From 52b74e6f53f98c69af073959bb89a320e3ba2238 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Sat, 11 Apr 2026 18:04:47 -0700 Subject: [PATCH] Add unit tests for pyrit/common/ utilities Adds tests for 10 previously untested utility files: - apply_defaults.py (decorator, sentinel, registry) - csv_helper.py (read/write/roundtrip) - data_url_converter.py (format support, encoding) - deprecation.py (warnings for callables, classes) - display_response.py (notebook skip, image display) - path.py (constants, git_repo detection, default paths) - question_answer_helpers.py (prompt formatting) - singleton.py (identity, distinct classes) - utils.py (combine_list, to_sha256) - yaml_loadable.py (basic load, from_dict, errors) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/unit/common/test_apply_defaults.py | 190 ++++++++++++++++++ tests/unit/common/test_csv_helper.py | 44 ++++ tests/unit/common/test_data_url_converter.py | 53 +++++ tests/unit/common/test_deprecation.py | 61 ++++++ tests/unit/common/test_display_response.py | 81 ++++++++ tests/unit/common/test_path.py | 84 ++++++++ .../common/test_question_answer_helpers.py | 52 +++++ tests/unit/common/test_singleton.py | 46 +++++ tests/unit/common/test_utils.py | 42 ++++ tests/unit/common/test_yaml_loadable.py | 60 ++++++ 10 files changed, 713 insertions(+) create mode 100644 tests/unit/common/test_apply_defaults.py create mode 100644 tests/unit/common/test_csv_helper.py create mode 100644 tests/unit/common/test_data_url_converter.py create mode 100644 tests/unit/common/test_deprecation.py create mode 100644 tests/unit/common/test_display_response.py create mode 100644 tests/unit/common/test_path.py create mode 100644 tests/unit/common/test_question_answer_helpers.py create mode 100644 tests/unit/common/test_singleton.py create mode 100644 tests/unit/common/test_utils.py create mode 100644 tests/unit/common/test_yaml_loadable.py diff --git a/tests/unit/common/test_apply_defaults.py b/tests/unit/common/test_apply_defaults.py new file mode 100644 index 0000000000..77f472ee07 --- /dev/null +++ b/tests/unit/common/test_apply_defaults.py @@ -0,0 +1,190 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from pyrit.common.apply_defaults import ( + REQUIRED_VALUE, + DefaultValueScope, + GlobalDefaultValues, + apply_defaults, + get_global_default_values, + reset_default_values, + set_default_value, + set_global_variable, +) + + +@pytest.fixture(autouse=True) +def _reset_defaults(): + reset_default_values() + yield + reset_default_values() + + +class _Base: + @apply_defaults + def __init__(self, *, name: str | None = None, count: int = 5) -> None: + self.name = name + self.count = count + + +class _Child(_Base): + pass + + +class _WithRequired: + @apply_defaults + def __init__(self, *, target: str = REQUIRED_VALUE) -> None: + self.target = target + + +# --- _RequiredValueSentinel --- + + +def test_required_value_sentinel_repr(): + assert repr(REQUIRED_VALUE) == "REQUIRED_VALUE" + + +def test_required_value_sentinel_is_falsy(): + assert not REQUIRED_VALUE + + +# --- DefaultValueScope --- + + +def test_default_value_scope_hash_equal(): + s1 = DefaultValueScope(class_type=_Base, parameter_name="name", include_subclasses=True) + s2 = DefaultValueScope(class_type=_Base, parameter_name="name", include_subclasses=True) + assert hash(s1) == hash(s2) + + +def test_default_value_scope_hash_differs_on_param(): + s1 = DefaultValueScope(class_type=_Base, parameter_name="name") + s2 = DefaultValueScope(class_type=_Base, parameter_name="count") + assert hash(s1) != hash(s2) + + +# --- GlobalDefaultValues --- + + +def test_global_default_values_set_and_get(): + registry = GlobalDefaultValues() + registry.set_default_value(class_type=_Base, parameter_name="name", value="hello") + found, val = registry.get_default_value(class_type=_Base, parameter_name="name") + assert found is True + assert val == "hello" + + +def test_global_default_values_not_found(): + registry = GlobalDefaultValues() + found, val = registry.get_default_value(class_type=_Base, parameter_name="name") + assert found is False + assert val is None + + +def test_global_default_values_subclass_inheritance(): + registry = GlobalDefaultValues() + registry.set_default_value(class_type=_Base, parameter_name="name", value="inherited") + found, val = registry.get_default_value(class_type=_Child, parameter_name="name") + assert found is True + assert val == "inherited" + + +def test_global_default_values_no_subclass_when_disabled(): + registry = GlobalDefaultValues() + registry.set_default_value(class_type=_Base, parameter_name="name", value="no-inherit", include_subclasses=False) + found, val = registry.get_default_value(class_type=_Child, parameter_name="name") + assert found is False + + +def test_global_default_values_reset(): + registry = GlobalDefaultValues() + registry.set_default_value(class_type=_Base, parameter_name="name", value="x") + registry.reset_defaults() + assert registry.all_defaults == {} + + +def test_global_default_values_all_defaults_returns_copy(): + registry = GlobalDefaultValues() + registry.set_default_value(class_type=_Base, parameter_name="name", value="x") + copy = registry.all_defaults + copy.clear() + assert len(registry.all_defaults) == 1 + + +# --- Module-level helpers --- + + +def test_get_global_default_values_returns_instance(): + assert isinstance(get_global_default_values(), GlobalDefaultValues) + + +def test_set_default_value_module_function(): + set_default_value(class_type=_Base, parameter_name="name", value="mod") + found, val = get_global_default_values().get_default_value(class_type=_Base, parameter_name="name") + assert found is True + assert val == "mod" + + +def test_reset_default_values_clears(): + set_default_value(class_type=_Base, parameter_name="name", value="clear") + reset_default_values() + found, _ = get_global_default_values().get_default_value(class_type=_Base, parameter_name="name") + assert found is False + + +def test_set_global_variable(): + import sys + + set_global_variable(name="_test_sentinel_var", value=42) + assert sys.modules["__main__"].__dict__["_test_sentinel_var"] == 42 + del sys.modules["__main__"].__dict__["_test_sentinel_var"] + + +# --- @apply_defaults decorator --- + + +def test_apply_defaults_uses_explicit_args(): + obj = _Base(name="explicit", count=10) + assert obj.name == "explicit" + assert obj.count == 10 + + +def test_apply_defaults_uses_registered_default_when_none(): + set_default_value(class_type=_Base, parameter_name="name", value="default_name") + obj = _Base() + assert obj.name == "default_name" + + +def test_apply_defaults_explicit_overrides_registered(): + set_default_value(class_type=_Base, parameter_name="name", value="default_name") + obj = _Base(name="explicit") + assert obj.name == "explicit" + + +def test_apply_defaults_inherits_to_subclass(): + set_default_value(class_type=_Base, parameter_name="name", value="parent_default") + obj = _Child() + assert obj.name == "parent_default" + + +def test_apply_defaults_required_value_raises_when_missing(): + with pytest.raises(ValueError, match="target is required"): + _WithRequired() + + +def test_apply_defaults_required_value_satisfied_by_registered(): + set_default_value(class_type=_WithRequired, parameter_name="target", value="registered") + obj = _WithRequired() + assert obj.target == "registered" + + +def test_apply_defaults_required_value_satisfied_by_explicit(): + obj = _WithRequired(target="explicit") + assert obj.target == "explicit" + + +def test_apply_defaults_none_on_required_value_param_raises(): + with pytest.raises(ValueError, match="target is required"): + _WithRequired(target=None) diff --git a/tests/unit/common/test_csv_helper.py b/tests/unit/common/test_csv_helper.py new file mode 100644 index 0000000000..f39ea9a3d7 --- /dev/null +++ b/tests/unit/common/test_csv_helper.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from io import StringIO + +from pyrit.common.csv_helper import read_csv, write_csv + + +def test_read_csv_returns_list_of_dicts(): + data = "name,age\nAlice,30\nBob,25\n" + result = read_csv(StringIO(data)) + assert result == [{"name": "Alice", "age": "30"}, {"name": "Bob", "age": "25"}] + + +def test_read_csv_empty_body(): + data = "name,age\n" + result = read_csv(StringIO(data)) + assert result == [] + + +def test_read_csv_single_column(): + data = "value\nfoo\nbar\n" + result = read_csv(StringIO(data)) + assert result == [{"value": "foo"}, {"value": "bar"}] + + +def test_write_csv_produces_expected_output(): + output = StringIO() + examples = [{"col1": "a", "col2": "b"}, {"col1": "c", "col2": "d"}] + write_csv(output, examples) + output.seek(0) + lines = output.read().strip().splitlines() + assert lines[0] == "col1,col2" + assert lines[1] == "a,b" + assert lines[2] == "c,d" + + +def test_write_then_read_roundtrip(): + examples = [{"x": "1", "y": "2"}, {"x": "3", "y": "4"}] + buf = StringIO() + write_csv(buf, examples) + buf.seek(0) + result = read_csv(buf) + assert result == examples diff --git a/tests/unit/common/test_data_url_converter.py b/tests/unit/common/test_data_url_converter.py new file mode 100644 index 0000000000..206d8fa86f --- /dev/null +++ b/tests/unit/common/test_data_url_converter.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +from tempfile import NamedTemporaryFile +from unittest.mock import AsyncMock, patch + +import pytest + +from pyrit.common.data_url_converter import ( + AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS, + convert_local_image_to_data_url, +) + + +def test_supported_image_formats_contains_common_types(): + assert ".jpg" in AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS + assert ".png" in AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS + assert ".gif" in AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS + + +@pytest.mark.asyncio +async def test_convert_raises_file_not_found(): + with pytest.raises(FileNotFoundError): + await convert_local_image_to_data_url("nonexistent_image.jpg") + + +@pytest.mark.asyncio +async def test_convert_raises_for_unsupported_format(): + with NamedTemporaryFile(suffix=".svg", delete=False) as f: + tmp = f.name + try: + with pytest.raises(ValueError, match="Unsupported image format"): + await convert_local_image_to_data_url(tmp) + finally: + os.remove(tmp) + + +@pytest.mark.asyncio +async def test_convert_returns_data_url(): + with NamedTemporaryFile(suffix=".png", delete=False) as f: + tmp = f.name + try: + mock_serializer = AsyncMock() + mock_serializer.read_data_base64 = AsyncMock(return_value="AAAA") + + with patch("pyrit.common.data_url_converter.data_serializer_factory", return_value=mock_serializer): + result = await convert_local_image_to_data_url(tmp) + + assert result.startswith("data:image/png;base64,") + assert result.endswith("AAAA") + finally: + os.remove(tmp) diff --git a/tests/unit/common/test_deprecation.py b/tests/unit/common/test_deprecation.py new file mode 100644 index 0000000000..e22d028f7f --- /dev/null +++ b/tests/unit/common/test_deprecation.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import warnings + +from pyrit.common.deprecation import print_deprecation_message + + +def _old_func(): + pass + + +def _new_func(): + pass + + +class _OldClass: + pass + + +class _NewClass: + pass + + +def test_deprecation_warning_with_callables(): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + print_deprecation_message(old_item=_old_func, new_item=_new_func, removed_in="2.0") + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "_old_func" in str(w[0].message) + assert "_new_func" in str(w[0].message) + assert "2.0" in str(w[0].message) + + +def test_deprecation_warning_with_classes(): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + print_deprecation_message(old_item=_OldClass, new_item=_NewClass, removed_in="3.0") + assert len(w) == 1 + assert "_OldClass" in str(w[0].message) + assert "_NewClass" in str(w[0].message) + assert "3.0" in str(w[0].message) + + +def test_deprecation_warning_with_strings(): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + print_deprecation_message(old_item="OldName", new_item="NewName", removed_in="4.0") + assert len(w) == 1 + assert "OldName" in str(w[0].message) + assert "NewName" in str(w[0].message) + + +def test_deprecation_warning_mixed_types(): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + print_deprecation_message(old_item=_OldClass, new_item="some.new.path", removed_in="5.0") + assert len(w) == 1 + assert "_OldClass" in str(w[0].message) + assert "some.new.path" in str(w[0].message) diff --git a/tests/unit/common/test_display_response.py b/tests/unit/common/test_display_response.py new file mode 100644 index 0000000000..43c23686cf --- /dev/null +++ b/tests/unit/common/test_display_response.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.common.display_response import display_image_response + + +@pytest.fixture() +def _mock_central_memory(): + mock_memory = MagicMock() + mock_memory.results_storage_io.read_file = AsyncMock(return_value=b"\x89PNG") + with patch("pyrit.memory.CentralMemory.get_memory_instance", return_value=mock_memory): + yield mock_memory + + +@pytest.mark.asyncio +@patch("pyrit.common.display_response.is_in_ipython_session", return_value=False) +async def test_display_image_skips_when_not_notebook(mock_ipython, _mock_central_memory): + piece = MagicMock() + piece.response_error = "none" + piece.converted_value_data_type = "image_path" + piece.converted_value = "some/image.png" + await display_image_response(piece) + # No error — function should silently skip display outside notebook + + +@pytest.mark.asyncio +async def test_display_image_logs_blocked_response(_mock_central_memory, caplog): + piece = MagicMock() + piece.response_error = "blocked" + piece.converted_value_data_type = "text" + with caplog.at_level(logging.INFO, logger="pyrit.common.display_response"): + await display_image_response(piece) + assert "Content blocked" in caplog.text + + +@pytest.mark.asyncio +async def test_display_image_no_action_for_text_type(_mock_central_memory): + piece = MagicMock() + piece.response_error = "none" + piece.converted_value_data_type = "text" + await display_image_response(piece) + + +@pytest.mark.asyncio +@patch("pyrit.common.display_response.is_in_ipython_session", return_value=True) +@patch("pyrit.common.display_response.Image") +@patch("builtins.display", create=True) +async def test_display_image_reads_and_displays(mock_display, mock_image, mock_ipython, _mock_central_memory): + piece = MagicMock() + piece.response_error = "none" + piece.converted_value_data_type = "image_path" + piece.converted_value = "path/to/img.png" + + mock_img_obj = MagicMock() + mock_image.open.return_value = mock_img_obj + + await display_image_response(piece) + + _mock_central_memory.results_storage_io.read_file.assert_awaited_once_with("path/to/img.png") + mock_image.open.assert_called_once() + mock_display.assert_called_once_with(mock_img_obj) + + +@pytest.mark.asyncio +@patch("pyrit.common.display_response.is_in_ipython_session", return_value=True) +async def test_display_image_logs_error_on_read_failure(mock_ipython, _mock_central_memory, caplog): + piece = MagicMock() + piece.response_error = "none" + piece.converted_value_data_type = "image_path" + piece.converted_value = "bad/path.png" + + _mock_central_memory.results_storage_io.read_file = AsyncMock(side_effect=Exception("disk error")) + + with caplog.at_level(logging.ERROR, logger="pyrit.common.display_response"): + await display_image_response(piece) + assert "Failed to read image" in caplog.text diff --git a/tests/unit/common/test_path.py b/tests/unit/common/test_path.py new file mode 100644 index 0000000000..45ebcd4bbb --- /dev/null +++ b/tests/unit/common/test_path.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pathlib +from unittest.mock import patch + +from pyrit.common.path import ( + CONFIGURATION_DIRECTORY_PATH, + DATASETS_PATH, + DB_DATA_PATH, + DEFAULT_CONFIG_FILENAME, + DEFAULT_CONFIG_PATH, + DOCS_CODE_PATH, + DOCS_PATH, + HOME_PATH, + LOG_PATH, + PATHS_DICT, + PYRIT_PATH, + get_default_data_path, + in_git_repo, +) + + +def test_pyrit_path_is_absolute(): + assert PYRIT_PATH.is_absolute() + + +def test_home_path_is_parent_of_pyrit_path(): + assert PYRIT_PATH.parent == HOME_PATH + + +def test_docs_path_relative_to_home(): + assert (HOME_PATH / "doc").resolve() == DOCS_PATH + + +def test_docs_code_path_relative_to_home(): + assert (HOME_PATH / "doc" / "code").resolve() == DOCS_CODE_PATH + + +def test_datasets_path_inside_pyrit(): + assert (PYRIT_PATH / "datasets").resolve() == DATASETS_PATH + + +def test_configuration_directory_is_in_home(): + assert pathlib.Path.home() / ".pyrit" == CONFIGURATION_DIRECTORY_PATH + + +def test_default_config_filename(): + assert DEFAULT_CONFIG_FILENAME == ".pyrit_conf" + + +def test_default_config_path(): + assert DEFAULT_CONFIG_PATH == CONFIGURATION_DIRECTORY_PATH / DEFAULT_CONFIG_FILENAME + + +def test_db_data_path_exists(): + assert DB_DATA_PATH.exists() + + +def test_log_path_exists(): + assert LOG_PATH.exists() + + +def test_paths_dict_contains_expected_keys(): + expected = {"pyrit_path", "datasets_path", "db_data_path", "log_path", "docs_path"} + assert expected.issubset(set(PATHS_DICT.keys())) + + +def test_in_git_repo_returns_bool(): + result = in_git_repo() + assert isinstance(result, bool) + + +def test_get_default_data_path_in_git_repo(): + with patch("pyrit.common.path.in_git_repo", return_value=True): + result = get_default_data_path("testdir") + assert result == pathlib.Path(PYRIT_PATH, "..", "testdir").resolve() + + +def test_get_default_data_path_not_in_git_repo(): + with patch("pyrit.common.path.in_git_repo", return_value=False): + result = get_default_data_path("testdir") + assert "testdir" in str(result) + assert result.is_absolute() diff --git a/tests/unit/common/test_question_answer_helpers.py b/tests/unit/common/test_question_answer_helpers.py new file mode 100644 index 0000000000..eac5beca14 --- /dev/null +++ b/tests/unit/common/test_question_answer_helpers.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from pyrit.common.question_answer_helpers import construct_evaluation_prompt +from pyrit.models import QuestionAnsweringEntry, QuestionChoice + + +def test_construct_evaluation_prompt_basic(): + entry = QuestionAnsweringEntry( + question="What color is the sky?", + answer_type="str", + correct_answer="blue", + choices=[ + QuestionChoice(index=0, text="red"), + QuestionChoice(index=1, text="blue"), + ], + ) + result = construct_evaluation_prompt(entry) + assert "What color is the sky?" in result + assert "index=0, value=red" in result + assert "index=1, value=blue" in result + + +def test_construct_evaluation_prompt_single_choice(): + entry = QuestionAnsweringEntry( + question="Is 1+1=2?", + answer_type="bool", + correct_answer="True", + choices=[QuestionChoice(index=0, text="True")], + ) + result = construct_evaluation_prompt(entry) + assert "Question:" in result + assert "Choices:" in result + assert "index=0, value=True" in result + + +def test_construct_evaluation_prompt_format(): + entry = QuestionAnsweringEntry( + question="Pick a number", + answer_type="int", + correct_answer=2, + choices=[ + QuestionChoice(index=0, text="1"), + QuestionChoice(index=1, text="2"), + QuestionChoice(index=2, text="3"), + ], + ) + result = construct_evaluation_prompt(entry) + lines = result.split("\n") + assert lines[0] == "Question:" + assert lines[1] == "Pick a number" + assert lines[3] == "Choices:" diff --git a/tests/unit/common/test_singleton.py b/tests/unit/common/test_singleton.py new file mode 100644 index 0000000000..8665d96f52 --- /dev/null +++ b/tests/unit/common/test_singleton.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import abc + +from pyrit.common.singleton import Singleton + + +def test_singleton_returns_same_instance(): + class _MySingleton(abc.ABC, metaclass=Singleton): + pass + + a = _MySingleton() + b = _MySingleton() + assert a is b + + # Cleanup to avoid polluting other tests + Singleton._instances.pop(_MySingleton, None) + + +def test_singleton_different_classes_have_different_instances(): + class _A(abc.ABC, metaclass=Singleton): + pass + + class _B(abc.ABC, metaclass=Singleton): + pass + + a = _A() + b = _B() + assert a is not b + + Singleton._instances.pop(_A, None) + Singleton._instances.pop(_B, None) + + +def test_singleton_preserves_init_args(): + class _Configured(abc.ABC, metaclass=Singleton): + def __init__(self, value: int = 0) -> None: + self.value = value + + first = _Configured(value=42) + second = _Configured(value=99) + assert first is second + assert second.value == 42 + + Singleton._instances.pop(_Configured, None) diff --git a/tests/unit/common/test_utils.py b/tests/unit/common/test_utils.py new file mode 100644 index 0000000000..da4e0840e9 --- /dev/null +++ b/tests/unit/common/test_utils.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from pyrit.common.utils import combine_list, to_sha256 + + +def test_combine_list_two_lists(): + result = combine_list(["a", "b"], ["b", "c"]) + assert set(result) == {"a", "b", "c"} + + +def test_combine_list_strings(): + result = combine_list("x", "y") + assert set(result) == {"x", "y"} + + +def test_combine_list_mixed(): + result = combine_list("a", ["a", "b"]) + assert set(result) == {"a", "b"} + + +def test_combine_list_duplicates_removed(): + result = combine_list(["a", "a"], ["a"]) + assert result == ["a"] + + +def test_to_sha256_deterministic(): + h1 = to_sha256("hello") + h2 = to_sha256("hello") + assert h1 == h2 + assert len(h1) == 64 + + +def test_to_sha256_different_inputs(): + assert to_sha256("a") != to_sha256("b") + + +def test_to_sha256_known_value(): + import hashlib + + expected = hashlib.sha256(b"test").hexdigest() + assert to_sha256("test") == expected diff --git a/tests/unit/common/test_yaml_loadable.py b/tests/unit/common/test_yaml_loadable.py new file mode 100644 index 0000000000..7bc91c4aa1 --- /dev/null +++ b/tests/unit/common/test_yaml_loadable.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from pathlib import Path + +import pytest + +from pyrit.common.yaml_loadable import YamlLoadable + + +class _SimpleYaml(YamlLoadable): + def __init__(self, name: str, value: int = 0) -> None: + self.name = name + self.value = value + + +class _WithFromDict(YamlLoadable): + def __init__(self, name: str) -> None: + self.name = name + + @classmethod + def from_dict(cls, data: dict) -> "_WithFromDict": + return cls(name=data["name"].upper()) + + +@pytest.fixture() +def yaml_file(tmp_path: Path) -> Path: + p = tmp_path / "test.yaml" + p.write_text("name: hello\nvalue: 42\n", encoding="utf-8") + return p + + +@pytest.fixture() +def yaml_file_for_from_dict(tmp_path: Path) -> Path: + p = tmp_path / "fd.yaml" + p.write_text("name: lower\n", encoding="utf-8") + return p + + +def test_from_yaml_file_basic(yaml_file: Path): + obj = _SimpleYaml.from_yaml_file(yaml_file) + assert obj.name == "hello" + assert obj.value == 42 + + +def test_from_yaml_file_uses_from_dict_if_available(yaml_file_for_from_dict: Path): + obj = _WithFromDict.from_yaml_file(yaml_file_for_from_dict) + assert obj.name == "LOWER" + + +def test_from_yaml_file_nonexistent_raises(): + with pytest.raises(FileNotFoundError): + _SimpleYaml.from_yaml_file("nonexistent_file.yaml") + + +def test_from_yaml_file_invalid_yaml(tmp_path: Path): + p = tmp_path / "bad.yaml" + p.write_text(":\n - :\n invalid: [unterminated", encoding="utf-8") + with pytest.raises((ValueError, TypeError)): + _SimpleYaml.from_yaml_file(p)