Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions tests/unit/auth/test_auth_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pyrit.auth.auth_config import AZURE_AI_SERVICES_DEFAULT_SCOPE, REFRESH_TOKEN_BEFORE_MSEC


def test_refresh_token_before_msec_is_int():
assert isinstance(REFRESH_TOKEN_BEFORE_MSEC, int)


def test_refresh_token_before_msec_value():
assert REFRESH_TOKEN_BEFORE_MSEC == 300


def test_azure_ai_services_default_scope_is_list():
assert isinstance(AZURE_AI_SERVICES_DEFAULT_SCOPE, list)


def test_azure_ai_services_default_scope_contains_expected_entries():
assert "https://cognitiveservices.azure.com/.default" in AZURE_AI_SERVICES_DEFAULT_SCOPE
assert "https://ml.azure.com/.default" in AZURE_AI_SERVICES_DEFAULT_SCOPE


def test_azure_ai_services_default_scope_length():
assert len(AZURE_AI_SERVICES_DEFAULT_SCOPE) == 2
51 changes: 51 additions & 0 deletions tests/unit/auth/test_authenticator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytest

from pyrit.auth.authenticator import Authenticator


class ConcreteAuthenticator(Authenticator):
"""Minimal concrete subclass for testing the ABC."""

def __init__(self) -> None:
self.token = "test-token"


@pytest.fixture
def authenticator():
return ConcreteAuthenticator()


def test_authenticator_is_abstract():
assert hasattr(Authenticator, "__abstractmethods__") is False or len(Authenticator.__abstractmethods__) == 0
# Authenticator has no abstract methods (uses NotImplementedError pattern instead)


def test_refresh_token_raises_not_implemented(authenticator):
with pytest.raises(NotImplementedError, match="refresh_token"):
authenticator.refresh_token()


@pytest.mark.asyncio
async def test_refresh_token_async_raises_not_implemented(authenticator):
with pytest.raises(NotImplementedError, match="refresh_token"):
await authenticator.refresh_token_async()


def test_get_token_raises_not_implemented(authenticator):
with pytest.raises(NotImplementedError, match="get_token"):
authenticator.get_token()


@pytest.mark.asyncio
async def test_get_token_async_raises_not_implemented(authenticator):
with pytest.raises(NotImplementedError, match="get_token"):
await authenticator.get_token_async()


def test_token_attribute_can_be_set(authenticator):
assert authenticator.token == "test-token"
authenticator.token = "new-token"
assert authenticator.token == "new-token"
99 changes: 99 additions & 0 deletions tests/unit/auth/test_manual_copilot_authenticator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
from unittest.mock import patch

import jwt as pyjwt
import pytest

from pyrit.auth.manual_copilot_authenticator import ManualCopilotAuthenticator


def _make_jwt(claims: dict) -> str:
"""Create an unsigned JWT with the given claims for testing."""
return pyjwt.encode(claims, key="secret", algorithm="HS256")


VALID_CLAIMS = {"tid": "tenant-id-123", "oid": "object-id-456", "sub": "user"}
VALID_TOKEN = _make_jwt(VALID_CLAIMS)


def test_init_with_valid_token():
auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN)
assert auth.get_token() == VALID_TOKEN


def test_init_reads_from_env_var_when_no_token_provided():
with patch.dict(os.environ, {ManualCopilotAuthenticator.ACCESS_TOKEN_ENV_VAR: VALID_TOKEN}):
auth = ManualCopilotAuthenticator()
assert auth.get_token() == VALID_TOKEN


def test_init_raises_when_no_token_and_no_env_var():
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(ValueError, match="access_token must be provided"):
ManualCopilotAuthenticator()


def test_init_raises_for_invalid_jwt():
with pytest.raises(ValueError, match="Failed to decode access_token as JWT"):
ManualCopilotAuthenticator(access_token="not-a-valid-jwt")


def test_init_raises_when_missing_tid_claim():
token = _make_jwt({"oid": "object-id-456"})
with pytest.raises(ValueError, match="missing required claims"):
ManualCopilotAuthenticator(access_token=token)


def test_init_raises_when_missing_oid_claim():
token = _make_jwt({"tid": "tenant-id-123"})
with pytest.raises(ValueError, match="missing required claims"):
ManualCopilotAuthenticator(access_token=token)


def test_init_raises_when_missing_both_required_claims():
token = _make_jwt({"sub": "user"})
with pytest.raises(ValueError, match="missing required claims"):
ManualCopilotAuthenticator(access_token=token)


def test_get_token_returns_access_token():
auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN)
assert auth.get_token() == VALID_TOKEN


@pytest.mark.asyncio
async def test_get_token_async_returns_access_token():
auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN)
result = await auth.get_token_async()
assert result == VALID_TOKEN


@pytest.mark.asyncio
async def test_get_claims_returns_decoded_claims():
auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN)
claims = await auth.get_claims()
assert claims["tid"] == "tenant-id-123"
assert claims["oid"] == "object-id-456"


def test_refresh_token_raises_runtime_error():
auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN)
with pytest.raises(RuntimeError, match="Manual token cannot be refreshed"):
auth.refresh_token()


@pytest.mark.asyncio
async def test_refresh_token_async_raises_runtime_error():
auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN)
with pytest.raises(RuntimeError, match="Manual token cannot be refreshed"):
await auth.refresh_token_async()


def test_direct_token_takes_precedence_over_env_var():
other_token = _make_jwt({"tid": "other-tenant", "oid": "other-oid"})
with patch.dict(os.environ, {ManualCopilotAuthenticator.ACCESS_TOKEN_ENV_VAR: other_token}):
auth = ManualCopilotAuthenticator(access_token=VALID_TOKEN)
assert auth.get_token() == VALID_TOKEN
Loading