Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 116 additions & 25 deletions ax/early_stopping/experiment_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,20 @@
# pyre-strict

import logging
import warnings
from logging import Logger
from time import perf_counter

from ax.adapter.registry import Generators
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.metric import Metric
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.objective import MultiObjective, Objective
from ax.core.optimization_config import (
MultiObjectiveOptimizationConfig,
OptimizationConfig,
)
from ax.core.outcome_constraint import OutcomeConstraint
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.early_stopping.dispatch import get_default_ess_or_none
Expand All @@ -25,7 +31,7 @@
GenerationStep,
GenerationStrategy,
)
from ax.metrics.map_replay import MapDataReplayMetric
from ax.metrics.map_replay import MapDataReplayMetric, MapDataReplayState
from ax.orchestration.orchestrator import Orchestrator, OrchestratorOptions
from ax.runners.map_replay import MapDataReplayRunner
from ax.utils.common.logger import get_logger
Expand All @@ -43,16 +49,51 @@ def replay_experiment(
historical_experiment: Experiment,
num_samples_per_curve: int,
max_replay_trials: int,
metric: Metric,
max_pending_trials: int,
early_stopping_strategy: BaseEarlyStoppingStrategy | None,
metrics: list[Metric] | None = None,
max_pending_trials: int = MAX_PENDING_TRIALS,
early_stopping_strategy: BaseEarlyStoppingStrategy | None = None,
logging_level: int = logging.ERROR,
# Deprecated backward-compat kwarg
metric: Metric | None = None,
) -> Experiment | None:
"""A utility function for replaying a historical experiment's data
by initializing an Orchestrator that quickly steps through the existing data.
The main purpose of this function is to compute a hypothetical capacity
savings for a given `early_stopping_strategy`.
"""Replay a historical experiment's data through an Orchestrator.

Initializes an Orchestrator that steps through existing data to compute
hypothetical capacity savings for a given ``early_stopping_strategy``.
Supports both single-objective and multi-objective optimization.

Args:
historical_experiment: The experiment whose data to replay.
num_samples_per_curve: Deprecated. Number of samples per curve for
subsampling. Use ``step_size`` on ``MapDataReplayState`` instead.
max_replay_trials: Maximum number of trials to replay.
metrics: List of metrics to replay. For single-objective, provide
one metric. For multi-objective, provide multiple metrics.
max_pending_trials: Maximum number of pending trials for the
replay orchestrator.
early_stopping_strategy: The early stopping strategy to evaluate.
logging_level: Logging level for the orchestrator.
metric: Deprecated. Use ``metrics`` instead.
"""
# Backward compat: accept metric= (singular) and wrap in list
if metric is not None:
warnings.warn(
"The `metric` parameter is deprecated. Use `metrics` instead.",
DeprecationWarning,
stacklevel=2,
)
if metrics is not None:
raise ValueError("Cannot specify both `metric` and `metrics`.")
metrics = [metric]
if metrics is None:
raise ValueError("Must specify `metrics`.")
warnings.warn(
"The `num_samples_per_curve` parameter is deprecated and will be "
"removed in a future release. The `step_size` parameter on "
"`MapDataReplayState` controls replay granularity.",
DeprecationWarning,
stacklevel=2,
)
historical_map_data = historical_experiment.lookup_data()
if not historical_map_data.has_step_column:
logger.warning(
Expand All @@ -62,16 +103,51 @@ def replay_experiment(
historical_map_data = historical_map_data.subsample(
limit_rows_per_group=num_samples_per_curve, include_first_last=True
)
replay_metric = MapDataReplayMetric(
name=f"replay_{historical_experiment.name}",
map_data=historical_map_data,
metric_name=metric.name,
lower_is_better=metric.lower_is_better,
)
optimization_config = OptimizationConfig(
objective=Objective(metric=replay_metric),

# Re-index non-contiguous trial indices to contiguous 0, 1, 2, ...
# so that replay trial N maps to the Nth historical trial.
df = historical_map_data.full_df
sorted_trial_indices = sorted(df["trial_index"].unique())
trial_index_map = {old: new for new, old in enumerate(sorted_trial_indices)}
df = df.copy()
df["trial_index"] = df["trial_index"].map(trial_index_map)
historical_map_data = Data(df=df)

metric_signatures = [m.signature for m in metrics]
replay_state = MapDataReplayState(
map_data=historical_map_data, metric_signatures=metric_signatures
)
runner = MapDataReplayRunner(replay_metric=replay_metric)

replay_metrics = [
MapDataReplayMetric(
name=m.name,
replay_state=replay_state,
metric_signature=m.signature,
lower_is_better=m.lower_is_better,
)
for m in metrics
]

if len(replay_metrics) == 1:
optimization_config: OptimizationConfig = OptimizationConfig(
objective=Objective(metric=replay_metrics[0]),
)
else:
# Extract objective thresholds from the historical experiment's config
historical_opt_config = historical_experiment.optimization_config
objective_thresholds: list[OutcomeConstraint] = []
if isinstance(historical_opt_config, MultiObjectiveOptimizationConfig):
objective_thresholds = [
ot.clone() for ot in historical_opt_config.objective_thresholds
]
optimization_config = MultiObjectiveOptimizationConfig(
objective=MultiObjective(
objectives=[Objective(metric=m) for m in replay_metrics]
),
objective_thresholds=objective_thresholds,
)

runner = MapDataReplayRunner(replay_state=replay_state)

# Set up a new experiment with a dummy search space
dummy_search_space = SearchSpace(
Expand All @@ -89,10 +165,10 @@ def replay_experiment(
optimization_config=optimization_config,
search_space=dummy_search_space,
runner=runner,
metrics=[replay_metric],
metrics=replay_metrics,
)

# Set up an Orchestrator with a dummy gs to replay the historical experiment
# Set up an Orchestrator with a dummy gs to replay the historical experiment
dummy_sobol_gs = GenerationStrategy(
name="sobol",
steps=[
Expand All @@ -101,7 +177,7 @@ def replay_experiment(
)
options = OrchestratorOptions(
max_pending_trials=max_pending_trials,
total_trials=min(len(historical_experiment.trials), max_replay_trials),
total_trials=min(len(sorted_trial_indices), max_replay_trials),
seconds_between_polls_backoff_factor=1.0,
min_seconds_before_poll=0.0,
init_seconds_between_polls=0,
Expand All @@ -119,8 +195,10 @@ def replay_experiment(

def estimate_hypothetical_early_stopping_savings(
experiment: Experiment,
metric: Metric,
metrics: list[Metric] | None = None,
max_pending_trials: int = MAX_PENDING_TRIALS,
# Deprecated backward-compat kwarg
metric: Metric | None = None,
) -> float:
"""Estimate hypothetical early stopping savings using experiment replay.

Expand All @@ -130,9 +208,10 @@ def estimate_hypothetical_early_stopping_savings(

Args:
experiment: The experiment to analyze.
metric: The metric to use for early stopping replay.
metrics: The metrics to use for early stopping replay.
max_pending_trials: Maximum number of pending trials for the replay
orchestrator. Defaults to ``MAX_PENDING_TRIALS``.
metric: Deprecated. Use ``metrics`` instead.

Returns:
Estimated savings as a fraction (0.0 to 1.0).
Expand All @@ -145,6 +224,18 @@ def estimate_hypothetical_early_stopping_savings(
- The experiment data does not have progression data for replay
- The experiment replay fails due to invalid experiment state
"""
# Backward compat: accept metric= (singular) and wrap in list
if metric is not None:
warnings.warn(
"The `metric` parameter is deprecated. Use `metrics` instead.",
DeprecationWarning,
stacklevel=2,
)
if metrics is not None:
raise ValueError("Cannot specify both `metric` and `metrics`.")
metrics = [metric]
if metrics is None:
raise ValueError("Must specify `metrics`.")
default_ess = get_default_ess_or_none(experiment=experiment)
if default_ess is None:
raise UnsupportedError(
Expand All @@ -156,7 +247,7 @@ def estimate_hypothetical_early_stopping_savings(
historical_experiment=experiment,
num_samples_per_curve=REPLAY_NUM_POINTS_PER_CURVE,
max_replay_trials=MAX_REPLAY_TRIALS,
metric=metric,
metrics=metrics,
max_pending_trials=max_pending_trials,
early_stopping_strategy=default_ess,
)
Expand Down
Loading
Loading