From 8bf9523d3f445e2669b0cc19eada44dd35e43eef Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 23 Mar 2026 12:01:02 -0400 Subject: [PATCH 1/9] feat: citation requirement Signed-off-by: Akihiko Kuroda --- docs/examples/citation_requirement_example.py | 102 ++++++ mellea/stdlib/requirements/__init__.py | 3 + mellea/stdlib/requirements/rag.py | 341 ++++++++++++++++++ test/stdlib/requirements/test_rag.py | 298 +++++++++++++++ 4 files changed, 744 insertions(+) create mode 100644 docs/examples/citation_requirement_example.py create mode 100644 mellea/stdlib/requirements/rag.py create mode 100644 test/stdlib/requirements/test_rag.py diff --git a/docs/examples/citation_requirement_example.py b/docs/examples/citation_requirement_example.py new file mode 100644 index 000000000..56bc3dc53 --- /dev/null +++ b/docs/examples/citation_requirement_example.py @@ -0,0 +1,102 @@ +# pytest: huggingface, llm, requires_heavy_ram +"""Example demonstrating CitationRequirement for RAG workflows. + +This example shows how to use CitationRequirement to validate that +assistant responses properly cite their sources in RAG workflows. + +Note: This example requires HuggingFace backend and access to the +meta-llama/Llama-3.2-1B-Instruct model. +""" + +import asyncio + +from mellea.backends.huggingface import LocalHFBackend +from mellea.stdlib.components import Document, Message +from mellea.stdlib.context import ChatContext +from mellea.stdlib.requirements.rag import CitationRequirement, citation_check + + +async def main(): + """Demonstrate CitationRequirement usage.""" + print("=" * 70) + print("CitationRequirement Example") + print("=" * 70) + + # Initialize HuggingFace backend + print("\nInitializing HuggingFace backend...") + backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + + # Create documents + docs = [ + Document( + doc_id="doc1", + title="Sky Facts", + text="The sky appears blue during the day due to Rayleigh scattering.", + ), + Document( + doc_id="doc2", + title="Grass Facts", + text="Grass is typically green because of chlorophyll in the leaves.", + ), + ] + + # Create a response that should have citations + response = ( + "The sky appears blue during the day. " + "Grass is green because it contains chlorophyll." + ) + + # Create context + print("\nCreating context with user question and assistant response...") + ctx = ChatContext().add(Message("user", "What colors are the sky and grass?")) + ctx = ctx.add(Message("assistant", response, documents=docs)) + + # Example 1: Using CitationRequirement directly + print("\n--- Example 1: CitationRequirement with 70% coverage ---") + req = CitationRequirement(min_citation_coverage=0.7, documents=docs) + result = await req.validate(backend, ctx) + + print(f"Validation passed: {result.as_bool()}") + print(f"Citation coverage score: {result.score:.2%}") + if result.reason: + reason_preview = ( + result.reason[:200] + "..." if len(result.reason) > 200 else result.reason + ) + print(f"Reason: {reason_preview}") + + # Example 2: Using citation_check factory + print("\n--- Example 2: Using citation_check factory ---") + req2 = citation_check(docs, min_citation_coverage=0.8) + result2 = await req2.validate(backend, ctx) + + print(f"Validation passed: {result2.as_bool()}") + print(f"Citation coverage score: {result2.score:.2%}") + + # Example 3: Documents attached to message + print("\n--- Example 3: Documents in message (not constructor) ---") + ctx2 = ChatContext().add(Message("user", "Tell me about Mars.")) + ctx2 = ctx2.add( + Message( + "assistant", + "Mars is the fourth planet from the Sun.", + documents=[ + Document(doc_id="doc1", text="Mars is the fourth planet from the Sun.") + ], + ) + ) + + req3 = CitationRequirement(min_citation_coverage=0.7) # No documents in constructor + result3 = await req3.validate(backend, ctx2) + + print(f"Validation passed: {result3.as_bool()}") + print(f"Citation coverage score: {result3.score:.2%}") + + print("\n" + "=" * 70) + print("Example completed successfully!") + print("=" * 70) + + +if __name__ == "__main__": + asyncio.run(main()) + +# Made with Bob diff --git a/mellea/stdlib/requirements/__init__.py b/mellea/stdlib/requirements/__init__.py index c0bd7d3c9..15787dc8a 100644 --- a/mellea/stdlib/requirements/__init__.py +++ b/mellea/stdlib/requirements/__init__.py @@ -4,6 +4,7 @@ from ...core import Requirement, ValidationResult, default_output_to_bool from .md import as_markdown_list, is_markdown_list, is_markdown_table from .python_reqs import PythonExecutionReq +from .rag import CitationRequirement, citation_check from .requirement import ( ALoraRequirement, LLMaJRequirement, @@ -17,12 +18,14 @@ __all__ = [ "ALoraRequirement", + "CitationRequirement", "LLMaJRequirement", "PythonExecutionReq", "Requirement", "ValidationResult", "as_markdown_list", "check", + "citation_check", "default_output_to_bool", "is_markdown_list", "is_markdown_table", diff --git a/mellea/stdlib/requirements/rag.py b/mellea/stdlib/requirements/rag.py new file mode 100644 index 000000000..f3502ea52 --- /dev/null +++ b/mellea/stdlib/requirements/rag.py @@ -0,0 +1,341 @@ +"""Requirements for RAG (Retrieval-Augmented Generation) workflows.""" + +from collections.abc import Iterable + +from ...backends.adapters import AdapterMixin +from ...core import Backend, Context, Requirement, ValidationResult +from ..components import Document, Message + + +class CitationRequirement(Requirement): + """Requirement that validates RAG responses have adequate citation coverage. + + Uses the find_citations intrinsic to identify which parts of an assistant's + response are supported by explicit citations to retrieved documents. Content + without citations below the minimum coverage threshold fails validation. + + **Important**: This requirement requires a HuggingFace backend (LocalHFBackend) + as the find_citations intrinsic only works with HuggingFace models. Using other + backends (Ollama, OpenAI, etc.) will result in a validation error. + + This requirement is designed for RAG workflows where you want to ensure + responses properly cite their sources. It works with: + - A user question in the context + - Retrieved documents + - An assistant response to validate + + Documents can be provided either: + 1. In the constructor (for reusable requirements with fixed documents) + 2. Attached to the assistant message in the context (for dynamic documents) + + Example: + ```python + from mellea.backends.huggingface import LocalHFBackend + from mellea.stdlib.requirements.rag import CitationRequirement + + backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + + # Option 1: Documents in constructor + req = CitationRequirement( + documents=doc_objects, + min_citation_coverage=0.8 + ) + + # Option 2: Documents in context (original pattern) + req = CitationRequirement(min_citation_coverage=0.8) + ctx = ChatContext().add( + Message("assistant", response, documents=doc_objects) + ) + ``` + """ + + def __init__( + self, + min_citation_coverage: float = 0.8, + documents: Iterable[Document] | Iterable[str] | None = None, + description: str | None = None, + ): + """Initialize citation coverage requirement. + + Args: + min_citation_coverage: Minimum ratio of cited content (0.0-1.0). + The ratio of characters with citations to total response length + must meet or exceed this threshold. Default: 0.8 (80% coverage) + documents: Optional documents to validate against. Can be Document + objects or strings (will be converted to Documents). If provided, + these documents will be used instead of documents attached to + messages in the context. Default: None (use context documents) + description: Custom description for the requirement. If None, + generates a description based on coverage threshold. + """ + if not 0.0 <= min_citation_coverage <= 1.0: + raise ValueError( + f"min_citation_coverage must be between 0.0 and 1.0, got {min_citation_coverage}" + ) + + self.min_citation_coverage = min_citation_coverage + + # Convert documents to Document objects if provided + if documents is not None: + self.documents: list[Document] | None = [ + doc + if isinstance(doc, Document) + else Document(doc_id=str(i), text=str(doc)) + for i, doc in enumerate(documents) + ] + else: + self.documents = None + + # Generate description if not provided + if description is None: + description = ( + f"Response must have adequate citation coverage " + f"(minimum {min_citation_coverage * 100:.0f}% of content cited)" + ) + + # Initialize parent without validation function - we override validate() instead + super().__init__(description=description, validation_fn=None) + + async def validate( + self, + backend: Backend, + ctx: Context, + *, + format: type | None = None, + model_options: dict | None = None, + ) -> ValidationResult: + """Validate citation coverage in the context using the backend. + + Args: + backend: Backend to use for citation detection. Must be LocalHFBackend + as the find_citations intrinsic only works with HuggingFace models. + ctx: Context containing the conversation history + format: Unused for this requirement + model_options: Unused for this requirement + + Returns: + ValidationResult with pass/fail status, reason, and score + """ + # Extract last message (should be assistant response) + messages = ctx.as_list() + if not messages: + return ValidationResult( + False, reason="Context is empty, cannot validate citation coverage" + ) + + last_message = messages[-1] + if not isinstance(last_message, Message): + return ValidationResult( + False, + reason="Last context item is not a Message, cannot validate citation coverage", + ) + + if last_message.role != "assistant": + return ValidationResult( + False, + reason=f"Last message must be assistant response, got role: {last_message.role}", + ) + + response = last_message.content + + # Use constructor documents if provided, otherwise get from message + if self.documents is not None: + documents = self.documents + else: + # Access private _docs attribute since documents property returns formatted strings + documents = last_message._docs or [] + + if not documents: + return ValidationResult( + False, + reason="No documents provided for citation validation. " + "Either pass documents to CitationRequirement constructor " + "or attach them to the assistant message.", + ) + + # Check backend compatibility + if not isinstance(backend, AdapterMixin): + return ValidationResult( + False, + reason=f"Backend {backend.__class__.__name__} does not support adapters required for citation detection", + ) + + # More specific check for HuggingFace backend + try: + from ...backends.huggingface import LocalHFBackend + + if not isinstance(backend, LocalHFBackend): + return ValidationResult( + False, + reason=f"Citation detection requires LocalHFBackend (HuggingFace), " + f"but got {backend.__class__.__name__}. The find_citations intrinsic " + f"only works with HuggingFace models.", + ) + except ImportError: + return ValidationResult( + False, + reason="HuggingFace backend not available. Please install mellea[hf] to use citation detection.", + ) + + # Create context before the response by getting all but the last message + all_messages = ctx.as_list() + if len(all_messages) > 1: + # Rebuild context without last message + from ..context import ChatContext + + context_before_response = ChatContext() + for msg in all_messages[:-1]: + context_before_response = context_before_response.add(msg) + else: + # If only one message, use empty context + from ..context import ChatContext + + context_before_response = ChatContext() + + # Call find_citations intrinsic + try: + # Import here to avoid circular dependency + from ..components.intrinsic import rag + + citations: list[dict] = rag.find_citations( + response, documents, context_before_response, backend + ) + except Exception as e: + return ValidationResult( + False, reason=f"Citation detection intrinsic failed: {e!s}" + ) + + # Calculate citation coverage + total_chars = len(response) + if total_chars == 0: + return ValidationResult( + True, reason="Empty response has 100% citation coverage", score=1.0 + ) + + cited_chars = sum( + citation["response_end"] - citation["response_begin"] + for citation in citations + ) + coverage_ratio = cited_chars / total_chars + + # Check against min_citation_coverage + passed = coverage_ratio >= self.min_citation_coverage + + # Build detailed reason + reason = self._build_reason(citations, coverage_ratio, passed) + + return ValidationResult(passed, reason=reason, score=coverage_ratio) + + def _build_reason( + self, citations: list[dict], coverage_ratio: float, passed: bool + ) -> str: + """Build a detailed reason string for the validation result. + + Args: + citations: List of citation records from find_citations + coverage_ratio: Ratio of cited content + passed: Whether validation passed + + Returns: + Detailed reason string + """ + num_citations = len(citations) + coverage_pct = coverage_ratio * 100 + threshold_pct = self.min_citation_coverage * 100 + + if passed: + reason = ( + f"Response has adequate citation coverage " + f"({coverage_pct:.1f}% cited, threshold: {threshold_pct:.1f}%)" + ) + else: + reason = ( + f"Response has insufficient citation coverage " + f"({coverage_pct:.1f}% cited, threshold: {threshold_pct:.1f}%)" + ) + + # Add details about citations + if citations: + reason += f"\n\nCitations found ({num_citations}):" + for i, citation in enumerate(citations[:5]): # Show first 5 + response_text = citation["response_text"].strip() + doc_id = citation.get("citation_doc_id", "unknown") + citation_text = citation.get("citation_text", "").strip() + # Truncate long texts + if len(response_text) > 60: + response_text = response_text[:57] + "..." + if len(citation_text) > 60: + citation_text = citation_text[:57] + "..." + reason += f"\n {i + 1}. '{response_text}' → Document '{doc_id}'" + if citation_text: + reason += f"\n Source: '{citation_text}'" + + if len(citations) > 5: + reason += f"\n ... and {len(citations) - 5} more citation(s)" + else: + reason += "\n\nNo citations found in the response." + + if not passed: + uncited_pct = 100.0 - coverage_pct + reason += ( + f"\n\nUncited content represents {uncited_pct:.1f}% of the response." + ) + + return reason + + +def citation_check( + documents: Iterable[Document] | Iterable[str], + min_citation_coverage: float = 0.8, + description: str | None = None, +) -> CitationRequirement: + """Create a citation coverage requirement with pre-attached documents. + + This is a convenience factory function that creates a CitationRequirement + with documents already attached. This is useful when you have a fixed set of + documents to validate against and want a cleaner API. + + **Important**: This requirement requires a HuggingFace backend (LocalHFBackend). + + Args: + documents: Documents to check for citations. Can be Document objects + or strings (will be converted to Documents). + min_citation_coverage: Minimum ratio of cited content (0.0-1.0). + Default: 0.8 (80% coverage) + description: Custom description for the requirement. If None, + generates a description based on coverage threshold. + + Returns: + A CitationRequirement with documents attached + + Example: + ```python + from mellea.backends.huggingface import LocalHFBackend + from mellea.stdlib.requirements.rag import citation_check + from mellea.stdlib.components import Document + + backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + docs = [ + Document(doc_id="1", text="The sky is blue."), + Document(doc_id="2", text="Grass is green.") + ] + req = citation_check(docs, min_citation_coverage=0.8) + + # Use with instruct() - no need to attach documents to messages + result = m.instruct( + "Answer: {{query}}", + grounding_context={"query": "What color is the sky?"}, + requirements=[req], + backend=backend, + strategy=RejectionSamplingStrategy() + ) + ``` + """ + return CitationRequirement( + min_citation_coverage=min_citation_coverage, + documents=documents, + description=description, + ) + + +# Made with Bob diff --git a/test/stdlib/requirements/test_rag.py b/test/stdlib/requirements/test_rag.py new file mode 100644 index 000000000..06e00c34d --- /dev/null +++ b/test/stdlib/requirements/test_rag.py @@ -0,0 +1,298 @@ +"""Tests for RAG requirements.""" +# pytest: huggingface, llm, requires_heavy_ram + +import pytest + +from mellea.backends.huggingface import LocalHFBackend +from mellea.stdlib.components import Document, Message +from mellea.stdlib.context import ChatContext +from mellea.stdlib.requirements.rag import CitationRequirement, citation_check + + +@pytest.mark.huggingface +@pytest.mark.llm +@pytest.mark.requires_heavy_ram +async def test_citation_requirement_basic(): + """Test basic citation requirement functionality.""" + backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + + # Create documents + docs = [ + Document(doc_id="doc1", text="The sky is blue during the day."), + Document(doc_id="doc2", text="Grass is typically green in color."), + ] + + # Create a response that should have citations + response = "The sky is blue. Grass is green." + + # Create context with assistant message + ctx = ChatContext().add(Message("user", "What colors are the sky and grass?")) + ctx = ctx.add(Message("assistant", response, documents=docs)) + + # Create requirement + req = CitationRequirement(min_citation_coverage=0.5) + + # Validate + result = await req.validate(backend, ctx) + + # Should pass if citations are found + assert isinstance(result.score, float) + assert 0.0 <= result.score <= 1.0 + assert result.reason is not None + + +@pytest.mark.huggingface +@pytest.mark.llm +@pytest.mark.requires_heavy_ram +async def test_citation_requirement_with_constructor_documents(): + """Test citation requirement with documents in constructor.""" + backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + + # Create documents + docs = [ + Document(doc_id="doc1", text="The sky is blue during the day."), + Document(doc_id="doc2", text="Grass is typically green in color."), + ] + + # Create a response + response = "The sky is blue. Grass is green." + + # Create context with assistant message (no documents attached) + ctx = ChatContext().add(Message("user", "What colors are the sky and grass?")) + ctx = ctx.add(Message("assistant", response)) + + # Create requirement with documents in constructor + req = CitationRequirement(min_citation_coverage=0.5, documents=docs) + + # Validate + result = await req.validate(backend, ctx) + + # Should use constructor documents + assert isinstance(result.score, float) + assert result.reason is not None + + +@pytest.mark.huggingface +@pytest.mark.llm +@pytest.mark.requires_heavy_ram +async def test_citation_check_factory(): + """Test citation_check factory function.""" + backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + + # Create documents + docs = [Document(doc_id="doc1", text="The sky is blue during the day.")] + + # Create a response + response = "The sky is blue." + + # Create context + ctx = ChatContext().add(Message("user", "What color is the sky?")) + ctx = ctx.add(Message("assistant", response)) + + # Use factory function + req = citation_check(docs, min_citation_coverage=0.5) + + # Validate + result = await req.validate(backend, ctx) + + # Should work the same as CitationRequirement + assert isinstance(result.score, float) + assert result.reason is not None + + +async def test_citation_requirement_empty_context(): + """Test citation requirement with empty context.""" + # Create a mock backend - we don't need a real one for this test + # since validation fails before backend is used + from unittest.mock import Mock + + backend = Mock(spec=LocalHFBackend) + + # Create empty context + ctx = ChatContext() + + # Create requirement + req = CitationRequirement(min_citation_coverage=0.8) + + # Validate + result = await req.validate(backend, ctx) + + # Should fail with clear error + assert not result.as_bool() + assert result.reason is not None + assert "empty" in result.reason.lower() + + +async def test_citation_requirement_wrong_message_role(): + """Test citation requirement with non-assistant last message.""" + # Create a mock backend - we don't need a real one for this test + from unittest.mock import Mock + + backend = Mock(spec=LocalHFBackend) + + # Create context ending with user message + ctx = ChatContext().add(Message("user", "What color is the sky?")) + + # Create requirement + req = CitationRequirement(min_citation_coverage=0.8) + + # Validate + result = await req.validate(backend, ctx) + + # Should fail with clear error + assert not result.as_bool() + assert result.reason is not None + assert "assistant" in result.reason.lower() + + +async def test_citation_requirement_no_documents(): + """Test citation requirement with no documents provided.""" + # Create a mock backend - we don't need a real one for this test + from unittest.mock import Mock + + backend = Mock(spec=LocalHFBackend) + + # Create context without documents + ctx = ChatContext().add(Message("user", "What color is the sky?")) + ctx = ctx.add(Message("assistant", "The sky is blue.")) + + # Create requirement without documents + req = CitationRequirement(min_citation_coverage=0.8) + + # Validate + result = await req.validate(backend, ctx) + + # Should fail with clear error about missing documents + assert not result.as_bool() + assert result.reason is not None + assert "documents" in result.reason.lower() + + +async def test_citation_requirement_wrong_backend(): + """Test citation requirement with non-HuggingFace backend.""" + try: + from mellea.backends.ollama import OllamaBackend # type: ignore + except ImportError: + pytest.skip("Ollama backend not available") + + backend = OllamaBackend(model_id="llama3.2") # type: ignore + + # Create documents + docs = [Document(doc_id="doc1", text="The sky is blue.")] + + # Create context + ctx = ChatContext().add(Message("user", "What color is the sky?")) + ctx = ctx.add(Message("assistant", "The sky is blue.", documents=docs)) + + # Create requirement + req = CitationRequirement(min_citation_coverage=0.8) + + # Validate + result = await req.validate(backend, ctx) + + # Should fail with clear error about backend requirement + assert not result.as_bool() + assert result.reason is not None + assert "LocalHFBackend" in result.reason or "HuggingFace" in result.reason + + +def test_citation_requirement_invalid_coverage(): + """Test citation requirement with invalid coverage values.""" + # Test coverage > 1.0 + with pytest.raises(ValueError, match=r"between 0\.0 and 1\.0"): + CitationRequirement(min_citation_coverage=1.5) + + # Test coverage < 0.0 + with pytest.raises(ValueError, match=r"between 0\.0 and 1\.0"): + CitationRequirement(min_citation_coverage=-0.5) + + +def test_citation_requirement_string_documents(): + """Test citation requirement with string documents.""" + # Should convert strings to Document objects + req = CitationRequirement( + min_citation_coverage=0.8, documents=["The sky is blue.", "Grass is green."] + ) + + # Check documents were converted + assert req.documents is not None + assert len(req.documents) == 2 + assert all(isinstance(doc, Document) for doc in req.documents) + + +def test_citation_requirement_custom_description(): + """Test citation requirement with custom description.""" + custom_desc = "Custom citation requirement description" + req = CitationRequirement(min_citation_coverage=0.8, description=custom_desc) + + assert req.description == custom_desc + + +def test_citation_requirement_default_description(): + """Test citation requirement generates default description.""" + req = CitationRequirement(min_citation_coverage=0.75) + + assert req.description is not None + assert "75%" in req.description or "0.75" in req.description + + +@pytest.mark.huggingface +@pytest.mark.llm +@pytest.mark.requires_heavy_ram +async def test_citation_requirement_empty_response(): + """Test citation requirement with empty response.""" + backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + + # Create documents + docs = [Document(doc_id="doc1", text="The sky is blue.")] + + # Create context with empty response + ctx = ChatContext().add(Message("user", "What color is the sky?")) + ctx = ctx.add(Message("assistant", "", documents=docs)) + + # Create requirement + req = CitationRequirement(min_citation_coverage=0.8) + + # Validate + result = await req.validate(backend, ctx) + + # Empty response should pass (100% coverage of nothing) + assert result.as_bool() + assert result.score == 1.0 + + +@pytest.mark.huggingface +@pytest.mark.llm +@pytest.mark.requires_heavy_ram +async def test_citation_requirement_threshold_boundary(): + """Test citation requirement at exact threshold boundary.""" + backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + + # Create documents + docs = [Document(doc_id="doc1", text="The sky is blue during the day.")] + + # Create a response + response = "The sky is blue." + + # Create context + ctx = ChatContext().add(Message("user", "What color is the sky?")) + ctx = ctx.add(Message("assistant", response, documents=docs)) + + # Create requirement with specific threshold + req = CitationRequirement(min_citation_coverage=0.8) + + # Validate + result = await req.validate(backend, ctx) + + # Check that score is calculated + assert isinstance(result.score, float) + assert 0.0 <= result.score <= 1.0 + + # Result should match threshold comparison + if result.score >= 0.8: + assert result.as_bool() + else: + assert not result.as_bool() + + +# Made with Bob From 52ae5f44971f789541904f0c3bd31d33c095b538 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 23 Mar 2026 15:40:30 -0400 Subject: [PATCH 2/9] fix test error Signed-off-by: Akihiko Kuroda --- docs/examples/citation_requirement_example.py | 2 -- mellea/stdlib/requirements/rag.py | 15 +++++++-------- test/stdlib/requirements/test_rag.py | 13 +++++-------- 3 files changed, 12 insertions(+), 18 deletions(-) diff --git a/docs/examples/citation_requirement_example.py b/docs/examples/citation_requirement_example.py index 56bc3dc53..be1590c41 100644 --- a/docs/examples/citation_requirement_example.py +++ b/docs/examples/citation_requirement_example.py @@ -98,5 +98,3 @@ async def main(): if __name__ == "__main__": asyncio.run(main()) - -# Made with Bob diff --git a/mellea/stdlib/requirements/rag.py b/mellea/stdlib/requirements/rag.py index f3502ea52..1b0a003a9 100644 --- a/mellea/stdlib/requirements/rag.py +++ b/mellea/stdlib/requirements/rag.py @@ -192,6 +192,13 @@ async def validate( context_before_response = ChatContext() + # Handle empty response before calling intrinsic + total_chars = len(response) + if total_chars == 0: + return ValidationResult( + True, reason="Empty response has 100% citation coverage", score=1.0 + ) + # Call find_citations intrinsic try: # Import here to avoid circular dependency @@ -206,11 +213,6 @@ async def validate( ) # Calculate citation coverage - total_chars = len(response) - if total_chars == 0: - return ValidationResult( - True, reason="Empty response has 100% citation coverage", score=1.0 - ) cited_chars = sum( citation["response_end"] - citation["response_begin"] @@ -336,6 +338,3 @@ def citation_check( documents=documents, description=description, ) - - -# Made with Bob diff --git a/test/stdlib/requirements/test_rag.py b/test/stdlib/requirements/test_rag.py index 06e00c34d..c57819ad0 100644 --- a/test/stdlib/requirements/test_rag.py +++ b/test/stdlib/requirements/test_rag.py @@ -14,7 +14,7 @@ @pytest.mark.requires_heavy_ram async def test_citation_requirement_basic(): """Test basic citation requirement functionality.""" - backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") # Create documents docs = [ @@ -46,7 +46,7 @@ async def test_citation_requirement_basic(): @pytest.mark.requires_heavy_ram async def test_citation_requirement_with_constructor_documents(): """Test citation requirement with documents in constructor.""" - backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") # Create documents docs = [ @@ -77,7 +77,7 @@ async def test_citation_requirement_with_constructor_documents(): @pytest.mark.requires_heavy_ram async def test_citation_check_factory(): """Test citation_check factory function.""" - backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") # Create documents docs = [Document(doc_id="doc1", text="The sky is blue during the day.")] @@ -241,7 +241,7 @@ def test_citation_requirement_default_description(): @pytest.mark.requires_heavy_ram async def test_citation_requirement_empty_response(): """Test citation requirement with empty response.""" - backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") # Create documents docs = [Document(doc_id="doc1", text="The sky is blue.")] @@ -266,7 +266,7 @@ async def test_citation_requirement_empty_response(): @pytest.mark.requires_heavy_ram async def test_citation_requirement_threshold_boundary(): """Test citation requirement at exact threshold boundary.""" - backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") # Create documents docs = [Document(doc_id="doc1", text="The sky is blue during the day.")] @@ -293,6 +293,3 @@ async def test_citation_requirement_threshold_boundary(): assert result.as_bool() else: assert not result.as_bool() - - -# Made with Bob From 1b5cc8230089e02a07efebb55f2ebe43822aa1fe Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 23 Mar 2026 16:01:23 -0400 Subject: [PATCH 3/9] fix example Signed-off-by: Akihiko Kuroda --- docs/examples/citation_requirement_example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/citation_requirement_example.py b/docs/examples/citation_requirement_example.py index be1590c41..8718e262a 100644 --- a/docs/examples/citation_requirement_example.py +++ b/docs/examples/citation_requirement_example.py @@ -5,7 +5,7 @@ assistant responses properly cite their sources in RAG workflows. Note: This example requires HuggingFace backend and access to the -meta-llama/Llama-3.2-1B-Instruct model. +ibm-granite/granite-4.0-micro model. """ import asyncio @@ -24,7 +24,7 @@ async def main(): # Initialize HuggingFace backend print("\nInitializing HuggingFace backend...") - backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") + backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") # Create documents docs = [ From 8f3779cf9cade433e4a748670961eb0866bbcf3d Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 23 Mar 2026 16:16:44 -0400 Subject: [PATCH 4/9] fix dockstring issue Signed-off-by: Akihiko Kuroda --- mellea/stdlib/requirements/rag.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/mellea/stdlib/requirements/rag.py b/mellea/stdlib/requirements/rag.py index 1b0a003a9..508bddbbf 100644 --- a/mellea/stdlib/requirements/rag.py +++ b/mellea/stdlib/requirements/rag.py @@ -28,6 +28,17 @@ class CitationRequirement(Requirement): 1. In the constructor (for reusable requirements with fixed documents) 2. Attached to the assistant message in the context (for dynamic documents) + Args: + min_citation_coverage: Minimum ratio of cited content (0.0-1.0). + The ratio of characters with citations to total response length + must meet or exceed this threshold. Default is 0.8 (80% coverage). + documents: Optional documents to validate against. Can be Document + objects or strings (will be converted to Documents). If provided, + these documents will be used instead of documents attached to + messages in the context. Default is None (use context documents). + description: Custom description for the requirement. If None, + generates a description based on coverage threshold. + Example: ```python from mellea.backends.huggingface import LocalHFBackend @@ -55,19 +66,7 @@ def __init__( documents: Iterable[Document] | Iterable[str] | None = None, description: str | None = None, ): - """Initialize citation coverage requirement. - - Args: - min_citation_coverage: Minimum ratio of cited content (0.0-1.0). - The ratio of characters with citations to total response length - must meet or exceed this threshold. Default: 0.8 (80% coverage) - documents: Optional documents to validate against. Can be Document - objects or strings (will be converted to Documents). If provided, - these documents will be used instead of documents attached to - messages in the context. Default: None (use context documents) - description: Custom description for the requirement. If None, - generates a description based on coverage threshold. - """ + """Initialize citation coverage requirement.""" if not 0.0 <= min_citation_coverage <= 1.0: raise ValueError( f"min_citation_coverage must be between 0.0 and 1.0, got {min_citation_coverage}" @@ -302,8 +301,8 @@ def citation_check( Args: documents: Documents to check for citations. Can be Document objects or strings (will be converted to Documents). - min_citation_coverage: Minimum ratio of cited content (0.0-1.0). - Default: 0.8 (80% coverage) + min_citation_coverage: Minimum ratio of cited content (0.0-1.0), + defaults to 0.8 (80% coverage). description: Custom description for the requirement. If None, generates a description based on coverage threshold. From 6677b5e0a9ff3ff57ac89d2bc723ed4aad3ffc75 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 23 Mar 2026 16:30:31 -0400 Subject: [PATCH 5/9] fix test name conflict Signed-off-by: Akihiko Kuroda --- .../stdlib/requirements/{test_rag.py => test_rag_requirements.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/stdlib/requirements/{test_rag.py => test_rag_requirements.py} (100%) diff --git a/test/stdlib/requirements/test_rag.py b/test/stdlib/requirements/test_rag_requirements.py similarity index 100% rename from test/stdlib/requirements/test_rag.py rename to test/stdlib/requirements/test_rag_requirements.py From 70d70cb9673b84ac9df28fd689721dab5bbc3246 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Tue, 24 Mar 2026 12:29:47 -0400 Subject: [PATCH 6/9] review comments Signed-off-by: Akihiko Kuroda --- docs/examples/citation_requirement_example.py | 12 +- mellea/stdlib/requirements/__init__.py | 3 +- mellea/stdlib/requirements/rag.py | 77 +------------ .../requirements/test_rag_requirements.py | 105 ++++++++---------- 4 files changed, 60 insertions(+), 137 deletions(-) diff --git a/docs/examples/citation_requirement_example.py b/docs/examples/citation_requirement_example.py index 8718e262a..c1d8faaa5 100644 --- a/docs/examples/citation_requirement_example.py +++ b/docs/examples/citation_requirement_example.py @@ -13,7 +13,7 @@ from mellea.backends.huggingface import LocalHFBackend from mellea.stdlib.components import Document, Message from mellea.stdlib.context import ChatContext -from mellea.stdlib.requirements.rag import CitationRequirement, citation_check +from mellea.stdlib.requirements.rag import CitationRequirement async def main(): @@ -51,8 +51,8 @@ async def main(): ctx = ChatContext().add(Message("user", "What colors are the sky and grass?")) ctx = ctx.add(Message("assistant", response, documents=docs)) - # Example 1: Using CitationRequirement directly - print("\n--- Example 1: CitationRequirement with 70% coverage ---") + # Example 1: Documents in constructor + print("\n--- Example 1: CitationRequirement with documents in constructor ---") req = CitationRequirement(min_citation_coverage=0.7, documents=docs) result = await req.validate(backend, ctx) @@ -64,9 +64,9 @@ async def main(): ) print(f"Reason: {reason_preview}") - # Example 2: Using citation_check factory - print("\n--- Example 2: Using citation_check factory ---") - req2 = citation_check(docs, min_citation_coverage=0.8) + # Example 2: Higher coverage threshold + print("\n--- Example 2: Higher coverage threshold (80%) ---") + req2 = CitationRequirement(min_citation_coverage=0.8, documents=docs) result2 = await req2.validate(backend, ctx) print(f"Validation passed: {result2.as_bool()}") diff --git a/mellea/stdlib/requirements/__init__.py b/mellea/stdlib/requirements/__init__.py index 15787dc8a..6a1c5af56 100644 --- a/mellea/stdlib/requirements/__init__.py +++ b/mellea/stdlib/requirements/__init__.py @@ -4,7 +4,7 @@ from ...core import Requirement, ValidationResult, default_output_to_bool from .md import as_markdown_list, is_markdown_list, is_markdown_table from .python_reqs import PythonExecutionReq -from .rag import CitationRequirement, citation_check +from .rag import CitationRequirement from .requirement import ( ALoraRequirement, LLMaJRequirement, @@ -25,7 +25,6 @@ "ValidationResult", "as_markdown_list", "check", - "citation_check", "default_output_to_bool", "is_markdown_list", "is_markdown_table", diff --git a/mellea/stdlib/requirements/rag.py b/mellea/stdlib/requirements/rag.py index 508bddbbf..46999bf52 100644 --- a/mellea/stdlib/requirements/rag.py +++ b/mellea/stdlib/requirements/rag.py @@ -159,27 +159,11 @@ async def validate( reason=f"Backend {backend.__class__.__name__} does not support adapters required for citation detection", ) - # More specific check for HuggingFace backend - try: - from ...backends.huggingface import LocalHFBackend - - if not isinstance(backend, LocalHFBackend): - return ValidationResult( - False, - reason=f"Citation detection requires LocalHFBackend (HuggingFace), " - f"but got {backend.__class__.__name__}. The find_citations intrinsic " - f"only works with HuggingFace models.", - ) - except ImportError: - return ValidationResult( - False, - reason="HuggingFace backend not available. Please install mellea[hf] to use citation detection.", - ) - # Create context before the response by getting all but the last message all_messages = ctx.as_list() if len(all_messages) > 1: # Rebuild context without last message + # Import here to avoid circular dependency from ..context import ChatContext context_before_response = ChatContext() @@ -187,6 +171,7 @@ async def validate( context_before_response = context_before_response.add(msg) else: # If only one message, use empty context + # Import here to avoid circular dependency from ..context import ChatContext context_before_response = ChatContext() @@ -195,7 +180,9 @@ async def validate( total_chars = len(response) if total_chars == 0: return ValidationResult( - True, reason="Empty response has 100% citation coverage", score=1.0 + True, + reason="Empty response is considered to have adequate citation coverage", + score=1.0, ) # Call find_citations intrinsic @@ -283,57 +270,3 @@ def _build_reason( ) return reason - - -def citation_check( - documents: Iterable[Document] | Iterable[str], - min_citation_coverage: float = 0.8, - description: str | None = None, -) -> CitationRequirement: - """Create a citation coverage requirement with pre-attached documents. - - This is a convenience factory function that creates a CitationRequirement - with documents already attached. This is useful when you have a fixed set of - documents to validate against and want a cleaner API. - - **Important**: This requirement requires a HuggingFace backend (LocalHFBackend). - - Args: - documents: Documents to check for citations. Can be Document objects - or strings (will be converted to Documents). - min_citation_coverage: Minimum ratio of cited content (0.0-1.0), - defaults to 0.8 (80% coverage). - description: Custom description for the requirement. If None, - generates a description based on coverage threshold. - - Returns: - A CitationRequirement with documents attached - - Example: - ```python - from mellea.backends.huggingface import LocalHFBackend - from mellea.stdlib.requirements.rag import citation_check - from mellea.stdlib.components import Document - - backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") - docs = [ - Document(doc_id="1", text="The sky is blue."), - Document(doc_id="2", text="Grass is green.") - ] - req = citation_check(docs, min_citation_coverage=0.8) - - # Use with instruct() - no need to attach documents to messages - result = m.instruct( - "Answer: {{query}}", - grounding_context={"query": "What color is the sky?"}, - requirements=[req], - backend=backend, - strategy=RejectionSamplingStrategy() - ) - ``` - """ - return CitationRequirement( - min_citation_coverage=min_citation_coverage, - documents=documents, - description=description, - ) diff --git a/test/stdlib/requirements/test_rag_requirements.py b/test/stdlib/requirements/test_rag_requirements.py index c57819ad0..08c22eaad 100644 --- a/test/stdlib/requirements/test_rag_requirements.py +++ b/test/stdlib/requirements/test_rag_requirements.py @@ -6,7 +6,7 @@ from mellea.backends.huggingface import LocalHFBackend from mellea.stdlib.components import Document, Message from mellea.stdlib.context import ChatContext -from mellea.stdlib.requirements.rag import CitationRequirement, citation_check +from mellea.stdlib.requirements.rag import CitationRequirement @pytest.mark.huggingface @@ -72,34 +72,6 @@ async def test_citation_requirement_with_constructor_documents(): assert result.reason is not None -@pytest.mark.huggingface -@pytest.mark.llm -@pytest.mark.requires_heavy_ram -async def test_citation_check_factory(): - """Test citation_check factory function.""" - backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") - - # Create documents - docs = [Document(doc_id="doc1", text="The sky is blue during the day.")] - - # Create a response - response = "The sky is blue." - - # Create context - ctx = ChatContext().add(Message("user", "What color is the sky?")) - ctx = ctx.add(Message("assistant", response)) - - # Use factory function - req = citation_check(docs, min_citation_coverage=0.5) - - # Validate - result = await req.validate(backend, ctx) - - # Should work the same as CitationRequirement - assert isinstance(result.score, float) - assert result.reason is not None - - async def test_citation_requirement_empty_context(): """Test citation requirement with empty context.""" # Create a mock backend - we don't need a real one for this test @@ -169,13 +141,12 @@ async def test_citation_requirement_no_documents(): async def test_citation_requirement_wrong_backend(): - """Test citation requirement with non-HuggingFace backend.""" - try: - from mellea.backends.ollama import OllamaBackend # type: ignore - except ImportError: - pytest.skip("Ollama backend not available") + """Test citation requirement with non-adapter backend.""" + from unittest.mock import Mock - backend = OllamaBackend(model_id="llama3.2") # type: ignore + # Create a mock backend that doesn't support adapters + backend = Mock() + backend.__class__.__name__ = "MockBackend" # Create documents docs = [Document(doc_id="doc1", text="The sky is blue.")] @@ -190,10 +161,10 @@ async def test_citation_requirement_wrong_backend(): # Validate result = await req.validate(backend, ctx) - # Should fail with clear error about backend requirement + # Should fail with clear error about adapter requirement assert not result.as_bool() assert result.reason is not None - assert "LocalHFBackend" in result.reason or "HuggingFace" in result.reason + assert "adapter" in result.reason.lower() def test_citation_requirement_invalid_coverage(): @@ -256,40 +227,60 @@ async def test_citation_requirement_empty_response(): # Validate result = await req.validate(backend, ctx) - # Empty response should pass (100% coverage of nothing) + # Empty response should pass (considered to have adequate coverage) assert result.as_bool() assert result.score == 1.0 + assert result.reason is not None + assert "adequate citation coverage" in result.reason.lower() -@pytest.mark.huggingface -@pytest.mark.llm -@pytest.mark.requires_heavy_ram async def test_citation_requirement_threshold_boundary(): - """Test citation requirement at exact threshold boundary.""" - backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") + """Test citation requirement at exact threshold boundary. + + This test mocks the find_citations intrinsic to return a controlled + result that produces exactly the threshold coverage (80%). + """ + from unittest.mock import Mock, patch + + backend = Mock(spec=LocalHFBackend) # Create documents docs = [Document(doc_id="doc1", text="The sky is blue during the day.")] - # Create a response - response = "The sky is blue." + # Create a response with 10 characters + response = "1234567890" # Create context ctx = ChatContext().add(Message("user", "What color is the sky?")) ctx = ctx.add(Message("assistant", response, documents=docs)) - # Create requirement with specific threshold - req = CitationRequirement(min_citation_coverage=0.8) - - # Validate - result = await req.validate(backend, ctx) + # Mock find_citations to return exactly 8 characters cited (80% of 10) + mock_citations = [ + { + "response_begin": 0, + "response_end": 8, # 8 characters cited + "response_text": "12345678", + "citation_doc_id": "doc1", + "citation_text": "The sky is blue", + } + ] - # Check that score is calculated - assert isinstance(result.score, float) - assert 0.0 <= result.score <= 1.0 + with patch( + "mellea.stdlib.components.intrinsic.rag.find_citations", + return_value=mock_citations, + ): + # Test at exact threshold (0.8) + req = CitationRequirement(min_citation_coverage=0.8) + result = await req.validate(backend, ctx) - # Result should match threshold comparison - if result.score >= 0.8: + # At exact threshold, should pass (>= comparison) assert result.as_bool() - else: - assert not result.as_bool() + assert result.score == 0.8 + + # Test just below threshold (0.81) + req_above = CitationRequirement(min_citation_coverage=0.81) + result_above = await req_above.validate(backend, ctx) + + # Just below threshold, should fail + assert not result_above.as_bool() + assert result_above.score == 0.8 From e08b8be7593c84903336e35c06e7ea5bb8c71423 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Tue, 24 Mar 2026 13:45:28 -0400 Subject: [PATCH 7/9] feat: review comments Signed-off-by: Akihiko Kuroda --- mellea/stdlib/requirements/__init__.py | 3 +- mellea/stdlib/requirements/rag.py | 86 +++++++++--- .../requirements/test_rag_requirements.py | 129 +++++++++++++++++- 3 files changed, 196 insertions(+), 22 deletions(-) diff --git a/mellea/stdlib/requirements/__init__.py b/mellea/stdlib/requirements/__init__.py index 6a1c5af56..ef03524ae 100644 --- a/mellea/stdlib/requirements/__init__.py +++ b/mellea/stdlib/requirements/__init__.py @@ -4,7 +4,7 @@ from ...core import Requirement, ValidationResult, default_output_to_bool from .md import as_markdown_list, is_markdown_list, is_markdown_table from .python_reqs import PythonExecutionReq -from .rag import CitationRequirement +from .rag import CitationMode, CitationRequirement from .requirement import ( ALoraRequirement, LLMaJRequirement, @@ -18,6 +18,7 @@ __all__ = [ "ALoraRequirement", + "CitationMode", "CitationRequirement", "LLMaJRequirement", "PythonExecutionReq", diff --git a/mellea/stdlib/requirements/rag.py b/mellea/stdlib/requirements/rag.py index 46999bf52..f71ffcd91 100644 --- a/mellea/stdlib/requirements/rag.py +++ b/mellea/stdlib/requirements/rag.py @@ -1,12 +1,27 @@ """Requirements for RAG (Retrieval-Augmented Generation) workflows.""" from collections.abc import Iterable +from enum import Enum from ...backends.adapters import AdapterMixin from ...core import Backend, Context, Requirement, ValidationResult from ..components import Document, Message +class CitationMode(Enum): + """Mode for calculating citation coverage. + + Attributes: + CLAIMS: Count the fraction of factual claims that have citations. + Each citation record from find_citations represents one claim. + CHARACTERS: Calculate the ratio of cited characters to total characters. + Sums character ranges covered by citations. + """ + + CLAIMS = "claims" + CHARACTERS = "characters" + + class CitationRequirement(Requirement): """Requirement that validates RAG responses have adequate citation coverage. @@ -30,14 +45,19 @@ class CitationRequirement(Requirement): Args: min_citation_coverage: Minimum ratio of cited content (0.0-1.0). - The ratio of characters with citations to total response length - must meet or exceed this threshold. Default is 0.8 (80% coverage). + Interpretation depends on mode: + - CLAIMS mode: fraction of factual claims with citations + - CHARACTERS mode: ratio of cited characters to total characters + Default is 0.8 (80% coverage). documents: Optional documents to validate against. Can be Document objects or strings (will be converted to Documents). If provided, these documents will be used instead of documents attached to messages in the context. Default is None (use context documents). + mode: Citation coverage calculation mode. Default is CitationMode.CLAIMS + (count fraction of claims with citations). Use CitationMode.CHARACTERS + to calculate character-based coverage ratio instead. description: Custom description for the requirement. If None, - generates a description based on coverage threshold. + generates a description based on coverage threshold and mode. Example: ```python @@ -64,6 +84,7 @@ def __init__( self, min_citation_coverage: float = 0.8, documents: Iterable[Document] | Iterable[str] | None = None, + mode: CitationMode = CitationMode.CLAIMS, description: str | None = None, ): """Initialize citation coverage requirement.""" @@ -73,6 +94,7 @@ def __init__( ) self.min_citation_coverage = min_citation_coverage + self.mode = mode # Convert documents to Document objects if provided if documents is not None: @@ -87,10 +109,16 @@ def __init__( # Generate description if not provided if description is None: - description = ( - f"Response must have adequate citation coverage " - f"(minimum {min_citation_coverage * 100:.0f}% of content cited)" - ) + if mode == CitationMode.CLAIMS: + description = ( + f"Response must have adequate citation coverage " + f"(minimum {min_citation_coverage * 100:.0f}% of factual claims cited)" + ) + else: # CitationMode.CHARACTERS + description = ( + f"Response must have adequate citation coverage " + f"(minimum {min_citation_coverage * 100:.0f}% of characters cited)" + ) # Initialize parent without validation function - we override validate() instead super().__init__(description=description, validation_fn=None) @@ -198,13 +226,34 @@ async def validate( False, reason=f"Citation detection intrinsic failed: {e!s}" ) - # Calculate citation coverage - - cited_chars = sum( - citation["response_end"] - citation["response_begin"] - for citation in citations - ) - coverage_ratio = cited_chars / total_chars + # Calculate citation coverage based on mode + if self.mode == CitationMode.CLAIMS: + # Count fraction of claims (citation records) that exist + # Each citation record represents a factual claim that has a citation + # We need to estimate total claims in the response + # For now, use a simple heuristic: split by sentence-ending punctuation + import re + + # Split response into sentences (simple heuristic) + sentences = re.split(r"[.!?]+", response) + # Filter out empty strings and whitespace-only strings + sentences = [s.strip() for s in sentences if s.strip()] + total_claims = len(sentences) + + if total_claims == 0: + # Edge case: no sentences detected + coverage_ratio = 1.0 if len(citations) == 0 else 0.0 + else: + # Number of claims with citations = number of citation records + cited_claims = len(citations) + coverage_ratio = cited_claims / total_claims + else: # CitationMode.CHARACTERS + # Calculate character-based coverage + cited_chars = sum( + citation["response_end"] - citation["response_begin"] + for citation in citations + ) + coverage_ratio = cited_chars / total_chars # Check against min_citation_coverage passed = coverage_ratio >= self.min_citation_coverage @@ -231,15 +280,20 @@ def _build_reason( coverage_pct = coverage_ratio * 100 threshold_pct = self.min_citation_coverage * 100 + if self.mode == CitationMode.CLAIMS: + metric_name = "claims" + else: + metric_name = "characters" + if passed: reason = ( f"Response has adequate citation coverage " - f"({coverage_pct:.1f}% cited, threshold: {threshold_pct:.1f}%)" + f"({coverage_pct:.1f}% of {metric_name} cited, threshold: {threshold_pct:.1f}%)" ) else: reason = ( f"Response has insufficient citation coverage " - f"({coverage_pct:.1f}% cited, threshold: {threshold_pct:.1f}%)" + f"({coverage_pct:.1f}% of {metric_name} cited, threshold: {threshold_pct:.1f}%)" ) # Add details about citations diff --git a/test/stdlib/requirements/test_rag_requirements.py b/test/stdlib/requirements/test_rag_requirements.py index 08c22eaad..29e80d4ae 100644 --- a/test/stdlib/requirements/test_rag_requirements.py +++ b/test/stdlib/requirements/test_rag_requirements.py @@ -6,7 +6,7 @@ from mellea.backends.huggingface import LocalHFBackend from mellea.stdlib.components import Document, Message from mellea.stdlib.context import ChatContext -from mellea.stdlib.requirements.rag import CitationRequirement +from mellea.stdlib.requirements.rag import CitationMode, CitationRequirement @pytest.mark.huggingface @@ -269,18 +269,137 @@ async def test_citation_requirement_threshold_boundary(): "mellea.stdlib.components.intrinsic.rag.find_citations", return_value=mock_citations, ): - # Test at exact threshold (0.8) - req = CitationRequirement(min_citation_coverage=0.8) + # Test at exact threshold (0.8) with CHARACTERS mode + req = CitationRequirement( + min_citation_coverage=0.8, mode=CitationMode.CHARACTERS + ) result = await req.validate(backend, ctx) # At exact threshold, should pass (>= comparison) assert result.as_bool() assert result.score == 0.8 - # Test just below threshold (0.81) - req_above = CitationRequirement(min_citation_coverage=0.81) + # Test just below threshold (0.81) with CHARACTERS mode + req_above = CitationRequirement( + min_citation_coverage=0.81, mode=CitationMode.CHARACTERS + ) result_above = await req_above.validate(backend, ctx) # Just below threshold, should fail assert not result_above.as_bool() assert result_above.score == 0.8 + + +async def test_citation_requirement_claims_mode(): + """Test citation requirement with CLAIMS mode.""" + from unittest.mock import Mock, patch + + backend = Mock(spec=LocalHFBackend) + + # Create documents + docs = [ + Document(doc_id="doc1", text="The sky is blue."), + Document(doc_id="doc2", text="Grass is green."), + ] + + # Create a response with 3 sentences + response = "The sky is blue. Grass is green. Water is wet." + + # Create context + ctx = ChatContext().add(Message("user", "Tell me some facts.")) + ctx = ctx.add(Message("assistant", response, documents=docs)) + + # Mock find_citations to return 2 citations (2 out of 3 claims = 66.7%) + mock_citations = [ + { + "response_begin": 0, + "response_end": 16, + "response_text": "The sky is blue", + "citation_doc_id": "doc1", + "citation_text": "The sky is blue.", + }, + { + "response_begin": 18, + "response_end": 34, + "response_text": "Grass is green", + "citation_doc_id": "doc2", + "citation_text": "Grass is green.", + }, + ] + + with patch( + "mellea.stdlib.components.intrinsic.rag.find_citations", + return_value=mock_citations, + ): + # Test with CLAIMS mode (default) - 2 citations out of 3 sentences = 66.7% + req = CitationRequirement(min_citation_coverage=0.6) + result = await req.validate(backend, ctx) + + # Should pass (66.7% >= 60%) + assert result.as_bool() + assert result.score is not None + assert abs(result.score - 0.667) < 0.01 # Allow small floating point error + assert result.reason is not None + assert "claims" in result.reason + + # Test with higher threshold that should fail + req_high = CitationRequirement(min_citation_coverage=0.7) + result_high = await req_high.validate(backend, ctx) + + # Should fail (66.7% < 70%) + assert not result_high.as_bool() + assert result_high.score is not None + assert abs(result_high.score - 0.667) < 0.01 + + +async def test_citation_requirement_characters_vs_claims(): + """Test that CHARACTERS and CLAIMS modes produce different results.""" + from unittest.mock import Mock, patch + + backend = Mock(spec=LocalHFBackend) + + # Create documents + docs = [Document(doc_id="doc1", text="Short fact.")] + + # Create a response: 1 short sentence with citation, 1 long sentence without + response = "Short. This is a much longer sentence without any citation support." + + # Create context + ctx = ChatContext().add(Message("user", "Tell me something.")) + ctx = ctx.add(Message("assistant", response, documents=docs)) + + # Mock find_citations to return 1 citation for the short sentence + mock_citations = [ + { + "response_begin": 0, + "response_end": 6, # "Short." = 6 characters + "response_text": "Short", + "citation_doc_id": "doc1", + "citation_text": "Short fact.", + } + ] + + with patch( + "mellea.stdlib.components.intrinsic.rag.find_citations", + return_value=mock_citations, + ): + # CLAIMS mode: 1 citation out of 2 sentences = 50% + req_claims = CitationRequirement( + min_citation_coverage=0.5, mode=CitationMode.CLAIMS + ) + result_claims = await req_claims.validate(backend, ctx) + + # CHARACTERS mode: 6 characters out of 67 total = ~9% + req_chars = CitationRequirement( + min_citation_coverage=0.5, mode=CitationMode.CHARACTERS + ) + result_chars = await req_chars.validate(backend, ctx) + + # CLAIMS mode should pass (50% >= 50%) + assert result_claims.as_bool() + assert result_claims.score == 0.5 + + # CHARACTERS mode should fail (~9% < 50%) + assert not result_chars.as_bool() + assert result_chars.score is not None + assert result_chars.score < 0.1 From c18c9b2f7981958f9e6a6a604460c50c406b310f Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Tue, 24 Mar 2026 16:19:33 -0400 Subject: [PATCH 8/9] feat: review comments Signed-off-by: Akihiko Kuroda --- docs/examples/citation_requirement_example.py | 46 +++++++++++++++---- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/docs/examples/citation_requirement_example.py b/docs/examples/citation_requirement_example.py index c1d8faaa5..580143732 100644 --- a/docs/examples/citation_requirement_example.py +++ b/docs/examples/citation_requirement_example.py @@ -13,7 +13,7 @@ from mellea.backends.huggingface import LocalHFBackend from mellea.stdlib.components import Document, Message from mellea.stdlib.context import ChatContext -from mellea.stdlib.requirements.rag import CitationRequirement +from mellea.stdlib.requirements.rag import CitationMode, CitationRequirement async def main(): @@ -64,16 +64,42 @@ async def main(): ) print(f"Reason: {reason_preview}") - # Example 2: Higher coverage threshold - print("\n--- Example 2: Higher coverage threshold (80%) ---") - req2 = CitationRequirement(min_citation_coverage=0.8, documents=docs) + # Example 2: CLAIMS mode (default) - counts fraction of claims with citations + print("\n--- Example 2: CLAIMS mode (default) - fraction of claims cited ---") + req2 = CitationRequirement( + min_citation_coverage=0.7, documents=docs, mode=CitationMode.CLAIMS + ) result2 = await req2.validate(backend, ctx) print(f"Validation passed: {result2.as_bool()}") print(f"Citation coverage score: {result2.score:.2%}") + if result2.reason: + reason_preview = ( + result2.reason[:200] + "..." + if len(result2.reason) > 200 + else result2.reason + ) + print(f"Reason: {reason_preview}") + + # Example 3: CHARACTERS mode - calculates character-based coverage + print("\n--- Example 3: CHARACTERS mode - character-based coverage ---") + req3 = CitationRequirement( + min_citation_coverage=0.7, documents=docs, mode=CitationMode.CHARACTERS + ) + result3 = await req3.validate(backend, ctx) + + print(f"Validation passed: {result3.as_bool()}") + print(f"Citation coverage score: {result3.score:.2%}") + if result3.reason: + reason_preview = ( + result3.reason[:200] + "..." + if len(result3.reason) > 200 + else result3.reason + ) + print(f"Reason: {reason_preview}") - # Example 3: Documents attached to message - print("\n--- Example 3: Documents in message (not constructor) ---") + # Example 4: Documents attached to message + print("\n--- Example 4: Documents in message (not constructor) ---") ctx2 = ChatContext().add(Message("user", "Tell me about Mars.")) ctx2 = ctx2.add( Message( @@ -85,11 +111,11 @@ async def main(): ) ) - req3 = CitationRequirement(min_citation_coverage=0.7) # No documents in constructor - result3 = await req3.validate(backend, ctx2) + req4 = CitationRequirement(min_citation_coverage=0.7) # No documents in constructor + result4 = await req4.validate(backend, ctx2) - print(f"Validation passed: {result3.as_bool()}") - print(f"Citation coverage score: {result3.score:.2%}") + print(f"Validation passed: {result4.as_bool()}") + print(f"Citation coverage score: {result4.score:.2%}") print("\n" + "=" * 70) print("Example completed successfully!") From 84472ff39c3e3fe829dbf110851f7d30a1d4f6cb Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Wed, 25 Mar 2026 10:00:26 -0400 Subject: [PATCH 9/9] review comment Signed-off-by: Akihiko Kuroda --- mellea/stdlib/requirements/rag.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/mellea/stdlib/requirements/rag.py b/mellea/stdlib/requirements/rag.py index f71ffcd91..705a2a012 100644 --- a/mellea/stdlib/requirements/rag.py +++ b/mellea/stdlib/requirements/rag.py @@ -6,6 +6,7 @@ from ...backends.adapters import AdapterMixin from ...core import Backend, Context, Requirement, ValidationResult from ..components import Document, Message +from ..context import ChatContext class CitationMode(Enum): @@ -191,17 +192,11 @@ async def validate( all_messages = ctx.as_list() if len(all_messages) > 1: # Rebuild context without last message - # Import here to avoid circular dependency - from ..context import ChatContext - context_before_response = ChatContext() for msg in all_messages[:-1]: context_before_response = context_before_response.add(msg) else: # If only one message, use empty context - # Import here to avoid circular dependency - from ..context import ChatContext - context_before_response = ChatContext() # Handle empty response before calling intrinsic @@ -215,7 +210,7 @@ async def validate( # Call find_citations intrinsic try: - # Import here to avoid circular dependency + # Import here to avoid circular dependency with backends from ..components.intrinsic import rag citations: list[dict] = rag.find_citations(