Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions ax/metrics/map_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,109 @@
# Module-level logger, obtained from ``get_logger`` using this module's ``__name__``.
logger: Logger = get_logger(__name__)


class MapDataReplayState:
    """Shared state coordinator for replaying historical map data.

    Manages normalized cursor-based progression across multiple metrics
    and trials. The cursor model uses a global min/max MAP_KEY across
    all metrics to preserve cross-metric timing alignment.

    Each trial has its own cursor in ``[0.0, 1.0]``. A cursor value ``c`` is
    mapped onto the global progression range as
    ``min_prog + c * (max_prog - min_prog)``, and a cursor of ``1.0`` maps to
    ``max_prog`` exactly.

    This class serves original MAP_KEY values (not normalized). Downstream
    early stopping strategies apply normalization independently via
    ``_maybe_normalize_map_key`` in ``ax.adapter.data_utils``.
    """

    def __init__(
        self,
        map_data: Data,
        metric_signatures: list[str],
        step_size: float = 0.01,
    ) -> None:
        """Initialize replay state from historical data.

        Args:
            map_data: Historical data containing progression data. Its
                ``full_df`` is filtered on the ``metric_signature`` column
                and grouped by ``trial_index``.
            metric_signatures: List of metric signatures to replay.
            step_size: Cursor increment per advancement step. Determines
                the granularity of replay (e.g. 0.01 = 100 steps).
        """
        self.step_size: float = step_size

        # Pre-index data by (trial_index, metric_signature) for O(1) lookups
        self._data: dict[tuple[int, str], pd.DataFrame] = {}
        all_trial_indices: set[int] = set()
        per_trial_max_prog: dict[int, float] = {}

        # Running global bounds; avoids materializing every MAP_KEY value
        # in a Python list just to take its min and max.
        global_min = float("inf")
        global_max = float("-inf")

        # The full frame does not depend on the metric, so read it once.
        full_df = map_data.full_df
        for metric_signature in metric_signatures:
            metric_df = full_df[full_df["metric_signature"] == metric_signature]
            replay_df = metric_df.sort_values(
                by=["trial_index", MAP_KEY], ascending=True, ignore_index=True
            )
            for raw_trial_index, group in replay_df.groupby("trial_index"):
                trial_index = int(raw_trial_index)
                self._data[(trial_index, metric_signature)] = group.reset_index(
                    drop=True
                )
                all_trial_indices.add(trial_index)

                prog_values = group[MAP_KEY].values
                trial_min = float(prog_values.min())
                trial_max = float(prog_values.max())
                global_min = min(global_min, trial_min)
                global_max = max(global_max, trial_max)
                # A trial's max is taken across all of its metrics.
                per_trial_max_prog[trial_index] = max(
                    per_trial_max_prog.get(trial_index, trial_max), trial_max
                )

        if all_trial_indices:
            self.min_prog: float = global_min
            self.max_prog: float = global_max
        else:
            # No matching rows at all: degenerate (empty) progression range.
            self.min_prog = 0.0
            self.max_prog = 0.0

        self._per_trial_max_prog: dict[int, float] = per_trial_max_prog
        self._trial_cursors: defaultdict[int, float] = defaultdict(float)
        self._trial_indices: set[int] = all_trial_indices

    def _current_prog(self, trial_index: int) -> float:
        """Map a trial's cursor onto the global progression range.

        A cursor clamped at 1.0 returns ``max_prog`` itself rather than
        ``min_prog + 1.0 * (max_prog - min_prog)``: the latter is subject to
        floating-point rounding and can land below ``max_prog``, which would
        leave the final rows unreachable and the trial never complete.
        """
        # ``.get`` keeps this read-only; indexing the defaultdict would
        # insert a cursor entry for every trial that is merely queried.
        cursor = self._trial_cursors.get(trial_index, 0.0)
        if cursor >= 1.0:
            return self.max_prog
        return self.min_prog + cursor * (self.max_prog - self.min_prog)

    def advance_trial(self, trial_index: int) -> None:
        """Advance the cursor for a trial by one step, capped at 1.0."""
        self._trial_cursors[trial_index] = min(
            self._trial_cursors[trial_index] + self.step_size, 1.0
        )

    def has_trial_data(self, trial_index: int) -> bool:
        """Check if any replay data exists for a given trial."""
        return trial_index in self._trial_indices

    def is_trial_complete(self, trial_index: int) -> bool:
        """Check if a trial's cursor has reached its maximum progression.

        Returns ``True`` immediately when the global progression range is
        degenerate (``min_prog == max_prog``) or when the trial has no
        replay data, since there is nothing left to replay in either case.
        """
        if self.min_prog == self.max_prog:
            return True
        trial_max = self._per_trial_max_prog.get(trial_index)
        if trial_max is None:
            # No data for this trial: complete regardless of the sign of
            # the progression values of other trials.
            return True
        return self._current_prog(trial_index) >= trial_max

    def get_data(self, trial_index: int, metric_signature: str) -> pd.DataFrame:
        """Get replay data for a trial up to the current cursor position.

        Returns a DataFrame filtered to rows where MAP_KEY <= current
        progression value, with original (non-normalized) MAP_KEY values.
        An empty DataFrame is returned when there is no data for the
        ``(trial_index, metric_signature)`` pair. When the global progression
        range is degenerate, the stored frame is returned unfiltered.
        """
        df = self._data.get((trial_index, metric_signature))
        if df is None:
            return pd.DataFrame()
        if self.min_prog == self.max_prog:
            return df
        return df[df[MAP_KEY] <= self._current_prog(trial_index)]


class MapDataReplayMetric(MapMetric):
"""A metric for replaying historical map data."""

Expand Down
Loading
Loading