diff --git a/AGENTS.md b/AGENTS.md index ed61b64e..ecd80822 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -23,6 +23,7 @@ Important code areas in `py/src/braintrust/`: - temporal: `contrib/temporal/` - CLI/devserver: `cli/`, `devserver/` - tests: colocated `test_*.py` +- type tests: `type_tests/` ## Setup @@ -83,6 +84,7 @@ Testing preferences: Key facts: - `test_core` runs without optional vendor packages. +- `test_types` runs pyright, mypy, and pytest on `py/src/braintrust/type_tests/`. Use this session when changing generic type signatures in the framework. - wrapper coverage is split across dedicated nox sessions by provider/version. - `pylint` installs the broad dependency surface before checking files. - `cd py && make pylint` runs only `pylint`; `cd py && make lint` runs pre-commit hooks first and then `pylint`. @@ -90,6 +92,19 @@ Key facts: When changing behavior, run the narrowest affected session first, then expand only if needed. +## Type Tests + +`py/src/braintrust/type_tests/` contains tests that are validated both by static type checkers (pyright, mypy) and by pytest at runtime. The `test_types` nox session runs all three checks and is auto-discovered by CI. + +When changing generic type signatures (e.g., `Eval`, `EvalCase`, `EvalScorer`, `EvalHooks`), add or update a test in `type_tests/` to verify the type checkers accept the intended usage patterns. + +New test files should be named `test_*.py` and use absolute imports (`from braintrust.framework import ...`). They are regular pytest files that also happen to be valid pyright/mypy targets. + +```bash +cd py +nox -s test_types +``` + ## VCR VCR/cassette coverage is the default and preferred testing strategy for provider and integration behavior in this repo. Reach for cassette-backed tests before introducing mocks or fakes, and keep new coverage aligned with the existing VCR patterns unless there is a strong reason not to. 
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ab5d641c..19c5c00c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -134,6 +134,17 @@ BRAINTRUST_CLAUDE_AGENT_SDK_RECORD_MODE=all \ nox -s "test_claude_agent_sdk(latest)" -- -k "test_calculator_with_multiple_operations" ``` +### Type Tests + +`py/src/braintrust/type_tests/` contains tests that are checked by pyright, mypy, and pytest. The `test_types` nox session runs all three and is included in CI automatically. + +When changing generic type signatures (e.g., `Eval`, `EvalCase`, `EvalScorer`, `EvalHooks`), add or update a test file in `type_tests/` to verify the type checker accepts the intended usage patterns. Test files are named `test_*.py`, use absolute imports (`from braintrust.framework import ...`), and double as regular pytest files. + +```bash +cd py +nox -s test_types +``` + ### Fixtures Shared test fixtures live in `py/src/braintrust/conftest.py`. diff --git a/py/noxfile.py b/py/noxfile.py index 09c044c1..77b73ad7 100644 --- a/py/noxfile.py +++ b/py/noxfile.py @@ -44,6 +44,7 @@ def _pinned_python_version(): INTEGRATION_DIR = "braintrust/integrations" CONTRIB_DIR = "braintrust/contrib" DEVSERVER_DIR = "braintrust/devserver" +TYPE_TESTS_DIR = "braintrust/type_tests" SILENT_INSTALLS = True @@ -390,6 +391,27 @@ def test_otel_not_installed(session): _run_tests(session, "braintrust/test_otel.py") +@nox.session() +def test_types(session): + """Run type-check tests with pyright, mypy, and pytest.""" + _install_test_deps(session) + session.install("pyright==1.1.408", "mypy==1.20.0") + + type_tests_dir = f"src/{TYPE_TESTS_DIR}" + test_files = glob.glob(os.path.join(type_tests_dir, "test_*.py")) + if not test_files: + session.skip("No type test files found") + + # Run pyright on each file + session.run("pyright", *test_files) + + # Run mypy on each file (only check the test files themselves, not transitive deps) + session.run("mypy", "--follow-imports=silent", *test_files) + + # Run pytest for the runtime 
assertions + _run_tests(session, TYPE_TESTS_DIR) + + @nox.session() def pylint(session): # pylint needs everything so we don't trigger missing import errors @@ -502,6 +524,7 @@ def _run_core_tests(session): *_integration_subdirs_to_ignore(), CONTRIB_DIR, DEVSERVER_DIR, + TYPE_TESTS_DIR, ], ) diff --git a/py/src/braintrust/devserver/server.py b/py/src/braintrust/devserver/server.py index 4b3ff79a..37e2965e 100644 --- a/py/src/braintrust/devserver/server.py +++ b/py/src/braintrust/devserver/server.py @@ -50,7 +50,7 @@ from .schemas import ValidationError, parse_eval_body -_all_evaluators: dict[str, Evaluator[Any, Any]] = {} +_all_evaluators: dict[str, Evaluator[Any, Any, Any]] = {} class _ParameterOverrideHooks: @@ -289,7 +289,7 @@ async def run_and_complete(): return JSONResponse({"error": f"Failed to run evaluation: {str(e)}"}, status_code=500) -def create_app(evaluators: list[Evaluator[Any, Any]], org_name: str | None = None): +def create_app(evaluators: list[Evaluator[Any, Any, Any]], org_name: str | None = None): """Create and configure the Starlette app for the dev server. 
Args: @@ -318,7 +318,7 @@ def create_app(evaluators: list[Evaluator[Any, Any]], org_name: str | None = Non def run_dev_server( - evaluators: list[Evaluator[Any, Any]], + evaluators: list[Evaluator[Any, Any, Any]], host: str = "localhost", port: int = 8300, org_name: str | None = None, @@ -346,7 +346,7 @@ def snake_to_camel(snake_str: str) -> str: def make_scorer( state: BraintrustState, name: str, score: FunctionId, project_id: str | None = None -) -> EvalScorer[Any, Any]: +) -> EvalScorer[Any, Any, Any]: def scorer_fn(input, output, expected, metadata): request = { **score, diff --git a/py/src/braintrust/framework.py b/py/src/braintrust/framework.py index 950e9fcc..cbe529cf 100644 --- a/py/src/braintrust/framework.py +++ b/py/src/braintrust/framework.py @@ -58,6 +58,7 @@ Input = TypeVar("Input") Output = TypeVar("Output") +Expected = TypeVar("Expected") # https://stackoverflow.com/questions/287871/how-do-i-print-colored-text-to-the-terminal @@ -74,14 +75,14 @@ class bcolors: @dataclasses.dataclass -class EvalCase(SerializableDataClass, Generic[Input, Output]): +class EvalCase(SerializableDataClass, Generic[Input, Expected]): """ An evaluation case. This is a single input to the evaluation task, along with an optional expected output, metadata, and tags. """ input: Input - expected: Output | None = None + expected: Expected | None = None metadata: Metadata | None = None tags: Sequence[str] | None = None @@ -94,13 +95,13 @@ class EvalCase(SerializableDataClass, Generic[Input, Output]): # Inheritance doesn't quite work for dataclasses, so we redefine the fields # from EvalCase here. @dataclasses.dataclass -class EvalResult(SerializableDataClass, Generic[Input, Output]): +class EvalResult(SerializableDataClass, Generic[Input, Output, Expected]): """The result of an evaluation. 
This includes the input, expected output, actual output, and metadata.""" input: Input output: Output scores: dict[str, float | None] - expected: Output | None = None + expected: Expected | None = None metadata: Metadata | None = None tags: list[str] | None = None error: Exception | None = None @@ -138,7 +139,7 @@ class SSEProgressEvent(TaskProgressEvent): name: str -class EvalHooks(abc.ABC, Generic[Output]): +class EvalHooks(abc.ABC, Generic[Expected]): """ An object that can be used to add metadata to an evaluation. This is passed to the `task` function. """ @@ -152,7 +153,7 @@ def metadata(self) -> Metadata | None: @property @abc.abstractmethod - def expected(self) -> Output | None: + def expected(self) -> Expected | None: """ The expected output for the current evaluation. """ @@ -204,14 +205,14 @@ def parameters(self) -> ValidatedParameters | None: """ -class EvalScorerArgs(SerializableDataClass, Generic[Input, Output]): +class EvalScorerArgs(SerializableDataClass, Generic[Input, Output, Expected]): """ Arguments passed to an evaluator scorer. This includes the input, expected output, actual output, and metadata. """ input: Input output: Output - expected: Output | None = None + expected: Expected | None = None metadata: Metadata | None = None @@ -219,35 +220,35 @@ class EvalScorerArgs(SerializableDataClass, Generic[Input, Output]): # Synchronous scorer interface - implements callable -class SyncScorerLike(Protocol, Generic[Input, Output]): +class SyncScorerLike(Protocol, Generic[Input, Output, Expected]): """ Protocol for synchronous scorers that implement the callable interface. This is the most common interface and is used when no async version is available. """ def __call__( - self, input: Input, output: Output, expected: Output | None = None, **kwargs: Any + self, input: Input, output: Output, expected: Expected | None = None, **kwargs: Any ) -> OneOrMoreScores: ... 
# Asynchronous scorer interface -class AsyncScorerLike(Protocol, Generic[Input, Output]): +class AsyncScorerLike(Protocol, Generic[Input, Output, Expected]): """ Protocol for asynchronous scorers that implement the eval_async interface. The framework will prefer this interface if available. """ - async def eval_async(self, output: Output, expected: Output | None = None, **kwargs: Any) -> OneOrMoreScores: ... + async def eval_async(self, output: Output, expected: Expected | None = None, **kwargs: Any) -> OneOrMoreScores: ... # Union type for any kind of scorer (for typing) -ScorerLike = Union[SyncScorerLike[Input, Output], AsyncScorerLike[Input, Output]] +ScorerLike = Union[SyncScorerLike[Input, Output, Expected], AsyncScorerLike[Input, Output, Expected]] EvalScorer = Union[ - ScorerLike[Input, Output], - type[ScorerLike[Input, Output]], - Callable[[Input, Output, Output], OneOrMoreScores], - Callable[[Input, Output, Output], Awaitable[OneOrMoreScores]], + ScorerLike[Input, Output, Expected], + type[ScorerLike[Input, Output, Expected]], + Callable[[Input, Output, Expected], OneOrMoreScores], + Callable[[Input, Output, Expected], Awaitable[OneOrMoreScores]], ] @@ -267,32 +268,32 @@ class BaseExperiment: _AnyEvalCase = Union[ - EvalCase[Input, Output], - EvalCaseDict[Input, Output], + EvalCase[Input, Expected], + EvalCaseDict[Input, Expected], EvalCaseDictNoOutput[Input], ExperimentDatasetEvent, ] _EvalDataObject = Union[ - Iterable[_AnyEvalCase[Input, Output]], - Iterator[_AnyEvalCase[Input, Output]], - Awaitable[Iterator[_AnyEvalCase[Input, Output]]], - Callable[[], Union[Iterator[_AnyEvalCase[Input, Output]], Awaitable[Iterator[_AnyEvalCase[Input, Output]]]]], + Iterable[_AnyEvalCase[Input, Expected]], + Iterator[_AnyEvalCase[Input, Expected]], + Awaitable[Iterator[_AnyEvalCase[Input, Expected]]], + Callable[[], Union[Iterator[_AnyEvalCase[Input, Expected]], Awaitable[Iterator[_AnyEvalCase[Input, Expected]]]]], BaseExperiment, ] -EvalData = 
Union[_EvalDataObject[Input, Output], type[_EvalDataObject[Input, Output]], Dataset] +EvalData = Union[_EvalDataObject[Input, Expected], type[_EvalDataObject[Input, Expected]], Dataset] EvalTask = Union[ Callable[[Input], Union[Output, Awaitable[Output]]], - Callable[[Input, EvalHooks[Output]], Union[Output, Awaitable[Output]]], + Callable[[Input, EvalHooks[Expected]], Union[Output, Awaitable[Output]]], ] -ErrorScoreHandler = Callable[[Span, EvalCase[Input, Output], list[str]], Optional[dict[str, float]]] +ErrorScoreHandler = Callable[[Span, EvalCase[Input, Expected], list[str]], Optional[dict[str, float]]] @dataclasses.dataclass -class Evaluator(Generic[Input, Output]): +class Evaluator(Generic[Input, Output, Expected]): """ An evaluator is an abstraction that defines an evaluation dataset, a task to run on the dataset, and a set of scorers to evaluate the results of the task. Each method attribute can be synchronous or asynchronous (for @@ -312,18 +313,18 @@ class Evaluator(Generic[Input, Output]): A name that describes the experiment. You do not need to change it each time the experiment runs. """ - data: EvalData[Input, Output] + data: EvalData[Input, Expected] """ Returns an iterator over the evaluation dataset. Each element of the iterator should be an `EvalCase` or a dict with the same fields as an `EvalCase` (`input`, `expected`, `metadata`). """ - task: EvalTask[Input, Output] + task: EvalTask[Input, Output, Expected] """ Runs the evaluation task on a single input. The `hooks` object can be used to add metadata to the evaluation. """ - scores: list[EvalScorer[Input, Output]] + scores: list[EvalScorer[Input, Output, Expected]] """ A list of scorers to evaluate the results of the task. Each scorer can be a Scorer object or a function that takes `input`, `output`, and `expected` arguments and returns a `Score` object. The function can be async. 
@@ -405,7 +406,7 @@ class Evaluator(Generic[Input, Output]): takes precedence over `git_metadata_settings` if specified. """ - error_score_handler: ErrorScoreHandler[Input, Output] | None = None + error_score_handler: ErrorScoreHandler[Input, Expected] | None = None """ Optionally supply a custom function to specifically handle score values when tasks or scoring functions have errored. A default implementation is exported as `default_error_score_handler` which will log a 0 score to the root span for any scorer that was not run. @@ -431,9 +432,9 @@ class Evaluator(Generic[Input, Output]): @dataclasses.dataclass -class EvalResultWithSummary(SerializableDataClass, Generic[Input, Output]): +class EvalResultWithSummary(SerializableDataClass, Generic[Input, Output, Expected]): summary: ExperimentSummary - results: list[EvalResult[Input, Output]] + results: list[EvalResult[Input, Output, Expected]] def _repr_pretty_(self, p, cycle): p.text(f'EvalResultWithSummary(summary="...", results=[...])') @@ -502,7 +503,7 @@ async def call_user_fn(event_loop, fn, **kwargs): @dataclasses.dataclass -class ReporterDef(SerializableDataClass, Generic[Input, Output, EvalReport]): +class ReporterDef(SerializableDataClass, Generic[Input, Output, Expected, EvalReport]): """ A reporter takes an evaluator and its result and returns a report. 
""" @@ -513,7 +514,7 @@ class ReporterDef(SerializableDataClass, Generic[Input, Output, EvalReport]): """ report_eval: Callable[ - [Evaluator[Input, Output], EvalResultWithSummary[Input, Output], bool, bool], + [Evaluator[Input, Output, Expected], EvalResultWithSummary[Input, Output, Expected], bool, bool], EvalReport | Awaitable[EvalReport], ] """ @@ -528,8 +529,8 @@ class ReporterDef(SerializableDataClass, Generic[Input, Output, EvalReport]): async def _call_report_eval( self, - evaluator: Evaluator[Input, Output], - result: EvalResultWithSummary[Input, Output], + evaluator: Evaluator[Input, Output, Expected], + result: EvalResultWithSummary[Input, Output, Expected], verbose: bool, jsonl: bool, ) -> EvalReport | Awaitable[EvalReport]: @@ -544,9 +545,9 @@ async def _call_report_run(self, results: list[EvalReport], verbose: bool, jsonl @dataclasses.dataclass -class EvaluatorInstance(SerializableDataClass, Generic[Input, Output, EvalReport]): - evaluator: Evaluator[Input, Output] - reporter: ReporterDef[Input, Output, EvalReport] | str | None +class EvaluatorInstance(SerializableDataClass, Generic[Input, Output, Expected, EvalReport]): + evaluator: Evaluator[Input, Output, Expected] + reporter: ReporterDef[Input, Output, Expected, EvalReport] | str | None @dataclasses.dataclass @@ -643,16 +644,16 @@ def _make_eval_name(name: str, experiment_name: str | None): def _EvalCommon( name: str, - data: EvalData[Input, Output], - task: EvalTask[Input, Output], - scores: Sequence[EvalScorer[Input, Output]], + data: EvalData[Input, Expected], + task: EvalTask[Input, Output, Expected], + scores: Sequence[EvalScorer[Input, Output, Expected]], experiment_name: str | None, trial_count: int, metadata: Metadata | None, tags: list[str] | None, is_public: bool, update: bool, - reporter: ReporterDef[Input, Output, EvalReport] | None, + reporter: ReporterDef[Input, Output, Expected, EvalReport] | None, timeout: float | None, max_concurrency: int | None, project_id: str | None, @@ 
-663,14 +664,14 @@ def _EvalCommon( description: str | None, summarize_scores: bool, no_send_logs: bool, - error_score_handler: ErrorScoreHandler[Input, Output] | None = None, + error_score_handler: ErrorScoreHandler[Input, Expected] | None = None, parameters: EvalParameters | RemoteEvalParameters | None = None, on_start: Callable[[ExperimentSummary], None] | None = None, stream: Callable[[SSEProgressEvent], None] | None = None, parent: str | None = None, state: BraintrustState | None = None, enable_cache: bool = True, -) -> Callable[[], Coroutine[Any, Any, EvalResultWithSummary[Input, Output]]]: +) -> Callable[[], Coroutine[Any, Any, EvalResultWithSummary[Input, Output, Expected]]]: """ This helper is needed because in case of `_lazy_load`, we need to update the `_evals` global immediately instead of whenever the coroutine is @@ -779,16 +780,16 @@ async def run_to_completion(): async def EvalAsync( name: str, - data: EvalData[Input, Output], - task: EvalTask[Input, Output], - scores: Sequence[EvalScorer[Input, Output]], + data: EvalData[Input, Expected], + task: EvalTask[Input, Output, Expected], + scores: Sequence[EvalScorer[Input, Output, Expected]], experiment_name: str | None = None, trial_count: int = 1, metadata: Metadata | None = None, tags: list[str] | None = None, is_public: bool = False, update: bool = False, - reporter: ReporterDef[Input, Output, EvalReport] | None = None, + reporter: ReporterDef[Input, Output, Expected, EvalReport] | None = None, timeout: float | None = None, max_concurrency: int | None = None, project_id: str | None = None, @@ -796,7 +797,7 @@ async def EvalAsync( base_experiment_id: str | None = None, git_metadata_settings: GitMetadataSettings | None = None, repo_info: RepoInfo | None = None, - error_score_handler: ErrorScoreHandler[Input, Output] | None = None, + error_score_handler: ErrorScoreHandler[Input, Expected] | None = None, description: str | None = None, summarize_scores: bool = True, no_send_logs: bool = False, @@ -806,7 
+807,7 @@ async def EvalAsync( parent: str | None = None, state: BraintrustState | None = None, enable_cache: bool = True, -) -> EvalResultWithSummary[Input, Output]: +) -> EvalResultWithSummary[Input, Output, Expected]: """ A function you can use to define an evaluator. This is a convenience wrapper around the `Evaluator` class. @@ -906,16 +907,16 @@ async def EvalAsync( def Eval( name: str, - data: EvalData[Input, Output], - task: EvalTask[Input, Output], - scores: Sequence[EvalScorer[Input, Output]], + data: EvalData[Input, Expected], + task: EvalTask[Input, Output, Expected], + scores: Sequence[EvalScorer[Input, Output, Expected]], experiment_name: str | None = None, trial_count: int = 1, metadata: Metadata | None = None, tags: list[str] | None = None, is_public: bool = False, update: bool = False, - reporter: ReporterDef[Input, Output, EvalReport] | None = None, + reporter: ReporterDef[Input, Output, Expected, EvalReport] | None = None, timeout: float | None = None, max_concurrency: int | None = None, project_id: str | None = None, @@ -923,7 +924,7 @@ def Eval( base_experiment_id: str | None = None, git_metadata_settings: GitMetadataSettings | None = None, repo_info: RepoInfo | None = None, - error_score_handler: ErrorScoreHandler[Input, Output] | None = None, + error_score_handler: ErrorScoreHandler[Input, Expected] | None = None, description: str | None = None, summarize_scores: bool = True, no_send_logs: bool = False, @@ -933,7 +934,7 @@ def Eval( parent: str | None = None, state: BraintrustState | None = None, enable_cache: bool = True, -) -> EvalResultWithSummary[Input, Output]: +) -> EvalResultWithSummary[Input, Output, Expected]: """ A function you can use to define an evaluator. This is a convenience wrapper around the `Evaluator` class. 
@@ -1053,7 +1054,7 @@ def Eval( def Reporter( name: str, report_eval: Callable[ - [Evaluator[Input, Output], EvalResultWithSummary[Input, Output], bool, bool], + [Evaluator[Input, Output, Expected], EvalResultWithSummary[Input, Output, Expected], bool, bool], EvalReport | Awaitable[EvalReport], ], report_run: Callable[[list[EvalReport], bool, bool], bool | Awaitable[bool]], @@ -1267,13 +1268,13 @@ def helper(): async def run_evaluator( experiment: Experiment | None, - evaluator: Evaluator[Input, Output], + evaluator: Evaluator[Input, Output, Expected], position: int | None, filters: list[Filter], stream: Callable[[SSEProgressEvent], None] | None = None, state: BraintrustState | None = None, enable_cache: bool = True, -) -> EvalResultWithSummary[Input, Output]: +) -> EvalResultWithSummary[Input, Output, Expected]: """Wrapper on _run_evaluator_internal that times out execution after evaluator.timeout.""" results = await asyncio.wait_for( _run_evaluator_internal(experiment, evaluator, position, filters, stream, state, enable_cache), @@ -1290,7 +1291,7 @@ async def run_evaluator( def default_error_score_handler( root_span: Span, - data: EvalCase[Input, Output], + data: EvalCase[Input, Expected], unhandled_scores: list[str], ): scores = {s: 0 for s in unhandled_scores} @@ -1691,7 +1692,7 @@ async def with_max_concurrency(coro): def build_local_summary( - evaluator: Evaluator[Input, Output], results: list[EvalResultWithSummary[Input, Output]] + evaluator: Evaluator[Input, Output, Expected], results: list[EvalResultWithSummary[Input, Output, Expected]] ) -> ExperimentSummary: scores_by_name = defaultdict(lambda: (0, 0)) for result in results: diff --git a/py/src/braintrust/type_tests/__init__.py b/py/src/braintrust/type_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/py/src/braintrust/type_tests/test_eval_generics.py b/py/src/braintrust/type_tests/test_eval_generics.py new file mode 100644 index 00000000..8ef793ac --- /dev/null +++ 
b/py/src/braintrust/type_tests/test_eval_generics.py @@ -0,0 +1,120 @@ +"""Type-check tests for the Eval framework generic parameters. + +These tests verify that pyright/mypy accept valid usage patterns +and that the runtime behavior is correct. + +Run as type checks: + nox -s test_types + +Run as pytest: + pytest src/braintrust/type_tests/test_eval_generics.py +""" + +from typing import TypedDict + +import pytest +from braintrust.framework import EvalAsync, EvalCase, EvalResultWithSummary +from braintrust.score import Score + + +# --- Domain types for testing --- +class ModelOutput(TypedDict): + answer: str + confidence: float + + +class AssertionSpec: + """Assertion specification — not the same type as the model output.""" + + def __init__(self, field: str, expected_value: str): + self.field = field + self.expected_value = expected_value + + +# ============================================================ +# Case 1: Same-type Output and Expected (should always work) +# ============================================================ + + +def same_type_data(): + return iter([EvalCase(input="query", expected="golden answer")]) + + +async def same_type_task(input: str) -> str: + return "model answer" + + +async def same_type_scorer(input: str, output: str, expected: str | None = None) -> Score: + return Score(name="match", score=1.0 if output == expected else 0.0) + + +# ============================================================ +# Case 2: Divergent Output and Expected (the bug from #240) +# ============================================================ + + +def divergent_data(): + return iter( + [ + EvalCase( + input="What is 2+2?", + expected=frozenset({AssertionSpec("answer", "4")}), + ), + ] + ) + + +async def divergent_task(input: str) -> ModelOutput: + return ModelOutput(answer="4", confidence=0.99) + + +async def divergent_scorer( + input: str, + output: ModelOutput, + expected: frozenset[AssertionSpec] | None = None, +) -> Score: + if expected is None: + return 
Score(name="match", score=0) + for spec in expected: + if output.get(spec.field) != spec.expected_value: + return Score(name="match", score=0) + return Score(name="match", score=1) + + +# ============================================================ +# Runtime tests — confirm the eval framework works correctly +# with divergent types at runtime. +# ============================================================ + + +@pytest.mark.asyncio +async def test_eval_same_type_output_and_expected(): + """Output and Expected are the same type — classic pattern.""" + result = await EvalAsync( + "test-same-type", + data=same_type_data, + task=same_type_task, + scores=[same_type_scorer], + no_send_logs=True, + ) + assert isinstance(result, EvalResultWithSummary) + assert len(result.results) == 1 + assert result.results[0].output == "model answer" + assert result.results[0].expected == "golden answer" + + +@pytest.mark.asyncio +async def test_eval_divergent_output_and_expected(): + """Output and Expected differ — the pattern reported in #240.""" + result = await EvalAsync( + "test-divergent", + data=divergent_data, + task=divergent_task, + scores=[divergent_scorer], + no_send_logs=True, + ) + assert isinstance(result, EvalResultWithSummary) + assert len(result.results) == 1 + assert result.results[0].output == ModelOutput(answer="4", confidence=0.99) + assert isinstance(result.results[0].expected, frozenset) + assert result.results[0].scores.get("match") == 1.0 diff --git a/py/src/braintrust/types/_eval.py b/py/src/braintrust/types/_eval.py index 1252d0b4..c8b5dc6d 100644 --- a/py/src/braintrust/types/_eval.py +++ b/py/src/braintrust/types/_eval.py @@ -11,15 +11,15 @@ Input = TypeVar("Input") -Output = TypeVar("Output") +Expected = TypeVar("Expected") class EvalCaseDictNoOutput(Generic[Input], TypedDict): """ Workaround for the Pyright type checker handling of generics. 
Specifically, the type checker doesn't know that a dict which is missing the key - "expected" can be used to satisfy ``EvalCaseDict[Input, Output]`` for any - ``Output`` type. + "expected" can be used to satisfy ``EvalCaseDict[Input, Expected]`` for any + ``Expected`` type. """ input: Input @@ -30,12 +30,12 @@ class EvalCaseDictNoOutput(Generic[Input], TypedDict): _xact_id: NotRequired[str | None] -class EvalCaseDict(Generic[Input, Output], EvalCaseDictNoOutput[Input]): +class EvalCaseDict(Generic[Input, Expected], EvalCaseDictNoOutput[Input]): """ Mirrors EvalCase for callers who pass a dict instead of dataclass. """ - expected: NotRequired[Output | None] + expected: NotRequired[Expected | None] class ExperimentDatasetEvent(TypedDict):