From 81473844bb867892cc761066e8ff24b65509c656 Mon Sep 17 00:00:00 2001 From: Lena Kashtelyan Date: Mon, 6 Apr 2026 07:20:32 -0700 Subject: [PATCH] =?UTF-8?q?Rename=20EarlyStoppingStrategy=20=E2=86=92=20Ar?= =?UTF-8?q?mStoppingStrategy=20with=20arm-level=20decisions=20(#5134)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/5134 # Context The current `EarlyStoppingStrategy` was built for single-arm `Trial`s only — it returns `dict[int, str | None]` (trial index → reason) and explicitly rejects `BatchTrial`s. This diff extends it to support `BatchTrial`s by making stopping decisions at the arm level. But we are going to need batch-level stopping. # Key design decision: `ArmStoppingStrategy` or `TrialStoppingStrategy`? [in Ax TLs sync, we decided we liked `ArmStoppingStrategy` One choice we're going to need to make is whether we want arm-level and trial-level stopping strategies to be the same or separate. The use cases for arm stopping will have to do with safety, constraint violations etc. The use cases for stopping trials will have more to do with normal orchestration, e.g. we need to stop a trial in order to run a new one. # Next step: add `Runner.stop_arm` and `Runner.stop_trial` I think that the two will be separate bc they'll often entail different logic at their respective backends. How we choose to do these may impact how we choose to do the stopping strategy, too. # Another likely next step: make GS and ESS decisions jointly I think that the two will be separate bc they'll often entail different logic at their respective backends. How we choose to do these may impact how we choose to do the stopping strategy, too. I know not everyone loves this idea; let's discuss. Related discussion about a use case here: https://docs.google.com/document/d/19K3kBXX9c5WUIUu_t_gC9KZkeAMo_EffSymKjaPTSyI/edit?tab=t.0, re-pasting for convenience: - Lena [on whether an experiment that does not yet do any generation and only stopping]: My preferred design would be that we use a GNode to fit the model in GS, then the ESS uses this, but at the moment ESS is applied first, then GS (in Orchestrator and thus Axolotl). I think this current order is right as long as we don't merge ESS and GS (which I'd like to do eventually, including for reasons like this). So what we can do for now is just put a GNode within an ESS, then worry about the rest later. And we can just have an empty GS for now to keep the Orchestrator happy. - Sam: Yeah we could do that. Calling ESS before GS in the orchestrator fundamentally seems like the wrong order as we move toward model-based early stopping (e.g., in the conductor case). I wonder if we should merge GS and ESS sooner rather than later to resolve that, rather than create some tech debt by having ESS fit its own adapter. Eventually we decided that we would like to (later this year) merge ESS and GS, such that the actual cycle is `gen` --> cache results --> compare ROI on new vs. running trial --> make decision on stopping and running jointly. We thought some kind of `DecisionNode` might do this: {F1987649989} ---- # Claude stuff below Key changes: - Rename `BaseEarlyStoppingStrategy` → `BaseArmStoppingStrategy` (with backward-compat alias) - Rename `ModelBasedEarlyStoppingStrategy` → `ModelBasedArmStoppingStrategy` (with alias) - Change return type of `should_stop_trials_early` / `_should_stop_trials_early` from `dict[int, str | None]` → `dict[int, dict[str, str | None]]` (trial_index → {arm_name → reason}) - Remove `BatchTrial` rejection check in `is_eligible_any` - Add `_wrap_trial_results_with_arms()` helper to convert trial-level decisions to arm-level format - Update all subclasses (percentile, threshold, logical, multi_objective, quickbo) to use the new return type - Update orchestrator to check if all arms are stopped before stopping a trial (raises `NotImplementedError` for partial arm stopping) - Update `ax_client`, `api/client`, and `internal_client` to extract reasons from arm-level dict - Update all tests Differential Revision: D97304068 --- ax/api/client.py | 10 +- ax/early_stopping/strategies/__init__.py | 6 + ax/early_stopping/strategies/base.py | 55 +++++---- ax/early_stopping/strategies/logical.py | 69 +++++++---- ax/early_stopping/strategies/percentile.py | 30 +++-- ax/early_stopping/strategies/threshold.py | 28 ++--- ax/early_stopping/tests/test_strategies.py | 124 +++++++++----------- ax/early_stopping/utils.py | 27 ++--- ax/orchestration/orchestrator.py | 29 ++++- ax/orchestration/orchestrator_options.py | 2 +- ax/orchestration/tests/test_orchestrator.py | 44 ++++--- ax/service/ax_client.py | 13 +- ax/service/tests/test_ax_client.py | 11 +- ax/service/tests/test_early_stopping.py | 15 +-- ax/service/utils/early_stopping.py | 11 +- ax/utils/testing/core_stubs.py | 12 +- 16 files changed, 266 insertions(+), 220 deletions(-) diff --git a/ax/api/client.py b/ax/api/client.py index f059d2dc0ac..8c4ddaa6ad4 100644 --- a/ax/api/client.py +++ b/ax/api/client.py @@ -659,17 +659,17 @@ def should_stop_trial_early(self, trial_index: int) -> bool: es_response = none_throws( self._early_stopping_strategy_or_choose() - ).should_stop_trials_early( + ).should_stop_arms( trial_indices={trial_index}, experiment=self._experiment, current_node=self._generation_strategy_or_choose()._curr, ) if trial_index in es_response: - logger.info( - f"Trial {trial_index} should be stopped early: " - f"{es_response[trial_index]}" - ) + # Extract reason from arm-level decisions (use first arm's reason) + arm_decisions = es_response[trial_index] + reason = next(iter(arm_decisions.values())) if arm_decisions else None + logger.info(f"Trial {trial_index} should be stopped early: {reason}") return True return False diff --git a/ax/early_stopping/strategies/__init__.py b/ax/early_stopping/strategies/__init__.py index 12f808ad79b..0f1114af973 100644 --- a/ax/early_stopping/strategies/__init__.py +++ b/ax/early_stopping/strategies/__init__.py @@ -7,8 +7,11 @@ # pyre-strict from ax.early_stopping.strategies.base import ( + BaseArmStoppingStrategy, BaseEarlyStoppingStrategy, + ModelBasedArmStoppingStrategy, ModelBasedEarlyStoppingStrategy, + TArmsToStop, ) from ax.early_stopping.strategies.logical import ( AndEarlyStoppingStrategy, @@ -20,8 +23,11 @@ __all__ = [ + "BaseArmStoppingStrategy", "BaseEarlyStoppingStrategy", + "ModelBasedArmStoppingStrategy", "ModelBasedEarlyStoppingStrategy", + "TArmsToStop", "PercentileEarlyStoppingStrategy", "ThresholdEarlyStoppingStrategy", "AndEarlyStoppingStrategy", diff --git a/ax/early_stopping/strategies/base.py b/ax/early_stopping/strategies/base.py index 9a60f8f81c3..78e4655b94d 100644 --- a/ax/early_stopping/strategies/base.py +++ b/ax/early_stopping/strategies/base.py @@ -14,7 +14,6 @@ import pandas as pd from ax.adapter.data_utils import _maybe_normalize_map_key -from ax.core.batch_trial import BatchTrial from ax.core.data import Data, MAP_KEY from ax.core.experiment import Experiment from ax.core.trial_status import TrialStatus @@ -36,10 +35,19 @@ # backwards compatibility when loading old strategies. REMOVED_EARLY_STOPPING_STRATEGY_KWARGS: set[str] = {"trial_indices_to_ignore"} +# Type alias for arm-level stopping decisions: +# trial_index -> {arm_name -> optional_reason} +TArmsToStop = dict[int, dict[str, str | None]] -class BaseEarlyStoppingStrategy(ABC, Base): + +class BaseArmStoppingStrategy(ABC, Base): """Interface for heuristics that halt trials early, typically based on early - results from that trial.""" + results from that trial. + + Stopping decisions are made at the arm level: the return type of + ``should_stop_arms`` is ``dict[int, dict[str, str | None]]`` + mapping ``trial_index -> {arm_name -> optional_reason}``. For single-arm + ``Trial`` objects this dict will contain exactly one entry per trial.""" def __init__( self, @@ -120,12 +128,12 @@ def __init__( self._last_check_progressions: dict[int, float] = {} @abstractmethod - def _should_stop_trials_early( + def _should_stop_arms( self, trial_indices: set[int], experiment: Experiment, current_node: GenerationNode | None = None, - ) -> dict[int, str | None]: + ) -> TArmsToStop: """Decide whether to complete trials before evaluation is fully concluded. Typical examples include stopping a machine learning model's training, or @@ -140,8 +148,9 @@ def _should_stop_trials_early( stopping decisions. Returns: - A dictionary mapping trial indices that should be early stopped to - (optional) messages with the associated reason. + A dictionary mapping trial indices to arm-level stopping decisions. + Each value is a dict mapping arm names to (optional) reason strings + for arms that should be stopped. """ pass @@ -163,12 +172,12 @@ def _is_harmful( """ pass - def should_stop_trials_early( + def should_stop_arms( self, trial_indices: set[int], experiment: Experiment, current_node: GenerationNode | None = None, - ) -> dict[int, str | None]: + ) -> TArmsToStop: """Decide whether trials should be stopped before evaluation is fully concluded. This method identifies trials that should be stopped based on early signals that are indicative of final performance. Early stopping is not applied if doing so @@ -183,17 +192,17 @@ def should_stop_trials_early( stopping decisions. Returns: - A dictionary mapping trial indices that should be early stopped to - (optional) messages with the associated reason. Returns an empty - dictionary if early stopping would be harmful (when safety check is - enabled). + A dictionary mapping trial indices to arm-level stopping decisions. + Each value is a dict mapping arm names to (optional) reason strings + for arms that should be stopped. Returns an empty dictionary if + early stopping would be harmful (when safety check is enabled). """ if self.check_safe and self._is_harmful( trial_indices=trial_indices, experiment=experiment, ): return {} - return self._should_stop_trials_early( + return self._should_stop_arms( trial_indices=trial_indices, experiment=experiment, current_node=current_node, @@ -340,17 +349,6 @@ def is_eligible_any( then we can skip costly steps, such as model fitting, that occur before individual trials are considered for stopping. """ - # check for batch trials - for idx, trial in experiment.trials.items(): - if isinstance(trial, BatchTrial): - # In particular, align_partial_results requires a 1-1 mapping between - # trial indices and arm names, which is not the case for batch trials. - # See align_partial_results for more details. - raise ValueError( - f"Trial {idx} is a BatchTrial, which is not yet supported by " - "early stopping strategies." - ) - # check that there are sufficient completed trials num_completed = len(experiment.trial_indices_by_status[TrialStatus.COMPLETED]) if self.min_curves is not None and num_completed < self.min_curves: @@ -585,7 +583,7 @@ def _prepare_aligned_data( return long_df, multilevel_wide_df -class ModelBasedEarlyStoppingStrategy(BaseEarlyStoppingStrategy): +class ModelBasedArmStoppingStrategy(BaseArmStoppingStrategy): """A base class for model based early stopping strategies. Includes a helper function for processing Data into arrays.""" @@ -666,3 +664,8 @@ def _lookup_and_validate_data( full_df = full_df[full_df[MAP_KEY] >= self.min_progression_modeling] map_data = Data(df=full_df) return map_data + + +# Deprecated aliases for backward compatibility. +BaseEarlyStoppingStrategy = BaseArmStoppingStrategy +ModelBasedEarlyStoppingStrategy = ModelBasedArmStoppingStrategy diff --git a/ax/early_stopping/strategies/logical.py b/ax/early_stopping/strategies/logical.py index e9d62d6b1c4..1ea9deecb85 100644 --- a/ax/early_stopping/strategies/logical.py +++ b/ax/early_stopping/strategies/logical.py @@ -9,7 +9,7 @@ from functools import reduce from ax.core.experiment import Experiment -from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy +from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy, TArmsToStop from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_node import GenerationNode @@ -41,25 +41,35 @@ def _is_harmful( experiment=experiment, ) - def _should_stop_trials_early( + def _should_stop_arms( self, trial_indices: set[int], experiment: Experiment, current_node: GenerationNode | None = None, - ) -> dict[int, str | None]: - left = self.left.should_stop_trials_early( + ) -> TArmsToStop: + left = self.left.should_stop_arms( trial_indices=trial_indices, experiment=experiment, current_node=current_node, ) - right = self.right.should_stop_trials_early( + right = self.right.should_stop_arms( trial_indices=trial_indices, experiment=experiment, current_node=current_node, ) - return { - trial: f"{left[trial]}, {right[trial]}" for trial in left if trial in right - } + # Combine at the arm level: only stop arms that both strategies agree on + result: TArmsToStop = {} + for trial in left: + if trial in right: + combined_arms: dict[str, str | None] = {} + for arm_name in left[trial]: + if arm_name in right[trial]: + combined_arms[arm_name] = ( + f"{left[trial][arm_name]}, {right[trial][arm_name]}" + ) + if combined_arms: + result[trial] = combined_arms + return result class OrEarlyStoppingStrategy(LogicalEarlyStoppingStrategy): @@ -91,21 +101,36 @@ def _is_harmful( experiment=experiment, ) - def _should_stop_trials_early( + def _should_stop_arms( self, trial_indices: set[int], experiment: Experiment, current_node: GenerationNode | None = None, - ) -> dict[int, str | None]: - return { - **self.left.should_stop_trials_early( - trial_indices=trial_indices, - experiment=experiment, - current_node=current_node, - ), - **self.right.should_stop_trials_early( - trial_indices=trial_indices, - experiment=experiment, - current_node=current_node, - ), - } + ) -> TArmsToStop: + left = self.left.should_stop_arms( + trial_indices=trial_indices, + experiment=experiment, + current_node=current_node, + ) + right = self.right.should_stop_arms( + trial_indices=trial_indices, + experiment=experiment, + current_node=current_node, + ) + # Merge at arm level: stop arms that either strategy wants to stop + result: TArmsToStop = {} + all_trials = set(left) | set(right) + for trial in all_trials: + left_arms = left.get(trial, {}) + right_arms = right.get(trial, {}) + merged_arms: dict[str, str | None] = {} + for arm_name in set(left_arms) | set(right_arms): + reasons = [] + if arm_name in left_arms and left_arms[arm_name] is not None: + reasons.append(left_arms[arm_name]) + if arm_name in right_arms and right_arms[arm_name] is not None: + reasons.append(right_arms[arm_name]) + merged_arms[arm_name] = ", ".join(reasons) if reasons else None + if merged_arms: + result[trial] = merged_arms + return result diff --git a/ax/early_stopping/strategies/percentile.py b/ax/early_stopping/strategies/percentile.py index 77995c29daa..d099b7fdf19 100644 --- a/ax/early_stopping/strategies/percentile.py +++ b/ax/early_stopping/strategies/percentile.py @@ -13,7 +13,7 @@ from ax.core.experiment import Experiment from ax.core.trial_status import TrialStatus from ax.early_stopping.simulation import best_trial_vulnerable -from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy +from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy, TArmsToStop from ax.early_stopping.utils import _is_worse from ax.exceptions.core import UnsupportedError, UserInputError from ax.generation_strategy.generation_node import GenerationNode @@ -175,12 +175,12 @@ def _is_harmful( return simulated_result.best_stopped - def _should_stop_trials_early( + def _should_stop_arms( self, trial_indices: set[int], experiment: Experiment, current_node: GenerationNode | None = None, - ) -> dict[int, str | None]: + ) -> TArmsToStop: """Stop a trial if its performance is in the bottom `percentile_threshold` of the trials at the same step. @@ -193,9 +193,9 @@ def _should_stop_trials_early( stopping decisions. Returns: - A dictionary mapping trial indices that should be early stopped to - (optional) messages with the associated reason. An empty dictionary - means no suggested updates to any trial's status. + A dictionary mapping trial indices to arm-level stopping decisions. + Each value is a dict mapping arm names to (optional) reason strings. + An empty dictionary means no suggested updates. """ metric_signature, minimize = self._default_objective_and_direction( experiment=experiment @@ -216,21 +216,19 @@ def _should_stop_trials_early( ): return {} - decisions = { - trial_index: self._should_stop_trial_early( + result: TArmsToStop = {} + for trial_index in trial_indices: + should_stop, reason = self._should_stop_trial_early( trial_index=trial_index, experiment=experiment, wide_df=wide_df, long_df=long_df, minimize=minimize, ) - for trial_index in trial_indices - } - return { - trial_index: reason - for trial_index, (should_stop, reason) in decisions.items() - if should_stop - } + if should_stop: + trial = experiment.trials[trial_index] + result[trial_index] = {a.name: reason for a in trial.arms} + return result def _should_stop_trial_early( self, @@ -287,7 +285,7 @@ def _should_stop_trial_early( window_num_active_trials: pd.Series = window_active_trials.sum(axis=1) # Verify that sufficiently many trials have data at each progression in - # the patience window. Note: `is_eligible_any` in `should_stop_trials_early` + # the patience window. Note: `is_eligible_any` in `should_stop_arms` # already checks that at least `min_curves` trials have completed and uses # `align_partial_results` to interpolate missing values. This condition # should only trigger if `align_partial_results` fails or if this method diff --git a/ax/early_stopping/strategies/threshold.py b/ax/early_stopping/strategies/threshold.py index 5dac5d62a48..3aa0e61bb2e 100644 --- a/ax/early_stopping/strategies/threshold.py +++ b/ax/early_stopping/strategies/threshold.py @@ -12,7 +12,7 @@ import pandas as pd from ax.core.data import MAP_KEY from ax.core.experiment import Experiment -from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy +from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy, TArmsToStop from ax.exceptions.core import UnsupportedError from ax.generation_strategy.generation_node import GenerationNode from ax.utils.common.logger import get_logger @@ -87,12 +87,12 @@ def _is_harmful( ) -> bool: return False - def _should_stop_trials_early( + def _should_stop_arms( self, trial_indices: set[int], experiment: Experiment, current_node: GenerationNode | None = None, - ) -> dict[int, str | None]: + ) -> TArmsToStop: """Stop a trial if its performance doesn't reach a pre-specified threshold by `min_progression`. @@ -105,9 +105,9 @@ def _should_stop_trials_early( stopping decisions. Returns: - A dictionary mapping trial indices that should be early stopped to - (optional) messages with the associated reason. An empty dictionary - means no suggested updates to any trial's status. + A dictionary mapping trial indices to arm-level stopping decisions. + Each value is a dict mapping arm names to (optional) reason strings. + An empty dictionary means no suggested updates. """ metric_signature, minimize = self._default_objective_and_direction( experiment=experiment @@ -129,20 +129,18 @@ def _should_stop_trials_early( return {} df_objective = df[df["metric_signature"] == metric_signature] - decisions = { - trial_index: self._should_stop_trial_early( + result: TArmsToStop = {} + for trial_index in trial_indices: + should_stop, reason = self._should_stop_trial_early( trial_index=trial_index, experiment=experiment, df=df_objective, minimize=minimize, ) - for trial_index in trial_indices - } - return { - trial_index: reason - for trial_index, (should_stop, reason) in decisions.items() - if should_stop - } + if should_stop: + trial = experiment.trials[trial_index] + result[trial_index] = {a.name: reason for a in trial.arms} + return result def _should_stop_trial_early( self, diff --git a/ax/early_stopping/tests/test_strategies.py b/ax/early_stopping/tests/test_strategies.py index 0af83a47b05..4ca83d3eaf9 100644 --- a/ax/early_stopping/tests/test_strategies.py +++ b/ax/early_stopping/tests/test_strategies.py @@ -20,6 +20,7 @@ BaseEarlyStoppingStrategy, ModelBasedEarlyStoppingStrategy, PercentileEarlyStoppingStrategy, + TArmsToStop, ThresholdEarlyStoppingStrategy, ) from ax.early_stopping.strategies.base import logger @@ -49,12 +50,12 @@ def _is_harmful( ) -> bool: return False - def _should_stop_trials_early( + def _should_stop_arms( self, trial_indices: set[int], experiment: Experiment, current_node: GenerationNode | None = None, - ) -> dict[int, str | None]: + ) -> TArmsToStop: return {} @@ -66,12 +67,12 @@ def _is_harmful( ) -> bool: return False - def _should_stop_trials_early( + def _should_stop_arms( self, trial_indices: set[int], experiment: Experiment, current_node: GenerationNode | None = None, - ) -> dict[int, str | None]: + ) -> TArmsToStop: if current_node is None: raise ValueError("current_node is required") return {} @@ -85,12 +86,12 @@ def _is_harmful( ) -> bool: return False - def _should_stop_trials_early( + def _should_stop_arms( self, trial_indices: set[int], experiment: Experiment, current_node: GenerationNode | None = None, - ) -> dict[int, str | None]: + ) -> TArmsToStop: return {} @@ -460,16 +461,15 @@ def test_is_eligible(self, _: MagicMock) -> None: )[0] ) - # testing batch trial error + # BatchTrials are now supported at the type level (stopping decisions + # are wrapped with arm names). Verify no error is raised. experiment.new_batch_trial() - with self.assertRaisesRegex( - ValueError, "is a BatchTrial, which is not yet supported" - ): - es_strategy.is_eligible_any( - trial_indices={0}, - experiment=experiment, - df=map_data.full_df, - ) + # is_eligible_any should not raise for batch trials + es_strategy.is_eligible_any( + trial_indices={0}, + experiment=experiment, + df=map_data.full_df, + ) def test_progression_interval(self) -> None: """Test progression interval with min_progression=0.""" @@ -695,7 +695,7 @@ def test_check_safe_parameter(self) -> None: # Execute: Patch _is_harmful to verify it's not called with patch.object(strategy, "_is_harmful") as mock_is_harmful: - strategy.should_stop_trials_early( + strategy.should_stop_arms( trial_indices=trial_indices, experiment=experiment, ) @@ -711,7 +711,7 @@ def test_check_safe_parameter(self) -> None: with patch.object( strategy, "_is_harmful", return_value=False ) as mock_is_harmful: - strategy.should_stop_trials_early( + strategy.should_stop_arms( trial_indices=trial_indices, experiment=experiment, ) @@ -728,7 +728,7 @@ def test_check_safe_parameter(self) -> None: # Execute: Patch _is_harmful to return True (indicating harmful) with patch.object(strategy, "_is_harmful", return_value=True): - result = strategy.should_stop_trials_early( + result = strategy.should_stop_arms( trial_indices=trial_indices, experiment=experiment, ) @@ -752,12 +752,12 @@ def test_with_current_node(self) -> None: es_strategy = FakeStrategyRequiresNode(min_progression=3, max_progression=5) with self.assertRaisesRegex(ValueError, "current_node is required"): - es_strategy.should_stop_trials_early( + es_strategy.should_stop_arms( trial_indices={0}, experiment=exp, ) - es_strategy.should_stop_trials_early( + es_strategy.should_stop_arms( trial_indices={0}, experiment=exp, current_node=Mock() ) @@ -777,7 +777,7 @@ def test_percentile_early_stopping_strategy_validation(self, _: MagicMock) -> No exp.attach_data(data=exp.fetch_data()) # data without "step" attached - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) self.assertEqual(should_stop, {}) @@ -788,7 +788,7 @@ def test_percentile_early_stopping_strategy_validation(self, _: MagicMock) -> No trial.run() # No data attached - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) self.assertEqual(should_stop, {}) @@ -799,7 +799,7 @@ def test_percentile_early_stopping_strategy_validation(self, _: MagicMock) -> No early_stopping_strategy = PercentileEarlyStoppingStrategy( min_curves=6, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) self.assertEqual(should_stop, {}) @@ -808,7 +808,7 @@ def test_percentile_early_stopping_strategy_validation(self, _: MagicMock) -> No early_stopping_strategy = PercentileEarlyStoppingStrategy( min_progression=3, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) self.assertEqual(should_stop, {}) @@ -892,7 +892,7 @@ def _test_percentile_early_stopping_strategy( min_curves=4, min_progression=0.1, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) if metric_signatures is None: @@ -912,13 +912,13 @@ def _test_percentile_early_stopping_strategy( min_curves=4, min_progression=0.1, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) self.assertEqual(set(should_stop), {0, 3}) # respect trial_indices argument - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices={0}, experiment=exp ) self.assertEqual(set(should_stop), {0}) @@ -929,7 +929,7 @@ def _test_percentile_early_stopping_strategy( min_curves=4, min_progression=0.1, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) self.assertEqual(set(should_stop), {0, 3, 1}) @@ -941,7 +941,7 @@ def _test_percentile_early_stopping_strategy( min_curves=5, min_progression=0.1, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) self.assertEqual(should_stop, {}) @@ -974,7 +974,7 @@ def test_percentile_early_stopping_with_n_best_trials_to_complete(self) -> None: min_progression=0.1, n_best_trials_to_complete=3, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) # Only trials 0 and 3 should be stopped (trial 1 is protected as it's in top 3) @@ -987,7 +987,7 @@ def test_percentile_early_stopping_with_n_best_trials_to_complete(self) -> None: min_progression=0.1, n_best_trials_to_complete=4, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) # Only trial 0 should be stopped (trials 1, 2, 3, 4 are protected as top 4) @@ -1000,7 +1000,7 @@ def test_percentile_early_stopping_with_n_best_trials_to_complete(self) -> None: min_progression=0.1, n_best_trials_to_complete=5, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) # No trials should be stopped (all 5 are protected) @@ -1013,7 +1013,7 @@ def test_percentile_early_stopping_with_n_best_trials_to_complete(self) -> None: min_progression=0.1, n_best_trials_to_complete=10, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) # No trials should be stopped (all 5 are protected) @@ -1027,7 +1027,7 @@ def test_percentile_early_stopping_with_n_best_trials_to_complete(self) -> None: min_progression=0.1, n_best_trials_to_complete=2, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) # Trial 0 is worst and not in top 2, so should be stopped @@ -1058,12 +1058,13 @@ def test_percentile_reason_messages(self) -> None: min_curves=4, min_progression=0.1, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=trial_indices, experiment=experiment ) # Trial 0 should be stopped self.assertIn(0, should_stop) - reason = none_throws(should_stop[0]) + arm_decisions = should_stop[0] + reason = none_throws(next(iter(arm_decisions.values()))) # Verify reason contains key information in correct format self.assertRegex( reason, @@ -1304,7 +1305,7 @@ def test_patience_basic_functionality(self) -> None: min_curves=4, patience=0, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=set(exp.trials.keys()), experiment=exp ) self.assertEqual(set(should_stop), {0}) @@ -1319,7 +1320,7 @@ def test_patience_basic_functionality(self) -> None: min_curves=4, patience=2, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=set(exp.trials.keys()), experiment=exp ) self.assertEqual(set(should_stop), {0}) @@ -1369,7 +1370,7 @@ def test_patience_underperformance_patterns(self) -> None: min_progression=2, # Must be >= patience patience=2, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices={0}, experiment=exp ) # Trial 0 should NOT be stopped due to inconsistent performance @@ -1418,7 +1419,7 @@ def test_patience_underperformance_patterns(self) -> None: patience=2, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices={0}, experiment=exp ) @@ -1510,7 +1511,7 @@ def test_patience_with_n_best_trials_interaction(self) -> None: patience=2, n_best_trials_to_complete=3, ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=set(exp.trials.keys()), experiment=exp ) @@ -1663,27 +1664,14 @@ def test_early_stopping_with_unaligned_results(self) -> None: ) self.assertEqual(set(should_stop), {0, 1}) - # test error throwing in align partial results, with non-unique trial / arm name + # Multi-arm and multi-trial arm mappings are now allowed (arm_name + # is dropped before pivoting). Verify no error is raised. exp = get_test_map_data_experiment(num_trials=5, num_fetches=3, num_complete=2) - - # manually "unalign" timestamps to simulate real-world scenario - # where each curve reports results at different steps data = exp.fetch_data() df_with_single_arm_name = data.full_df.copy() df_with_single_arm_name["arm_name"] = "0_0" - with self.assertRaisesRegex( - UnsupportedError, - "Arm 0_0 has multiple trial indices", - ): - align_partial_results(df=df_with_single_arm_name, metrics=["branin_map"]) - - df_with_single_trial_index = data.full_df.copy() - df_with_single_trial_index["trial_index"] = 0 - with self.assertRaisesRegex( - UnsupportedError, - "Trial 0 has multiple arm names", - ): - align_partial_results(df=df_with_single_trial_index, metrics=["branin_map"]) + # Should not raise -- arm_name is dropped before pivot + align_partial_results(df=df_with_single_arm_name, metrics=["branin_map"]) class TestThresholdEarlyStoppingStrategy(TestCase): @@ -1716,13 +1704,13 @@ def test_threshold_early_stopping_strategy(self, _: MagicMock) -> None: early_stopping_strategy = ThresholdEarlyStoppingStrategy( metric_threshold=50, min_progression=1 ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) self.assertEqual(set(should_stop), {0, 1, 3}) # respect trial_indices argument - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices={0}, experiment=exp ) self.assertEqual(set(should_stop), {0}) @@ -1731,7 +1719,7 @@ def test_threshold_early_stopping_strategy(self, _: MagicMock) -> None: early_stopping_strategy = ThresholdEarlyStoppingStrategy( metric_threshold=50, min_progression=3 ) - should_stop = early_stopping_strategy.should_stop_trials_early( + should_stop = early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) self.assertEqual(should_stop, {}) @@ -1774,13 +1762,13 @@ def test_and_early_stopping_strategy(self, _: MagicMock) -> None: left=left_early_stopping_strategy, right=right_early_stopping_strategy ) - left_should_stop = left_early_stopping_strategy.should_stop_trials_early( + left_should_stop = left_early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) - right_should_stop = right_early_stopping_strategy.should_stop_trials_early( + right_should_stop = right_early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) - and_should_stop = and_early_stopping_strategy.should_stop_trials_early( + and_should_stop = and_early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) @@ -1835,17 +1823,17 @@ def test_or_early_stopping_strategy(self, _: MagicMock) -> None: ) ) - left_should_stop = left_early_stopping_strategy.should_stop_trials_early( + left_should_stop = left_early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) - right_should_stop = right_early_stopping_strategy.should_stop_trials_early( + right_should_stop = right_early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) - or_should_stop = or_early_stopping_strategy.should_stop_trials_early( + or_should_stop = or_early_stopping_strategy.should_stop_arms( trial_indices=idcs, experiment=exp ) or_from_collection_should_stop = ( - or_early_stopping_strategy_from_collection.should_stop_trials_early( + or_early_stopping_strategy_from_collection.should_stop_arms( trial_indices=idcs, experiment=exp ) ) diff --git a/ax/early_stopping/utils.py b/ax/early_stopping/utils.py index d434f4d7369..9c26b16057c 100644 --- a/ax/early_stopping/utils.py +++ b/ax/early_stopping/utils.py @@ -14,7 +14,6 @@ from ax.core.data import MAP_KEY from ax.core.experiment import Experiment from ax.core.trial_status import TrialStatus -from ax.exceptions.core import UnsupportedError from ax.utils.common.logger import get_logger logger: Logger = get_logger(__name__) @@ -178,25 +177,13 @@ def align_partial_results( ) else: logger.info(f"No data from metric {m} yet.") - # drop arm names (assumes 1:1 map between trial indices and arm names) - # NOTE: this is not the case for BatchTrials and repeated arms - # if we didn't catch that there were multiple arms per trial, the interpolation - # code below would interpolate between data points from potentially different arms, - # as only the trial index is used to differentiate distinct data for interpolation. - for trial_index, trial_group in df.groupby("trial_index"): - if len(trial_group["arm_name"].unique()) != 1: - raise UnsupportedError( - f"Trial {trial_index} has multiple arm names: " - f"{trial_group['arm_name'].unique()}." - ) - - for arm_name, arm_group in df.groupby("arm_name"): - if len(arm_group["trial_index"].unique()) != 1: - raise UnsupportedError( - f"Arm {arm_name} has multiple trial indices: " - f"{arm_group['trial_index'].unique()}." - ) - + # Drop arm_name column before pivoting. The pivot operates on trial_index + # only. For single-arm trials (the common case) this is a no-op. For batch + # trials with multiple arms, rows are currently aggregated by trial_index; + # per-arm early stopping decisions are handled at the strategy level (each + # strategy natively returns arm-level results). + # TODO: For true per-arm stopping with BatchTrials, pivot on + # (trial_index, arm_name) and update downstream indexing logic. df = df.drop("arm_name", axis=1) # remove duplicates (same trial, metric, step), which can happen # if the same progression is erroneously reported more than once diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index f6ec4abebe0..d59fd890407 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -1346,20 +1346,39 @@ def poll_and_process_results(self, poll_all_trial_statuses: bool = False) -> boo # stopping decisions are based on trial data. This avoids redundant # expensive checks when no new data has been fetched. if len(trial_indices_with_new_data) > 0: - stop_trial_info = early_stopping_utils.should_stop_trials_early( + stop_arm_info = early_stopping_utils.should_stop_arms( early_stopping_strategy=self.options.early_stopping_strategy, trial_indices=self.running_trial_indices, experiment=self.experiment, current_node=self.generation_strategy._curr, ) + # For each trial, check if ALL arms are stopped. + # Partial arm stopping is not yet supported. + trials_to_stop: list[int] = [] + for trial_idx, arm_decisions in stop_arm_info.items(): + trial = self.experiment.trials[trial_idx] + trial_arm_names = {a.name for a in trial.arms} + if set(arm_decisions.keys()) == trial_arm_names: + trials_to_stop.append(trial_idx) + else: + raise NotImplementedError( + "Partial arm stopping is not yet supported. " + f"Trial {trial_idx} has arms {trial_arm_names} but " + f"only arms {set(arm_decisions.keys())} were stopped. " + "Runner.stop_arm() is needed." + ) + # Build reasons list from first arm's reason for each trial + reasons = [ + next(iter(stop_arm_info[idx].values())) for idx in trials_to_stop + ] self.experiment.stop_trial_runs( trials=[ - self.experiment.trials[trial_idx] for trial_idx in stop_trial_info + self.experiment.trials[trial_idx] for trial_idx in trials_to_stop ], - reasons=list(stop_trial_info.values()), + reasons=reasons, ) - if len(stop_trial_info) > 0: - trial_indices_with_updated_data_or_status.update(set(stop_trial_info)) + if len(trials_to_stop) > 0: + trial_indices_with_updated_data_or_status.update(set(trials_to_stop)) updated_any_trial_status = True # UPDATE TRIALS IN DB diff --git a/ax/orchestration/orchestrator_options.py b/ax/orchestration/orchestrator_options.py index 70ab3b04cf3..1add4d2c9a3 100644 --- a/ax/orchestration/orchestrator_options.py +++ b/ax/orchestration/orchestrator_options.py @@ -94,7 +94,7 @@ class OrchestratorOptions: debug_log_run_metadata: Whether to log run_metadata for debugging purposes. early_stopping_strategy: A ``BaseEarlyStoppingStrategy`` that determines whether a trial should be stopped given the current state of - the experiment. Used in ``should_stop_trials_early``. + the experiment. Used in ``should_stop_arms``. global_stopping_strategy: A ``BaseGlobalStoppingStrategy`` that determines whether the full optimization should be stopped or not. suppress_storage_errors_after_retries: Whether to fully suppress SQL diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 6cead6408ac..2f0a6ecd897 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -38,6 +38,7 @@ get_pending_observation_features_based_on_trial_status, ) from ax.early_stopping.strategies import BaseEarlyStoppingStrategy +from ax.early_stopping.strategies.base import TArmsToStop from ax.exceptions.core import ( AxError, OptimizationComplete, @@ -1161,12 +1162,24 @@ def test_run_trials_and_yield_results_with_early_stopper(self) -> None: ), db_settings=self.db_settings_if_always_needed, ) + + def _mock_stop_all( + early_stopping_strategy: Any, + trial_indices: set[int], + experiment: Any, + **kwargs: Any, + ) -> TArmsToStop: + return { + idx: {a.name: None for a in experiment.trials[idx].arms} + for idx in trial_indices + } + # All trials should be marked complete after one run. with ( patch( - "ax.service.utils.early_stopping.should_stop_trials_early", - wraps=lambda trial_indices, **kwargs: dict.fromkeys(trial_indices), - ) as mock_should_stop_trials_early, + "ax.service.utils.early_stopping.should_stop_arms", + wraps=_mock_stop_all, + ) as mock_should_stop_arms, patch.object( InfinitePollRunner, "stop", return_value=None ) as mock_stop_trial_run, @@ -1189,9 +1202,7 @@ def test_run_trials_and_yield_results_with_early_stopper(self) -> None: ) # Third trial in second batch of parallelism will be early stopped self.assertEqual(len(res_list[1]["trials_early_stopped_so_far"]), 3) - self.assertEqual( - mock_should_stop_trials_early.call_count, expected_num_polls - ) + self.assertEqual(mock_should_stop_arms.call_count, expected_num_polls) self.assertEqual( mock_stop_trial_run.call_count, len(res_list[1]["trials_early_stopped_so_far"]), @@ -1249,7 +1260,7 @@ def track_fetch_results(*args: Any, **kwargs: Any) -> set[int]: with ( patch( - "ax.service.utils.early_stopping.should_stop_trials_early", + "ax.service.utils.early_stopping.should_stop_arms", return_value={}, ) as mock_should_stop, patch.object( @@ -1279,17 +1290,22 @@ def _is_harmful( # Trials with odd indices will be early stopped # Thus, with 3 total trials, trial #1 will be early stopped - def _should_stop_trials_early( + def _should_stop_arms( self, trial_indices: set[int], experiment: Experiment, current_node: GenerationNode | None = None, - ) -> dict[int, str | None]: - return { - idx: f"Trial {idx} stopped by OddIndexEarlyStoppingStrategy" - for idx in trial_indices - if idx % 2 == 1 - } + ) -> TArmsToStop: + result: TArmsToStop = {} + for idx in trial_indices: + if idx % 2 == 1: + trial = experiment.trials[idx] + result[idx] = { + a.name: f"Trial {idx} stopped by " + "OddIndexEarlyStoppingStrategy" + for a in trial.arms + } + return result self.branin_timestamp_map_metric_experiment.runner = ( RunnerWithEarlyStoppingStrategy() diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 50c8045f0ab..82b3b434156 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -42,6 +42,7 @@ ) from ax.core.utils import compute_metric_availability, MetricAvailability from ax.early_stopping.strategies import BaseEarlyStoppingStrategy +from ax.early_stopping.strategies.base import TArmsToStop from ax.early_stopping.utils import estimate_early_stopping_savings from ax.exceptions.constants import CHOLESKY_ERROR_ANNOTATION from ax.exceptions.core import ( @@ -168,7 +169,7 @@ class AxClient(AnalysisBase, BestPointMixin, InstantiationBase): early_stopping_strategy: A ``BaseEarlyStoppingStrategy`` that determines whether a trial should be stopped given the current state of - the experiment. Used in ``should_stop_trials_early``. + the experiment. Used in ``should_stop_arms``. global_stopping_strategy: A ``BaseGlobalStoppingStrategy`` that determines whether the full optimization should be stopped or not. @@ -1432,24 +1433,22 @@ def verify_trial_parameterization( none_throws(self.get_trial(trial_index).arm).parameters == parameterization ) - def should_stop_trials_early( - self, trial_indices: set[int] - ) -> dict[int, str | None]: + def should_stop_arms(self, trial_indices: set[int]) -> TArmsToStop: """Evaluate whether to early-stop running trials. Args: trial_indices: Indices of trials to consider for early stopping. Returns: - A dictionary mapping trial indices that should be early stopped to - (optional) messages with the associated reason. + A dictionary mapping trial indices to arm-level stopping decisions. + Each value is a dict mapping arm names to (optional) reason strings. """ if self._early_stopping_strategy is None: logger.warning( "No early_stopping_strategy was passed to AxClient. " "Defaulting to never stopping any trials early." ) - return early_stopping_utils.should_stop_trials_early( + return early_stopping_utils.should_stop_arms( early_stopping_strategy=self._early_stopping_strategy, trial_indices=trial_indices, experiment=self.experiment, diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 386f508165b..f891486aa04 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -48,6 +48,7 @@ TParameterization, TParamValue, ) +from ax.early_stopping.strategies.base import TArmsToStop from ax.exceptions.core import ( DataRequiredError, OptimizationComplete, @@ -3166,10 +3167,10 @@ def test_with_hss(self) -> None: {"model": "Linear", "learning_rate": 1, "l2_reg_weight": 0.0001} ) - def test_should_stop_trials_early(self) -> None: - expected: dict[int, str | None] = { - 1: "Stopped due to testing.", - 3: "Stopped due to testing.", + def test_should_stop_arms(self) -> None: + expected: TArmsToStop = { + 1: {"1_0": "Stopped due to testing."}, + 3: {"3_0": "Stopped due to testing."}, } ax_client = AxClient( early_stopping_strategy=DummyEarlyStoppingStrategy(expected) @@ -3181,7 +3182,7 @@ def test_should_stop_trials_early(self) -> None: ], support_intermediate_data=True, ) - actual = ax_client.should_stop_trials_early(trial_indices={1, 2, 3}) + actual = ax_client.should_stop_arms(trial_indices={1, 2, 3}) self.assertEqual(actual, expected) def test_stop_trial_early(self) -> None: diff --git a/ax/service/tests/test_early_stopping.py b/ax/service/tests/test_early_stopping.py index b2ca0714042..b7a33a905ab 100644 --- a/ax/service/tests/test_early_stopping.py +++ b/ax/service/tests/test_early_stopping.py @@ -6,6 +6,7 @@ # pyre-strict +from ax.early_stopping.strategies.base import TArmsToStop from ax.service.utils import early_stopping as early_stopping_utils from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -22,20 +23,20 @@ def setUp(self) -> None: super().setUp() self.branin_experiment = get_branin_experiment() - def test_should_stop_trials_early(self) -> None: - expected: dict[int, str | None] = { - 1: "Stopped due to testing.", - 3: "Stopped due to testing.", + def test_should_stop_arms(self) -> None: + expected: TArmsToStop = { + 1: {"1_0": "Stopped due to testing."}, + 3: {"3_0": "Stopped due to testing."}, } - actual = early_stopping_utils.should_stop_trials_early( + actual = early_stopping_utils.should_stop_arms( early_stopping_strategy=DummyEarlyStoppingStrategy(expected), trial_indices={1, 2, 3}, experiment=self.branin_experiment, ) self.assertEqual(actual, expected) - def test_should_stop_trials_early_no_strategy(self) -> None: - actual = early_stopping_utils.should_stop_trials_early( + def test_should_stop_arms_no_strategy(self) -> None: + actual = early_stopping_utils.should_stop_arms( early_stopping_strategy=None, trial_indices={1, 2, 3}, experiment=self.branin_experiment, diff --git a/ax/service/utils/early_stopping.py b/ax/service/utils/early_stopping.py index 75b285ce1be..42d36cdeb28 100644 --- a/ax/service/utils/early_stopping.py +++ b/ax/service/utils/early_stopping.py @@ -8,16 +8,17 @@ from ax.core.experiment import Experiment from ax.early_stopping.strategies import BaseEarlyStoppingStrategy +from ax.early_stopping.strategies.base import TArmsToStop from ax.generation_strategy.generation_node import GenerationNode from pyre_extensions import none_throws -def should_stop_trials_early( +def should_stop_arms( early_stopping_strategy: BaseEarlyStoppingStrategy | None, trial_indices: set[int], experiment: Experiment, current_node: GenerationNode | None = None, -) -> dict[int, str | None]: +) -> TArmsToStop: """Evaluate whether to early-stop running trials. Args: @@ -30,14 +31,14 @@ def should_stop_trials_early( strategies may utilize components of the current node when making stopping decisions. Returns: - A dictionary mapping trial indices that should be early stopped to - (optional) messages with the associated reason. + A dictionary mapping trial indices to arm-level stopping decisions. + Each value is a dict mapping arm names to (optional) reason strings. """ if early_stopping_strategy is None: return {} early_stopping_strategy = none_throws(early_stopping_strategy) - return early_stopping_strategy.should_stop_trials_early( + return early_stopping_strategy.should_stop_arms( trial_indices=trial_indices, experiment=experiment, current_node=current_node, diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 7c143d45463..9535c6dfea3 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -72,6 +72,7 @@ from ax.early_stopping.strategies import ( BaseEarlyStoppingStrategy, PercentileEarlyStoppingStrategy, + TArmsToStop, ThresholdEarlyStoppingStrategy, ) from ax.early_stopping.strategies.logical import ( @@ -2782,9 +2783,12 @@ def get_or_early_stopping_strategy() -> OrEarlyStoppingStrategy: class DummyEarlyStoppingStrategy(BaseEarlyStoppingStrategy): - def __init__(self, early_stop_trials: dict[int, str | None] | None = None) -> None: + def __init__( + self, + early_stop_trials: TArmsToStop | None = None, + ) -> None: super().__init__() - self.early_stop_trials: dict[int, str | None] = early_stop_trials or {} + self.early_stop_trials: TArmsToStop = early_stop_trials or {} def _is_harmful( self, @@ -2793,12 +2797,12 @@ def _is_harmful( ) -> bool: return False - def _should_stop_trials_early( + def _should_stop_arms( self, trial_indices: set[int], experiment: Experiment, current_node: GenerationNode | None = None, - ) -> dict[int, str | None]: + ) -> TArmsToStop: return self.early_stop_trials