Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ Use `--` to forward extra arguments to the eval file via `process.argv`:
bt eval foo.eval.ts -- --description "Prod" --shard=1/4
```

**Sampling modes:**

- `bt eval --first 20 qa.eval.ts` — run the first 20 examples and clearly label the summary as a non-final smoke run.
- `bt eval --sample 20 --sample-seed 7 qa.eval.ts` — run a deterministic random sample and clearly label the summary as a non-final smoke run.
- If you do not pass a sampling flag, `bt eval` runs the full dataset and marks the summary as final.

## `bt sql`

- Runs interactively on TTY by default.
Expand Down
150 changes: 143 additions & 7 deletions scripts/eval-runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import inspect
import json
import os
import random
import re
import socket
import sys
Expand Down Expand Up @@ -59,6 +60,9 @@ class RunnerConfig:
terminate_on_failure: bool
num_workers: int | None
filters: list[EvalFilter]
first: int | None
sample: int | None
sample_seed: int | None
dev_mode: str | None
dev_request_json: str | None

Expand Down Expand Up @@ -142,6 +146,23 @@ def parse_dev_mode(value: str | None) -> str | None:
raise ValueError(f"Invalid BT_EVAL_DEV_MODE value: {value}")


def parse_positive_int_env(name: str) -> int | None:
value = os.getenv(name)
if value is None or value == "":
return None
parsed = int(value)
if parsed < 1:
raise ValueError(f"{name} must be a positive integer")
return parsed


def parse_int_env(name: str) -> int | None:
value = os.getenv(name)
if value is None or value == "":
return None
return int(value)


def read_runner_config() -> RunnerConfig:
num_workers_value = os.getenv("BT_EVAL_NUM_WORKERS")
num_workers = int(num_workers_value) if num_workers_value else None
Expand All @@ -151,6 +172,9 @@ def read_runner_config() -> RunnerConfig:
terminate_on_failure=env_flag("BT_EVAL_TERMINATE_ON_FAILURE"),
num_workers=num_workers,
filters=parse_serialized_filters(os.getenv("BT_EVAL_FILTER_PARSED")),
first=parse_positive_int_env("BT_EVAL_FIRST"),
sample=parse_positive_int_env("BT_EVAL_SAMPLE"),
sample_seed=parse_int_env("BT_EVAL_SAMPLE_SEED"),
dev_mode=parse_dev_mode(os.getenv("BT_EVAL_DEV_MODE")),
dev_request_json=os.getenv("BT_EVAL_DEV_REQUEST_JSON"),
)
Expand Down Expand Up @@ -202,8 +226,35 @@ def snake_to_camel(value: str) -> str:
return parts[0] + "".join(word.title() for word in parts[1:])


def format_summary(summary: dict[str, Any]) -> dict[str, Any]:
return {snake_to_camel(k): v for k, v in summary.items()}
def sampling_metadata(config: RunnerConfig) -> dict[str, Any]:
    """Describe the configured sampling mode for inclusion in run summaries.

    A --first or --sample run is marked non-final (smoke run); a run with
    no sampling flag covers the full dataset and is marked final.
    """
    if config.first is not None:
        limit = config.first
        return {
            "runMode": "first",
            "isFinal": False,
            "runLabel": f"Run mode: first {limit} examples (non-final smoke run)",
            "sampleCount": limit,
        }

    if config.sample is not None:
        # Default the seed to 0 so the label always shows a concrete value.
        seed = 0 if config.sample_seed is None else config.sample_seed
        size = config.sample
        metadata = {
            "runMode": "sample",
            "isFinal": False,
            "runLabel": f"Run mode: random sample of {size} examples (seed {seed}, non-final smoke run)",
        }
        metadata["sampleCount"] = size
        metadata["sampleSeed"] = seed
        return metadata

    return {
        "runMode": "full",
        "isFinal": True,
        "runLabel": "Run mode: full dataset",
    }


def format_summary(summary: dict[str, Any], config: RunnerConfig) -> dict[str, Any]:
    """Camel-case the summary keys and attach the run's sampling metadata."""
    formatted = {snake_to_camel(key): value for key, value in summary.items()}
    formatted.update(sampling_metadata(config))
    return formatted


def send_eval_progress(sse: SseWriter | None, evaluator_name: str, kind: str, total: int | None = None) -> None:
Expand Down Expand Up @@ -324,6 +375,79 @@ def resolve_eval_data(data: dict[str, Any]) -> Any:
raise ValueError("Invalid eval data payload.")


async def resolve_sampling_source(data: Any) -> Any:
    """Unwrap a data source until a concrete value remains.

    Repeatedly calls callables and awaits awaitables (a factory may return
    a coroutine, which may itself resolve to another callable) and returns
    the first non-callable, non-awaitable result.
    """
    value = data
    while callable(value) or inspect.isawaitable(value):
        if callable(value):
            value = value()
        else:
            value = await value
    return value


async def iter_data_source(data: Any, batch_size_hint: int | None = None):
    """Yield individual records from a resolved data source.

    Supports Braintrust Dataset objects (fetched with *batch_size_hint*),
    async iterables, and plain sync iterables. Strings, bytes, and dicts
    are rejected because iterating them would yield characters/keys, not
    eval records.
    """
    source = await resolve_sampling_source(data)

    if isinstance(source, Dataset):
        records = source.fetch(batch_size=batch_size_hint)
        if hasattr(records, "__aiter__"):
            async for record in records:
                yield record
        else:
            for record in records:
                yield record
        return

    if isinstance(source, (str, bytes, dict)):
        raise ValueError(
            "Sampling is only supported for arrays, iterables, and Braintrust datasets."
        )

    if hasattr(source, "__aiter__"):
        async for record in source:
            yield record
        return

    try:
        sync_records = iter(source)
    except TypeError as exc:
        raise ValueError(
            "Sampling is only supported for arrays, iterables, and Braintrust datasets."
        ) from exc
    for record in sync_records:
        yield record


async def collect_first_records(data: Any, count: int) -> list[Any]:
    """Return up to *count* leading records from the data source.

    Passes *count* as a batch-size hint so dataset fetches don't pull
    more rows than needed.
    """
    collected: list[Any] = []
    async for record in iter_data_source(data, batch_size_hint=count):
        collected.append(record)
        if len(collected) == count:
            break
    return collected


async def reservoir_sample_records(data: Any, count: int, seed: int) -> list[Any]:
    """Draw a uniform random sample of up to *count* records in one pass.

    Classic reservoir sampling (Algorithm R): the first *count* records
    fill the reservoir; each later record replaces a random slot with
    probability count/total. Seeded RNG keeps the sample deterministic.
    """
    generator = random.Random(seed)
    reservoir: list[Any] = []
    total = 0
    async for record in iter_data_source(data):
        total += 1
        if len(reservoir) < count:
            reservoir.append(record)
        else:
            # randrange(total) gives each of the `total` records seen so
            # far an equal chance of occupying a reservoir slot.
            slot = generator.randrange(total)
            if slot < count:
                reservoir[slot] = record
    return reservoir


async def apply_sampling_to_data(data: Any, config: RunnerConfig) -> Any:
    """Apply the configured sampling mode, returning the data to evaluate.

    --first wins over --sample; with neither flag set, the original data
    source is returned untouched so the full dataset runs.
    """
    if config.first is not None:
        return await collect_first_records(data, config.first)
    if config.sample is None:
        return data
    # A missing seed falls back to 0, matching the label in sampling_metadata.
    return await reservoir_sample_records(data, config.sample, config.sample_seed or 0)


def make_eval_scorer(
score: dict[str, Any],
project_id: str | None,
Expand Down Expand Up @@ -680,7 +804,13 @@ def run_evaluator_supports_stream() -> bool:


async def run_evaluator_task(
evaluator, position: int, no_send_logs: bool, progress_cb, progress_mode: str, sse: SseWriter | None,
evaluator,
position: int,
no_send_logs: bool,
progress_cb,
progress_mode: str,
sse: SseWriter | None,
config: RunnerConfig,
parent: str | None = None,
):
experiment = None
Expand All @@ -690,6 +820,9 @@ async def run_evaluator_task(

fallback_progress = progress_cb is not None and progress_mode != "progress"
original_task = evaluator.task
original_data = evaluator.data
sampled_data = await apply_sampling_to_data(original_data, config)
evaluator.data = sampled_data
supports_stream = run_evaluator_supports_stream()

if fallback_progress:
Expand Down Expand Up @@ -727,6 +860,7 @@ async def run_evaluator_task(
)
finally:
evaluator.task = original_task
evaluator.data = original_data
if fallback_progress:
progress_cb("stop", None)
if experiment:
Expand Down Expand Up @@ -812,6 +946,7 @@ async def run_requested_eval(
progress_cb,
progress_mode,
sse,
config,
parent=parent,
)
except Exception as exc:
Expand All @@ -820,9 +955,9 @@ async def run_requested_eval(
return False

if sse:
sse.send("summary", format_summary(result.summary.as_dict()))
sse.send("summary", format_summary(result.summary.as_dict(), config))
elif config.jsonl:
print(json.dumps(format_summary(result.summary.as_dict())))
print(json.dumps(format_summary(result.summary.as_dict(), config)))
else:
print(result.summary)

Expand Down Expand Up @@ -889,6 +1024,7 @@ async def run_single_evaluator(
progress_cb,
progress_mode,
sse,
config,
)
except Exception as exc:
err = serialize_error(str(exc), traceback.format_exc())
Expand Down Expand Up @@ -921,9 +1057,9 @@ async def run_single_evaluator(
continue

if sse:
sse.send("summary", format_summary(result.summary.as_dict()))
sse.send("summary", format_summary(result.summary.as_dict(), config))
elif config.jsonl:
print(json.dumps(format_summary(result.summary.as_dict())))
print(json.dumps(format_summary(result.summary.as_dict(), config)))
else:
print(result.summary)

Expand Down
Loading
Loading