Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions tests/unit/models/test_chat_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json

import pytest
from pydantic import ValidationError

from pyrit.models.chat_message import (
ChatMessage,
ChatMessageListDictContent,
ChatMessagesDataset,
ToolCall,
)


def test_tool_call_init():
tc = ToolCall(id="call_1", type="function", function="get_weather")
assert tc.id == "call_1"
assert tc.type == "function"
assert tc.function == "get_weather"


def test_tool_call_forbids_extra_fields():
with pytest.raises(ValidationError):
ToolCall(id="call_1", type="function", function="get_weather", extra="bad")


def test_chat_message_init_with_string_content():
msg = ChatMessage(role="user", content="hello")
assert msg.role == "user"
assert msg.content == "hello"
assert msg.name is None
assert msg.tool_calls is None
assert msg.tool_call_id is None


def test_chat_message_init_with_list_content():
parts = [{"type": "text", "text": "hello"}, {"type": "image_url", "url": "http://img.png"}]
msg = ChatMessage(role="assistant", content=parts)
assert msg.content == parts


def test_chat_message_init_with_all_fields():
tc = ToolCall(id="call_1", type="function", function="lookup")
msg = ChatMessage(
role="assistant",
content="result",
name="helper",
tool_calls=[tc],
tool_call_id="call_1",
)
assert msg.name == "helper"
assert msg.tool_calls == [tc]
assert msg.tool_call_id == "call_1"


def test_chat_message_forbids_extra_fields():
with pytest.raises(ValidationError):
ChatMessage(role="user", content="hi", extra_field="bad")


def test_chat_message_invalid_role():
with pytest.raises(ValidationError):
ChatMessage(role="invalid_role", content="hi")


def test_chat_message_to_json():
msg = ChatMessage(role="user", content="test")
json_str = msg.to_json()
parsed = json.loads(json_str)
assert parsed["role"] == "user"
assert parsed["content"] == "test"


def test_chat_message_to_dict_excludes_none():
msg = ChatMessage(role="user", content="test")
d = msg.to_dict()
assert "name" not in d
assert "tool_calls" not in d
assert "tool_call_id" not in d
assert d["role"] == "user"
assert d["content"] == "test"


def test_chat_message_to_dict_includes_non_none():
msg = ChatMessage(role="user", content="test", name="bot")
d = msg.to_dict()
assert d["name"] == "bot"


def test_chat_message_from_json():
original = ChatMessage(role="system", content="you are helpful")
json_str = original.to_json()
restored = ChatMessage.from_json(json_str)
assert restored.role == original.role
assert restored.content == original.content


def test_chat_message_from_json_roundtrip_with_tool_calls():
tc = ToolCall(id="c1", type="function", function="fn")
original = ChatMessage(role="assistant", content="ok", tool_calls=[tc], tool_call_id="c1")
restored = ChatMessage.from_json(original.to_json())
assert restored.tool_calls[0].id == "c1"
assert restored.tool_call_id == "c1"


@pytest.mark.parametrize("role", ["system", "user", "assistant", "simulated_assistant", "tool", "developer"])
def test_chat_message_accepts_all_valid_roles(role):
msg = ChatMessage(role=role, content="test")
assert msg.role == role


def test_chat_message_list_dict_content_deprecated(capsys):
msg = ChatMessageListDictContent(role="user", content="hello")
assert msg.role == "user"
assert msg.content == "hello"


def test_chat_messages_dataset_init():
msgs = [[ChatMessage(role="user", content="hi"), ChatMessage(role="assistant", content="hello")]]
dataset = ChatMessagesDataset(name="test_ds", description="A test dataset", list_of_chat_messages=msgs)
assert dataset.name == "test_ds"
assert dataset.description == "A test dataset"
assert len(dataset.list_of_chat_messages) == 1
assert len(dataset.list_of_chat_messages[0]) == 2


def test_chat_messages_dataset_forbids_extra_fields():
with pytest.raises(ValidationError):
ChatMessagesDataset(
name="ds",
description="desc",
list_of_chat_messages=[],
extra="bad",
)
78 changes: 78 additions & 0 deletions tests/unit/models/test_conversation_reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytest

from pyrit.models.conversation_reference import ConversationReference, ConversationType


def test_conversation_type_values():
assert ConversationType.ADVERSARIAL.value == "adversarial"
assert ConversationType.PRUNED.value == "pruned"
assert ConversationType.SCORE.value == "score"
assert ConversationType.CONVERTER.value == "converter"


def test_conversation_reference_init():
ref = ConversationReference(conversation_id="abc-123", conversation_type=ConversationType.ADVERSARIAL)
assert ref.conversation_id == "abc-123"
assert ref.conversation_type == ConversationType.ADVERSARIAL
assert ref.description is None


def test_conversation_reference_with_description():
ref = ConversationReference(
conversation_id="abc-123",
conversation_type=ConversationType.PRUNED,
description="pruned branch",
)
assert ref.description == "pruned branch"


def test_conversation_reference_is_frozen():
ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.SCORE)
with pytest.raises(AttributeError):
ref.conversation_id = "new_id"


def test_conversation_reference_hash():
ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL)
assert hash(ref) == hash("abc")


def test_conversation_reference_eq_same_id():
ref1 = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL)
ref2 = ConversationReference(
conversation_id="abc",
conversation_type=ConversationType.PRUNED,
description="different",
)
assert ref1 == ref2


def test_conversation_reference_eq_different_id():
ref1 = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL)
ref2 = ConversationReference(conversation_id="xyz", conversation_type=ConversationType.ADVERSARIAL)
assert ref1 != ref2


def test_conversation_reference_eq_non_reference():
ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL)
assert ref != "abc"
assert ref != 42
assert ref != None # noqa: E711


def test_conversation_reference_usable_in_set():
ref1 = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL)
ref2 = ConversationReference(conversation_id="abc", conversation_type=ConversationType.PRUNED)
ref3 = ConversationReference(conversation_id="xyz", conversation_type=ConversationType.SCORE)
s = {ref1, ref2, ref3}
assert len(s) == 2


def test_conversation_reference_usable_as_dict_key():
ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.CONVERTER)
d = {ref: "value"}
lookup_ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL)
assert d[lookup_ref] == "value"
46 changes: 46 additions & 0 deletions tests/unit/models/test_conversation_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from datetime import datetime, timezone

import pytest

from pyrit.models.conversation_stats import ConversationStats


def test_conversation_stats_defaults():
stats = ConversationStats()
assert stats.message_count == 0
assert stats.last_message_preview is None
assert stats.labels == {}
assert stats.created_at is None


def test_conversation_stats_with_values():
now = datetime.now(timezone.utc)
stats = ConversationStats(
message_count=5,
last_message_preview="Hello world",
labels={"env": "test"},
created_at=now,
)
assert stats.message_count == 5
assert stats.last_message_preview == "Hello world"
assert stats.labels == {"env": "test"}
assert stats.created_at == now


def test_conversation_stats_is_frozen():
stats = ConversationStats(message_count=3)
with pytest.raises(AttributeError):
stats.message_count = 10


def test_conversation_stats_preview_max_len_class_var():
assert ConversationStats.PREVIEW_MAX_LEN == 100


def test_conversation_stats_labels_default_factory():
stats1 = ConversationStats()
stats2 = ConversationStats()
assert stats1.labels is not stats2.labels
Loading
Loading