class MapDataReplayState:
    """Shared state coordinator for replaying historical map data.

    Manages normalized cursor-based progression across multiple metrics
    and trials. Every trial owns a cursor that starts at 0 and is capped at
    1; the cursor is mapped linearly onto the global ``[min_prog, max_prog]``
    range of MAP_KEY values seen across all replayed metrics, which preserves
    cross-metric timing alignment.

    This class serves original MAP_KEY values (not normalized). Downstream
    early stopping strategies apply normalization independently via
    ``_maybe_normalize_map_key`` in ``ax.adapter.data_utils``.
    """

    def __init__(
        self,
        map_data: Data,
        metric_signatures: list[str],
        step_size: float = 0.01,
    ) -> None:
        """Initialize replay state from historical data.

        Args:
            map_data: Historical data containing progression data.
            metric_signatures: List of metric signatures to replay.
            step_size: Cursor increment per advancement step. Determines
                the granularity of replay (e.g. 0.01 = 100 steps).
        """
        self.step_size: float = step_size

        # Pre-index data by (trial_index, metric_signature) for O(1) lookups.
        self._data: dict[tuple[int, str], pd.DataFrame] = {}
        trial_indices: set[int] = set()
        per_trial_max_prog: dict[int, float] = {}
        # Running global bounds; avoids materializing every MAP_KEY value.
        global_min = float("inf")
        global_max = float("-inf")

        # Loop-invariant: read the full frame once rather than per metric.
        full_df = map_data.full_df
        for metric_signature in metric_signatures:
            metric_df = full_df[full_df["metric_signature"] == metric_signature]
            replay_df = metric_df.sort_values(
                by=["trial_index", MAP_KEY], ascending=True, ignore_index=True
            )
            for raw_trial_index, group in replay_df.groupby("trial_index"):
                trial_index = int(raw_trial_index)
                self._data[(trial_index, metric_signature)] = group.reset_index(
                    drop=True
                )
                trial_indices.add(trial_index)

                # pandas reductions skip NaN by default, so a stray NaN
                # MAP_KEY cannot poison the bounds.
                progressions = group[MAP_KEY]
                trial_min = float(progressions.min())
                trial_max = float(progressions.max())
                if pd.isna(trial_max):
                    # No usable progression value in this group.
                    continue
                global_min = min(global_min, trial_min)
                global_max = max(global_max, trial_max)
                per_trial_max_prog[trial_index] = max(
                    per_trial_max_prog.get(trial_index, trial_max), trial_max
                )

        if per_trial_max_prog:
            self.min_prog: float = global_min
            self.max_prog: float = global_max
        else:
            # No progression values at all: degenerate, zero-width range.
            self.min_prog = 0.0
            self.max_prog = 0.0

        self._per_trial_max_prog: dict[int, float] = per_trial_max_prog
        self._trial_cursors: defaultdict[int, float] = defaultdict(float)
        self._trial_indices: set[int] = trial_indices

    def _current_progression(self, trial_index: int) -> float:
        """Map a trial's cursor onto the global MAP_KEY range.

        Reads the cursor with ``.get`` so that querying a trial never inserts
        an entry into the cursor ``defaultdict``.
        """
        cursor = self._trial_cursors.get(trial_index, 0.0)
        if cursor >= 1.0:
            # ``min + 1.0 * (max - min)`` is not guaranteed to equal ``max``
            # in floating point; return the bound itself so the final data
            # point is always reachable.
            return self.max_prog
        return self.min_prog + cursor * (self.max_prog - self.min_prog)

    def advance_trial(self, trial_index: int) -> None:
        """Advance the cursor for a trial by one step, capping it at 1.0."""
        self._trial_cursors[trial_index] = min(
            self._trial_cursors[trial_index] + self.step_size, 1.0
        )

    def has_trial_data(self, trial_index: int) -> bool:
        """Check if any replay data exists for a given trial."""
        return trial_index in self._trial_indices

    def is_trial_complete(self, trial_index: int) -> bool:
        """Check if a trial's cursor has reached its maximum progression.

        Returns True when the global progression range has zero width, or
        when no maximum progression is recorded for the trial (there is
        nothing left to replay in either case).
        """
        if self.min_prog == self.max_prog:
            return True
        trial_max = self._per_trial_max_prog.get(trial_index)
        if trial_max is None:
            return True
        return self._current_progression(trial_index) >= trial_max

    def get_data(self, trial_index: int, metric_signature: str) -> pd.DataFrame:
        """Get replay data for a trial up to the current cursor position.

        Returns a DataFrame filtered to rows where MAP_KEY <= current
        progression value, with original (non-normalized) MAP_KEY values.
        An empty DataFrame is returned when nothing is indexed for the
        ``(trial_index, metric_signature)`` pair. The result is always a new
        object, never the internally stored frame.
        """
        df = self._data.get((trial_index, metric_signature))
        if df is None:
            return pd.DataFrame()
        if self.min_prog == self.max_prog:
            # Copy so callers cannot mutate the indexed replay data.
            return df.copy()
        return df[df[MAP_KEY] <= self._current_progression(trial_index)]
def _make_map_data(
    trial_metric_data: dict[int, dict[str, list[tuple[float, float, float]]]],
) -> Data:
    """Build a ``Data`` object holding map-key rows from a nested mapping.

    Args:
        trial_metric_data: Nested mapping of the form
            ``{trial_index: {metric_name: [(step, mean, sem), ...]}}``.
            The metric name doubles as the metric signature.
    """
    records = [
        {
            "trial_index": trial_index,
            "arm_name": f"{trial_index}_0",
            "metric_name": metric_name,
            "metric_signature": metric_name,
            "mean": mean,
            "sem": sem,
            MAP_KEY: step,
        }
        for trial_index, metrics in trial_metric_data.items()
        for metric_name, points in metrics.items()
        for step, mean, sem in points
    ]
    return Data(df=DataFrame(records))


class MapDataReplayStateTest(TestCase):
    @staticmethod
    def _num_rows(
        state: MapDataReplayState, trial_index: int, metric_signature: str
    ) -> int:
        """Number of rows currently served for a (trial, metric) pair."""
        served = state.get_data(
            trial_index=trial_index, metric_signature=metric_signature
        )
        return len(served)

    def test_state_computation(self) -> None:
        """Progression bounds and trial bookkeeping for several data shapes."""
        # (subtest name, data spec, signatures, min_prog, max_prog, per-trial max)
        bounds_cases = [
            (
                "uniform_steps",
                {
                    0: {"m1": [(0.0, 1.0, 0.0), (1.0, 2.0, 0.0)]},
                    1: {"m1": [(0.0, 3.0, 0.0), (1.0, 4.0, 0.0)]},
                },
                ["m1"],
                0.0,
                1.0,
                {0: 1.0, 1: 1.0},
            ),
            (
                "non_uniform_steps",
                {
                    0: {"m1": [(0.25, 1.0, 0.0), (0.95, 2.0, 0.0)]},
                    1: {"m1": [(0.25, 3.0, 0.0), (1.0, 4.0, 0.0)]},
                },
                ["m1"],
                0.25,
                1.0,
                {0: 0.95, 1: 1.0},
            ),
            (
                "multi_metric",
                {
                    0: {
                        "m1": [(0.0, 1.0, 0.0), (5.0, 2.0, 0.0)],
                        "m2": [(1.0, 3.0, 0.0), (10.0, 4.0, 0.0)],
                    },
                },
                ["m1", "m2"],
                0.0,
                10.0,
                {0: 10.0},
            ),
        ]
        for name, spec, signatures, low, high, trial_max in bounds_cases:
            with self.subTest(name):
                replay = MapDataReplayState(
                    map_data=_make_map_data(spec), metric_signatures=signatures
                )
                self.assertEqual(replay.min_prog, low)
                self.assertEqual(replay.max_prog, high)
                self.assertEqual(replay._per_trial_max_prog, trial_max)

        with self.subTest("single_trial"):
            replay = MapDataReplayState(
                map_data=_make_map_data(
                    {0: {"m1": [(0.0, 1.0, 0.0), (1.0, 2.0, 0.0)]}}
                ),
                metric_signatures=["m1"],
            )
            self.assertEqual(replay._trial_indices, {0})
            self.assertTrue(replay.has_trial_data(trial_index=0))
            self.assertFalse(replay.has_trial_data(trial_index=1))

        with self.subTest("non_contiguous_trial_indices"):
            sparse_spec = {
                0: {"m1": [(0.0, 1.0, 0.0)]},
                5: {"m1": [(0.0, 2.0, 0.0)]},
                10: {"m1": [(0.0, 3.0, 0.0)]},
            }
            replay = MapDataReplayState(
                map_data=_make_map_data(sparse_spec), metric_signatures=["m1"]
            )
            self.assertEqual(replay._trial_indices, {0, 5, 10})
            self.assertTrue(replay.has_trial_data(trial_index=5))
            self.assertFalse(replay.has_trial_data(trial_index=3))

        with self.subTest("min_equals_max_prog"):
            flat_spec = {
                0: {"m1": [(3.0, 1.0, 0.0)]},
                1: {"m1": [(3.0, 2.0, 0.0)]},
            }
            replay = MapDataReplayState(
                map_data=_make_map_data(flat_spec), metric_signatures=["m1"]
            )
            self.assertEqual(replay.min_prog, 3.0)
            self.assertEqual(replay.max_prog, 3.0)
            # A zero-width range means every trial is complete immediately.
            self.assertTrue(replay.is_trial_complete(trial_index=0))
            self.assertTrue(replay.is_trial_complete(trial_index=1))
            self.assertEqual(self._num_rows(replay, 0, "m1"), 1)

        with self.subTest("empty_metric_data"):
            replay = MapDataReplayState(
                map_data=_make_map_data(
                    {0: {"m1": [(0.0, 1.0, 0.0), (1.0, 2.0, 0.0)]}}
                ),
                metric_signatures=["m1", "m_empty"],
            )
            self.assertTrue(replay.has_trial_data(trial_index=0))
            self.assertTrue(
                replay.get_data(trial_index=0, metric_signature="m_empty").empty
            )
            self.assertEqual(replay.min_prog, 0.0)
            self.assertEqual(replay.max_prog, 1.0)

        with self.subTest("different_num_points_per_trial"):
            uneven_spec = {
                0: {"m1": [(0.0, 1.0, 0.0), (0.5, 2.0, 0.0), (1.0, 3.0, 0.0)]},
                1: {"m1": [(0.0, 4.0, 0.0)]},
            }
            replay = MapDataReplayState(
                map_data=_make_map_data(uneven_spec), metric_signatures=["m1"]
            )
            self.assertEqual(replay._per_trial_max_prog, {0: 1.0, 1: 0.0})

    def test_cursor_advancement_and_data_serving(self) -> None:
        """Cursor stepping and capping, progressive serving, per-trial
        independence, and completion transitions.

        The subtests share one state object and must run in order.
        """
        replay = MapDataReplayState(
            map_data=_make_map_data(
                {
                    0: {"m1": [(0.0, 1.0, 0.0), (0.5, 2.0, 0.0), (1.0, 3.0, 0.0)]},
                    1: {"m1": [(0.0, 4.0, 0.0), (1.0, 5.0, 0.0)]},
                }
            ),
            metric_signatures=["m1"],
            step_size=0.5,
        )

        with self.subTest("initial_cursor_is_zero"):
            for trial_index in (0, 1):
                self.assertEqual(replay._trial_cursors[trial_index], 0.0)

        with self.subTest("progressive_data_at_cursor_0"):
            self.assertEqual(self._num_rows(replay, 0, "m1"), 1)
            self.assertEqual(self._num_rows(replay, 1, "m1"), 1)

        with self.subTest("advance_and_check_independence"):
            replay.advance_trial(trial_index=0)
            self.assertAlmostEqual(replay._trial_cursors[0], 0.5)
            self.assertAlmostEqual(replay._trial_cursors[1], 0.0)

        with self.subTest("progressive_data_at_cursor_0_5"):
            self.assertEqual(self._num_rows(replay, 0, "m1"), 2)
            self.assertEqual(self._num_rows(replay, 1, "m1"), 1)

        with self.subTest("advance_to_full"):
            replay.advance_trial(trial_index=0)
            self.assertEqual(self._num_rows(replay, 0, "m1"), 3)

        with self.subTest("cursor_caps_at_one"):
            replay.advance_trial(trial_index=0)
            self.assertAlmostEqual(replay._trial_cursors[0], 1.0)

        with self.subTest("trial_completion_transitions"):
            self.assertTrue(replay.is_trial_complete(trial_index=0))
            self.assertFalse(replay.is_trial_complete(trial_index=1))
            for _ in range(2):
                replay.advance_trial(trial_index=1)
            self.assertTrue(replay.is_trial_complete(trial_index=1))

        with self.subTest("heterogeneous_max_prog_completion"):
            mixed = MapDataReplayState(
                map_data=_make_map_data(
                    {
                        0: {"m1": [(0.0, 1.0, 0.0), (0.5, 2.0, 0.0)]},
                        1: {"m1": [(0.0, 3.0, 0.0), (1.0, 4.0, 0.0)]},
                    }
                ),
                metric_signatures=["m1"],
                step_size=0.5,
            )
            mixed.advance_trial(trial_index=0)
            mixed.advance_trial(trial_index=1)
            # Trial 0 tops out at 0.5, so one step finishes it; trial 1 does not.
            self.assertTrue(mixed.is_trial_complete(trial_index=0))
            self.assertFalse(mixed.is_trial_complete(trial_index=1))

    def test_multi_metric_and_data_integrity(self) -> None:
        """Shared timeline across metrics, untouched MAP_KEY values, and
        lookups for a trial or metric that was never indexed."""
        replay = MapDataReplayState(
            map_data=_make_map_data(
                {
                    0: {
                        "m1": [(10.0, 1.0, 0.0), (20.0, 2.0, 0.0)],
                        "m2": [(10.0, 10.0, 0.0), (20.0, 20.0, 0.0)],
                    },
                }
            ),
            metric_signatures=["m1", "m2"],
            step_size=1.0,
        )
        replay.advance_trial(trial_index=0)

        with self.subTest("shared_timeline"):
            self.assertEqual(
                self._num_rows(replay, 0, "m1"),
                self._num_rows(replay, 0, "m2"),
            )

        with self.subTest("original_map_key_values"):
            served = replay.get_data(trial_index=0, metric_signature="m1")
            self.assertListEqual(served[MAP_KEY].tolist(), [10.0, 20.0])

        with self.subTest("nonexistent_trial"):
            missing_trial = replay.get_data(trial_index=99, metric_signature="m1")
            self.assertTrue(missing_trial.empty)

        with self.subTest("nonexistent_metric"):
            missing_metric = replay.get_data(
                trial_index=0, metric_signature="m_missing"
            )
            self.assertTrue(missing_metric.empty)
- # offset = min(first step of each trial) = min(0, 0) = 0 self.assertEqual(replay_metric.offset, 0) - # scaling_factor = mean((final_step - offset) / num_points) - # = mean((1 - 0) / 2, (1 - 0) / 2) = mean(0.5, 0.5) = 0.5 self.assertEqual(replay_metric.scaling_factor, 0.5) experiment = Experiment( @@ -61,9 +292,7 @@ def test_map_replay(self) -> None: trial.add_arm(Arm(parameters={"x1": float(i), "x2": 0.0})) trial.run() - # fetch once for MAP_KEY = 0 experiment.fetch_data() - # the second fetch will be for MAP_KEY = 0 and MAP_KEY = 1 data = experiment.fetch_data() metric_name = [replay_metric.name] * 4 expected_df = Data( @@ -86,10 +315,6 @@ def test_map_replay_non_uniform(self) -> None: num_trials=2, num_fetches=2, num_complete=2 ) full_df = historical_experiment.lookup_data().full_df - # The original data has 6 rows: 4 for branin_map and 2 for branin. - # After assinging steps, we have following steps for branin_map: - # Trial 0: steps [0.25, 0.95] - # Trial 1: steps [0.25, 1.0] full_df[MAP_KEY] = pd.Series([0.25, 0.0, 0.95, 0.25, 0.0, 1.0]) historical_data = Data(df=full_df) replay_metric = MapDataReplayMetric( @@ -98,12 +323,7 @@ def test_map_replay_non_uniform(self) -> None: metric_name="branin_map", lower_is_better=True, ) - # Verify offset: min(first step of each trial after sorting) self.assertEqual(replay_metric.offset, 0.25) - # Verify scaling_factor: mean((final_step - offset) / num_points) across trials - # Trial 0: (0.95 - 0.25) / 2 = 0.35 - # Trial 1: (1.0 - 0.25) / 2 = 0.375 - # scaling_factor = (0.35 + 0.375) / 2 = 0.3625 self.assertEqual(replay_metric.scaling_factor, 0.3625) experiment = Experiment( @@ -138,17 +358,13 @@ def test_map_replay_non_uniform(self) -> None: ) ).full_df - # Test that as we step through with steps of size 0.3625, we - # first get both points at step 0.25. 
data = experiment.fetch_data() assert_frame_equal( data.full_df, full_expected_df.iloc[[0, 2]].reset_index(drop=True) ) - # Next, we add the point at step 0.95 of Trial 0. data = experiment.fetch_data() assert_frame_equal(data.full_df, full_expected_df.iloc[:3]) - # Finally, we get the point at step 1.0 of Trial 1. data = experiment.fetch_data() assert_frame_equal(data.full_df, full_expected_df.iloc[:4])