Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions ax/early_stopping/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pandas as pd
from ax.adapter.data_utils import _maybe_normalize_map_key
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data, MAP_KEY
from ax.core.data import MAP_KEY
from ax.core.experiment import Experiment
from ax.core.trial_status import TrialStatus
from ax.early_stopping.utils import (
Expand Down Expand Up @@ -217,13 +217,10 @@ def estimate_early_stopping_savings(self, experiment: Experiment) -> float:

return estimate_early_stopping_savings(experiment=experiment)

def _lookup_and_validate_data(
def _lookup_and_validate(
self, experiment: Experiment, metric_signatures: list[str]
) -> Data | None:
"""Looks up and validates the `Data` used for early stopping that
is associated with `metric_signatures`. This function also handles normalizing
progressions.
"""
) -> pd.DataFrame | None:
"""Look up and validate experiment data for early stopping."""
data = experiment.lookup_data()
if data.df.empty:
logger.info(
Expand All @@ -250,6 +247,17 @@ def _lookup_and_validate_data(
full_df = data.full_df
full_df = full_df[full_df["metric_signature"].isin(metric_signatures)]

# Check that no arm name appears across multiple trials.
# This can happen with duplicate arm parameterizations that reuse arm
# names across trials, which would corrupt the alignment step.
arm_trial_counts = full_df.groupby("arm_name")["trial_index"].nunique()
bad_arms = arm_trial_counts[arm_trial_counts > 1]
if len(bad_arms) > 0:
raise UnsupportedError(
f"Arm(s) {bad_arms.index.tolist()} appear across multiple "
f"trial indices. Each arm name must map to exactly one trial."
)

# Drop rows with NaN values in MAP_KEY column to prevent issues in
# align_partial_results which uses MAP_KEY as the pivot index
nan_mask = full_df[MAP_KEY].isna()
Expand All @@ -264,7 +272,7 @@ def _lookup_and_validate_data(

if self.normalize_progressions:
full_df = _maybe_normalize_map_key(df=full_df)
return Data(df=full_df)
return full_df

@staticmethod
def _log_and_return_no_data(
Expand Down Expand Up @@ -547,7 +555,7 @@ def _all_objectives_and_directions(self, experiment: Experiment) -> dict[str, bo

return directions

def _prepare_aligned_data(
def _prepare_aligned_frames(
self, experiment: Experiment, metric_signatures: list[str]
) -> tuple[pd.DataFrame, pd.DataFrame] | None:
"""Get raw experiment data and align it for early stopping evaluation.
Expand All @@ -564,15 +572,15 @@ def _prepare_aligned_data(
with first level ["mean", "sem"] and second level metric signatures
Returns None if data cannot be retrieved or aligned.
"""
data = self._lookup_and_validate_data(
long_df = self._lookup_and_validate(
experiment=experiment, metric_signatures=metric_signatures
)
if data is None:
if long_df is None:
return None

try:
multilevel_wide_df = align_partial_results(
df=(long_df := data.full_df),
df=long_df,
metrics=metric_signatures,
)
except Exception as e:
Expand Down Expand Up @@ -651,18 +659,15 @@ def __init__(
)
self.min_progression_modeling = min_progression_modeling

def _lookup_and_validate_data(
def _lookup_and_validate(
self, experiment: Experiment, metric_signatures: list[str]
) -> Data | None:
"""Looks up and validates the `Data` used for early stopping that
is associated with `metric_signatures`. This function also handles normalizing
progressions.
) -> pd.DataFrame | None:
"""Look up and validate experiment data for early stopping, applying
``min_progression_modeling`` filter if configured.
"""
map_data = super()._lookup_and_validate_data(
df = super()._lookup_and_validate(
experiment=experiment, metric_signatures=metric_signatures
)
if map_data is not None and self.min_progression_modeling is not None:
full_df = map_data.full_df
full_df = full_df[full_df[MAP_KEY] >= self.min_progression_modeling]
map_data = Data(df=full_df)
return map_data
if df is not None and self.min_progression_modeling is not None:
df = df[df[MAP_KEY] >= self.min_progression_modeling]
return df
86 changes: 66 additions & 20 deletions ax/early_stopping/strategies/percentile.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,63 @@ def __init__(
"with multiple metrics."
)

def should_stop_trials_early(
    self,
    trial_indices: set[int],
    experiment: Experiment,
    current_node: GenerationNode | None = None,
) -> dict[int, str | None]:
    """Return the trials that should be stopped before they finish evaluating.

    This override builds the aligned data a single time and hands the same
    frames to both the safety check (``_is_harmful``) and the stopping
    decision (``_should_stop_trials_early``), so the lookup and alignment
    are not repeated when ``check_safe=True``.

    Args:
        trial_indices: Indices of candidate trials to stop early.
        experiment: Experiment that contains the trials and other contextual
            data.
        current_node: The current ``GenerationNode`` on the
            ``GenerationStrategy`` used to generate trials for the
            ``Experiment``.

    Returns:
        A dictionary mapping trial indices that should be early stopped to
        (optional) messages with the associated reason. Empty when no aligned
        data is available, or when ``check_safe`` is set and the safety check
        reports the strategy as harmful.
    """
    objective_signature, is_minimize = self._default_objective_and_direction(
        experiment=experiment
    )
    frames = self._prepare_aligned_frames(
        experiment=experiment, metric_signatures=[objective_signature]
    )
    # No aligned data means neither check can run, so stop nothing.
    if frames is None:
        return {}

    # Arguments common to the safety check and the stopping decision.
    shared_kwargs = {
        "trial_indices": trial_indices,
        "experiment": experiment,
        "metric_signature": objective_signature,
        "minimize": is_minimize,
        "aligned_frames": frames,
    }

    # The safety check is only consulted when `check_safe` is enabled.
    if self.check_safe:
        if self._is_harmful(**shared_kwargs):
            return {}

    return self._should_stop_trials_early(
        current_node=current_node,
        **shared_kwargs,
    )

def _is_harmful(
self,
trial_indices: set[int],
experiment: Experiment,
metric_signature: str,
minimize: bool,
aligned_frames: tuple[pd.DataFrame, pd.DataFrame],
) -> bool:
"""Check if the early stopping strategy would stop the globally best trial.

Expand All @@ -139,21 +192,16 @@ def _is_harmful(
Args:
trial_indices: Set of trial indices being evaluated (ignored).
experiment: Experiment that contains the trials and other contextual data.
metric_signature: The metric signature to evaluate.
minimize: Whether the metric is being minimized.
aligned_frames: Pre-computed ``(long_df, multilevel_wide_df)`` tuple
from ``_prepare_aligned_frames``.

Returns:
True if the strategy would have stopped the globally best trial,
False otherwise.
"""
metric_signature, minimize = self._default_objective_and_direction(
experiment=experiment
)
maybe_aligned_dataframes = self._prepare_aligned_data(
experiment=experiment, metric_signatures=[metric_signature]
)
if maybe_aligned_dataframes is None:
return False

long_df, multilevel_wide_df = maybe_aligned_dataframes
long_df, multilevel_wide_df = aligned_frames
wide_df = multilevel_wide_df["mean"][metric_signature]

# Get completed trials
Expand All @@ -179,6 +227,9 @@ def _should_stop_trials_early(
self,
trial_indices: set[int],
experiment: Experiment,
metric_signature: str,
minimize: bool,
aligned_frames: tuple[pd.DataFrame, pd.DataFrame],
current_node: GenerationNode | None = None,
) -> dict[int, str | None]:
"""Stop a trial if its performance is in the bottom `percentile_threshold`
Expand All @@ -187,6 +238,10 @@ def _should_stop_trials_early(
Args:
trial_indices: Indices of candidate trials to consider for early stopping.
experiment: Experiment that contains the trials and other contextual data.
metric_signature: The metric signature to evaluate.
minimize: Whether the metric is being minimized.
aligned_frames: Pre-computed ``(long_df, multilevel_wide_df)`` tuple
from ``_prepare_aligned_frames``.
current_node: The current ``GenerationNode`` on the ``GenerationStrategy``
used to generate trials for the ``Experiment``. Early stopping
strategies may utilize components of the current node when making
Expand All @@ -197,16 +252,7 @@ def _should_stop_trials_early(
(optional) messages with the associated reason. An empty dictionary
means no suggested updates to any trial's status.
"""
metric_signature, minimize = self._default_objective_and_direction(
experiment=experiment
)
maybe_aligned_dataframes = self._prepare_aligned_data(
experiment=experiment, metric_signatures=[metric_signature]
)
if maybe_aligned_dataframes is None:
return {}

long_df, multilevel_wide_df = maybe_aligned_dataframes
long_df, multilevel_wide_df = aligned_frames
wide_df = multilevel_wide_df["mean"][metric_signature]

# default checks on `min_progression` and `min_curves`; if not met, don't do
Expand Down
6 changes: 2 additions & 4 deletions ax/early_stopping/strategies/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,13 @@ def _should_stop_trials_early(
metric_signature, minimize = self._default_objective_and_direction(
experiment=experiment
)
data = self._lookup_and_validate_data(
df = self._lookup_and_validate(
experiment=experiment, metric_signatures=[metric_signature]
)
if data is None:
if df is None:
# don't stop any trials if we don't get data back
return {}

df = data.full_df

# default checks on `min_progression` and `min_curves`; if not met, don't do
# early stopping at all and return {}
if not self.is_eligible_any(
Expand Down
Loading
Loading