Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions ax/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,17 +659,17 @@ def should_stop_trial_early(self, trial_index: int) -> bool:

es_response = none_throws(
self._early_stopping_strategy_or_choose()
).should_stop_trials_early(
).should_stop_arms(
trial_indices={trial_index},
experiment=self._experiment,
current_node=self._generation_strategy_or_choose()._curr,
)

if trial_index in es_response:
logger.info(
f"Trial {trial_index} should be stopped early: "
f"{es_response[trial_index]}"
)
# Extract reason from arm-level decisions (use first arm's reason)
arm_decisions = es_response[trial_index]
reason = next(iter(arm_decisions.values())) if arm_decisions else None
logger.info(f"Trial {trial_index} should be stopped early: {reason}")
return True

return False
Expand Down
6 changes: 6 additions & 0 deletions ax/early_stopping/strategies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
# pyre-strict

from ax.early_stopping.strategies.base import (
BaseArmStoppingStrategy,
BaseEarlyStoppingStrategy,
ModelBasedArmStoppingStrategy,
ModelBasedEarlyStoppingStrategy,
TArmsToStop,
)
from ax.early_stopping.strategies.logical import (
AndEarlyStoppingStrategy,
Expand All @@ -20,8 +23,11 @@


__all__ = [
"BaseArmStoppingStrategy",
"BaseEarlyStoppingStrategy",
"ModelBasedArmStoppingStrategy",
"ModelBasedEarlyStoppingStrategy",
"TArmsToStop",
"PercentileEarlyStoppingStrategy",
"ThresholdEarlyStoppingStrategy",
"AndEarlyStoppingStrategy",
Expand Down
55 changes: 29 additions & 26 deletions ax/early_stopping/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import pandas as pd
from ax.adapter.data_utils import _maybe_normalize_map_key
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data, MAP_KEY
from ax.core.experiment import Experiment
from ax.core.trial_status import TrialStatus
Expand All @@ -36,10 +35,19 @@
# backwards compatibility when loading old strategies.
REMOVED_EARLY_STOPPING_STRATEGY_KWARGS: set[str] = {"trial_indices_to_ignore"}

# Type alias for arm-level stopping decisions:
# trial_index -> {arm_name -> optional_reason}
# A reason of ``None`` means the arm should be stopped but no explanatory
# message was supplied.
TArmsToStop = dict[int, dict[str, str | None]]

class BaseEarlyStoppingStrategy(ABC, Base):

class BaseArmStoppingStrategy(ABC, Base):
"""Interface for heuristics that halt trials early, typically based on early
results from that trial."""
results from that trial.

Stopping decisions are made at the arm level: the return type of
``should_stop_arms`` is ``dict[int, dict[str, str | None]]``
mapping ``trial_index -> {arm_name -> optional_reason}``. For single-arm
``Trial`` objects this dict will contain exactly one entry per trial."""

def __init__(
self,
Expand Down Expand Up @@ -120,12 +128,12 @@ def __init__(
self._last_check_progressions: dict[int, float] = {}

@abstractmethod
def _should_stop_trials_early(
def _should_stop_arms(
self,
trial_indices: set[int],
experiment: Experiment,
current_node: GenerationNode | None = None,
) -> dict[int, str | None]:
) -> TArmsToStop:
"""Decide whether to complete trials before evaluation is fully concluded.

Typical examples include stopping a machine learning model's training, or
Expand All @@ -140,8 +148,9 @@ def _should_stop_trials_early(
stopping decisions.

Returns:
A dictionary mapping trial indices that should be early stopped to
(optional) messages with the associated reason.
A dictionary mapping trial indices to arm-level stopping decisions.
Each value is a dict mapping arm names to (optional) reason strings
for arms that should be stopped.
"""
pass

Expand All @@ -163,12 +172,12 @@ def _is_harmful(
"""
pass

def should_stop_trials_early(
def should_stop_arms(
self,
trial_indices: set[int],
experiment: Experiment,
current_node: GenerationNode | None = None,
) -> dict[int, str | None]:
) -> TArmsToStop:
"""Decide whether trials should be stopped before evaluation is fully concluded.
This method identifies trials that should be stopped based on early signals that
are indicative of final performance. Early stopping is not applied if doing so
Expand All @@ -183,17 +192,17 @@ def should_stop_trials_early(
stopping decisions.

Returns:
A dictionary mapping trial indices that should be early stopped to
(optional) messages with the associated reason. Returns an empty
dictionary if early stopping would be harmful (when safety check is
enabled).
A dictionary mapping trial indices to arm-level stopping decisions.
Each value is a dict mapping arm names to (optional) reason strings
for arms that should be stopped. Returns an empty dictionary if
early stopping would be harmful (when safety check is enabled).
"""
if self.check_safe and self._is_harmful(
trial_indices=trial_indices,
experiment=experiment,
):
return {}
return self._should_stop_trials_early(
return self._should_stop_arms(
trial_indices=trial_indices,
experiment=experiment,
current_node=current_node,
Expand Down Expand Up @@ -340,17 +349,6 @@ def is_eligible_any(
then we can skip costly steps, such as model fitting, that occur before
individual trials are considered for stopping.
"""
# check for batch trials
for idx, trial in experiment.trials.items():
if isinstance(trial, BatchTrial):
# In particular, align_partial_results requires a 1-1 mapping between
# trial indices and arm names, which is not the case for batch trials.
# See align_partial_results for more details.
raise ValueError(
f"Trial {idx} is a BatchTrial, which is not yet supported by "
"early stopping strategies."
)

# check that there are sufficient completed trials
num_completed = len(experiment.trial_indices_by_status[TrialStatus.COMPLETED])
if self.min_curves is not None and num_completed < self.min_curves:
Expand Down Expand Up @@ -585,7 +583,7 @@ def _prepare_aligned_data(
return long_df, multilevel_wide_df


class ModelBasedEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
class ModelBasedArmStoppingStrategy(BaseArmStoppingStrategy):
"""A base class for model based early stopping strategies. Includes
a helper function for processing Data into arrays."""

Expand Down Expand Up @@ -666,3 +664,8 @@ def _lookup_and_validate_data(
full_df = full_df[full_df[MAP_KEY] >= self.min_progression_modeling]
map_data = Data(df=full_df)
return map_data


# Deprecated aliases for backward compatibility.
# The pre-rename class names are bound directly to the arm-level classes, so
# code that still imports or subclasses the old names keeps resolving to the
# same objects. These are plain assignments: no deprecation warning is emitted
# when the old names are used.
BaseEarlyStoppingStrategy = BaseArmStoppingStrategy
ModelBasedEarlyStoppingStrategy = ModelBasedArmStoppingStrategy
69 changes: 47 additions & 22 deletions ax/early_stopping/strategies/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from functools import reduce

from ax.core.experiment import Experiment
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy, TArmsToStop
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_node import GenerationNode

Expand Down Expand Up @@ -41,25 +41,35 @@ def _is_harmful(
experiment=experiment,
)

def _should_stop_arms(
    self,
    trial_indices: set[int],
    experiment: Experiment,
    current_node: GenerationNode | None = None,
) -> TArmsToStop:
    """Stop only the arms that both sub-strategies agree should be stopped.

    Both ``self.left`` and ``self.right`` are queried through their public
    ``should_stop_arms`` method with the same arguments, and the two
    results are intersected at the arm level.

    Args:
        trial_indices: Indices of the candidate trials; forwarded unchanged
            to both sub-strategies.
        experiment: Experiment forwarded unchanged to both sub-strategies.
        current_node: Optional generation node forwarded unchanged to both
            sub-strategies.

    Returns:
        A dictionary mapping trial indices to arm-level stopping decisions.
        Only arms present in both sub-strategy results are included, and
        trials left with no such arm are omitted. Each arm maps to the
        non-``None`` reasons joined by ``", "`` (left first), or to
        ``None`` when neither sub-strategy supplied a reason.
    """
    left = self.left.should_stop_arms(
        trial_indices=trial_indices,
        experiment=experiment,
        current_node=current_node,
    )
    right = self.right.should_stop_arms(
        trial_indices=trial_indices,
        experiment=experiment,
        current_node=current_node,
    )
    # Combine at the arm level: only stop arms that both strategies agree on.
    result: TArmsToStop = {}
    for trial, left_arms in left.items():
        right_arms = right.get(trial)
        if right_arms is None:
            continue
        combined_arms: dict[str, str | None] = {}
        for arm_name, left_reason in left_arms.items():
            if arm_name not in right_arms:
                continue
            # Drop missing reasons instead of formatting them as the literal
            # text "None", matching how the Or strategy merges reasons.
            reasons = [
                reason
                for reason in (left_reason, right_arms[arm_name])
                if reason is not None
            ]
            combined_arms[arm_name] = ", ".join(reasons) if reasons else None
        if combined_arms:
            result[trial] = combined_arms
    return result


class OrEarlyStoppingStrategy(LogicalEarlyStoppingStrategy):
Expand Down Expand Up @@ -91,21 +101,36 @@ def _is_harmful(
experiment=experiment,
)

def _should_stop_arms(
    self,
    trial_indices: set[int],
    experiment: Experiment,
    current_node: GenerationNode | None = None,
) -> TArmsToStop:
    """Stop every arm that at least one sub-strategy wants to stop.

    Both ``self.left`` and ``self.right`` are queried through their public
    ``should_stop_arms`` method with the same arguments, and the two
    results are unioned at the arm level.

    Args:
        trial_indices: Indices of the candidate trials; forwarded unchanged
            to both sub-strategies.
        experiment: Experiment forwarded unchanged to both sub-strategies.
        current_node: Optional generation node forwarded unchanged to both
            sub-strategies.

    Returns:
        A dictionary mapping trial indices to arm-level stopping decisions.
        Each arm maps to the non-``None`` reasons joined by ``", "`` (left
        first), or to ``None`` when no reason was supplied. Trials with no
        arms to stop are omitted. Trials and arms appear in insertion
        order: entries from the left result first, then entries that only
        the right result contains.
    """
    left = self.left.should_stop_arms(
        trial_indices=trial_indices,
        experiment=experiment,
        current_node=current_node,
    )
    right = self.right.should_stop_arms(
        trial_indices=trial_indices,
        experiment=experiment,
        current_node=current_node,
    )
    # Merge at arm level: stop arms that either strategy wants to stop.
    # Keys are walked through a merged dict rather than a set union so the
    # output order is deterministic; set order over arm-name strings varies
    # with hash randomization, and callers may read the first arm's reason.
    result: TArmsToStop = {}
    for trial in {**left, **right}:
        left_arms = left.get(trial, {})
        right_arms = right.get(trial, {})
        merged_arms: dict[str, str | None] = {}
        for arm_name in {**left_arms, **right_arms}:
            # ``.get`` yields None both for an absent arm and for an arm
            # with no reason; either way there is nothing to report.
            reasons = [
                reason
                for reason in (left_arms.get(arm_name), right_arms.get(arm_name))
                if reason is not None
            ]
            merged_arms[arm_name] = ", ".join(reasons) if reasons else None
        if merged_arms:
            result[trial] = merged_arms
    return result
30 changes: 14 additions & 16 deletions ax/early_stopping/strategies/percentile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ax.core.experiment import Experiment
from ax.core.trial_status import TrialStatus
from ax.early_stopping.simulation import best_trial_vulnerable
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy, TArmsToStop
from ax.early_stopping.utils import _is_worse
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.generation_strategy.generation_node import GenerationNode
Expand Down Expand Up @@ -175,12 +175,12 @@ def _is_harmful(

return simulated_result.best_stopped

def _should_stop_trials_early(
def _should_stop_arms(
self,
trial_indices: set[int],
experiment: Experiment,
current_node: GenerationNode | None = None,
) -> dict[int, str | None]:
) -> TArmsToStop:
"""Stop a trial if its performance is in the bottom `percentile_threshold`
of the trials at the same step.

Expand All @@ -193,9 +193,9 @@ def _should_stop_trials_early(
stopping decisions.

Returns:
A dictionary mapping trial indices that should be early stopped to
(optional) messages with the associated reason. An empty dictionary
means no suggested updates to any trial's status.
A dictionary mapping trial indices to arm-level stopping decisions.
Each value is a dict mapping arm names to (optional) reason strings.
An empty dictionary means no suggested updates.
"""
metric_signature, minimize = self._default_objective_and_direction(
experiment=experiment
Expand All @@ -216,21 +216,19 @@ def _should_stop_trials_early(
):
return {}

decisions = {
trial_index: self._should_stop_trial_early(
result: TArmsToStop = {}
for trial_index in trial_indices:
should_stop, reason = self._should_stop_trial_early(
trial_index=trial_index,
experiment=experiment,
wide_df=wide_df,
long_df=long_df,
minimize=minimize,
)
for trial_index in trial_indices
}
return {
trial_index: reason
for trial_index, (should_stop, reason) in decisions.items()
if should_stop
}
if should_stop:
trial = experiment.trials[trial_index]
result[trial_index] = {a.name: reason for a in trial.arms}
return result

def _should_stop_trial_early(
self,
Expand Down Expand Up @@ -287,7 +285,7 @@ def _should_stop_trial_early(
window_num_active_trials: pd.Series = window_active_trials.sum(axis=1)

# Verify that sufficiently many trials have data at each progression in
# the patience window. Note: `is_eligible_any` in `should_stop_trials_early`
# the patience window. Note: `is_eligible_any` in `should_stop_arms`
# already checks that at least `min_curves` trials have completed and uses
# `align_partial_results` to interpolate missing values. This condition
# should only trigger if `align_partial_results` fails or if this method
Expand Down
Loading
Loading