From 77ba52e246d6bd690b064feb89441dee65ac0391 Mon Sep 17 00:00:00 2001 From: aviau Date: Sun, 1 Mar 2026 11:16:38 -0500 Subject: [PATCH] add customizable auth --- src/flareio/api_client.py | 34 +++++++++++++------- src/flareio/auth.py | 26 ++++++++++++++++ tests/test_api_client_auth.py | 50 ++++++++++++++++++++++++++++++ tests/test_api_client_endpoints.py | 19 ------------ tests/utils.py | 6 ++-- 5 files changed, 102 insertions(+), 33 deletions(-) create mode 100644 src/flareio/auth.py create mode 100644 tests/test_api_client_auth.py diff --git a/src/flareio/api_client.py b/src/flareio/api_client.py index 20858eb..161c049 100644 --- a/src/flareio/api_client.py +++ b/src/flareio/api_client.py @@ -9,6 +9,7 @@ import requests from requests.adapters import HTTPAdapter +from requests.auth import AuthBase from urllib3.util import Retry import typing as t @@ -34,7 +35,7 @@ def __init__( tenant_id: t.Optional[int] = None, session: t.Optional[requests.Session] = None, api_domain: t.Optional[str] = None, - _disable_auth: bool = False, + _auth: AuthBase | None = None, _enable_beta_features: bool = False, ) -> None: if not api_key: @@ -52,9 +53,9 @@ def __init__( self._api_key: str = api_key self._tenant_id: t.Optional[int] = tenant_id + self._auth: t.Optional[AuthBase] = _auth self._api_token: t.Optional[str] = None self._api_token_exp: t.Optional[datetime] = None - self._disable_auth: bool = _disable_auth self._session = session or self._create_session() @classmethod @@ -135,16 +136,24 @@ def generate_token(self) -> str: return token - def _auth_headers(self) -> dict: - if self._disable_auth: - return dict() + def _apply_auth( + self, + *, + request: requests.PreparedRequest, + ) -> requests.PreparedRequest: + if self._auth: + self._auth(request) + return request + api_token: t.Optional[str] = self._api_token if not api_token or ( self._api_token_exp and self._api_token_exp < datetime.now() ): api_token = self.generate_token() - return {"Authorization": f"Bearer {api_token}"} + request.headers["Authorization"] = f"Bearer {api_token}" + + return request def _request( self, @@ -163,12 +172,7 @@ def _request( f"Client was used to access {netloc=} at {url=}. Only the domain {self._api_domain} is supported." ) - headers = { - **(headers or {}), - **self._auth_headers(), - } - - return self._session.request( + request = requests.Request( method=method, url=url, params=params, @@ -176,6 +180,12 @@ def _request( headers=headers, ) + prepared = self._session.prepare_request(request) + prepared = self._apply_auth(request=prepared) + resp = self._session.send(prepared) + + return resp + def post( self, url: str, diff --git a/src/flareio/auth.py b/src/flareio/auth.py new file mode 100644 index 0000000..35d5e1a --- /dev/null +++ b/src/flareio/auth.py @@ -0,0 +1,26 @@ +from requests import PreparedRequest +from requests.auth import AuthBase + + +class _StaticHeadersAuth(AuthBase): + def __init__( + self, + *, + headers: dict[str, str], + ) -> None: + self._headers: dict[str, str] = headers + + def __call__( + self, + r: PreparedRequest, + ) -> PreparedRequest: + r.headers.update(self._headers) + return r + + +class _EmptyAuth(AuthBase): + def __call__( + self, + r: PreparedRequest, + ) -> PreparedRequest: + return r diff --git a/tests/test_api_client_auth.py b/tests/test_api_client_auth.py new file mode 100644 index 0000000..3443c15 --- /dev/null +++ b/tests/test_api_client_auth.py @@ -0,0 +1,50 @@ +import requests_mock + +from .utils import get_test_client + +from flareio.auth import _EmptyAuth +from flareio.auth import _StaticHeadersAuth + + +def test_custom_auth_empty() -> None: + client = get_test_client( + authenticated=False, + _auth=_EmptyAuth(), + ) + with requests_mock.Mocker() as mocker: + mocker.register_uri( + "POST", + "https://api.flare.io/hello-post", + status_code=200, + ) + client.post( + "https://api.flare.io/hello-post", + json={"foo": "bar"}, + ) + assert not mocker.last_request.headers.get("Authorization") + + +def test_custom_auth_static() -> None: + client = get_test_client( + authenticated=False, + _auth=_StaticHeadersAuth( + headers={ + "first-header": "first-value", + "Authorization": "auth-value", + } + ), + ) + with requests_mock.Mocker() as mocker: + mocker.register_uri( + "POST", + "https://api.flare.io/hello-post", + status_code=200, + ) + client.post( + "https://api.flare.io/hello-post", + json={"foo": "bar"}, + headers={"second-header": "second-value"}, + ) + assert mocker.last_request.headers["Authorization"] == "auth-value" + assert mocker.last_request.headers["first-header"] == "first-value" + assert mocker.last_request.headers["second-header"] == "second-value" diff --git a/tests/test_api_client_endpoints.py b/tests/test_api_client_endpoints.py index 45866f2..e23c616 100644 --- a/tests/test_api_client_endpoints.py +++ b/tests/test_api_client_endpoints.py @@ -121,22 +121,3 @@ def test_bad_domain() -> None: match="Client was used to access netloc='bad.com' at url='https://bad.com/hello-post'. Only the domain api.flare.io is supported.", ): client.post("https://bad.com/hello-post") - - -def test_disable_auth_does_not_call_generate() -> None: - client = get_test_client( - authenticated=False, - _disable_auth=True, - ) - with requests_mock.Mocker() as mocker: - mocker.register_uri( - "POST", - "https://api.flare.io/hello-post", - status_code=200, - ) - client.post("https://api.flare.io/hello-post", json={"foo": "bar"}) - assert mocker.last_request.url == "https://api.flare.io/hello-post" - assert mocker.last_request.json() == {"foo": "bar"} - - # Authorization header should not be present when auth is disabled - assert not mocker.last_request.headers.get("Authorization") diff --git a/tests/utils.py b/tests/utils.py index 17237bb..3d90520 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,7 @@ import requests_mock +from requests.auth import AuthBase + import typing as t from flareio import FlareApiClient @@ -11,14 +13,14 @@ def get_test_client( authenticated: bool = True, api_domain: t.Optional[str] = None, _enable_beta_features: bool = False, - _disable_auth: bool = False, + _auth: t.Optional[AuthBase] = None, ) -> FlareApiClient: client = FlareApiClient( api_key="test-api-key", tenant_id=tenant_id, api_domain=api_domain, _enable_beta_features=_enable_beta_features, - _disable_auth=_disable_auth, + _auth=_auth, ) if authenticated: