diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index da28bc01b..9dd3561f4 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -8,7 +8,7 @@ from contextlib import closing, suppress from datetime import datetime from pathlib import Path -from typing import Any, Optional, TypeVar, Union +from typing import Any, Optional, TypeVar, Union, cast from sqlalchemy import and_, create_engine, func, or_, text from sqlalchemy.engine.base import Engine @@ -35,6 +35,14 @@ Model = TypeVar("Model") +class _ExportableConversationPiece: + def __init__(self, data: dict[str, Any]) -> None: + self._data = data + + def to_dict(self) -> dict[str, Any]: + return self._data + + class SQLiteMemory(MemoryInterface, metaclass=Singleton): """ A memory interface that uses SQLite as the backend database. @@ -474,6 +482,9 @@ def export_conversations( Returns: Path: The path to the exported file. + + Raises: + ValueError: If the specified export format is not supported. """ # Import here to avoid circular import issues from pyrit.memory.memory_exporter import MemoryExporter @@ -522,9 +533,20 @@ def export_conversations( piece_data["scores"] = [score.to_dict() for score in piece_scores] merged_data.append(piece_data) - # Export to JSON manually since the exporter expects objects but we have dicts - with open(file_path, "w") as f: - json.dump(merged_data, f, indent=4) + if not merged_data: + if export_type == "json": + with open(file_path, "w", encoding="utf-8") as f: + json.dump(merged_data, f, indent=4) + elif export_type in self.exporter.export_strategies: + file_path.write_text("", encoding="utf-8") + else: + raise ValueError(f"Unsupported export format: {export_type}") + return file_path + + exportable_pieces = [_ExportableConversationPiece(data=piece_data) for piece_data in merged_data] + self.exporter.export_data( + cast("list[MessagePiece]", exportable_pieces), file_path=file_path, export_type=export_type + ) return file_path def print_schema(self) -> None: diff --git a/tests/unit/memory/memory_interface/test_interface_export.py b/tests/unit/memory/memory_interface/test_interface_export.py index 14064b718..42d4e6d80 100644 --- a/tests/unit/memory/memory_interface/test_interface_export.py +++ b/tests/unit/memory/memory_interface/test_interface_export.py @@ -1,12 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import csv +import json import os import tempfile from collections.abc import Sequence from pathlib import Path from unittest.mock import MagicMock, patch +import pytest + from pyrit.common.path import DB_DATA_PATH from pyrit.memory import MemoryExporter, MemoryInterface from pyrit.models import MessagePiece @@ -103,8 +107,6 @@ def test_export_all_conversations_with_scores_correct_data(sqlite_instance: Memo assert file_path.exists() # Read and verify the exported JSON content - import json - with open(file_path) as f: exported_data = json.load(f) @@ -141,8 +143,6 @@ def test_export_all_conversations_with_scores_empty_data(sqlite_instance: Memory assert file_path.exists() # Read and verify the exported JSON content is empty - import json - with open(file_path) as f: exported_data = json.load(f) @@ -151,3 +151,57 @@ def test_export_all_conversations_with_scores_empty_data(sqlite_instance: Memory # Clean up the temp file if file_path.exists(): os.remove(file_path) + + +@pytest.mark.parametrize("export_type, suffix", [("json", ".json"), ("csv", ".csv"), ("md", ".md")]) +def test_export_all_conversations_with_scores_respects_export_type( + sqlite_instance: MemoryInterface, export_type: str, suffix: str +): + sqlite_instance.exporter = MemoryExporter() + + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file: + file_path = Path(temp_file.name) + temp_file.close() + + try: + with ( + patch.object(sqlite_instance, "get_message_pieces") as mock_get_pieces, + patch.object(sqlite_instance, "get_prompt_scores") as mock_get_scores, + ): + mock_piece = MagicMock() + mock_piece.id = "piece_id_1234" + mock_piece.to_dict.return_value = { + "id": "piece_id_1234", + "converted_value": "sample piece", + } + + mock_score = MagicMock() + mock_score.message_piece_id = "piece_id_1234" + mock_score.to_dict.return_value = {"message_piece_id": "piece_id_1234", "score_value": 10} + + mock_get_pieces.return_value = [mock_piece] + mock_get_scores.return_value = [mock_score] + + sqlite_instance.export_conversations(file_path=file_path, export_type=export_type) + + assert file_path.exists() + exported_content = file_path.read_text(encoding="utf-8") + assert "piece_id_1234" in exported_content + assert "sample piece" in exported_content + + if export_type == "json": + exported_data = json.loads(exported_content) + assert len(exported_data) == 1 + assert exported_data[0]["id"] == "piece_id_1234" + elif export_type == "csv": + with open(file_path, newline="") as exported_file: + reader = csv.DictReader(exported_file) + assert reader.fieldnames == ["id", "converted_value", "scores"] + rows = list(reader) + assert len(rows) == 1 + assert rows[0]["id"] == "piece_id_1234" + elif export_type == "md": + assert exported_content.startswith("| id | converted_value | scores |") + finally: + if file_path.exists(): + os.remove(file_path)