diff --git a/tests/unit/auth/test_auth_config.py b/tests/unit/auth/test_auth_config.py new file mode 100644 index 0000000000..531b3c7fc1 --- /dev/null +++ b/tests/unit/auth/test_auth_config.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from pyrit.auth.auth_config import AZURE_AI_SERVICES_DEFAULT_SCOPE, REFRESH_TOKEN_BEFORE_MSEC + + +def test_refresh_token_before_msec_is_int(): + assert isinstance(REFRESH_TOKEN_BEFORE_MSEC, int) + + +def test_refresh_token_before_msec_value(): + assert REFRESH_TOKEN_BEFORE_MSEC == 300 + + +def test_azure_ai_services_default_scope_is_list(): + assert isinstance(AZURE_AI_SERVICES_DEFAULT_SCOPE, list) + + +def test_azure_ai_services_default_scope_contains_expected_entries(): + assert "https://cognitiveservices.azure.com/.default" in AZURE_AI_SERVICES_DEFAULT_SCOPE + assert "https://ml.azure.com/.default" in AZURE_AI_SERVICES_DEFAULT_SCOPE + + +def test_azure_ai_services_default_scope_length(): + assert len(AZURE_AI_SERVICES_DEFAULT_SCOPE) == 2 diff --git a/tests/unit/auth/test_authenticator.py b/tests/unit/auth/test_authenticator.py new file mode 100644 index 0000000000..87956b03a4 --- /dev/null +++ b/tests/unit/auth/test_authenticator.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from pyrit.auth.authenticator import Authenticator + + +class ConcreteAuthenticator(Authenticator): + """Minimal concrete subclass for testing the ABC.""" + + def __init__(self) -> None: + self.token = "test-token" + + +@pytest.fixture +def authenticator(): + return ConcreteAuthenticator() + + +def test_authenticator_is_abstract(): + assert hasattr(Authenticator, "__abstractmethods__") is False or len(Authenticator.__abstractmethods__) == 0 + # Authenticator has no abstract methods (uses NotImplementedError pattern instead) + + +def test_refresh_token_raises_not_implemented(authenticator): + with pytest.raises(NotImplementedError, match="refresh_token"): + authenticator.refresh_token() + + +@pytest.mark.asyncio +async def test_refresh_token_async_raises_not_implemented(authenticator): + with pytest.raises(NotImplementedError, match="refresh_token"): + await authenticator.refresh_token_async() + + +def test_get_token_raises_not_implemented(authenticator): + with pytest.raises(NotImplementedError, match="get_token"): + authenticator.get_token() + + +@pytest.mark.asyncio +async def test_get_token_async_raises_not_implemented(authenticator): + with pytest.raises(NotImplementedError, match="get_token"): + await authenticator.get_token_async() + + +def test_token_attribute_can_be_set(authenticator): + assert authenticator.token == "test-token" + authenticator.token = "new-token" + assert authenticator.token == "new-token" diff --git a/tests/unit/auth/test_manual_copilot_authenticator.py b/tests/unit/auth/test_manual_copilot_authenticator.py new file mode 100644 index 0000000000..dda8eab232 --- /dev/null +++ b/tests/unit/auth/test_manual_copilot_authenticator.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +from unittest.mock import patch + +import jwt as pyjwt +import pytest + +from pyrit.auth.manual_copilot_authenticator import ManualCopilotAuthenticator + + +def _make_jwt(claims: dict) -> str: + """Create an unsigned JWT with the given claims for testing.""" + return pyjwt.encode(claims, key="secret", algorithm="HS256") + + +VALID_CLAIMS = {"tid": "tenant-id-123", "oid": "object-id-456", "sub": "user"} +VALID_TOKEN = _make_jwt(VALID_CLAIMS) + + +def test_init_with_valid_token(): + auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN) + assert auth.get_token() == VALID_TOKEN + + +def test_init_reads_from_env_var_when_no_token_provided(): + with patch.dict(os.environ, {ManualCopilotAuthenticator.ACCESS_TOKEN_ENV_VAR: VALID_TOKEN}): + auth = ManualCopilotAuthenticator() + assert auth.get_token() == VALID_TOKEN + + +def test_init_raises_when_no_token_and_no_env_var(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="access_token must be provided"): + ManualCopilotAuthenticator() + + +def test_init_raises_for_invalid_jwt(): + with pytest.raises(ValueError, match="Failed to decode access_token as JWT"): + ManualCopilotAuthenticator(access_token="not-a-valid-jwt") + + +def test_init_raises_when_missing_tid_claim(): + token = _make_jwt({"oid": "object-id-456"}) + with pytest.raises(ValueError, match="missing required claims"): + ManualCopilotAuthenticator(access_token=token) + + +def test_init_raises_when_missing_oid_claim(): + token = _make_jwt({"tid": "tenant-id-123"}) + with pytest.raises(ValueError, match="missing required claims"): + ManualCopilotAuthenticator(access_token=token) + + +def test_init_raises_when_missing_both_required_claims(): + token = _make_jwt({"sub": "user"}) + with pytest.raises(ValueError, match="missing required claims"): + ManualCopilotAuthenticator(access_token=token) + + +def test_get_token_returns_access_token(): + auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN) + assert auth.get_token() == VALID_TOKEN + + +@pytest.mark.asyncio +async def test_get_token_async_returns_access_token(): + auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN) + result = await auth.get_token_async() + assert result == VALID_TOKEN + + +@pytest.mark.asyncio +async def test_get_claims_returns_decoded_claims(): + auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN) + claims = await auth.get_claims() + assert claims["tid"] == "tenant-id-123" + assert claims["oid"] == "object-id-456" + + +def test_refresh_token_raises_runtime_error(): + auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN) + with pytest.raises(RuntimeError, match="Manual token cannot be refreshed"): + auth.refresh_token() + + +@pytest.mark.asyncio +async def test_refresh_token_async_raises_runtime_error(): + auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN) + with pytest.raises(RuntimeError, match="Manual token cannot be refreshed"): + await auth.refresh_token_async() + + +def test_direct_token_takes_precedence_over_env_var(): + other_token = _make_jwt({"tid": "other-tenant", "oid": "other-oid"}) + with patch.dict(os.environ, {ManualCopilotAuthenticator.ACCESS_TOKEN_ENV_VAR: other_token}): + auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN) + assert auth.get_token() == VALID_TOKEN