def _print_result(result) -> None:
    """Print the outcome of one CitationRequirement validation.

    The score is only present when coverage was actually computed; results
    produced by an early failure (missing documents, unsupported backend,
    intrinsic error) carry a reason but no score, so it is printed as "n/a"
    instead of being formatted as a percentage.

    Args:
        result: The ValidationResult returned by ``CitationRequirement.validate``.
    """
    print(f"Validation passed: {result.as_bool()}")

    score = result.score
    if score is None:
        print("Citation coverage score: n/a")
    else:
        print(f"Citation coverage score: {score:.2%}")

    reason = result.reason
    if reason:
        # Reasons list every citation found; keep the console output short.
        reason_preview = reason[:200] + "..." if len(reason) > 200 else reason
        print(f"Reason: {reason_preview}")


async def main():
    """Demonstrate CitationRequirement usage.

    Runs four validations against a local HuggingFace backend: documents
    passed to the constructor, explicit CLAIMS mode, CHARACTERS mode, and
    documents attached to the assistant message instead of the constructor.
    """
    print("=" * 70)
    print("CitationRequirement Example")
    print("=" * 70)

    # Initialize HuggingFace backend
    print("\nInitializing HuggingFace backend...")
    backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro")

    # Create documents
    docs = [
        Document(
            doc_id="doc1",
            title="Sky Facts",
            text="The sky appears blue during the day due to Rayleigh scattering.",
        ),
        Document(
            doc_id="doc2",
            title="Grass Facts",
            text="Grass is typically green because of chlorophyll in the leaves.",
        ),
    ]

    # Create a response that should have citations
    response = (
        "The sky appears blue during the day. "
        "Grass is green because it contains chlorophyll."
    )

    # Create context
    print("\nCreating context with user question and assistant response...")
    ctx = ChatContext().add(Message("user", "What colors are the sky and grass?"))
    ctx = ctx.add(Message("assistant", response, documents=docs))

    # Example 1: Documents in constructor
    print("\n--- Example 1: CitationRequirement with documents in constructor ---")
    req = CitationRequirement(min_citation_coverage=0.7, documents=docs)
    _print_result(await req.validate(backend, ctx))

    # Example 2: CLAIMS mode (default) - counts fraction of claims with citations
    print("\n--- Example 2: CLAIMS mode (default) - fraction of claims cited ---")
    req2 = CitationRequirement(
        min_citation_coverage=0.7, documents=docs, mode=CitationMode.CLAIMS
    )
    _print_result(await req2.validate(backend, ctx))

    # Example 3: CHARACTERS mode - calculates character-based coverage
    print("\n--- Example 3: CHARACTERS mode - character-based coverage ---")
    req3 = CitationRequirement(
        min_citation_coverage=0.7, documents=docs, mode=CitationMode.CHARACTERS
    )
    _print_result(await req3.validate(backend, ctx))

    # Example 4: Documents attached to message
    print("\n--- Example 4: Documents in message (not constructor) ---")
    ctx2 = ChatContext().add(Message("user", "Tell me about Mars."))
    ctx2 = ctx2.add(
        Message(
            "assistant",
            "Mars is the fourth planet from the Sun.",
            documents=[
                Document(doc_id="doc1", text="Mars is the fourth planet from the Sun.")
            ],
        )
    )

    req4 = CitationRequirement(min_citation_coverage=0.7)  # No documents in constructor
    _print_result(await req4.validate(backend, ctx2))

    print("\n" + "=" * 70)
    print("Example completed successfully!")
    print("=" * 70)
class CitationMode(Enum):
    """Mode for calculating citation coverage.

    Attributes:
        CLAIMS: Count the fraction of factual claims that have citations.
            Cited claims are the distinct response spans reported by
            find_citations; total claims are estimated by sentence splitting.
        CHARACTERS: Calculate the ratio of cited characters to total characters.
            Uses the union of the character ranges covered by citations.
    """

    CLAIMS = "claims"
    CHARACTERS = "characters"


class CitationRequirement(Requirement):
    """Requirement that validates RAG responses have adequate citation coverage.

    Uses the find_citations intrinsic to identify which parts of an assistant's
    response are supported by explicit citations to retrieved documents. A
    response whose coverage is below the minimum threshold fails validation.

    **Important**: The backend passed to ``validate`` must implement
    ``AdapterMixin`` (e.g. LocalHFBackend). Any other backend yields a failed
    ValidationResult rather than an exception.

    This requirement is designed for RAG workflows where you want to ensure
    responses properly cite their sources. It works with:
    - A user question in the context
    - Retrieved documents
    - An assistant response to validate (must be the last context item)

    Documents can be provided either:
    1. In the constructor (for reusable requirements with fixed documents)
    2. Attached to the assistant message in the context (for dynamic documents)

    Args:
        min_citation_coverage: Minimum ratio of cited content (0.0-1.0).
            Interpretation depends on mode:
            - CLAIMS mode: fraction of factual claims with citations
            - CHARACTERS mode: ratio of cited characters to total characters
            Default is 0.8 (80% coverage).
        documents: Optional documents to validate against. Can be Document
            objects or strings (converted to Documents whose doc_id is their
            position). A single bare string is treated as one document. If
            provided, these documents are used instead of documents attached
            to the message. Default is None (use context documents).
        mode: Citation coverage calculation mode. Default is CitationMode.CLAIMS.
            Use CitationMode.CHARACTERS for a character-based ratio instead.
        description: Custom description for the requirement. If None,
            a description is generated from the threshold and mode.

    Raises:
        ValueError: If min_citation_coverage is outside [0.0, 1.0].

    Example:
        ```python
        from mellea.backends.huggingface import LocalHFBackend
        from mellea.stdlib.components import Message
        from mellea.stdlib.context import ChatContext
        from mellea.stdlib.requirements.rag import CitationRequirement

        backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro")

        # Option 1: Documents in constructor
        req = CitationRequirement(
            documents=doc_objects,
            min_citation_coverage=0.8
        )

        # Option 2: Documents in context (original pattern)
        req = CitationRequirement(min_citation_coverage=0.8)
        ctx = ChatContext().add(
            Message("assistant", response, documents=doc_objects)
        )
        ```
    """

    def __init__(
        self,
        min_citation_coverage: float = 0.8,
        documents: Iterable[Document] | Iterable[str] | None = None,
        mode: CitationMode = CitationMode.CLAIMS,
        description: str | None = None,
    ):
        """Initialize citation coverage requirement."""
        if not 0.0 <= min_citation_coverage <= 1.0:
            raise ValueError(
                f"min_citation_coverage must be between 0.0 and 1.0, got {min_citation_coverage}"
            )

        self.min_citation_coverage = min_citation_coverage
        self.mode = mode

        # Convert documents to Document objects if provided
        self.documents: list[Document] | None
        if documents is None:
            self.documents = None
        else:
            # A bare string is itself an iterable of characters; treat it as a
            # single document instead of one document per character.
            if isinstance(documents, str):
                documents = [documents]
            self.documents = [
                doc
                if isinstance(doc, Document)
                else Document(doc_id=str(i), text=str(doc))
                for i, doc in enumerate(documents)
            ]

        # Generate description if not provided
        if description is None:
            unit = (
                "factual claims" if mode == CitationMode.CLAIMS else "characters"
            )
            description = (
                f"Response must have adequate citation coverage "
                f"(minimum {min_citation_coverage * 100:.0f}% of {unit} cited)"
            )

        # Initialize parent without validation function - we override validate() instead
        super().__init__(description=description, validation_fn=None)

    async def validate(
        self,
        backend: Backend,
        ctx: Context,
        *,
        format: type | None = None,
        model_options: dict | None = None,
    ) -> ValidationResult:
        """Validate citation coverage in the context using the backend.

        Args:
            backend: Backend to use for citation detection. Must implement
                AdapterMixin (e.g. LocalHFBackend).
            ctx: Context containing the conversation history; its last item
                must be the assistant Message to validate.
            format: Unused for this requirement
            model_options: Unused for this requirement

        Returns:
            ValidationResult with pass/fail status and reason. A score (the
            coverage ratio, within [0.0, 1.0]) is only set when coverage was
            computed; early failures carry a reason but no score.
        """
        # Extract last message (should be assistant response)
        messages = ctx.as_list()
        if not messages:
            return ValidationResult(
                False, reason="Context is empty, cannot validate citation coverage"
            )

        last_message = messages[-1]
        if not isinstance(last_message, Message):
            return ValidationResult(
                False,
                reason="Last context item is not a Message, cannot validate citation coverage",
            )

        if last_message.role != "assistant":
            return ValidationResult(
                False,
                reason=f"Last message must be assistant response, got role: {last_message.role}",
            )

        response = last_message.content

        # Use constructor documents if provided, otherwise get from message
        if self.documents is not None:
            documents = self.documents
        else:
            # Access private _docs attribute since documents property returns formatted strings
            documents = last_message._docs or []

        if not documents:
            return ValidationResult(
                False,
                reason="No documents provided for citation validation. "
                "Either pass documents to CitationRequirement constructor "
                "or attach them to the assistant message.",
            )

        # Check backend compatibility
        if not isinstance(backend, AdapterMixin):
            return ValidationResult(
                False,
                reason=f"Backend {backend.__class__.__name__} does not support adapters required for citation detection",
            )

        # Handle empty response before calling intrinsic
        if len(response) == 0:
            return ValidationResult(
                True,
                reason="Empty response is considered to have adequate citation coverage",
                score=1.0,
            )

        # The intrinsic receives the conversation *before* the response, so
        # rebuild a context from every message except the last one.
        context_before_response = ChatContext()
        for msg in messages[:-1]:
            context_before_response = context_before_response.add(msg)

        # Call find_citations intrinsic
        try:
            # Import here to avoid circular dependency with backends
            from ..components.intrinsic import rag

            citations: list[dict] = rag.find_citations(
                response, documents, context_before_response, backend
            )
        except Exception as e:
            # Deliberate boundary: any intrinsic failure becomes a failed
            # validation with the error text, not an exception for the caller.
            return ValidationResult(
                False, reason=f"Citation detection intrinsic failed: {e!s}"
            )

        # Calculate citation coverage based on mode
        if self.mode == CitationMode.CLAIMS:
            coverage_ratio = self._claims_coverage(response, citations)
        else:  # CitationMode.CHARACTERS
            coverage_ratio = self._character_coverage(response, citations)

        # Check against min_citation_coverage
        passed = coverage_ratio >= self.min_citation_coverage

        # Build detailed reason
        reason = self._build_reason(citations, coverage_ratio, passed)

        return ValidationResult(passed, reason=reason, score=coverage_ratio)

    @staticmethod
    def _claims_coverage(response: str, citations: list[dict]) -> float:
        """Return the fraction of estimated claims that carry a citation.

        Total claims are estimated by splitting the response on sentence-ending
        punctuation (a simple heuristic). Cited claims are the *distinct*
        response spans among the citation records, so a span cited against
        several documents counts once. The ratio is capped at 1.0 because the
        sentence heuristic may find fewer sentences than there are cited spans.
        """
        import re

        sentences = [s for s in re.split(r"[.!?]+", response) if s.strip()]
        total_claims = len(sentences)

        cited_spans = set()
        for index, citation in enumerate(citations):
            begin = citation.get("response_begin")
            end = citation.get("response_end")
            if begin is None or end is None:
                # No offsets to deduplicate on: count the record on its own.
                cited_spans.add(("record", index))
            else:
                cited_spans.add((begin, end))
        cited_claims = len(cited_spans)

        if total_claims == 0:
            # Edge case: no sentences detected
            return 1.0 if cited_claims == 0 else 0.0
        return min(1.0, cited_claims / total_claims)

    @staticmethod
    def _character_coverage(response: str, citations: list[dict]) -> float:
        """Return the fraction of response characters covered by citations.

        Spans are clamped to the response bounds and merged before counting,
        so duplicate or overlapping spans are not counted twice and the ratio
        stays within [0.0, 1.0].
        """
        total_chars = len(response)
        if total_chars == 0:
            return 1.0

        spans = sorted(
            (max(0, c["response_begin"]), min(total_chars, c["response_end"]))
            for c in citations
        )
        cited_chars = 0
        cursor = 0  # end of the region already counted
        for begin, end in spans:
            begin = max(begin, cursor)
            if end > begin:
                cited_chars += end - begin
                cursor = end
        return cited_chars / total_chars

    def _build_reason(
        self, citations: list[dict], coverage_ratio: float, passed: bool
    ) -> str:
        """Build a detailed reason string for the validation result.

        Args:
            citations: List of citation records from find_citations
            coverage_ratio: Ratio of cited content
            passed: Whether validation passed

        Returns:
            Detailed reason string: a summary line, up to five citations
            (texts truncated to 60 characters), and on failure the uncited share.
        """
        num_citations = len(citations)
        coverage_pct = coverage_ratio * 100
        threshold_pct = self.min_citation_coverage * 100

        metric_name = "claims" if self.mode == CitationMode.CLAIMS else "characters"
        verdict = "adequate" if passed else "insufficient"
        reason = (
            f"Response has {verdict} citation coverage "
            f"({coverage_pct:.1f}% of {metric_name} cited, threshold: {threshold_pct:.1f}%)"
        )

        # Add details about citations
        if citations:
            reason += f"\n\nCitations found ({num_citations}):"
            for i, citation in enumerate(citations[:5]):  # Show first 5
                # Every field is optional here: a record missing one should
                # not break reason building.
                response_text = (citation.get("response_text") or "").strip()
                doc_id = citation.get("citation_doc_id", "unknown")
                citation_text = (citation.get("citation_text") or "").strip()
                # Truncate long texts
                if len(response_text) > 60:
                    response_text = response_text[:57] + "..."
                if len(citation_text) > 60:
                    citation_text = citation_text[:57] + "..."
                reason += f"\n  {i + 1}. '{response_text}' → Document '{doc_id}'"
                if citation_text:
                    reason += f"\n     Source: '{citation_text}'"

            if len(citations) > 5:
                reason += f"\n  ... and {len(citations) - 5} more citation(s)"
        else:
            reason += "\n\nNo citations found in the response."

        if not passed:
            uncited_pct = max(0.0, 100.0 - coverage_pct)
            reason += (
                f"\n\nUncited content represents {uncited_pct:.1f}% of the response."
            )

        return reason
@pytest.mark.huggingface
@pytest.mark.llm
@pytest.mark.requires_heavy_ram
async def test_citation_requirement_with_constructor_documents():
    """Documents given to the constructor are used when the message has none."""
    hf_backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro")

    source_docs = [
        Document(doc_id="doc1", text="The sky is blue during the day."),
        Document(doc_id="doc2", text="Grass is typically green in color."),
    ]
    answer = "The sky is blue. Grass is green."

    # The assistant message deliberately carries no documents.
    conversation = (
        ChatContext()
        .add(Message("user", "What colors are the sky and grass?"))
        .add(Message("assistant", answer))
    )

    requirement = CitationRequirement(
        min_citation_coverage=0.5, documents=source_docs
    )
    outcome = await requirement.validate(hf_backend, conversation)

    # A float score shows coverage was computed from the constructor documents.
    assert isinstance(outcome.score, float)
    assert outcome.reason is not None
def test_citation_requirement_invalid_coverage():
    """Thresholds outside [0.0, 1.0] are rejected when the requirement is built."""
    expected_message = r"between 0\.0 and 1\.0"

    # One value above the upper bound, one below the lower bound.
    for out_of_range in (1.5, -0.5):
        with pytest.raises(ValueError, match=expected_message):
            CitationRequirement(min_citation_coverage=out_of_range)
async def test_citation_requirement_threshold_boundary():
    """Coverage exactly at the threshold passes; a slightly higher threshold fails.

    find_citations is patched to report 8 cited characters in a 10-character
    response, i.e. exactly 80% coverage in CHARACTERS mode.
    """
    from unittest.mock import Mock, patch

    fake_backend = Mock(spec=LocalHFBackend)

    source_docs = [Document(doc_id="doc1", text="The sky is blue during the day.")]
    ten_chars = "1234567890"

    conversation = ChatContext().add(Message("user", "What color is the sky?"))
    conversation = conversation.add(
        Message("assistant", ten_chars, documents=source_docs)
    )

    # One record covering characters 0..8 of the response.
    eight_of_ten = [
        {
            "response_begin": 0,
            "response_end": 8,
            "response_text": "12345678",
            "citation_doc_id": "doc1",
            "citation_text": "The sky is blue",
        }
    ]

    with patch(
        "mellea.stdlib.components.intrinsic.rag.find_citations",
        return_value=eight_of_ten,
    ):
        # Threshold equal to the coverage: the comparison is >=, so it passes.
        at_threshold = CitationRequirement(
            min_citation_coverage=0.8, mode=CitationMode.CHARACTERS
        )
        outcome_at = await at_threshold.validate(fake_backend, conversation)

        assert outcome_at.as_bool()
        assert outcome_at.score == 0.8

        # Threshold (0.81) just above the 0.8 coverage: must fail.
        stricter = CitationRequirement(
            min_citation_coverage=0.81, mode=CitationMode.CHARACTERS
        )
        outcome_stricter = await stricter.validate(fake_backend, conversation)

        assert not outcome_stricter.as_bool()
        assert outcome_stricter.score == 0.8
async def test_citation_requirement_characters_vs_claims():
    """The same citations score differently under CLAIMS and CHARACTERS modes."""
    from unittest.mock import Mock, patch

    fake_backend = Mock(spec=LocalHFBackend)

    source_docs = [Document(doc_id="doc1", text="Short fact.")]

    # A short cited sentence followed by a long uncited one.
    answer = "Short. This is a much longer sentence without any citation support."

    conversation = (
        ChatContext()
        .add(Message("user", "Tell me something."))
        .add(Message("assistant", answer, documents=source_docs))
    )

    # Only the first six characters ("Short.") are cited.
    single_citation = [
        {
            "response_begin": 0,
            "response_end": 6,
            "response_text": "Short",
            "citation_doc_id": "doc1",
            "citation_text": "Short fact.",
        }
    ]

    with patch(
        "mellea.stdlib.components.intrinsic.rag.find_citations",
        return_value=single_citation,
    ):
        # CLAIMS: one cited claim out of two sentences -> 50%.
        by_claims = await CitationRequirement(
            min_citation_coverage=0.5, mode=CitationMode.CLAIMS
        ).validate(fake_backend, conversation)

        # CHARACTERS: 6 of 67 characters -> roughly 9%.
        by_chars = await CitationRequirement(
            min_citation_coverage=0.5, mode=CitationMode.CHARACTERS
        ).validate(fake_backend, conversation)

    # 50% meets the 50% threshold.
    assert by_claims.as_bool()
    assert by_claims.score == 0.5

    # ~9% is well below the 50% threshold.
    assert not by_chars.as_bool()
    assert by_chars.score is not None
    assert by_chars.score < 0.1