diff --git a/tests/unit/models/test_chat_message.py b/tests/unit/models/test_chat_message.py new file mode 100644 index 000000000..2a5f591c4 --- /dev/null +++ b/tests/unit/models/test_chat_message.py @@ -0,0 +1,136 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json + +import pytest +from pydantic import ValidationError + +from pyrit.models.chat_message import ( + ChatMessage, + ChatMessageListDictContent, + ChatMessagesDataset, + ToolCall, +) + + +def test_tool_call_init(): + tc = ToolCall(id="call_1", type="function", function="get_weather") + assert tc.id == "call_1" + assert tc.type == "function" + assert tc.function == "get_weather" + + +def test_tool_call_forbids_extra_fields(): + with pytest.raises(ValidationError): + ToolCall(id="call_1", type="function", function="get_weather", extra="bad") + + +def test_chat_message_init_with_string_content(): + msg = ChatMessage(role="user", content="hello") + assert msg.role == "user" + assert msg.content == "hello" + assert msg.name is None + assert msg.tool_calls is None + assert msg.tool_call_id is None + + +def test_chat_message_init_with_list_content(): + parts = [{"type": "text", "text": "hello"}, {"type": "image_url", "url": "http://img.png"}] + msg = ChatMessage(role="assistant", content=parts) + assert msg.content == parts + + +def test_chat_message_init_with_all_fields(): + tc = ToolCall(id="call_1", type="function", function="lookup") + msg = ChatMessage( + role="assistant", + content="result", + name="helper", + tool_calls=[tc], + tool_call_id="call_1", + ) + assert msg.name == "helper" + assert msg.tool_calls == [tc] + assert msg.tool_call_id == "call_1" + + +def test_chat_message_forbids_extra_fields(): + with pytest.raises(ValidationError): + ChatMessage(role="user", content="hi", extra_field="bad") + + +def test_chat_message_invalid_role(): + with pytest.raises(ValidationError): + ChatMessage(role="invalid_role", content="hi") + + +def test_chat_message_to_json(): 
+ msg = ChatMessage(role="user", content="test") + json_str = msg.to_json() + parsed = json.loads(json_str) + assert parsed["role"] == "user" + assert parsed["content"] == "test" + + +def test_chat_message_to_dict_excludes_none(): + msg = ChatMessage(role="user", content="test") + d = msg.to_dict() + assert "name" not in d + assert "tool_calls" not in d + assert "tool_call_id" not in d + assert d["role"] == "user" + assert d["content"] == "test" + + +def test_chat_message_to_dict_includes_non_none(): + msg = ChatMessage(role="user", content="test", name="bot") + d = msg.to_dict() + assert d["name"] == "bot" + + +def test_chat_message_from_json(): + original = ChatMessage(role="system", content="you are helpful") + json_str = original.to_json() + restored = ChatMessage.from_json(json_str) + assert restored.role == original.role + assert restored.content == original.content + + +def test_chat_message_from_json_roundtrip_with_tool_calls(): + tc = ToolCall(id="c1", type="function", function="fn") + original = ChatMessage(role="assistant", content="ok", tool_calls=[tc], tool_call_id="c1") + restored = ChatMessage.from_json(original.to_json()) + assert restored.tool_calls[0].id == "c1" + assert restored.tool_call_id == "c1" + + +@pytest.mark.parametrize("role", ["system", "user", "assistant", "simulated_assistant", "tool", "developer"]) +def test_chat_message_accepts_all_valid_roles(role): + msg = ChatMessage(role=role, content="test") + assert msg.role == role + + +def test_chat_message_list_dict_content_deprecated(): + msg = ChatMessageListDictContent(role="user", content="hello") + assert msg.role == "user" + assert msg.content == "hello" + + +def test_chat_messages_dataset_init(): + msgs = [[ChatMessage(role="user", content="hi"), ChatMessage(role="assistant", content="hello")]] + dataset = ChatMessagesDataset(name="test_ds", description="A test dataset", list_of_chat_messages=msgs) + assert dataset.name == "test_ds" + assert dataset.description == "A test 
dataset" + assert len(dataset.list_of_chat_messages) == 1 + assert len(dataset.list_of_chat_messages[0]) == 2 + + +def test_chat_messages_dataset_forbids_extra_fields(): + with pytest.raises(ValidationError): + ChatMessagesDataset( + name="ds", + description="desc", + list_of_chat_messages=[], + extra="bad", + ) diff --git a/tests/unit/models/test_conversation_reference.py b/tests/unit/models/test_conversation_reference.py new file mode 100644 index 000000000..5bf4e2833 --- /dev/null +++ b/tests/unit/models/test_conversation_reference.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from pyrit.models.conversation_reference import ConversationReference, ConversationType + + +def test_conversation_type_values(): + assert ConversationType.ADVERSARIAL.value == "adversarial" + assert ConversationType.PRUNED.value == "pruned" + assert ConversationType.SCORE.value == "score" + assert ConversationType.CONVERTER.value == "converter" + + +def test_conversation_reference_init(): + ref = ConversationReference(conversation_id="abc-123", conversation_type=ConversationType.ADVERSARIAL) + assert ref.conversation_id == "abc-123" + assert ref.conversation_type == ConversationType.ADVERSARIAL + assert ref.description is None + + +def test_conversation_reference_with_description(): + ref = ConversationReference( + conversation_id="abc-123", + conversation_type=ConversationType.PRUNED, + description="pruned branch", + ) + assert ref.description == "pruned branch" + + +def test_conversation_reference_is_frozen(): + ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.SCORE) + with pytest.raises(AttributeError): + ref.conversation_id = "new_id" + + +def test_conversation_reference_hash(): + ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL) + assert hash(ref) == hash("abc") + + +def test_conversation_reference_eq_same_id(): + ref1 = 
ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL) + ref2 = ConversationReference( + conversation_id="abc", + conversation_type=ConversationType.PRUNED, + description="different", + ) + assert ref1 == ref2 + + +def test_conversation_reference_eq_different_id(): + ref1 = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL) + ref2 = ConversationReference(conversation_id="xyz", conversation_type=ConversationType.ADVERSARIAL) + assert ref1 != ref2 + + +def test_conversation_reference_eq_non_reference(): + ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL) + assert ref != "abc" + assert ref != 42 + assert ref != None # noqa: E711 + + +def test_conversation_reference_usable_in_set(): + ref1 = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL) + ref2 = ConversationReference(conversation_id="abc", conversation_type=ConversationType.PRUNED) + ref3 = ConversationReference(conversation_id="xyz", conversation_type=ConversationType.SCORE) + s = {ref1, ref2, ref3} + assert len(s) == 2 + + +def test_conversation_reference_usable_as_dict_key(): + ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.CONVERTER) + d = {ref: "value"} + lookup_ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL) + assert d[lookup_ref] == "value" diff --git a/tests/unit/models/test_conversation_stats.py b/tests/unit/models/test_conversation_stats.py new file mode 100644 index 000000000..adeabb174 --- /dev/null +++ b/tests/unit/models/test_conversation_stats.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from datetime import datetime, timezone + +import pytest + +from pyrit.models.conversation_stats import ConversationStats + + +def test_conversation_stats_defaults(): + stats = ConversationStats() + assert stats.message_count == 0 + assert stats.last_message_preview is None + assert stats.labels == {} + assert stats.created_at is None + + +def test_conversation_stats_with_values(): + now = datetime.now(timezone.utc) + stats = ConversationStats( + message_count=5, + last_message_preview="Hello world", + labels={"env": "test"}, + created_at=now, + ) + assert stats.message_count == 5 + assert stats.last_message_preview == "Hello world" + assert stats.labels == {"env": "test"} + assert stats.created_at == now + + +def test_conversation_stats_is_frozen(): + stats = ConversationStats(message_count=3) + with pytest.raises(AttributeError): + stats.message_count = 10 + + +def test_conversation_stats_preview_max_len_class_var(): + assert ConversationStats.PREVIEW_MAX_LEN == 100 + + +def test_conversation_stats_labels_default_factory(): + stats1 = ConversationStats() + stats2 = ConversationStats() + assert stats1.labels is not stats2.labels diff --git a/tests/unit/models/test_question_answering.py b/tests/unit/models/test_question_answering.py new file mode 100644 index 000000000..aba02d9f6 --- /dev/null +++ b/tests/unit/models/test_question_answering.py @@ -0,0 +1,151 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import pytest +from pydantic import ValidationError + +from pyrit.models.question_answering import ( + QuestionAnsweringDataset, + QuestionAnsweringEntry, + QuestionChoice, +) + + +def test_question_choice_init(): + choice = QuestionChoice(index=0, text="Option A") + assert choice.index == 0 + assert choice.text == "Option A" + + +def test_question_choice_forbids_extra(): + with pytest.raises(ValidationError): + QuestionChoice(index=0, text="A", extra="bad") + + +def test_question_answering_entry_init(): + choices = [QuestionChoice(index=0, text="A"), QuestionChoice(index=1, text="B")] + entry = QuestionAnsweringEntry( + question="What is 1+1?", + answer_type="int", + correct_answer=1, + choices=choices, + ) + assert entry.question == "What is 1+1?" + assert entry.answer_type == "int" + assert entry.correct_answer == 1 + assert len(entry.choices) == 2 + + +def test_question_answering_entry_invalid_answer_type(): + with pytest.raises(ValidationError): + QuestionAnsweringEntry( + question="Q", + answer_type="invalid", + correct_answer=0, + choices=[QuestionChoice(index=0, text="A")], + ) + + +def test_question_answering_entry_get_correct_answer_text(): + choices = [QuestionChoice(index=0, text="Paris"), QuestionChoice(index=1, text="London")] + entry = QuestionAnsweringEntry( + question="Capital of France?", + answer_type="int", + correct_answer=0, + choices=choices, + ) + assert entry.get_correct_answer_text() == "Paris" + + +def test_question_answering_entry_get_correct_answer_text_string_answer(): + choices = [QuestionChoice(index=0, text="Paris"), QuestionChoice(index=1, text="London")] + entry = QuestionAnsweringEntry( + question="Capital of France?", + answer_type="str", + correct_answer="0", + choices=choices, + ) + assert entry.get_correct_answer_text() == "Paris" + + +def test_question_answering_entry_get_correct_answer_text_no_match(): + choices = [QuestionChoice(index=0, text="A"), QuestionChoice(index=1, text="B")] + entry = QuestionAnsweringEntry( + 
question="Q", + answer_type="int", + correct_answer=99, + choices=choices, + ) + with pytest.raises(ValueError, match="No matching choice"): + entry.get_correct_answer_text() + + +def test_question_answering_entry_hash(): + choices = [QuestionChoice(index=0, text="A")] + entry1 = QuestionAnsweringEntry(question="Q", answer_type="str", correct_answer="A", choices=choices) + entry2 = QuestionAnsweringEntry(question="Q", answer_type="str", correct_answer="A", choices=choices) + assert hash(entry1) == hash(entry2) + + +def test_question_answering_entry_hash_different(): + choices = [QuestionChoice(index=0, text="A")] + entry1 = QuestionAnsweringEntry(question="Q1", answer_type="str", correct_answer="A", choices=choices) + entry2 = QuestionAnsweringEntry(question="Q2", answer_type="str", correct_answer="A", choices=choices) + assert hash(entry1) != hash(entry2) + + +def test_question_answering_entry_forbids_extra(): + with pytest.raises(ValidationError): + QuestionAnsweringEntry( + question="Q", + answer_type="str", + correct_answer="A", + choices=[], + extra="bad", + ) + + +def test_question_answering_dataset_init(): + choices = [QuestionChoice(index=0, text="A")] + entry = QuestionAnsweringEntry(question="Q", answer_type="str", correct_answer="A", choices=choices) + dataset = QuestionAnsweringDataset( + name="test_ds", + version="1.0", + description="test", + author="tester", + group="grp", + source="src", + questions=[entry], + ) + assert dataset.name == "test_ds" + assert dataset.version == "1.0" + assert len(dataset.questions) == 1 + + +def test_question_answering_dataset_defaults(): + choices = [QuestionChoice(index=0, text="A")] + entry = QuestionAnsweringEntry(question="Q", answer_type="str", correct_answer="A", choices=choices) + dataset = QuestionAnsweringDataset(questions=[entry]) + assert dataset.name == "" + assert dataset.version == "" + assert dataset.description == "" + assert dataset.author == "" + assert dataset.group == "" + assert dataset.source == 
"" + + +def test_question_answering_dataset_forbids_extra(): + with pytest.raises(ValidationError): + QuestionAnsweringDataset(questions=[], extra="bad") + + +@pytest.mark.parametrize("answer_type", ["int", "float", "str", "bool"]) +def test_question_answering_entry_valid_answer_types(answer_type): + choices = [QuestionChoice(index=0, text="A")] + entry = QuestionAnsweringEntry( + question="Q", + answer_type=answer_type, + correct_answer=0, + choices=choices, + ) + assert entry.answer_type == answer_type diff --git a/tests/unit/models/test_scenario_result.py b/tests/unit/models/test_scenario_result.py new file mode 100644 index 000000000..c650b1311 --- /dev/null +++ b/tests/unit/models/test_scenario_result.py @@ -0,0 +1,183 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import uuid +from unittest.mock import patch + +from pyrit.models.attack_result import AttackOutcome, AttackResult +from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult + + +def _make_scenario_identifier(**kwargs): + defaults = {"name": "TestScenario", "description": "A test", "scenario_version": 1} + defaults.update(kwargs) + return ScenarioIdentifier(**defaults) + + +def _make_component_identifier_dict(class_name="TestTarget"): + return {"__type__": class_name, "__module__": "test.module", "params": {}} + + +def _make_attack_result(*, objective="test objective", outcome=AttackOutcome.SUCCESS): + return AttackResult( + conversation_id=str(uuid.uuid4()), + objective=objective, + outcome=outcome, + ) + + +class TestScenarioIdentifier: + def test_init_basic(self): + si = ScenarioIdentifier(name="MySc") + assert si.name == "MySc" + assert si.description == "" + assert si.version == 1 + assert si.init_data is None + + def test_init_with_all_params(self): + si = ScenarioIdentifier( + name="MySc", + description="desc", + scenario_version=2, + init_data={"key": "val"}, + pyrit_version="1.0.0", + ) + assert si.version == 2 + assert si.init_data == 
{"key": "val"} + assert si.pyrit_version == "1.0.0" + + def test_init_default_pyrit_version(self): + import pyrit + + si = ScenarioIdentifier(name="X") + assert si.pyrit_version == pyrit.__version__ + + +class TestScenarioResult: + @patch("pyrit.identifiers.component_identifier.ComponentIdentifier.normalize") + def test_init_basic(self, mock_normalize): + mock_normalize.side_effect = lambda x: x + si = _make_scenario_identifier() + target_id = _make_component_identifier_dict() + scorer_id = _make_component_identifier_dict("TestScorer") + result = ScenarioResult( + scenario_identifier=si, + objective_target_identifier=target_id, + attack_results={"strat1": []}, + objective_scorer_identifier=scorer_id, + ) + assert result.scenario_identifier is si + assert result.scenario_run_state == "CREATED" + assert result.labels == {} + assert result.number_tries == 0 + assert isinstance(result.id, uuid.UUID) + + @patch("pyrit.identifiers.component_identifier.ComponentIdentifier.normalize") + def test_init_with_explicit_id(self, mock_normalize): + mock_normalize.side_effect = lambda x: x + si = _make_scenario_identifier() + explicit_id = uuid.uuid4() + result = ScenarioResult( + scenario_identifier=si, + objective_target_identifier={}, + attack_results={}, + objective_scorer_identifier={}, + id=explicit_id, + ) + assert result.id == explicit_id + + @patch("pyrit.identifiers.component_identifier.ComponentIdentifier.normalize") + def test_get_strategies_used(self, mock_normalize): + mock_normalize.side_effect = lambda x: x + si = _make_scenario_identifier() + result = ScenarioResult( + scenario_identifier=si, + objective_target_identifier={}, + attack_results={"crescendo": [], "flip": []}, + objective_scorer_identifier={}, + ) + strategies = result.get_strategies_used() + assert sorted(strategies) == ["crescendo", "flip"] + + @patch("pyrit.identifiers.component_identifier.ComponentIdentifier.normalize") + def test_get_objectives_all(self, mock_normalize): + 
mock_normalize.side_effect = lambda x: x + ar1 = _make_attack_result(objective="obj1") + ar2 = _make_attack_result(objective="obj2") + ar3 = _make_attack_result(objective="obj1") + result = ScenarioResult( + scenario_identifier=_make_scenario_identifier(), + objective_target_identifier={}, + attack_results={"s1": [ar1, ar3], "s2": [ar2]}, + objective_scorer_identifier={}, + ) + objectives = result.get_objectives() + assert sorted(objectives) == ["obj1", "obj2"] + + @patch("pyrit.identifiers.component_identifier.ComponentIdentifier.normalize") + def test_get_objectives_by_attack_name(self, mock_normalize): + mock_normalize.side_effect = lambda x: x + ar1 = _make_attack_result(objective="obj1") + ar2 = _make_attack_result(objective="obj2") + result = ScenarioResult( + scenario_identifier=_make_scenario_identifier(), + objective_target_identifier={}, + attack_results={"s1": [ar1], "s2": [ar2]}, + objective_scorer_identifier={}, + ) + assert result.get_objectives(atomic_attack_name="s1") == ["obj1"] + assert result.get_objectives(atomic_attack_name="nonexistent") == [] + + @patch("pyrit.identifiers.component_identifier.ComponentIdentifier.normalize") + def test_objective_achieved_rate_all(self, mock_normalize): + mock_normalize.side_effect = lambda x: x + results = [ + _make_attack_result(outcome=AttackOutcome.SUCCESS), + _make_attack_result(outcome=AttackOutcome.FAILURE), + _make_attack_result(outcome=AttackOutcome.SUCCESS), + _make_attack_result(outcome=AttackOutcome.UNDETERMINED), + ] + sr = ScenarioResult( + scenario_identifier=_make_scenario_identifier(), + objective_target_identifier={}, + attack_results={"s1": results}, + objective_scorer_identifier={}, + ) + assert sr.objective_achieved_rate() == 50 + + @patch("pyrit.identifiers.component_identifier.ComponentIdentifier.normalize") + def test_objective_achieved_rate_empty(self, mock_normalize): + mock_normalize.side_effect = lambda x: x + sr = ScenarioResult( + scenario_identifier=_make_scenario_identifier(), + 
objective_target_identifier={}, + attack_results={"s1": []}, + objective_scorer_identifier={}, + ) + assert sr.objective_achieved_rate() == 0 + + @patch("pyrit.identifiers.component_identifier.ComponentIdentifier.normalize") + def test_objective_achieved_rate_by_name(self, mock_normalize): + mock_normalize.side_effect = lambda x: x + sr = ScenarioResult( + scenario_identifier=_make_scenario_identifier(), + objective_target_identifier={}, + attack_results={ + "s1": [_make_attack_result(outcome=AttackOutcome.SUCCESS)], + "s2": [_make_attack_result(outcome=AttackOutcome.FAILURE)], + }, + objective_scorer_identifier={}, + ) + assert sr.objective_achieved_rate(atomic_attack_name="s1") == 100 + assert sr.objective_achieved_rate(atomic_attack_name="s2") == 0 + assert sr.objective_achieved_rate(atomic_attack_name="missing") == 0 + + def test_normalize_scenario_name_snake_case(self): + assert ScenarioResult.normalize_scenario_name("content_harms") == "ContentHarms" + assert ScenarioResult.normalize_scenario_name("foundry") == "foundry" + + def test_normalize_scenario_name_already_pascal(self): + assert ScenarioResult.normalize_scenario_name("ContentHarms") == "ContentHarms" + + def test_normalize_scenario_name_mixed_case_with_underscore(self): + assert ScenarioResult.normalize_scenario_name("Content_harms") == "Content_harms" diff --git a/tests/unit/models/test_seed_attack_group.py b/tests/unit/models/test_seed_attack_group.py new file mode 100644 index 000000000..4321a7fbb --- /dev/null +++ b/tests/unit/models/test_seed_attack_group.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ + +import pytest + +from pyrit.models.seeds.seed_attack_group import SeedAttackGroup +from pyrit.models.seeds.seed_objective import SeedObjective +from pyrit.models.seeds.seed_prompt import SeedPrompt + + +def _make_prompt(*, value="test prompt", sequence=0): + return SeedPrompt(value=value, data_type="text", role="user", sequence=sequence) + + +def _make_objective(*, value="test objective"): + return SeedObjective(value=value) + + +def test_seed_attack_group_valid_init(): + objective = _make_objective() + prompt = _make_prompt() + group = SeedAttackGroup(seeds=[objective, prompt]) + assert group.objective is objective + assert len(group.seeds) == 2 + + +def test_seed_attack_group_objective_property(): + objective = _make_objective(value="achieve goal") + group = SeedAttackGroup(seeds=[objective, _make_prompt()]) + assert group.objective.value == "achieve goal" + + +def test_seed_attack_group_no_objective_raises(): + prompt = _make_prompt() + with pytest.raises(ValueError, match="exactly one objective"): + SeedAttackGroup(seeds=[prompt]) + + +def test_seed_attack_group_two_objectives_raises(): + obj1 = _make_objective(value="obj1") + obj2 = _make_objective(value="obj2") + prompt = _make_prompt() + with pytest.raises(ValueError, match="one objective"): + SeedAttackGroup(seeds=[obj1, obj2, prompt]) + + +def test_seed_attack_group_empty_seeds_raises(): + with pytest.raises(ValueError): + SeedAttackGroup(seeds=[]) + + +def test_seed_attack_group_consistent_group_id(): + objective = _make_objective() + prompt = _make_prompt() + group = SeedAttackGroup(seeds=[objective, prompt]) + group_ids = {s.prompt_group_id for s in group.seeds} + assert len(group_ids) == 1 + assert None not in group_ids + + +def test_seed_attack_group_with_multiple_prompts(): + objective = _make_objective() + p1 = _make_prompt(value="p1", sequence=0) + p2 = _make_prompt(value="p2", sequence=1) + group = SeedAttackGroup(seeds=[objective, p1, p2]) + assert len(group.prompts) == 2 diff --git 
a/tests/unit/models/test_seed_dataset.py b/tests/unit/models/test_seed_dataset.py new file mode 100644 index 000000000..8232aaaf9 --- /dev/null +++ b/tests/unit/models/test_seed_dataset.py @@ -0,0 +1,174 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import uuid + +import pytest + +from pyrit.models.seeds.seed_dataset import SeedDataset +from pyrit.models.seeds.seed_objective import SeedObjective +from pyrit.models.seeds.seed_prompt import SeedPrompt + + +def test_seed_dataset_init_from_dicts(): + ds = SeedDataset(seeds=[{"value": "hello"}, {"value": "world"}]) + assert len(ds.seeds) == 2 + assert all(isinstance(s, SeedPrompt) for s in ds.seeds) + + +def test_seed_dataset_init_from_seed_objects(): + sp = SeedPrompt(value="hello", data_type="text", role="user") + ds = SeedDataset(seeds=[sp]) + assert ds.seeds[0] is sp + + +def test_seed_dataset_empty_raises(): + with pytest.raises(ValueError, match="cannot be empty"): + SeedDataset(seeds=[]) + + +def test_seed_dataset_none_raises(): + with pytest.raises(ValueError, match="cannot be empty"): + SeedDataset(seeds=None) + + +def test_seed_dataset_invalid_seed_type_raises(): + with pytest.raises(ValueError, match="dicts or Seed objects"): + SeedDataset(seeds=[42]) + + +def test_seed_dataset_defaults(): + ds = SeedDataset(seeds=[{"value": "hi"}]) + assert ds.data_type == "text" + assert ds.name is None + assert ds.dataset_name is None + assert ds.description is None + assert ds.source is None + assert ds.date_added is not None + + +def test_seed_dataset_with_metadata(): + ds = SeedDataset( + seeds=[{"value": "hi"}], + name="test_ds", + dataset_name="ds1", + description="a dataset", + authors=["author1"], + groups=["group1"], + source="test", + ) + assert ds.name == "test_ds" + assert ds.dataset_name == "ds1" + assert ds.authors == ["author1"] + + +def test_seed_dataset_objective_seeds(): + ds = SeedDataset(seeds=[{"value": "objective text", "seed_type": "objective"}]) + assert 
len(ds.seeds) == 1 + assert isinstance(ds.seeds[0], SeedObjective) + + +def test_seed_dataset_get_values(): + ds = SeedDataset(seeds=[{"value": "a"}, {"value": "b"}, {"value": "c"}]) + values = ds.get_values() + assert list(values) == ["a", "b", "c"] + + +def test_seed_dataset_get_values_first(): + ds = SeedDataset(seeds=[{"value": "a"}, {"value": "b"}, {"value": "c"}]) + assert list(ds.get_values(first=2)) == ["a", "b"] + + +def test_seed_dataset_get_values_last(): + ds = SeedDataset(seeds=[{"value": "a"}, {"value": "b"}, {"value": "c"}]) + assert list(ds.get_values(last=2)) == ["b", "c"] + + +def test_seed_dataset_get_values_first_and_last_overlap(): + ds = SeedDataset(seeds=[{"value": "a"}, {"value": "b"}]) + assert list(ds.get_values(first=2, last=2)) == ["a", "b"] + + +def test_seed_dataset_get_values_by_harm_category(): + ds = SeedDataset( + seeds=[ + {"value": "a", "harm_categories": ["violence"]}, + {"value": "b", "harm_categories": ["hate"]}, + {"value": "c", "harm_categories": ["violence", "hate"]}, + ] + ) + values = ds.get_values(harm_categories=["violence"]) + assert "a" in values + assert "c" in values + assert "b" not in values + + +def test_seed_dataset_get_random_values(): + ds = SeedDataset(seeds=[{"value": str(i)} for i in range(10)]) + result = ds.get_random_values(number=3) + assert len(result) == 3 + + +def test_seed_dataset_get_random_values_more_than_available(): + ds = SeedDataset(seeds=[{"value": "a"}, {"value": "b"}]) + result = ds.get_random_values(number=10) + assert len(result) == 2 + + +def test_seed_dataset_prompts_property(): + ds = SeedDataset(seeds=[{"value": "p1"}, {"value": "obj", "seed_type": "objective"}]) + assert len(ds.prompts) == 1 + assert isinstance(ds.prompts[0], SeedPrompt) + + +def test_seed_dataset_objectives_property(): + ds = SeedDataset(seeds=[{"value": "p1"}, {"value": "obj", "seed_type": "objective"}]) + assert len(ds.objectives) == 1 + assert isinstance(ds.objectives[0], SeedObjective) + + +def 
test_seed_dataset_repr(): + ds = SeedDataset(seeds=[{"value": "a"}]) + assert "1 seeds" in repr(ds) + + +def test_seed_dataset_from_dict(): + data = { + "name": "test_ds", + "description": "desc", + "seeds": [{"value": "hello"}, {"value": "world"}], + } + ds = SeedDataset.from_dict(data) + assert len(ds.seeds) == 2 + assert ds.description == "desc" + + +def test_seed_dataset_from_dict_rejects_preset_group_id(): + data = { + "seeds": [{"value": "hello", "prompt_group_id": str(uuid.uuid4())}], + } + with pytest.raises(ValueError, match="prompt_group_id"): + SeedDataset.from_dict(data) + + +def test_seed_dataset_group_seed_prompts_by_group_id(): + gid = uuid.uuid4() + p1 = SeedPrompt(value="a", data_type="text", role="user", prompt_group_id=gid, sequence=0) + obj = SeedObjective(value="objective", prompt_group_id=gid) + groups = SeedDataset.group_seed_prompts_by_prompt_group_id([p1, obj]) + assert len(groups) == 1 + + +def test_seed_dataset_group_without_group_id(): + p1 = SeedPrompt(value="a", data_type="text", role="user") + p2 = SeedPrompt(value="b", data_type="text", role="user") + p1.prompt_group_id = None + p2.prompt_group_id = None + groups = SeedDataset.group_seed_prompts_by_prompt_group_id([p1, p2]) + assert len(groups) == 2 + + +def test_seed_dataset_is_objective_deprecated(): + with pytest.warns(DeprecationWarning, match="is_objective"): + ds = SeedDataset(seeds=[{"value": "obj"}], is_objective=True) + assert isinstance(ds.seeds[0], SeedObjective) diff --git a/tests/unit/models/test_seed_objective.py b/tests/unit/models/test_seed_objective.py new file mode 100644 index 000000000..b3742a2af --- /dev/null +++ b/tests/unit/models/test_seed_objective.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import pytest + +from pyrit.models.seeds.seed_objective import SeedObjective + + +def test_seed_objective_init(): + obj = SeedObjective(value="test objective") + assert obj.value == "test objective" + assert obj.data_type == "text" + assert obj.is_general_technique is False + + +def test_seed_objective_data_type_always_text(): + obj = SeedObjective(value="objective") + assert obj.data_type == "text" + + +def test_seed_objective_is_general_technique_raises(): + with pytest.raises(ValueError, match="general technique"): + SeedObjective(value="bad", is_general_technique=True) + + +def test_seed_objective_with_metadata(): + obj = SeedObjective( + value="objective", + name="test_obj", + dataset_name="ds", + harm_categories=["violence"], + description="an objective", + ) + assert obj.name == "test_obj" + assert obj.dataset_name == "ds" + assert obj.harm_categories == ["violence"] + assert obj.description == "an objective" + + +def test_seed_objective_jinja_template_rendering(): + obj = SeedObjective(value="Hello {{ name }}", is_jinja_template=True) + assert "name" in obj.value or "Hello" in obj.value + + +def test_seed_objective_non_jinja_template_preserved(): + obj = SeedObjective(value="Hello {{ name }}") + assert obj.value == "Hello {{ name }}" + + +def test_seed_objective_id_auto_generated(): + obj = SeedObjective(value="test") + assert obj.id is not None diff --git a/tests/unit/models/test_seed_prompt.py b/tests/unit/models/test_seed_prompt.py new file mode 100644 index 000000000..16181d4e6 --- /dev/null +++ b/tests/unit/models/test_seed_prompt.py @@ -0,0 +1,148 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import uuid +from unittest.mock import MagicMock, patch + +import pytest + +from pyrit.models.seeds.seed_prompt import SeedPrompt + + +def test_seed_prompt_init_defaults(): + sp = SeedPrompt(value="hello", data_type="text") + assert sp.value == "hello" + assert sp.data_type == "text" + assert sp.role is None + assert sp.sequence == 0 + assert sp.parameters == [] + + +def test_seed_prompt_init_with_all_fields(): + gid = uuid.uuid4() + sp = SeedPrompt( + value="prompt text", + data_type="text", + role="user", + sequence=3, + parameters=["param1", "param2"], + name="test_prompt", + dataset_name="ds", + prompt_group_id=gid, + ) + assert sp.role == "user" + assert sp.sequence == 3 + assert sp.parameters == ["param1", "param2"] + assert sp.prompt_group_id == gid + + +def test_seed_prompt_infers_text_data_type(): + sp = SeedPrompt(value="just text") + assert sp.data_type == "text" + + +@patch("os.path.isfile", return_value=True) +@patch("os.path.splitext", return_value=("/path/file", ".mp4")) +def test_seed_prompt_infers_video_data_type(mock_splitext, mock_isfile): + sp = SeedPrompt(value="/path/file.mp4") + assert sp.data_type == "video_path" + + +@patch("os.path.isfile", return_value=True) +@patch("os.path.splitext", return_value=("/path/file", ".wav")) +def test_seed_prompt_infers_audio_data_type(mock_splitext, mock_isfile): + sp = SeedPrompt(value="/path/file.wav") + assert sp.data_type == "audio_path" + + +@patch("os.path.isfile", return_value=True) +@patch("os.path.splitext", return_value=("/path/file", ".png")) +def test_seed_prompt_infers_image_data_type(mock_splitext, mock_isfile): + sp = SeedPrompt(value="/path/file.png") + assert sp.data_type == "image_path" + + +@patch("os.path.isfile", return_value=True) +@patch("os.path.splitext", return_value=("/path/file", ".xyz")) +def test_seed_prompt_unknown_file_extension_raises(mock_splitext, mock_isfile): + with pytest.raises(ValueError, match="Unable to infer data_type"): + SeedPrompt(value="/path/file.xyz") + + 
+def test_seed_prompt_explicit_data_type_not_overridden(): + sp = SeedPrompt(value="some text", data_type="text") + assert sp.data_type == "text" + + +def test_seed_prompt_jinja_template_rendering(): + sp = SeedPrompt(value="Hello {{ name }}", data_type="text", is_jinja_template=True) + assert "name" in sp.value or "Hello" in sp.value + + +def test_seed_prompt_non_jinja_preserved(): + sp = SeedPrompt(value="Hello {{ name }}", data_type="text") + assert sp.value == "Hello {{ name }}" + + +def test_seed_prompt_from_messages(): + piece_mock = MagicMock() + piece_mock.converted_value = "test value" + piece_mock.converted_value_data_type = "text" + + message_mock = MagicMock() + message_mock.api_role = "user" + message_mock.message_pieces = [piece_mock] + + result = SeedPrompt.from_messages([message_mock]) + assert len(result) == 1 + assert result[0].value == "test value" + assert result[0].role == "user" + assert result[0].sequence == 0 + + +def test_seed_prompt_from_messages_multiple(): + piece1 = MagicMock() + piece1.converted_value = "user msg" + piece1.converted_value_data_type = "text" + msg1 = MagicMock() + msg1.api_role = "user" + msg1.message_pieces = [piece1] + + piece2 = MagicMock() + piece2.converted_value = "assistant msg" + piece2.converted_value_data_type = "text" + msg2 = MagicMock() + msg2.api_role = "assistant" + msg2.message_pieces = [piece2] + + result = SeedPrompt.from_messages([msg1, msg2]) + assert len(result) == 2 + assert result[0].role == "user" + assert result[0].sequence == 0 + assert result[1].role == "assistant" + assert result[1].sequence == 1 + + +def test_seed_prompt_from_messages_with_group_id(): + piece = MagicMock() + piece.converted_value = "val" + piece.converted_value_data_type = "text" + msg = MagicMock() + msg.api_role = "user" + msg.message_pieces = [piece] + + gid = uuid.uuid4() + result = SeedPrompt.from_messages([msg], prompt_group_id=gid) + assert result[0].prompt_group_id == gid + + +def 
test_seed_prompt_from_messages_with_starting_sequence(): + piece = MagicMock() + piece.converted_value = "val" + piece.converted_value_data_type = "text" + msg = MagicMock() + msg.api_role = "user" + msg.message_pieces = [piece] + + result = SeedPrompt.from_messages([msg], starting_sequence=5) + assert result[0].sequence == 5 diff --git a/tests/unit/models/test_strategy_result.py b/tests/unit/models/test_strategy_result.py new file mode 100644 index 000000000..15eceb5cb --- /dev/null +++ b/tests/unit/models/test_strategy_result.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from dataclasses import dataclass + +from pyrit.models.strategy_result import StrategyResult + + +@dataclass +class ConcreteResult(StrategyResult): + value: str = "" + count: int = 0 + + +def test_strategy_result_duplicate_creates_deep_copy(): + original = ConcreteResult(value="hello", count=5) + copy = original.duplicate() + assert copy.value == "hello" + assert copy.count == 5 + assert copy is not original + + +def test_strategy_result_duplicate_is_independent(): + original = ConcreteResult(value="hello", count=5) + copy = original.duplicate() + copy.value = "changed" + copy.count = 99 + assert original.value == "hello" + assert original.count == 5 + + +def test_strategy_result_duplicate_preserves_type(): + original = ConcreteResult(value="test", count=1) + copy = original.duplicate() + assert type(copy) is ConcreteResult