Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 190 additions & 0 deletions tests/unit/common/test_apply_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytest

from pyrit.common.apply_defaults import (
REQUIRED_VALUE,
DefaultValueScope,
GlobalDefaultValues,
apply_defaults,
get_global_default_values,
reset_default_values,
set_default_value,
set_global_variable,
)


@pytest.fixture(autouse=True)
def _reset_defaults():
reset_default_values()
yield
reset_default_values()


class _Base:
@apply_defaults
def __init__(self, *, name: str | None = None, count: int = 5) -> None:
self.name = name
self.count = count


class _Child(_Base):
pass


class _WithRequired:
@apply_defaults
def __init__(self, *, target: str = REQUIRED_VALUE) -> None:
self.target = target


# --- _RequiredValueSentinel ---


def test_required_value_sentinel_repr():
assert repr(REQUIRED_VALUE) == "REQUIRED_VALUE"


def test_required_value_sentinel_is_falsy():
assert not REQUIRED_VALUE


# --- DefaultValueScope ---


def test_default_value_scope_hash_equal():
s1 = DefaultValueScope(class_type=_Base, parameter_name="name", include_subclasses=True)
s2 = DefaultValueScope(class_type=_Base, parameter_name="name", include_subclasses=True)
assert hash(s1) == hash(s2)


def test_default_value_scope_hash_differs_on_param():
s1 = DefaultValueScope(class_type=_Base, parameter_name="name")
s2 = DefaultValueScope(class_type=_Base, parameter_name="count")
assert hash(s1) != hash(s2)


# --- GlobalDefaultValues ---


def test_global_default_values_set_and_get():
registry = GlobalDefaultValues()
registry.set_default_value(class_type=_Base, parameter_name="name", value="hello")
found, val = registry.get_default_value(class_type=_Base, parameter_name="name")
assert found is True
assert val == "hello"


def test_global_default_values_not_found():
registry = GlobalDefaultValues()
found, val = registry.get_default_value(class_type=_Base, parameter_name="name")
assert found is False
assert val is None


def test_global_default_values_subclass_inheritance():
registry = GlobalDefaultValues()
registry.set_default_value(class_type=_Base, parameter_name="name", value="inherited")
found, val = registry.get_default_value(class_type=_Child, parameter_name="name")
assert found is True
assert val == "inherited"


def test_global_default_values_no_subclass_when_disabled():
registry = GlobalDefaultValues()
registry.set_default_value(class_type=_Base, parameter_name="name", value="no-inherit", include_subclasses=False)
found, val = registry.get_default_value(class_type=_Child, parameter_name="name")
assert found is False


def test_global_default_values_reset():
registry = GlobalDefaultValues()
registry.set_default_value(class_type=_Base, parameter_name="name", value="x")
registry.reset_defaults()
assert registry.all_defaults == {}


def test_global_default_values_all_defaults_returns_copy():
registry = GlobalDefaultValues()
registry.set_default_value(class_type=_Base, parameter_name="name", value="x")
copy = registry.all_defaults
copy.clear()
assert len(registry.all_defaults) == 1


# --- Module-level helpers ---


def test_get_global_default_values_returns_instance():
assert isinstance(get_global_default_values(), GlobalDefaultValues)


def test_set_default_value_module_function():
set_default_value(class_type=_Base, parameter_name="name", value="mod")
found, val = get_global_default_values().get_default_value(class_type=_Base, parameter_name="name")
assert found is True
assert val == "mod"


def test_reset_default_values_clears():
set_default_value(class_type=_Base, parameter_name="name", value="clear")
reset_default_values()
found, _ = get_global_default_values().get_default_value(class_type=_Base, parameter_name="name")
assert found is False


def test_set_global_variable():
import sys

set_global_variable(name="_test_sentinel_var", value=42)
assert sys.modules["__main__"].__dict__["_test_sentinel_var"] == 42
del sys.modules["__main__"].__dict__["_test_sentinel_var"]


# --- @apply_defaults decorator ---


def test_apply_defaults_uses_explicit_args():
obj = _Base(name="explicit", count=10)
assert obj.name == "explicit"
assert obj.count == 10


def test_apply_defaults_uses_registered_default_when_none():
set_default_value(class_type=_Base, parameter_name="name", value="default_name")
obj = _Base()
assert obj.name == "default_name"


def test_apply_defaults_explicit_overrides_registered():
set_default_value(class_type=_Base, parameter_name="name", value="default_name")
obj = _Base(name="explicit")
assert obj.name == "explicit"


def test_apply_defaults_inherits_to_subclass():
set_default_value(class_type=_Base, parameter_name="name", value="parent_default")
obj = _Child()
assert obj.name == "parent_default"


def test_apply_defaults_required_value_raises_when_missing():
with pytest.raises(ValueError, match="target is required"):
_WithRequired()


def test_apply_defaults_required_value_satisfied_by_registered():
set_default_value(class_type=_WithRequired, parameter_name="target", value="registered")
obj = _WithRequired()
assert obj.target == "registered"


def test_apply_defaults_required_value_satisfied_by_explicit():
obj = _WithRequired(target="explicit")
assert obj.target == "explicit"


def test_apply_defaults_none_on_required_value_param_raises():
with pytest.raises(ValueError, match="target is required"):
_WithRequired(target=None)
44 changes: 44 additions & 0 deletions tests/unit/common/test_csv_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from io import StringIO

from pyrit.common.csv_helper import read_csv, write_csv


def test_read_csv_returns_list_of_dicts():
data = "name,age\nAlice,30\nBob,25\n"
result = read_csv(StringIO(data))
assert result == [{"name": "Alice", "age": "30"}, {"name": "Bob", "age": "25"}]


def test_read_csv_empty_body():
data = "name,age\n"
result = read_csv(StringIO(data))
assert result == []


def test_read_csv_single_column():
data = "value\nfoo\nbar\n"
result = read_csv(StringIO(data))
assert result == [{"value": "foo"}, {"value": "bar"}]


def test_write_csv_produces_expected_output():
output = StringIO()
examples = [{"col1": "a", "col2": "b"}, {"col1": "c", "col2": "d"}]
write_csv(output, examples)
output.seek(0)
lines = output.read().strip().splitlines()
assert lines[0] == "col1,col2"
assert lines[1] == "a,b"
assert lines[2] == "c,d"


def test_write_then_read_roundtrip():
examples = [{"x": "1", "y": "2"}, {"x": "3", "y": "4"}]
buf = StringIO()
write_csv(buf, examples)
buf.seek(0)
result = read_csv(buf)
assert result == examples
53 changes: 53 additions & 0 deletions tests/unit/common/test_data_url_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
from tempfile import NamedTemporaryFile
from unittest.mock import AsyncMock, patch

import pytest

from pyrit.common.data_url_converter import (
AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS,
convert_local_image_to_data_url,
)


def test_supported_image_formats_contains_common_types():
assert ".jpg" in AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS
assert ".png" in AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS
assert ".gif" in AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS


@pytest.mark.asyncio
async def test_convert_raises_file_not_found():
with pytest.raises(FileNotFoundError):
await convert_local_image_to_data_url("nonexistent_image.jpg")


@pytest.mark.asyncio
async def test_convert_raises_for_unsupported_format():
with NamedTemporaryFile(suffix=".svg", delete=False) as f:
tmp = f.name
try:
with pytest.raises(ValueError, match="Unsupported image format"):
await convert_local_image_to_data_url(tmp)
finally:
os.remove(tmp)


@pytest.mark.asyncio
async def test_convert_returns_data_url():
with NamedTemporaryFile(suffix=".png", delete=False) as f:
tmp = f.name
try:
mock_serializer = AsyncMock()
mock_serializer.read_data_base64 = AsyncMock(return_value="AAAA")

with patch("pyrit.common.data_url_converter.data_serializer_factory", return_value=mock_serializer):
result = await convert_local_image_to_data_url(tmp)

assert result.startswith("data:image/png;base64,")
assert result.endswith("AAAA")
finally:
os.remove(tmp)
61 changes: 61 additions & 0 deletions tests/unit/common/test_deprecation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import warnings

from pyrit.common.deprecation import print_deprecation_message


def _old_func():
pass


def _new_func():
pass


class _OldClass:
pass


class _NewClass:
pass


def test_deprecation_warning_with_callables():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
print_deprecation_message(old_item=_old_func, new_item=_new_func, removed_in="2.0")
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "_old_func" in str(w[0].message)
assert "_new_func" in str(w[0].message)
assert "2.0" in str(w[0].message)


def test_deprecation_warning_with_classes():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
print_deprecation_message(old_item=_OldClass, new_item=_NewClass, removed_in="3.0")
assert len(w) == 1
assert "_OldClass" in str(w[0].message)
assert "_NewClass" in str(w[0].message)
assert "3.0" in str(w[0].message)


def test_deprecation_warning_with_strings():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
print_deprecation_message(old_item="OldName", new_item="NewName", removed_in="4.0")
assert len(w) == 1
assert "OldName" in str(w[0].message)
assert "NewName" in str(w[0].message)


def test_deprecation_warning_mixed_types():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
print_deprecation_message(old_item=_OldClass, new_item="some.new.path", removed_in="5.0")
assert len(w) == 1
assert "_OldClass" in str(w[0].message)
assert "some.new.path" in str(w[0].message)
Loading