diff --git a/README.md b/README.md index 92a972e..30a4ca5 100644 --- a/README.md +++ b/README.md @@ -151,6 +151,12 @@ Use `--` to forward extra arguments to the eval file via `process.argv`: bt eval foo.eval.ts -- --description "Prod" --shard=1/4 ``` +**Sampling modes:** + +- `bt eval --first 20 qa.eval.ts` — run the first 20 examples and clearly label the summary as a non-final smoke run. +- `bt eval --sample 20 --sample-seed 7 qa.eval.ts` — run a deterministic random sample and clearly label the summary as a non-final smoke run. +- If you do not pass a sampling flag, `bt eval` runs the full dataset and marks the summary as final. + ## `bt sql` - Runs interactively on TTY by default. diff --git a/scripts/eval-runner.py b/scripts/eval-runner.py index 8742375..32d3436 100755 --- a/scripts/eval-runner.py +++ b/scripts/eval-runner.py @@ -6,6 +6,7 @@ import inspect import json import os +import random import re import socket import sys @@ -59,6 +60,9 @@ class RunnerConfig: terminate_on_failure: bool num_workers: int | None filters: list[EvalFilter] + first: int | None + sample: int | None + sample_seed: int | None dev_mode: str | None dev_request_json: str | None @@ -142,6 +146,23 @@ def parse_dev_mode(value: str | None) -> str | None: raise ValueError(f"Invalid BT_EVAL_DEV_MODE value: {value}") +def parse_positive_int_env(name: str) -> int | None: + value = os.getenv(name) + if value is None or value == "": + return None + parsed = int(value) + if parsed < 1: + raise ValueError(f"{name} must be a positive integer") + return parsed + + +def parse_int_env(name: str) -> int | None: + value = os.getenv(name) + if value is None or value == "": + return None + return int(value) + + def read_runner_config() -> RunnerConfig: num_workers_value = os.getenv("BT_EVAL_NUM_WORKERS") num_workers = int(num_workers_value) if num_workers_value else None @@ -151,6 +172,9 @@ def read_runner_config() -> RunnerConfig: terminate_on_failure=env_flag("BT_EVAL_TERMINATE_ON_FAILURE"), 
num_workers=num_workers, filters=parse_serialized_filters(os.getenv("BT_EVAL_FILTER_PARSED")), + first=parse_positive_int_env("BT_EVAL_FIRST"), + sample=parse_positive_int_env("BT_EVAL_SAMPLE"), + sample_seed=parse_int_env("BT_EVAL_SAMPLE_SEED"), dev_mode=parse_dev_mode(os.getenv("BT_EVAL_DEV_MODE")), dev_request_json=os.getenv("BT_EVAL_DEV_REQUEST_JSON"), ) @@ -202,8 +226,35 @@ def snake_to_camel(value: str) -> str: return parts[0] + "".join(word.title() for word in parts[1:]) -def format_summary(summary: dict[str, Any]) -> dict[str, Any]: - return {snake_to_camel(k): v for k, v in summary.items()} +def sampling_metadata(config: RunnerConfig) -> dict[str, Any]: + if config.first is not None: + return { + "runMode": "first", + "isFinal": False, + "runLabel": f"Run mode: first {config.first} examples (non-final smoke run)", + "sampleCount": config.first, + } + if config.sample is not None: + seed = config.sample_seed if config.sample_seed is not None else 0 + return { + "runMode": "sample", + "isFinal": False, + "runLabel": f"Run mode: random sample of {config.sample} examples (seed {seed}, non-final smoke run)", + "sampleCount": config.sample, + "sampleSeed": seed, + } + return { + "runMode": "full", + "isFinal": True, + "runLabel": "Run mode: full dataset", + } + + +def format_summary(summary: dict[str, Any], config: RunnerConfig) -> dict[str, Any]: + return { + **{snake_to_camel(k): v for k, v in summary.items()}, + **sampling_metadata(config), + } def send_eval_progress(sse: SseWriter | None, evaluator_name: str, kind: str, total: int | None = None) -> None: @@ -324,6 +375,79 @@ def resolve_eval_data(data: dict[str, Any]) -> Any: raise ValueError("Invalid eval data payload.") +async def resolve_sampling_source(data: Any) -> Any: + current = data + while True: + if callable(current): + current = current() + continue + if inspect.isawaitable(current): + current = await current + continue + return current + + +async def iter_data_source(data: Any, batch_size_hint: 
int | None = None): + resolved = await resolve_sampling_source(data) + if isinstance(resolved, Dataset): + fetched = resolved.fetch(batch_size=batch_size_hint) + if hasattr(fetched, "__aiter__"): + async for item in fetched: + yield item + return + for item in fetched: + yield item + return + if isinstance(resolved, (str, bytes, dict)): + raise ValueError( + "Sampling is only supported for arrays, iterables, and Braintrust datasets." + ) + if hasattr(resolved, "__aiter__"): + async for item in resolved: + yield item + return + try: + iterator = iter(resolved) + except TypeError as exc: + raise ValueError( + "Sampling is only supported for arrays, iterables, and Braintrust datasets." + ) from exc + for item in iterator: + yield item + + +async def collect_first_records(data: Any, count: int) -> list[Any]: + items: list[Any] = [] + async for item in iter_data_source(data, batch_size_hint=count): + items.append(item) + if len(items) >= count: + break + return items + + +async def reservoir_sample_records(data: Any, count: int, seed: int) -> list[Any]: + rng = random.Random(seed) + sample: list[Any] = [] + seen = 0 + async for item in iter_data_source(data): + seen += 1 + if len(sample) < count: + sample.append(item) + continue + index = rng.randrange(seen) + if index < count: + sample[index] = item + return sample + + +async def apply_sampling_to_data(data: Any, config: RunnerConfig) -> Any: + if config.first is not None: + return await collect_first_records(data, config.first) + if config.sample is not None: + return await reservoir_sample_records(data, config.sample, config.sample_seed or 0) + return data + + def make_eval_scorer( score: dict[str, Any], project_id: str | None, @@ -680,7 +804,13 @@ def run_evaluator_supports_stream() -> bool: async def run_evaluator_task( - evaluator, position: int, no_send_logs: bool, progress_cb, progress_mode: str, sse: SseWriter | None, + evaluator, + position: int, + no_send_logs: bool, + progress_cb, + progress_mode: str, + 
sse: SseWriter | None, + config: RunnerConfig, parent: str | None = None, ): experiment = None @@ -690,6 +820,9 @@ async def run_evaluator_task( fallback_progress = progress_cb is not None and progress_mode != "progress" original_task = evaluator.task + original_data = evaluator.data + sampled_data = await apply_sampling_to_data(original_data, config) + evaluator.data = sampled_data supports_stream = run_evaluator_supports_stream() if fallback_progress: @@ -727,6 +860,7 @@ async def run_evaluator_task( ) finally: evaluator.task = original_task + evaluator.data = original_data if fallback_progress: progress_cb("stop", None) if experiment: @@ -812,6 +946,7 @@ async def run_requested_eval( progress_cb, progress_mode, sse, + config, parent=parent, ) except Exception as exc: @@ -820,9 +955,9 @@ async def run_requested_eval( return False if sse: - sse.send("summary", format_summary(result.summary.as_dict())) + sse.send("summary", format_summary(result.summary.as_dict(), config)) elif config.jsonl: - print(json.dumps(format_summary(result.summary.as_dict()))) + print(json.dumps(format_summary(result.summary.as_dict(), config))) else: print(result.summary) @@ -889,6 +1024,7 @@ async def run_single_evaluator( progress_cb, progress_mode, sse, + config, ) except Exception as exc: err = serialize_error(str(exc), traceback.format_exc()) @@ -921,9 +1057,9 @@ async def run_single_evaluator( continue if sse: - sse.send("summary", format_summary(result.summary.as_dict())) + sse.send("summary", format_summary(result.summary.as_dict(), config)) elif config.jsonl: - print(json.dumps(format_summary(result.summary.as_dict()))) + print(json.dumps(format_summary(result.summary.as_dict(), config))) else: print(result.summary) diff --git a/scripts/eval-runner.ts b/scripts/eval-runner.ts index 2a19c10..9c8b6a6 100644 --- a/scripts/eval-runner.ts +++ b/scripts/eval-runner.ts @@ -116,6 +116,9 @@ type RunnerConfig = { list: boolean; terminateOnFailure: boolean; filters: EvalFilter[]; + first: 
number | null; + sample: number | null; + sampleSeed: number | null; devMode: "list" | "eval" | null; devRequestJson: string | null; }; @@ -253,12 +256,45 @@ function parseDevMode(value: string | undefined): "list" | "eval" | null { throw new Error(`Invalid BT_EVAL_DEV_MODE value: ${value}`); } +function parsePositiveIntegerEnv(name: string): number | null { + const value = process.env[name]; + if (!value) { + return null; + } + if (!/^[0-9]+$/.test(value)) { + throw new Error(`${name} must be a positive integer.`); + } + const parsed = Number(value); + if (!Number.isSafeInteger(parsed) || parsed < 1) { + throw new Error(`${name} must be a positive integer.`); + } + return parsed; +} + +function parseIntegerEnv(name: string): number | null { + const value = process.env[name]; + if (!value) { + return null; + } + if (!/^-?[0-9]+$/.test(value)) { + throw new Error(`${name} must be an integer.`); + } + const parsed = Number(value); + if (!Number.isSafeInteger(parsed)) { + throw new Error(`${name} must be an integer.`); + } + return parsed; +} + function readRunnerConfig(): RunnerConfig { return { jsonl: envFlag("BT_EVAL_JSONL"), list: envFlag("BT_EVAL_LIST"), terminateOnFailure: envFlag("BT_EVAL_TERMINATE_ON_FAILURE"), filters: parseSerializedFilters(process.env.BT_EVAL_FILTER_PARSED), + first: parsePositiveIntegerEnv("BT_EVAL_FIRST"), + sample: parsePositiveIntegerEnv("BT_EVAL_SAMPLE"), + sampleSeed: parseIntegerEnv("BT_EVAL_SAMPLE_SEED"), devMode: parseDevMode(process.env.BT_EVAL_DEV_MODE), devRequestJson: process.env.BT_EVAL_DEV_REQUEST_JSON ?? 
null, }; @@ -1487,6 +1523,244 @@ function resolveEvalData( throw new Error("Invalid eval data payload."); } +type SamplingMetadata = { + runMode: "full" | "first" | "sample"; + isFinal: boolean; + runLabel: string; + sampleCount?: number; + sampleSeed?: number; +}; + +function samplingMetadata(config: RunnerConfig): SamplingMetadata { + if (config.first !== null) { + return { + runMode: "first", + isFinal: false, + runLabel: `Run mode: first ${config.first} examples (non-final smoke run)`, + sampleCount: config.first, + }; + } + if (config.sample !== null) { + const seed = config.sampleSeed ?? 0; + return { + runMode: "sample", + isFinal: false, + runLabel: `Run mode: random sample of ${config.sample} examples (seed ${seed}, non-final smoke run)`, + sampleCount: config.sample, + sampleSeed: seed, + }; + } + return { + runMode: "full", + isFinal: true, + runLabel: "Run mode: full dataset", + }; +} + +function attachSamplingSummary( + summary: unknown, + config: RunnerConfig, +): unknown { + const metadata = samplingMetadata(config); + if (isObject(summary)) { + return { + ...summary, + runMode: metadata.runMode, + isFinal: metadata.isFinal, + runLabel: metadata.runLabel, + ...(metadata.sampleCount !== undefined + ? { sampleCount: metadata.sampleCount } + : {}), + ...(metadata.sampleSeed !== undefined + ? { sampleSeed: metadata.sampleSeed } + : {}), + }; + } + return { + summary, + runMode: metadata.runMode, + isFinal: metadata.isFinal, + runLabel: metadata.runLabel, + ...(metadata.sampleCount !== undefined + ? { sampleCount: metadata.sampleCount } + : {}), + ...(metadata.sampleSeed !== undefined + ? 
{ sampleSeed: metadata.sampleSeed }
+        : {}),
+  };
+}
+
+function isPromiseLike(value: unknown): value is PromiseLike<unknown> {
+  return (
+    typeof value === "object" &&
+    value !== null &&
+    typeof Reflect.get(value, "then") === "function"
+  );
+}
+
+function isAsyncIterable(value: unknown): value is AsyncIterable<unknown> {
+  return (
+    typeof value === "object" &&
+    value !== null &&
+    typeof Reflect.get(value, Symbol.asyncIterator) === "function"
+  );
+}
+
+function isIterable(value: unknown): value is Iterable<unknown> {
+  return (
+    typeof value === "object" &&
+    value !== null &&
+    typeof Reflect.get(value, Symbol.iterator) === "function"
+  );
+}
+
+function isDatasetLike(value: unknown): value is {
+  fetch: (options?: { batchSize?: number }) => AsyncIterable<unknown>;
+} {
+  return (
+    typeof value === "object" &&
+    value !== null &&
+    typeof Reflect.get(value, "fetch") === "function" &&
+    typeof Reflect.get(value, "summarize") === "function"
+  );
+}
+
+function createSeededRandom(seed: number): () => number {
+  // SplitMix64 keeps seed entropy beyond 32 bits so distinct seeds stay distinct.
+  let state = BigInt.asUintN(64, BigInt(seed));
+  return () => {
+    state = BigInt.asUintN(64, state + 0x9e3779b97f4a7c15n);
+    let z = state;
+    z = BigInt.asUintN(64, (z ^ (z >> 30n)) * 0xbf58476d1ce4e5b9n);
+    z = BigInt.asUintN(64, (z ^ (z >> 27n)) * 0x94d049bb133111ebn);
+    z = BigInt.asUintN(64, z ^ (z >> 31n));
+    return Number(z & 0x1fffffffffffffn) / 9007199254740992;
+  };
+}
+
+type SamplingSourceOptions = {
+  initialCallReceiver?: unknown;
+};
+
+async function resolveSamplingSource(
+  source: unknown,
+  options?: SamplingSourceOptions,
+): Promise<unknown> {
+  let current: unknown = source;
+  let firstCall = true;
+  while (true) {
+    current = isPromiseLike(current) ?
await current : current;
+    if (typeof current !== "function") {
+      return current;
+    }
+    if (firstCall && options?.initialCallReceiver !== undefined) {
+      current = Reflect.apply(
+        current as (...args: unknown[]) => unknown,
+        options.initialCallReceiver,
+        [],
+      );
+      firstCall = false;
+      continue;
+    }
+    current = (current as () => unknown)();
+    firstCall = false;
+  }
+}
+
+async function* iterateDataSource(
+  source: unknown,
+  batchSizeHint?: number,
+  options?: SamplingSourceOptions,
+): AsyncGenerator<unknown> {
+  const resolved = await resolveSamplingSource(source, options);
+  if (Array.isArray(resolved)) {
+    for (const item of resolved) {
+      yield item;
+    }
+    return;
+  }
+  if (isDatasetLike(resolved)) {
+    const options =
+      batchSizeHint !== undefined ? { batchSize: batchSizeHint } : undefined;
+    for await (const item of resolved.fetch(options)) {
+      yield item;
+    }
+    return;
+  }
+  if (isAsyncIterable(resolved)) {
+    for await (const item of resolved) {
+      yield item;
+    }
+    return;
+  }
+  if (typeof resolved !== "string" && isIterable(resolved)) {
+    for (const item of resolved) {
+      yield item;
+    }
+    return;
+  }
+  throw new Error(
+    "Sampling is only supported for arrays, iterables, async iterables, and Braintrust datasets.",
+  );
+}
+
+async function collectFirstRecords(
+  source: unknown,
+  count: number,
+  options?: SamplingSourceOptions,
+): Promise<unknown[]> {
+  const items: unknown[] = [];
+  for await (const item of iterateDataSource(source, count, options)) {
+    items.push(item);
+    if (items.length >= count) {
+      break;
+    }
+  }
+  return items;
+}
+
+async function reservoirSampleRecords(
+  source: unknown,
+  count: number,
+  seed: number,
+  options?: SamplingSourceOptions,
+): Promise<unknown[]> {
+  const random = createSeededRandom(seed);
+  const sample: unknown[] = [];
+  let seen = 0;
+  for await (const item of iterateDataSource(source, undefined, options)) {
+    seen += 1;
+    if (sample.length < count) {
+      sample.push(item);
+      continue;
+    }
+    const index = Math.floor(random() * seen);
+    if
(index < count) { + sample[index] = item; + } + } + return sample; +} + +async function applySamplingToData( + data: unknown, + config: RunnerConfig, + options?: SamplingSourceOptions, +): Promise { + if (config.first !== null) { + return await collectFirstRecords(data, config.first, options); + } + if (config.sample !== null) { + return await reservoirSampleRecords( + data, + config.sample, + config.sampleSeed ?? 0, + options, + ); + } + return data; +} + function convertFunctionId( functionId: Record, ): Record { @@ -1886,8 +2160,15 @@ async function createEvalRunner(config: RunnerConfig): Promise { globalThis._lazy_load = false; const evaluatorName = getEvaluatorName(evaluator, projectName); const opts = makeEvalOptions(evaluatorName, options); - const wrappedEvaluator = wrapTaskForStreamingProgress(evaluator); + const sampledData = await applySamplingToData(evaluator.data, config, { + initialCallReceiver: evaluator, + }); + const wrappedEvaluator = wrapTaskForStreamingProgress({ + ...evaluator, + data: sampledData, + }); const result = await Eval(projectName, wrappedEvaluator, opts); + const summary = attachSamplingSummary(result.summary, config); const failingResults = result.results.filter( (r: { error?: unknown }) => r.error !== undefined, ); @@ -1898,9 +2179,9 @@ async function createEvalRunner(config: RunnerConfig): Promise { ); } if (sse) { - sse.send("summary", result.summary); + sse.send("summary", summary); } else if (config.jsonl) { - console.log(JSON.stringify(result.summary)); + console.log(JSON.stringify(summary)); } return result; }; diff --git a/src/eval.rs b/src/eval.rs index 128a51e..17f9ab7 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -59,6 +59,17 @@ const HEADER_CORS_ALLOW_PRIVATE_NETWORK: &str = "access-control-allow-private-ne const SSE_SOCKET_BIND_MAX_ATTEMPTS: u8 = 16; const EVAL_NODE_MAX_OLD_SPACE_SIZE_MB: usize = 8192; const MAX_DEFERRED_EVAL_ERRORS: usize = 8; +const DEFAULT_EVAL_SAMPLE_SEED: u64 = 0; + +fn parse_positive_usize(value: 
&str) -> std::result::Result<usize, String> {
+    let parsed = value
+        .parse::<usize>()
+        .map_err(|_| format!("invalid positive integer '{value}'"))?;
+    if parsed == 0 {
+        return Err("value must be greater than 0".to_string());
+    }
+    Ok(parsed)
+}
 
 static SSE_SOCKET_COUNTER: AtomicU64 = AtomicU64::new(0);
 
 struct EvalRunOutput {
@@ -294,6 +305,35 @@ pub struct EvalArgs {
     )]
     pub filter: Vec<String>,
 
+    /// Run only the first N dataset records. Marks the run as non-final.
+    #[arg(
+        long,
+        env = "BT_EVAL_FIRST",
+        value_name = "N",
+        value_parser = parse_positive_usize,
+        conflicts_with = "sample"
+    )]
+    pub first: Option<usize>,
+
+    /// Run a deterministic random sample of N dataset records. Marks the run as non-final.
+    #[arg(
+        long,
+        env = "BT_EVAL_SAMPLE",
+        value_name = "N",
+        value_parser = parse_positive_usize,
+        conflicts_with = "first"
+    )]
+    pub sample: Option<usize>,
+
+    /// Seed used with --sample.
+    #[arg(
+        long = "sample-seed",
+        env = "BT_EVAL_SAMPLE_SEED",
+        value_name = "SEED",
+        requires = "sample"
+    )]
+    pub sample_seed: Option<u64>,
+
     /// Show verbose evaluator errors and stderr output.
#[arg( long, @@ -350,6 +390,13 @@ pub struct EvalArgs { pub dev_allowed_origin: Vec, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum EvalSamplingMode { + Full, + First { count: usize }, + Sample { count: usize, seed: u64 }, +} + #[derive(Debug, Clone)] struct EvalRunOptions { jsonl: bool, @@ -357,6 +404,7 @@ struct EvalRunOptions { num_workers: Option, list: bool, filter: Vec, + sampling: EvalSamplingMode, verbose: bool, extra_args: Vec, } @@ -379,12 +427,24 @@ pub async fn run(base: BaseArgs, args: EvalArgs) -> Result<()> { } validate_eval_input_files(&files)?; + let sampling = if let Some(first) = args.first { + EvalSamplingMode::First { count: first } + } else if let Some(sample) = args.sample { + EvalSamplingMode::Sample { + count: sample, + seed: args.sample_seed.unwrap_or(DEFAULT_EVAL_SAMPLE_SEED), + } + } else { + EvalSamplingMode::Full + }; + let options = EvalRunOptions { jsonl: args.jsonl, terminate_on_failure: args.terminate_on_failure, num_workers: args.num_workers, list: args.list, filter: args.filter, + sampling, verbose: args.verbose, extra_args: args.extra_args, }; @@ -746,6 +806,16 @@ async fn spawn_eval_runner( serde_json::to_string(&parsed).context("failed to serialize eval filters")?; cmd.env("BT_EVAL_FILTER_PARSED", serialized); } + match options.sampling { + EvalSamplingMode::Full => {} + EvalSamplingMode::First { count } => { + cmd.env("BT_EVAL_FIRST", count.to_string()); + } + EvalSamplingMode::Sample { count, seed } => { + cmd.env("BT_EVAL_SAMPLE", count.to_string()); + cmd.env("BT_EVAL_SAMPLE_SEED", seed.to_string()); + } + } if language == EvalLanguage::JavaScript && force_esm { cmd.env("BT_EVAL_FORCE_ESM", "1"); } @@ -2556,6 +2626,16 @@ struct ExperimentSummary { comparison_experiment_name: Option, scores: HashMap, metrics: Option>, + #[serde(default)] + run_mode: Option, + #[serde(default)] + is_final: Option, + #[serde(default)] + run_label: Option, + #[serde(default)] + sample_count: Option, + #[serde(default)] + sample_seed: 
Option, } #[derive(Debug, Deserialize, Serialize)] @@ -3037,6 +3117,15 @@ fn format_start_line(start: &ExperimentStart) -> Option { fn format_experiment_summary(summary: &ExperimentSummary) -> String { let mut parts: Vec = Vec::new(); + if let Some(run_label) = summary.run_label.as_deref() { + let line = if summary.is_final == Some(false) { + run_label.yellow().to_string() + } else { + run_label.to_string() + }; + parts.push(line); + } + if let Some(comparison) = summary.comparison_experiment_name.as_deref() { let line = format!( "{baseline} {baseline_tag} ← {comparison_name} {comparison_tag}", @@ -4230,6 +4319,63 @@ mod tests { } } + #[test] + fn handle_sse_event_parses_summary_run_metadata() { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + handle_sse_event( + Some("summary".to_string()), + r#"{ + "projectName":"Topics", + "experimentName":"sample-run", + "scores":{}, + "runMode":"sample", + "isFinal":false, + "runLabel":"Run mode: random sample of 20 examples (seed 7, non-final smoke run)", + "sampleCount":20, + "sampleSeed":7 + }"# + .to_string(), + &tx, + ); + + match rx.try_recv().expect("summary event should be emitted") { + EvalEvent::Summary(summary) => { + assert_eq!(summary.run_mode.as_deref(), Some("sample")); + assert_eq!(summary.is_final, Some(false)); + assert_eq!( + summary.run_label.as_deref(), + Some("Run mode: random sample of 20 examples (seed 7, non-final smoke run)") + ); + assert_eq!(summary.sample_count, Some(20)); + assert_eq!(summary.sample_seed, Some(7)); + } + other => panic!("unexpected event: {other:?}"), + } + } + + #[test] + fn format_experiment_summary_includes_run_label() { + let summary = ExperimentSummary { + project_name: "Demo".to_string(), + experiment_name: "sample-run".to_string(), + project_id: None, + experiment_id: None, + project_url: None, + experiment_url: None, + comparison_experiment_name: None, + scores: HashMap::new(), + metrics: None, + run_mode: Some("first".to_string()), + is_final: Some(false), + 
run_label: Some("Run mode: first 20 examples (non-final smoke run)".to_string()), + sample_count: Some(20), + sample_seed: None, + }; + + let rendered = format_experiment_summary(&summary); + assert!(rendered.contains("Run mode: first 20 examples (non-final smoke run)")); + } + #[test] fn parse_eval_filter_expression_splits_path_and_pattern() { let parsed = @@ -4262,6 +4408,52 @@ mod tests { ); } + #[test] + fn eval_args_parse_first_sampling_flag() { + let _guard = env_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + let keys = ["BT_EVAL_FIRST", "BT_EVAL_SAMPLE", "BT_EVAL_SAMPLE_SEED"]; + let previous: Vec<(&str, Option)> = + keys.iter().map(|key| (*key, clear_env_var(key))).collect(); + + let parsed = EvalArgsHarness::try_parse_from(["bt", "--first", "20", "sample.eval.ts"]) + .expect("first flag should parse"); + assert_eq!(parsed.eval.first, Some(20)); + assert_eq!(parsed.eval.sample, None); + + for (key, value) in previous { + restore_env_var(key, value); + } + } + + #[test] + fn eval_args_parse_sample_and_seed_flags() { + let _guard = env_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + let keys = ["BT_EVAL_FIRST", "BT_EVAL_SAMPLE", "BT_EVAL_SAMPLE_SEED"]; + let previous: Vec<(&str, Option)> = + keys.iter().map(|key| (*key, clear_env_var(key))).collect(); + + let parsed = EvalArgsHarness::try_parse_from([ + "bt", + "--sample", + "20", + "--sample-seed", + "7", + "sample.eval.ts", + ]) + .expect("sample flags should parse"); + assert_eq!(parsed.eval.sample, Some(20)); + assert_eq!(parsed.eval.sample_seed, Some(7)); + assert_eq!(parsed.eval.first, None); + + for (key, value) in previous { + restore_env_var(key, value); + } + } + #[test] fn eval_args_from_env_populates_supported_fields() { let _guard = env_test_lock() @@ -4273,6 +4465,9 @@ mod tests { "BT_EVAL_NUM_WORKERS", "BT_EVAL_LIST", "BT_EVAL_FILTER", + "BT_EVAL_FIRST", + "BT_EVAL_SAMPLE", + "BT_EVAL_SAMPLE_SEED", "BT_EVAL_VERBOSE", "BT_EVAL_WATCH", 
"BT_EVAL_DEV", diff --git a/tests/evals/js/direct-basic/direct.eval.ts b/tests/evals/js/direct-basic/direct.eval.ts index e037a8a..12dd7e8 100644 --- a/tests/evals/js/direct-basic/direct.eval.ts +++ b/tests/evals/js/direct-basic/direct.eval.ts @@ -14,14 +14,37 @@ function exactMatch({ output, expected }: ScoreArgs) { return output === expected ? 1 : 0; } +type EvalInput = { + text: string; + shouldFail: boolean; +}; + export async function btEvalMain(ctx: EvalContext) { - await ctx.runEval("BT CLI Tests", { + const evaluator = { evalName: "direct-basic", - data: () => [ - { input: "Cara", expected: "Hello Cara" }, - { input: "Dan", expected: "Hello Dan" }, + records: [ + { input: { text: "sample-0", shouldFail: true }, expected: "sample-0" }, + { input: { text: "sample-1", shouldFail: true }, expected: "sample-1" }, + { input: { text: "sample-2", shouldFail: true }, expected: "sample-2" }, + { input: { text: "sample-3", shouldFail: true }, expected: "sample-3" }, + { input: { text: "sample-4", shouldFail: true }, expected: "sample-4" }, + { input: { text: "sample-5", shouldFail: true }, expected: "sample-5" }, + { input: { text: "sample-6", shouldFail: true }, expected: "sample-6" }, + { input: { text: "sample-7", shouldFail: true }, expected: "sample-7" }, + { input: { text: "sample-8", shouldFail: true }, expected: "sample-8" }, + { input: { text: "sample-9", shouldFail: false }, expected: "sample-9" }, ], - task: (input: string) => `Hello ${input}`, + data() { + return this.records; + }, + task: (input: EvalInput) => { + if (input.shouldFail) { + throw new Error("intentional fixture failure"); + } + return input.text; + }, scores: [exactMatch], - }); + }; + + await ctx.runEval("BT CLI Tests", evaluator); } diff --git a/tests/evals/js/direct-basic/fixture.json b/tests/evals/js/direct-basic/fixture.json index 242aee2..86ba970 100644 --- a/tests/evals/js/direct-basic/fixture.json +++ b/tests/evals/js/direct-basic/fixture.json @@ -1,3 +1,4 @@ { - "files": 
["direct.eval.ts"] + "files": ["direct.eval.ts"], + "args": ["--sample", "1", "--sample-seed", "4294967297"] }