diff --git a/ax/analysis/healthcheck/early_stopping_healthcheck.py b/ax/analysis/healthcheck/early_stopping_healthcheck.py index ccc283b935a..11ee45f8739 100644 --- a/ax/analysis/healthcheck/early_stopping_healthcheck.py +++ b/ax/analysis/healthcheck/early_stopping_healthcheck.py @@ -413,7 +413,7 @@ def _report_early_stopping_nudge( try: savings = estimate_hypothetical_early_stopping_savings( experiment=experiment, - metric=metric, + metrics=[metric], max_pending_trials=self.max_pending_trials, ) except Exception as e: diff --git a/ax/early_stopping/experiment_replay.py b/ax/early_stopping/experiment_replay.py index a6b81b7226e..23f693d72f0 100644 --- a/ax/early_stopping/experiment_replay.py +++ b/ax/early_stopping/experiment_replay.py @@ -7,14 +7,20 @@ # pyre-strict import logging +import warnings from logging import Logger from time import perf_counter from ax.adapter.registry import Generators +from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.metric import Metric -from ax.core.objective import Objective -from ax.core.optimization_config import OptimizationConfig +from ax.core.objective import MultiObjective, Objective +from ax.core.optimization_config import ( + MultiObjectiveOptimizationConfig, + OptimizationConfig, +) +from ax.core.outcome_constraint import OutcomeConstraint from ax.core.parameter import ParameterType, RangeParameter from ax.core.search_space import SearchSpace from ax.early_stopping.dispatch import get_default_ess_or_none @@ -25,7 +31,7 @@ GenerationStep, GenerationStrategy, ) -from ax.metrics.map_replay import MapDataReplayMetric +from ax.metrics.map_replay import MapDataReplayMetric, MapDataReplayState from ax.orchestration.orchestrator import Orchestrator, OrchestratorOptions from ax.runners.map_replay import MapDataReplayRunner from ax.utils.common.logger import get_logger @@ -43,16 +49,36 @@ def replay_experiment( historical_experiment: Experiment, num_samples_per_curve: int, max_replay_trials: int, - metric: Metric, + metrics: list[Metric], max_pending_trials: int, early_stopping_strategy: BaseEarlyStoppingStrategy | None, logging_level: int = logging.ERROR, ) -> Experiment | None: - """A utility function for replaying a historical experiment's data - by initializing a Orchestrator that quickly steps through the existing data. - The main purpose of this function is to compute an hypothetical capacity - savings for a given `early_stopping_strategy`. + """Replay a historical experiment's data through an Orchestrator. + + Initializes an Orchestrator that steps through existing data to compute + hypothetical capacity savings for a given ``early_stopping_strategy``. + Supports both single-objective and multi-objective optimization. + + Args: + historical_experiment: The experiment whose data to replay. + num_samples_per_curve: Deprecated. Number of samples per curve for + subsampling. Use ``step_size`` on ``MapDataReplayState`` instead. + max_replay_trials: Maximum number of trials to replay. + metrics: List of metrics to replay. For single-objective, provide + one metric. For multi-objective, provide multiple metrics. + max_pending_trials: Maximum number of pending trials for the + replay orchestrator. + early_stopping_strategy: The early stopping strategy to evaluate. + logging_level: Logging level for the orchestrator. """ + warnings.warn( + "The `num_samples_per_curve` parameter is deprecated and will be " + "removed in a future release. The `step_size` parameter on " + "`MapDataReplayState` controls replay granularity.", + DeprecationWarning, + stacklevel=2, + ) historical_map_data = historical_experiment.lookup_data() if not historical_map_data.has_step_column: logger.warning( @@ -62,16 +88,51 @@ def replay_experiment( historical_map_data = historical_map_data.subsample( limit_rows_per_group=num_samples_per_curve, include_first_last=True ) - replay_metric = MapDataReplayMetric( - name=f"replay_{historical_experiment.name}", - map_data=historical_map_data, - metric_name=metric.name, - lower_is_better=metric.lower_is_better, - ) - optimization_config = OptimizationConfig( - objective=Objective(metric=replay_metric), + + # Re-index non-contiguous trial indices to contiguous 0, 1, 2, ... + # so that replay trial N maps to the Nth historical trial. + df = historical_map_data.full_df + sorted_trial_indices = sorted(df["trial_index"].unique()) + trial_index_map = {old: new for new, old in enumerate(sorted_trial_indices)} + df = df.copy() + df["trial_index"] = df["trial_index"].map(trial_index_map) + historical_map_data = Data(df=df) + + metric_signatures = [m.signature for m in metrics] + replay_state = MapDataReplayState( + map_data=historical_map_data, metric_signatures=metric_signatures ) - runner = MapDataReplayRunner(replay_metric=replay_metric) + + replay_metrics = [ + MapDataReplayMetric( + name=m.name, + replay_state=replay_state, + metric_signature=m.signature, + lower_is_better=m.lower_is_better, + ) + for m in metrics + ] + + if len(replay_metrics) == 1: + optimization_config: OptimizationConfig = OptimizationConfig( + objective=Objective(metric=replay_metrics[0]), + ) + else: + # Extract objective thresholds from the historical experiment's config + historical_opt_config = historical_experiment.optimization_config + objective_thresholds: list[OutcomeConstraint] = [] + if isinstance(historical_opt_config, MultiObjectiveOptimizationConfig): + objective_thresholds = [ + ot.clone() for ot in historical_opt_config.objective_thresholds + ] + optimization_config = MultiObjectiveOptimizationConfig( + objective=MultiObjective( + objectives=[Objective(metric=m) for m in replay_metrics] + ), + objective_thresholds=objective_thresholds, + ) + + runner = MapDataReplayRunner(replay_state=replay_state) # Setup a new experiment with a dummy search space dummy_search_space = SearchSpace( @@ -89,10 +150,10 @@ def replay_experiment( optimization_config=optimization_config, search_space=dummy_search_space, runner=runner, - metrics=[replay_metric], + metrics=replay_metrics, ) - # Setup a Orchestrator with a dummy gs to replay the historical experiment + # Setup an Orchestrator with a dummy gs to replay the historical experiment dummy_sobol_gs = GenerationStrategy( name="sobol", steps=[ @@ -101,7 +162,7 @@ def replay_experiment( ) options = OrchestratorOptions( max_pending_trials=max_pending_trials, - total_trials=min(len(historical_experiment.trials), max_replay_trials), + total_trials=min(len(sorted_trial_indices), max_replay_trials), seconds_between_polls_backoff_factor=1.0, min_seconds_before_poll=0.0, init_seconds_between_polls=0, @@ -119,7 +180,7 @@ def replay_experiment( def estimate_hypothetical_early_stopping_savings( experiment: Experiment, - metric: Metric, + metrics: list[Metric], max_pending_trials: int = MAX_PENDING_TRIALS, ) -> float: """Estimate hypothetical early stopping savings using experiment replay. @@ -130,7 +191,7 @@ def estimate_hypothetical_early_stopping_savings( Args: experiment: The experiment to analyze. - metric: The metric to use for early stopping replay. + metrics: The metrics to use for early stopping replay. max_pending_trials: Maximum number of pending trials for the replay orchestrator. Defaults to 5. @@ -156,7 +217,7 @@ def estimate_hypothetical_early_stopping_savings( historical_experiment=experiment, num_samples_per_curve=REPLAY_NUM_POINTS_PER_CURVE, max_replay_trials=MAX_REPLAY_TRIALS, - metric=metric, + metrics=metrics, max_pending_trials=max_pending_trials, early_stopping_strategy=default_ess, ) diff --git a/ax/early_stopping/tests/test_experiment_replay.py b/ax/early_stopping/tests/test_experiment_replay.py index f62873dd27b..232d1cf0477 100644 --- a/ax/early_stopping/tests/test_experiment_replay.py +++ b/ax/early_stopping/tests/test_experiment_replay.py @@ -8,18 +8,274 @@ from unittest.mock import patch +from ax.core.base_trial import TrialStatus +from ax.core.data import MAP_KEY from ax.early_stopping.experiment_replay import ( estimate_hypothetical_early_stopping_savings, + replay_experiment, ) +from ax.early_stopping.strategies.percentile import PercentileEarlyStoppingStrategy +from ax.early_stopping.utils import estimate_early_stopping_savings from ax.exceptions.core import UnsupportedError from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_branin_experiment, get_branin_experiment_with_timestamp_map_metric, + get_test_map_data_experiment, ) from pyre_extensions import none_throws +class TestReplayExperiment(TestCase): + def test_single_objective_replay(self) -> None: + """Single-objective replay with heterogeneous trials.""" + historical_experiment = get_test_map_data_experiment( + num_trials=3, num_fetches=3, num_complete=3 + ) + metric_name = none_throws( + historical_experiment.optimization_config + ).objective.metric_names[0] + metric = historical_experiment.get_metric(metric_name) + + replayed = replay_experiment( + historical_experiment=historical_experiment, + num_samples_per_curve=20, + max_replay_trials=3, + metrics=[metric], + max_pending_trials=3, + early_stopping_strategy=None, + ) + replayed = none_throws(replayed) + # All trials should have been processed + self.assertGreater(len(replayed.trials), 0) + self.assertLessEqual(len(replayed.trials), 3) + + def test_multi_objective_replay(self) -> None: + """Multi-objective replay with shared state.""" + historical_experiment = get_test_map_data_experiment( + num_trials=2, + num_fetches=2, + num_complete=2, + multi_objective=True, + ) + opt_config = none_throws(historical_experiment.optimization_config) + metric_names = opt_config.objective.metric_names + metrics = [historical_experiment.get_metric(mn) for mn in metric_names] + + replayed = replay_experiment( + historical_experiment=historical_experiment, + num_samples_per_curve=20, + max_replay_trials=2, + metrics=metrics, + max_pending_trials=2, + early_stopping_strategy=None, + ) + replayed = none_throws(replayed) + # Should have replay metrics for each original metric + replay_metric_names = {m.name for m in replayed.metrics.values()} + for mn in metric_names: + self.assertIn(mn, replay_metric_names) + + def test_multi_objective_replayed_data_matches_historical(self) -> None: + """Verify that MOO replay serves correct data for every objective + metric across all trials.""" + historical_experiment = get_test_map_data_experiment( + num_trials=2, + num_fetches=3, + num_complete=2, + multi_objective=True, + ) + opt_config = none_throws(historical_experiment.optimization_config) + metric_names = opt_config.objective.metric_names + metrics = [historical_experiment.get_metric(mn) for mn in metric_names] + + replayed = none_throws( + replay_experiment( + historical_experiment=historical_experiment, + num_samples_per_curve=20, + max_replay_trials=2, + metrics=metrics, + max_pending_trials=2, + early_stopping_strategy=None, + ) + ) + + with self.subTest("all_trials_completed"): + for t in replayed.trials.values(): + self.assertEqual(t.status, TrialStatus.COMPLETED) + + historical_data = historical_experiment.lookup_data() + historical_subsampled = historical_data.subsample( + limit_rows_per_group=20, include_first_last=True + ) + hist_df = historical_subsampled.full_df + replayed_data = replayed.lookup_data() + replay_df = replayed_data.full_df + + for mn in metric_names: + with self.subTest(f"data_matches_for_{mn}"): + hist_metric_df = hist_df[hist_df["metric_name"] == mn] + sorted_hist_indices = sorted(hist_metric_df["trial_index"].unique()) + + for replay_trial_index in sorted(replay_df["trial_index"].unique()): + replay_metric_df = replay_df[ + (replay_df["trial_index"] == replay_trial_index) + & (replay_df["metric_name"] == mn) + ] + replay_steps = sorted(replay_metric_df[MAP_KEY].tolist()) + + hist_trial_index = sorted_hist_indices[int(replay_trial_index)] + hist_steps = sorted( + hist_metric_df[ + hist_metric_df["trial_index"] == hist_trial_index + ][MAP_KEY].tolist() + ) + self.assertEqual( + replay_steps, + hist_steps, + f"Metric {mn}, trial {replay_trial_index}: " + f"replayed steps {replay_steps} " + f"!= historical steps {hist_steps}", + ) + + def test_replay_with_early_stopping(self) -> None: + """End-to-end replay with a PercentileEarlyStoppingStrategy.""" + historical_experiment = get_test_map_data_experiment( + num_trials=3, num_fetches=5, num_complete=3 + ) + metric_name = none_throws( + historical_experiment.optimization_config + ).objective.metric_names[0] + metric = historical_experiment.get_metric(metric_name) + + ess = PercentileEarlyStoppingStrategy( + percentile_threshold=50.0, + min_curves=1, + min_progression=0.1, + ) + replayed = replay_experiment( + historical_experiment=historical_experiment, + num_samples_per_curve=20, + max_replay_trials=3, + metrics=[metric], + max_pending_trials=3, + early_stopping_strategy=ess, + ) + replayed = none_throws(replayed) + self.assertEqual(len(replayed.trials), 3) + + def test_replay_no_step_column(self) -> None: + """Test that replay returns None when data has no step column.""" + exp = get_branin_experiment(has_optimization_config=True) + metric_name = none_throws(exp.optimization_config).objective.metric_names[0] + metric = exp.get_metric(metric_name) + result = replay_experiment( + historical_experiment=exp, + num_samples_per_curve=20, + max_replay_trials=50, + metrics=[metric], + max_pending_trials=5, + early_stopping_strategy=None, + ) + self.assertIsNone(result) + + def test_replayed_data_matches_historical(self) -> None: + """Verify that after replay without ESS, every trial's replayed data + contains exactly the same set of MAP_KEY values and metric values + as the historical data (after subsampling).""" + historical_experiment = get_test_map_data_experiment( + num_trials=3, num_fetches=5, num_complete=3 + ) + metric_name = none_throws( + historical_experiment.optimization_config + ).objective.metric_names[0] + metric = historical_experiment.get_metric(metric_name) + + replayed = none_throws( + replay_experiment( + historical_experiment=historical_experiment, + num_samples_per_curve=20, + max_replay_trials=3, + metrics=[metric], + max_pending_trials=1, + early_stopping_strategy=None, + ) + ) + + with self.subTest("all_trials_completed"): + for t in replayed.trials.values(): + self.assertEqual(t.status, TrialStatus.COMPLETED) + + with self.subTest("replayed_data_matches_historical"): + # Subsample the historical data the same way replay_experiment does + historical_data = historical_experiment.lookup_data() + historical_subsampled = historical_data.subsample( + limit_rows_per_group=20, include_first_last=True + ) + hist_df = historical_subsampled.full_df + hist_metric_df = hist_df[hist_df["metric_name"] == metric_name] + + replayed_data = replayed.lookup_data() + replay_df = replayed_data.full_df + + # For each replayed trial, verify step values match historical + for replay_trial_index in sorted(replay_df["trial_index"].unique()): + replay_steps = sorted( + replay_df[replay_df["trial_index"] == replay_trial_index][ + MAP_KEY + ].tolist() + ) + # Map replay trial index back to historical trial index + sorted_hist_indices = sorted(hist_metric_df["trial_index"].unique()) + hist_trial_index = sorted_hist_indices[int(replay_trial_index)] + hist_steps = sorted( + hist_metric_df[hist_metric_df["trial_index"] == hist_trial_index][ + MAP_KEY + ].tolist() + ) + self.assertEqual( + replay_steps, + hist_steps, + f"Trial {replay_trial_index}: replayed steps {replay_steps} " + f"!= historical steps {hist_steps}", + ) + + def test_early_stopping_produces_savings(self) -> None: + """Verify that replay with an ESS completes successfully and + produces a valid savings estimate (>= 0).""" + historical_experiment = get_test_map_data_experiment( + num_trials=5, num_fetches=10, num_complete=5 + ) + metric_name = none_throws( + historical_experiment.optimization_config + ).objective.metric_names[0] + metric = historical_experiment.get_metric(metric_name) + + ess = PercentileEarlyStoppingStrategy( + percentile_threshold=70.0, + min_curves=2, + min_progression=0.1, + ) + replayed = none_throws( + replay_experiment( + historical_experiment=historical_experiment, + num_samples_per_curve=20, + max_replay_trials=5, + metrics=[metric], + max_pending_trials=5, + early_stopping_strategy=ess, + ) + ) + + with self.subTest("all_trials_created"): + self.assertEqual(len(replayed.trials), 5) + + with self.subTest("savings_are_valid"): + savings = estimate_early_stopping_savings(experiment=replayed) + self.assertGreaterEqual(savings, 0.0) + self.assertLessEqual(savings, 1.0) + + class TestEstimateHypotheticalEss(TestCase): def setUp(self) -> None: super().setUp() @@ -40,7 +296,7 @@ def test_estimate_hypothetical_ess_no_default_strategy(self) -> None: with self.assertRaises(UnsupportedError) as e: estimate_hypothetical_early_stopping_savings( experiment=exp, - metric=metric, + metrics=[metric], ) self.assertIn( @@ -58,7 +314,7 @@ def test_estimate_hypothetical_ess_no_progression_data(self) -> None: with self.assertRaises(UnsupportedError) as e: estimate_hypothetical_early_stopping_savings( experiment=self.exp, - metric=self.metric, + metrics=[self.metric], ) self.assertIn( @@ -79,7 +335,7 @@ def test_estimate_hypothetical_ess_success(self) -> None: ): result = estimate_hypothetical_early_stopping_savings( experiment=self.exp, - metric=self.metric, + metrics=[self.metric], ) self.assertEqual(result, 0.25) @@ -95,7 +351,7 @@ def test_estimate_hypothetical_ess_exception(self) -> None: with self.assertRaises(ValueError) as e: estimate_hypothetical_early_stopping_savings( experiment=self.exp, - metric=self.metric, + metrics=[self.metric], ) self.assertIn("Experiment's name is None.", str(e.exception)) diff --git a/ax/metrics/map_replay.py b/ax/metrics/map_replay.py index a05e36e9266..90a268c65d8 100644 --- a/ax/metrics/map_replay.py +++ b/ax/metrics/map_replay.py @@ -23,113 +23,140 @@ logger: Logger = get_logger(__name__) +class MapDataReplayState: + """Shared state coordinator for replaying historical map data. + + Manages normalized cursor-based progression across multiple metrics + and trials. The cursor model uses a global min/max MAP_KEY across + all metrics to preserve cross-metric timing alignment. + + This class serves original MAP_KEY values (not normalized). Downstream + early stopping strategies apply normalization independently via + ``_maybe_normalize_map_key`` in ``ax.adapter.data_utils``. + """ + + def __init__( + self, + map_data: Data, + metric_signatures: list[str], + step_size: float = 0.01, + ) -> None: + """Initialize replay state from historical data. + + Args: + map_data: Historical data containing progression data. + metric_signatures: List of metric signatures to replay. + step_size: Cursor increment per advancement step. Determines + the granularity of replay (e.g. 0.01 = 100 steps). + """ + self.step_size: float = step_size + + # Pre-index data by (trial_index, metric_signature) for O(1) lookups + self._data: dict[tuple[int, str], pd.DataFrame] = {} + all_trial_indices: set[int] = set() + all_prog_values: list[float] = [] + per_trial_max_prog: dict[int, float] = {} + + for metric_signature in metric_signatures: + replay_df = _prepare_replay_dataframe( + map_data=map_data, metric_signature=metric_signature + ) + for trial_index, group in replay_df.groupby("trial_index"): + trial_index = int(trial_index) + self._data[(trial_index, metric_signature)] = group.reset_index( + drop=True + ) + all_trial_indices.add(trial_index) + prog_values = group[MAP_KEY].values + all_prog_values.extend(prog_values.tolist()) + trial_max = float(prog_values.max()) + if trial_index in per_trial_max_prog: + per_trial_max_prog[trial_index] = max( + per_trial_max_prog[trial_index], trial_max + ) + else: + per_trial_max_prog[trial_index] = trial_max + + if all_prog_values: + self.min_prog: float = float(min(all_prog_values)) + self.max_prog: float = float(max(all_prog_values)) + else: + self.min_prog = 0.0 + self.max_prog = 0.0 + + self._per_trial_max_prog: dict[int, float] = per_trial_max_prog + self._trial_cursors: defaultdict[int, float] = defaultdict(float) + self._trial_indices: set[int] = all_trial_indices + + def advance_trial(self, trial_index: int) -> None: + """Advance the cursor for a trial by one resolution step.""" + self._trial_cursors[trial_index] = min( + self._trial_cursors[trial_index] + self.step_size, 1.0 + ) + + def has_trial_data(self, trial_index: int) -> bool: + """Check if any replay data exists for a given trial.""" + return trial_index in self._trial_indices + + def is_trial_complete(self, trial_index: int) -> bool: + """Check if a trial's cursor has reached its maximum progression.""" + if self.min_prog == self.max_prog: + return True + curr_prog = self.min_prog + self._trial_cursors[trial_index] * ( + self.max_prog - self.min_prog + ) + return curr_prog >= self._per_trial_max_prog.get(trial_index, 0.0) + + def get_data(self, trial_index: int, metric_signature: str) -> pd.DataFrame: + """Get replay data for a trial up to the current cursor position. + + Returns a DataFrame filtered to rows where MAP_KEY <= current + progression value, with original (non-normalized) MAP_KEY values. + """ + df = self._data.get((trial_index, metric_signature)) + if df is None: + return pd.DataFrame() + if self.min_prog == self.max_prog: + return df + curr_prog = self.min_prog + self._trial_cursors[trial_index] * ( + self.max_prog - self.min_prog + ) + return df[df[MAP_KEY] <= curr_prog] + + class MapDataReplayMetric(MapMetric): - """A metric for replaying historical map data.""" + """A metric for replaying historical map data. + + Delegates data storage and progression state to a shared + ``MapDataReplayState`` instance, allowing multiple metrics + to share the same progression timeline. + """ def __init__( self, name: str, - map_data: Data, - metric_name: str, - max_steps_validation: int | None = 200, + replay_state: MapDataReplayState, + metric_signature: str, lower_is_better: bool | None = None, ) -> None: - """Inits MapDataReplayMetric. + """Initialize a replay metric. Args: - name: The name of the metric. - map_data: Historical data to use for replaying. It is assumed that - there is a single curve (arm) per trial (i.e., no batch trials). - metric_name: The metric to replay from `map_data`. - max_steps_validation: If not None, we check to see that the inferred - scaling factor and offset does not lead to a number of replay steps - that is larger than `max_steps_validation` for any trial. + name: The name of this metric in the replay experiment. + replay_state: Shared state coordinator for replay progression. + metric_signature: The metric signature to replay from the + historical data. lower_is_better: If True, lower metric values are considered desirable. """ - self.map_data: Data = map_data - self.max_steps_validation = max_steps_validation - self.metric_name: str = metric_name - # Store pre-processed DataFrame sorted by trial_index and step - self._replay_df: pd.DataFrame = _prepare_replay_dataframe( - map_data=map_data, metric_name=self.metric_name - ) - # Pre-group by trial_index for O(1) trial lookups instead of O(n) filtering - self._trial_groups: dict[int, pd.DataFrame] = { - int(trial_idx): group - for trial_idx, group in self._replay_df.groupby("trial_index") - } - # Pre-compute trial statistics using vectorized groupby, then extract - # offset and scaling_factor once, and store only last_step as a dict - trial_stats = _compute_trial_stats(self._replay_df) - self.offset: float = trial_stats["first_step"].min() - self.scaling_factor: float = _compute_scaling_factor( - trial_stats=trial_stats, offset=self.offset - ) - # Store only last_step as dict for O(1) lookups in hot paths - # Explicitly convert keys to int for consistency with _trial_groups - self._trial_last_step: dict[int, float] = { - int(k): float(v) for k, v in trial_stats["last_step"].items() - } - self._trial_index_to_step: dict[int, int] = defaultdict(int) + self._replay_state: MapDataReplayState = replay_state + self._metric_signature: str = metric_signature super().__init__(name=name, lower_is_better=lower_is_better) - self._validate_replay_feasibility(trial_stats=trial_stats) @classmethod def is_available_while_running(cls) -> bool: return True - def _validate_replay_feasibility(self, trial_stats: pd.DataFrame) -> None: - """Check that the offset and scaling factor results in a reasonable number - of steps for all trials (i.e., we don't want an intractable number of trials - if (trial_max_step - offset) / scaling_factor is too large). - - Args: - trial_stats: DataFrame with trial statistics (first_step, last_step, - num_points). Passed in to avoid recomputing or storing it. - """ - if self.max_steps_validation is None: - return - - # Vectorized computation of max steps per trial - max_steps_per_trial = ( - trial_stats["last_step"] - self.offset - ) / self.scaling_factor - max_steps = max_steps_per_trial.max() - - # Find violating trials - violating = max_steps_per_trial[max_steps_per_trial > self.max_steps_validation] - if not violating.empty: - trial_idx = violating.index[0] - max_steps_trial = violating.iloc[0] - raise ValueError( - f"For trial {trial_idx}, the computed offset {self.offset} and " - f"scaling factor {self.scaling_factor} lead to " - f"{max_steps_trial} steps, which is larger than " - f"{self.max_steps_validation} steps to replay." - ) - logger.debug( - f"Validated MapReplayMetric {self.name} with " - f"{len(trial_stats)} trials, scaling factor = " - f"{self.scaling_factor:.2f}, and offset = {self.offset:.2f}, " - f"resulting in maximum steps = {max_steps}." - ) - - def has_trial_data(self, trial_idx: int) -> bool: - """Check if any replay data exists for a given trial.""" - # Use pre-grouped dict for O(1) lookup instead of checking DataFrame index - return trial_idx in self._trial_groups - - def more_replay_available(self, trial_idx: int) -> bool: - """Check if more replay data is available for a given trial.""" - trial_max_step = self._trial_last_step.get(trial_idx) - if trial_max_step is None: - return False - current_step = ( - self.offset + self._trial_index_to_step[trial_idx] * self.scaling_factor - ) - return current_step < trial_max_step - def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult: try: if not isinstance(trial, Trial): @@ -137,29 +164,14 @@ def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult "Only (non-batch) Trials are supported by " f"{self.__class__.__name__}." ) - trial_idx = trial.index - # Increment the step counter if we can. - if trial.status.is_running and self.more_replay_available( - trial_idx=trial_idx - ): - self._trial_index_to_step[trial_idx] += 1 - trial_scaled_step = ( - self.offset + self._trial_index_to_step[trial_idx] * self.scaling_factor + trial_data = self._replay_state.get_data( + trial_index=trial.index, + metric_signature=self._metric_signature, ) - logger.info(f"Trial {trial_idx} is at step {trial_scaled_step}.") - - # Use pre-grouped data for O(1) lookup instead of filtering full DataFrame - trial_group = self._trial_groups.get(trial_idx) - if trial_group is None: - return Ok(value=Data.from_multiple_data(data=[])) - - # Filter only the trial's subset (much smaller than full DataFrame) - trial_data = trial_group[trial_group[MAP_KEY] <= trial_scaled_step] if trial_data.empty: return Ok(value=Data()) - # Create the result DataFrame in one operation result_df = pd.DataFrame( { "arm_name": none_throws(trial.arm).name, @@ -180,50 +192,15 @@ def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult ) -def _prepare_replay_dataframe(map_data: Data, metric_name: str) -> pd.DataFrame: +def _prepare_replay_dataframe(map_data: Data, metric_signature: str) -> pd.DataFrame: """Prepare a pre-sorted DataFrame for efficient replay lookups. - Filters the data to the specified metric and sorts by trial_index and step. - This allows efficient vectorized filtering during fetch_trial_data. + Filters the data to the specified metric signature and sorts by + trial_index and step. """ df = map_data.full_df - df = df[df["metric_name"] == metric_name] + df = df[df["metric_signature"] == metric_signature] # Sort once upfront for efficient lookups return df.sort_values( by=["trial_index", MAP_KEY], ascending=True, ignore_index=True ) - - -def _compute_trial_stats(replay_df: pd.DataFrame) -> pd.DataFrame: - """Compute per-trial statistics using vectorized groupby operations. - - Returns a DataFrame indexed by trial_index with columns: - - first_step: the first (minimum) step value for each trial - - last_step: the last (maximum) step value for each trial - - num_points: the number of data points per trial - """ - stats = replay_df.groupby("trial_index")[MAP_KEY].agg( - first_step="first", # Data is pre-sorted, so first/last are min/max - last_step="last", - num_points="count", - ) - return stats - - -def _compute_scaling_factor(trial_stats: pd.DataFrame, offset: float) -> float: - """Compute the scaling factor for replay data using vectorized operations. - - The scaling factor is: - `mean_{trial in trials} (max_steps_trial - offset) / num_points_trial`. - """ - # Vectorized computation of per-trial scaling factors - valid_mask = (trial_stats["num_points"] > 0) & (trial_stats["last_step"] > offset) - if not valid_mask.any(): - return 1.0 - - scaling_factors = ( - trial_stats.loc[valid_mask, "last_step"] - offset - ) / trial_stats.loc[valid_mask, "num_points"] - scaling_factor = float(scaling_factors.mean()) - - return scaling_factor if scaling_factor > 0.0 else 1.0 diff --git a/ax/metrics/tests/test_map_replay.py b/ax/metrics/tests/test_map_replay.py index c2957bd0ae5..99975096a40 100644 --- a/ax/metrics/tests/test_map_replay.py +++ b/ax/metrics/tests/test_map_replay.py @@ -12,7 +12,7 @@ from ax.core.experiment import Experiment from ax.core.objective import Objective from ax.core.optimization_config import OptimizationConfig -from ax.metrics.map_replay import MapDataReplayMetric +from ax.metrics.map_replay import MapDataReplayMetric, MapDataReplayState from ax.runners.synthetic import SyntheticRunner from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -21,49 +21,286 @@ ) from pandas import DataFrame from pandas.testing import assert_frame_equal +from pyre_extensions import none_throws + + +def _make_map_data( + trial_metric_data: dict[int, dict[str, list[tuple[float, float, float]]]], +) -> Data: + """Helper to build map data from a nested dict. + + Args: + trial_metric_data: + {trial_index: {metric_name: [(step, mean, sem), ...]}} + """ + rows = [] + for trial_index, metrics in trial_metric_data.items(): + for metric_name, points in metrics.items(): + for step, mean, sem in points: + rows.append( + { + "trial_index": trial_index, + "arm_name": f"{trial_index}_0", + "metric_name": metric_name, + "metric_signature": metric_name, + "mean": mean, + "sem": sem, + MAP_KEY: step, + } + ) + return Data(df=DataFrame(rows)) + + +class MapDataReplayStateTest(TestCase): + def test_state_computation(self) -> None: + """Test min_prog, max_prog, and per_trial_max_prog for various data shapes.""" + with self.subTest("uniform_steps"): + map_data = _make_map_data( + { + 0: {"m1": [(0.0, 1.0, 0.0), (1.0, 2.0, 0.0)]}, + 1: {"m1": [(0.0, 3.0, 0.0), (1.0, 4.0, 0.0)]}, + } + ) + state = MapDataReplayState(map_data=map_data, metric_signatures=["m1"]) + self.assertEqual(state.min_prog, 0.0) + self.assertEqual(state.max_prog, 1.0) + self.assertEqual(state._per_trial_max_prog, {0: 1.0, 1: 1.0}) + + with self.subTest("non_uniform_steps"): + map_data = _make_map_data( + { + 0: {"m1": [(0.25, 1.0, 0.0), (0.95, 2.0, 0.0)]}, + 1: {"m1": [(0.25, 3.0, 0.0), (1.0, 4.0, 0.0)]}, + } + ) + state = MapDataReplayState(map_data=map_data, metric_signatures=["m1"]) + self.assertEqual(state.min_prog, 0.25) + self.assertEqual(state.max_prog, 1.0) + self.assertEqual(state._per_trial_max_prog, {0: 0.95, 1: 1.0}) + + with self.subTest("multi_metric"): + map_data = _make_map_data( + { + 0: { + "m1": [(0.0, 1.0, 0.0), (5.0, 2.0, 0.0)], + "m2": [(1.0, 3.0, 0.0), (10.0, 4.0, 0.0)], + }, + } + ) + state = MapDataReplayState( + map_data=map_data, metric_signatures=["m1", "m2"] + ) + self.assertEqual(state.min_prog, 0.0) + self.assertEqual(state.max_prog, 10.0) + self.assertEqual(state._per_trial_max_prog, {0: 10.0}) + + with self.subTest("single_trial"): + map_data = _make_map_data({0: {"m1": [(0.0, 1.0, 0.0), (1.0, 2.0, 0.0)]}}) + state = MapDataReplayState(map_data=map_data, metric_signatures=["m1"]) + self.assertEqual(state._trial_indices, {0}) + self.assertTrue(state.has_trial_data(trial_index=0)) + self.assertFalse(state.has_trial_data(trial_index=1)) + + with self.subTest("non_contiguous_trial_indices"): + map_data = _make_map_data( + { + 0: {"m1": [(0.0, 1.0, 0.0)]}, + 5: {"m1": [(0.0, 2.0, 0.0)]}, + 10: {"m1": [(0.0, 3.0, 0.0)]}, + } + ) + state = MapDataReplayState(map_data=map_data, metric_signatures=["m1"]) + self.assertEqual(state._trial_indices, {0, 5, 10}) + self.assertTrue(state.has_trial_data(trial_index=5)) + self.assertFalse(state.has_trial_data(trial_index=3)) + + with self.subTest("min_equals_max_prog"): + map_data = _make_map_data( + { + 0: {"m1": [(3.0, 1.0, 0.0)]}, + 1: {"m1": [(3.0, 2.0, 0.0)]}, + } + ) + state = MapDataReplayState(map_data=map_data, metric_signatures=["m1"]) + self.assertEqual(state.min_prog, 3.0) + self.assertEqual(state.max_prog, 3.0) + self.assertTrue(state.is_trial_complete(trial_index=0)) + self.assertTrue(state.is_trial_complete(trial_index=1)) + self.assertEqual( + len(state.get_data(trial_index=0, metric_signature="m1")), 1 + ) + + with self.subTest("empty_metric_data"): + map_data = _make_map_data({0: {"m1": [(0.0, 1.0, 0.0), (1.0, 2.0, 0.0)]}}) + # Request a metric signature that has no data + state = MapDataReplayState( + map_data=map_data, metric_signatures=["m1", "m_empty"] + ) + # m1 should be present, m_empty should return empty + self.assertTrue(state.has_trial_data(trial_index=0)) + self.assertTrue( + state.get_data(trial_index=0, metric_signature="m_empty").empty + ) + # min/max should be computed from m1 data only + self.assertEqual(state.min_prog, 0.0) + self.assertEqual(state.max_prog, 1.0) + + with self.subTest("different_num_points_per_trial"): + map_data = _make_map_data( + { + 0: {"m1": [(0.0, 1.0, 0.0), (0.5, 2.0, 0.0), (1.0, 3.0, 0.0)]}, + 1: {"m1": [(0.0, 4.0, 0.0)]}, + } + ) + state = MapDataReplayState(map_data=map_data, metric_signatures=["m1"]) + self.assertEqual(state._per_trial_max_prog, {0: 1.0, 1: 0.0}) + + def test_cursor_advancement_and_data_serving(self) -> None: + """Test cursor advancement, capping, progressive data serving, + per-trial independence, and trial completion transitions.""" + map_data = _make_map_data( + { + 0: {"m1": [(0.0, 1.0, 0.0), (0.5, 2.0, 0.0), (1.0, 3.0, 0.0)]}, + 1: {"m1": [(0.0, 4.0, 0.0), (1.0, 5.0, 0.0)]}, + } + ) + state = MapDataReplayState( + map_data=map_data, metric_signatures=["m1"], step_size=0.5 + ) + + with self.subTest("initial_cursor_is_zero"): + self.assertEqual(state._trial_cursors[0], 0.0) + self.assertEqual(state._trial_cursors[1], 0.0) + + with self.subTest("progressive_data_at_cursor_0"): + self.assertEqual( + len(state.get_data(trial_index=0, metric_signature="m1")), 1 + ) + self.assertEqual( + len(state.get_data(trial_index=1, metric_signature="m1")), 1 + ) + + with self.subTest("advance_and_check_independence"): + state.advance_trial(trial_index=0) + self.assertAlmostEqual(state._trial_cursors[0], 0.5) + self.assertAlmostEqual(state._trial_cursors[1], 0.0) + + with self.subTest("progressive_data_at_cursor_0_5"): + self.assertEqual( + len(state.get_data(trial_index=0, metric_signature="m1")), 2 + ) + self.assertEqual( + len(state.get_data(trial_index=1, metric_signature="m1")), 1 + ) + + with self.subTest("advance_to_full"): + state.advance_trial(trial_index=0) + self.assertEqual( + len(state.get_data(trial_index=0, metric_signature="m1")), 3 + ) + + with self.subTest("cursor_caps_at_one"): + state.advance_trial(trial_index=0) + self.assertAlmostEqual(state._trial_cursors[0], 1.0) + + with self.subTest("trial_completion_transitions"): + self.assertTrue(state.is_trial_complete(trial_index=0)) + self.assertFalse(state.is_trial_complete(trial_index=1)) + state.advance_trial(trial_index=1) + state.advance_trial(trial_index=1) + self.assertTrue(state.is_trial_complete(trial_index=1)) + + with self.subTest("heterogeneous_max_prog_completion"): + # Trial with lower max_prog completes before trial with higher + map_data = _make_map_data( + { + 0: {"m1": [(0.0, 1.0, 0.0), (0.5, 2.0, 0.0)]}, + 1: {"m1": [(0.0, 3.0, 0.0), (1.0, 4.0, 0.0)]}, + } + ) + state = MapDataReplayState( + map_data=map_data, metric_signatures=["m1"], step_size=0.5 + ) + # Global range [0.0, 1.0]; trial 0 max=0.5, trial 1 max=1.0 + state.advance_trial(trial_index=0) + state.advance_trial(trial_index=1) + # cursor=0.5, curr_prog=0.5: trial 0 complete, trial 1 not + self.assertTrue(state.is_trial_complete(trial_index=0)) + self.assertFalse(state.is_trial_complete(trial_index=1)) + + def test_multi_metric_and_data_integrity(self) -> None: + """Test multi-metric shared timeline, original MAP_KEY preservation, + and get_data for nonexistent trial/metric.""" + map_data = _make_map_data( + { + 0: { + "m1": [(10.0, 1.0, 0.0), (20.0, 2.0, 0.0)], + "m2": [(10.0, 10.0, 0.0), (20.0, 20.0, 0.0)], + }, + } + ) + state = MapDataReplayState( + map_data=map_data, metric_signatures=["m1", "m2"], step_size=1.0 + ) + state.advance_trial(trial_index=0) + + with self.subTest("shared_timeline"): + self.assertEqual( + len(state.get_data(trial_index=0, metric_signature="m1")), + len(state.get_data(trial_index=0, metric_signature="m2")), + ) + + with self.subTest("original_map_key_values"): + self.assertListEqual( + state.get_data(trial_index=0, metric_signature="m1")[MAP_KEY].tolist(), + [10.0, 20.0], + ) + + with self.subTest("nonexistent_trial"): + self.assertTrue(state.get_data(trial_index=99, metric_signature="m1").empty) + + with self.subTest("nonexistent_metric"): + self.assertTrue( + state.get_data(trial_index=0, metric_signature="m_missing").empty + ) class MapDataReplayMetricTest(TestCase): - def test_map_replay(self) -> None: + def test_map_replay_uniform(self) -> None: + """Test metric data fetching with uniform steps.""" historical_experiment = get_test_map_data_experiment( num_trials=2, num_fetches=2, num_complete=2 ) historical_data: Data = historical_experiment.lookup_data() + state = MapDataReplayState( + map_data=historical_data, + metric_signatures=["branin_map"], + step_size=1.0, + ) replay_metric = MapDataReplayMetric( name="test_metric", - map_data=historical_data, - metric_name="branin_map", + replay_state=state, + metric_signature="branin_map", lower_is_better=True, ) - # Verify offset and scaling factor for uniform step data. - # The test data has 2 trials, each with 2 fetches, resulting in steps 0 and 1. - # offset = min(first step of each trial) = min(0, 0) = 0 - self.assertEqual(replay_metric.offset, 0) - # scaling_factor = mean((final_step - offset) / num_points) - # = mean((1 - 0) / 2, (1 - 0) / 2) = mean(0.5, 0.5) = 0.5 - self.assertEqual(replay_metric.scaling_factor, 0.5) - experiment = Experiment( name="dummy_experiment", search_space=get_branin_search_space(), optimization_config=OptimizationConfig( - objective=Objective( - metric=replay_metric, - minimize=True, - ) + objective=Objective(metric=replay_metric, minimize=True) ), tracking_metrics=[replay_metric], runner=SyntheticRunner(), ) - for i in range(0, 2): + for i in range(2): trial = experiment.new_trial() trial.add_arm(Arm(parameters={"x1": float(i), "x2": 0.0})) trial.run() - # fetch once for MAP_KEY = 0 - experiment.fetch_data() - # the second fetch will be for MAP_KEY = 0 and MAP_KEY = 1 + state.advance_trial(trial_index=0) + state.advance_trial(trial_index=1) + data = experiment.fetch_data() metric_name = [replay_metric.name] * 4 expected_df = Data( @@ -82,43 +319,36 @@ def test_map_replay(self) -> None: assert_frame_equal(data.full_df, expected_df) def test_map_replay_non_uniform(self) -> None: + """Test metric data fetching with non-uniform steps and progressive + cursor advancement.""" historical_experiment = get_test_map_data_experiment( num_trials=2, num_fetches=2, num_complete=2 ) full_df = historical_experiment.lookup_data().full_df - # The original data has 6 rows: 4 for branin_map and 2 for branin. - # After assinging steps, we have following steps for branin_map: - # Trial 0: steps [0.25, 0.95] - # Trial 1: steps [0.25, 1.0] full_df[MAP_KEY] = pd.Series([0.25, 0.0, 0.95, 0.25, 0.0, 1.0]) historical_data = Data(df=full_df) + state = MapDataReplayState( + map_data=historical_data, + metric_signatures=["branin_map"], + step_size=0.5, + ) replay_metric = MapDataReplayMetric( name="test_metric", - map_data=historical_data, - metric_name="branin_map", + replay_state=state, + metric_signature="branin_map", lower_is_better=True, ) - # Verify offset: min(first step of each trial after sorting) - self.assertEqual(replay_metric.offset, 0.25) - # Verify scaling_factor: mean((final_step - offset) / num_points) across trials - # Trial 0: (0.95 - 0.25) / 2 = 0.35 - # Trial 1: (1.0 - 0.25) / 2 = 0.375 - # scaling_factor = (0.35 + 0.375) / 2 = 0.3625 - self.assertEqual(replay_metric.scaling_factor, 0.3625) experiment = Experiment( name="dummy_experiment", search_space=get_branin_search_space(), optimization_config=OptimizationConfig( - objective=Objective( - metric=replay_metric, - minimize=True, - ) + objective=Objective(metric=replay_metric, minimize=True) ), tracking_metrics=[replay_metric], runner=SyntheticRunner(), ) - for i in range(0, 2): + for i in range(2): trial = experiment.new_trial() trial.add_arm(Arm(parameters={"x1": float(i), "x2": 0.0})) trial.run() @@ -138,17 +368,69 @@ def test_map_replay_non_uniform(self) -> None: ) ).full_df - # Test that as we step through with steps of size 0.3625, we - # first get both points at step 0.25. - data = experiment.fetch_data() - assert_frame_equal( - data.full_df, full_expected_df.iloc[[0, 2]].reset_index(drop=True) + with self.subTest("cursor_0_only_first_points"): + data = experiment.fetch_data() + assert_frame_equal( + data.full_df, full_expected_df.iloc[[0, 2]].reset_index(drop=True) + ) + + with self.subTest("cursor_0_5_intermediate"): + state.advance_trial(trial_index=0) + state.advance_trial(trial_index=1) + data = experiment.fetch_data() + assert_frame_equal( + data.full_df, full_expected_df.iloc[[0, 2]].reset_index(drop=True) + ) + + with self.subTest("cursor_1_0_all_data"): + state.advance_trial(trial_index=0) + state.advance_trial(trial_index=1) + data = experiment.fetch_data() + assert_frame_equal(data.full_df, full_expected_df) + + def test_fetch_trial_data_multi_metric(self) -> None: + """Test that fetch_trial_data returns data filtered by metric + signature when multiple metrics share the same state.""" + map_data = _make_map_data( + { + 0: { + "m1": [(0.0, 1.0, 0.0), (1.0, 2.0, 0.0)], + "m2": [(0.0, 10.0, 0.0), (1.0, 20.0, 0.0)], + }, + } + ) + state = MapDataReplayState( + map_data=map_data, metric_signatures=["m1", "m2"], step_size=1.0 + ) + metric_m1 = MapDataReplayMetric( + name="replay_m1", replay_state=state, metric_signature="m1" + ) + metric_m2 = MapDataReplayMetric( + name="replay_m2", replay_state=state, metric_signature="m2" ) - # Next, we add the point at step 0.95 of Trial 0. - data = experiment.fetch_data() - assert_frame_equal(data.full_df, full_expected_df.iloc[:3]) + experiment = Experiment( + name="dummy", + search_space=get_branin_search_space(), + runner=SyntheticRunner(), + ) + trial = experiment.new_trial() + trial.add_arm(Arm(parameters={"x1": 0.0, "x2": 0.0})) + trial.run() + state.advance_trial(trial_index=0) - # Finally, we get the point at step 1.0 of Trial 1. - data = experiment.fetch_data() - assert_frame_equal(data.full_df, full_expected_df.iloc[:4]) + with self.subTest("m1_filtered"): + result_m1 = metric_m1.fetch_trial_data(trial=trial) + self.assertTrue(result_m1.is_ok()) + df_m1 = none_throws(result_m1.ok).full_df + self.assertEqual(len(df_m1), 2) + self.assertTrue((df_m1["metric_name"] == "replay_m1").all()) + self.assertListEqual(df_m1["mean"].tolist(), [1.0, 2.0]) + + with self.subTest("m2_filtered"): + result_m2 = metric_m2.fetch_trial_data(trial=trial) + self.assertTrue(result_m2.is_ok()) + df_m2 = none_throws(result_m2.ok).full_df + self.assertEqual(len(df_m2), 2) + self.assertTrue((df_m2["metric_name"] == "replay_m2").all()) + self.assertListEqual(df_m2["mean"].tolist(), [10.0, 20.0]) diff --git a/ax/runners/map_replay.py b/ax/runners/map_replay.py index 59ada1248c5..d6e390f6e7e 100644 --- a/ax/runners/map_replay.py +++ b/ax/runners/map_replay.py @@ -12,18 +12,21 @@ from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.runner import Runner -from ax.metrics.map_replay import MapDataReplayMetric +from ax.metrics.map_replay import MapDataReplayState STARTED_KEY = "replay_started" class MapDataReplayRunner(Runner): - """A runner that uses a `MapDataReplayMetric` to determine trial statuses. - This runner does not actually 'run' anything.""" + """A runner that determines trial statuses from a shared + ``MapDataReplayState`` and advances replay progression on each poll. - def __init__(self, replay_metric: MapDataReplayMetric) -> None: - self.replay_metric: MapDataReplayMetric = replay_metric + This runner does not actually 'run' anything. + """ + + def __init__(self, replay_state: MapDataReplayState) -> None: + self._replay_state: MapDataReplayState = replay_state def run(self, trial: BaseTrial) -> dict[str, Any]: return {STARTED_KEY: True} @@ -35,17 +38,13 @@ def poll_trial_status( self, trials: Iterable[BaseTrial] ) -> dict[TrialStatus, set[int]]: result = defaultdict(set) - # For each trial, if it hasn't been started yet by this runner, - # then mark is as a CANDIDATE. If there is no replay data - # associated with that trial at all, mark is FAILED. Otherwise, - # depending on whether or not there is more data available, - # mark it either RUNNING or COMPLETED. for t in trials: if not t.run_metadata.get(STARTED_KEY, False): result[TrialStatus.CANDIDATE].add(t.index) - elif not self.replay_metric.has_trial_data(t.index): + elif not self._replay_state.has_trial_data(trial_index=t.index): result[TrialStatus.ABANDONED].add(t.index) - elif self.replay_metric.more_replay_available(t.index): + elif not self._replay_state.is_trial_complete(trial_index=t.index): + self._replay_state.advance_trial(trial_index=t.index) result[TrialStatus.RUNNING].add(t.index) else: result[TrialStatus.COMPLETED].add(t.index) diff --git a/ax/runners/tests/test_map_replay.py b/ax/runners/tests/test_map_replay.py index 5a9b6d8f2a1..458410a7bea 100644 --- a/ax/runners/tests/test_map_replay.py +++ b/ax/runners/tests/test_map_replay.py @@ -5,62 +5,109 @@ # LICENSE file in the root directory of this source tree. # pyre-strict + from ax.core.arm import Arm -from ax.core.data import Data +from ax.core.data import Data, MAP_KEY from ax.core.experiment import Experiment from ax.core.objective import Objective from ax.core.optimization_config import OptimizationConfig from ax.core.trial_status import TrialStatus -from ax.metrics.map_replay import MapDataReplayMetric +from ax.metrics.map_replay import MapDataReplayMetric, MapDataReplayState from ax.runners.map_replay import MapDataReplayRunner from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_branin_search_space, get_test_map_data_experiment, ) +from pandas import DataFrame class MapReplayRunnerTest(TestCase): - def test_map_replay(self) -> None: + def test_trial_lifecycle(self) -> None: + """Test CANDIDATE -> RUNNING -> COMPLETED transitions, ABANDONED for + unknown trials, and cursor advancement during polling.""" historical_experiment = get_test_map_data_experiment( num_trials=2, num_fetches=2, num_complete=2 ) historical_data: Data = historical_experiment.lookup_data() + state = MapDataReplayState( + map_data=historical_data, + metric_signatures=["branin_map"], + step_size=1.0, + ) metric = MapDataReplayMetric( name="test_metric", - map_data=historical_data, - metric_name="branin_map", + replay_state=state, + metric_signature="branin_map", lower_is_better=True, ) - runner = MapDataReplayRunner( - replay_metric=metric, - ) + runner = MapDataReplayRunner(replay_state=state) experiment = Experiment( name="dummy_experiment", search_space=get_branin_search_space(), optimization_config=OptimizationConfig( - objective=Objective( - metric=metric, - minimize=True, - ) + objective=Objective(metric=metric, minimize=True) ), runner=runner, tracking_metrics=[metric], ) - for _ in range(2): + + # Create 3 trials: 2 with data (indices 0, 1), 1 without (index 2) + for _ in range(3): trial = experiment.new_trial() trial.add_arm(Arm(parameters={"x1": 0.0, "x2": 0.0})) - trial.run() - self.assertTrue(trial.run_metadata.get("replay_started")) - # After 1 fetch, both trials should still be running since there is - # still data available to replay - experiment.fetch_data() - trial_status = runner.poll_trial_status(trials=experiment.trials.values()) - self.assertEqual(trial_status[TrialStatus.RUNNING], {0, 1}) + with self.subTest("unstarted_trials_are_candidates"): + trial_status = runner.poll_trial_status(trials=experiment.trials.values()) + self.assertEqual(trial_status[TrialStatus.CANDIDATE], {0, 1, 2}) + + # Start all trials + for t in experiment.trials.values(): + t.run() + self.assertTrue(t.run_metadata.get("replay_started")) + + with self.subTest("first_poll_running_and_abandoned"): + trial_status = runner.poll_trial_status(trials=experiment.trials.values()) + # Trials 0, 1 have data -> RUNNING; trial 2 has no data -> ABANDONED + self.assertEqual(trial_status[TrialStatus.RUNNING], {0, 1}) + self.assertIn(2, trial_status[TrialStatus.ABANDONED]) + + with self.subTest("second_poll_completed"): + trial_status = runner.poll_trial_status(trials=experiment.trials.values()) + self.assertEqual(trial_status[TrialStatus.COMPLETED], {0, 1}) + + def test_cursor_advances_during_poll(self) -> None: + """Test that the runner advances the cursor for running trials on each + poll cycle.""" + map_data = Data( + df=DataFrame( + { + "trial_index": [0, 0, 0], + "arm_name": ["0_0", "0_0", "0_0"], + "metric_name": ["m1", "m1", "m1"], + "metric_signature": ["m1", "m1", "m1"], + "mean": [1.0, 2.0, 3.0], + "sem": [0.0, 0.0, 0.0], + MAP_KEY: [0.0, 0.5, 1.0], + } + ) + ) + state = MapDataReplayState( + map_data=map_data, metric_signatures=["m1"], step_size=0.25 + ) + runner = MapDataReplayRunner(replay_state=state) + + experiment = Experiment( + name="dummy", + search_space=get_branin_search_space(), + runner=runner, + ) + trial = experiment.new_trial() + trial.add_arm(Arm(parameters={"x1": 0.0, "x2": 0.0})) + trial.run() - # After 2 fetches, there is no data left to replay and both trials should - # be completed - experiment.fetch_data() - trial_status = runner.poll_trial_status(trials=experiment.trials.values()) - self.assertEqual(trial_status[TrialStatus.COMPLETED], {0, 1}) + self.assertAlmostEqual(state._trial_cursors[0], 0.0) + runner.poll_trial_status(trials=experiment.trials.values()) + self.assertAlmostEqual(state._trial_cursors[0], 0.25) + runner.poll_trial_status(trials=experiment.trials.values()) + self.assertAlmostEqual(state._trial_cursors[0], 0.50)