From c6a49ee795ba434d02062d0ee8bbde5cc9a781b6 Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Mon, 6 Apr 2026 15:20:19 -0700 Subject: [PATCH] Clean up align_partial_results, eliminate double alignment, and remove Data round-trip Summary: Profiling-driven cleanup and performance improvements to the early stopping pipeline. **1. Remove redundant validations from `align_partial_results` (utils.py)** `align_partial_results` contained three validation/logging blocks that are redundant with upstream checks already performed by `_lookup_and_validate` in `base.py`: - **Missing metrics check + `raise ValueError`:** `_lookup_and_validate` already verifies each metric signature exists in the data and returns `None` before `align_partial_results` is ever called. - **Per-metric logging loop:** The "no data" branch is unreachable because the upstream checks guarantee data exists for each metric. Debug logging about MAP_KEY ranges is misplaced for a pure alignment function. - **Trial-to-arm uniqueness check:** Redundant with `is_eligible_any` which already rejects `BatchTrial` -- the only way a trial gets multiple arms. The arm-to-trial uniqueness check was moved (not removed) to `_lookup_and_validate`, where it properly guards all downstream consumers, not just alignment. The `isin` filter (`df = df[df['metric_signature'].isin(metrics)]`) was kept -- it is part of `align_partial_results`'s own contract since the function accepts a `metrics` argument and callers depend on it filtering to those metrics. After cleanup, `align_partial_results` is a focused alignment function: filter to requested metrics -> drop `arm_name` -> drop duplicates -> sort -> pivot -> interpolate. **2. Eliminate double data alignment in `PercentileEarlyStoppingStrategy` (percentile.py)** When `check_safe=True`, the base class `should_stop_trials_early` calls `_is_harmful()` which calls `_prepare_aligned_frames()`, then `_should_stop_trials_early()` calls `_prepare_aligned_frames()` again -- identical data lookup + alignment running twice. Fix: override `should_stop_trials_early` to call `_default_objective_and_direction()` and `_prepare_aligned_frames()` once, passing results as required arguments (`metric_signature`, `minimize`, `aligned_frames`) to both `_is_harmful` and `_should_stop_trials_early`. **3. Eliminate the `Data` round-trip in `_lookup_and_validate` (base.py, threshold.py)** `_lookup_and_validate` (formerly `_lookup_and_validate_data`) used to wrap its filtered DataFrame back into `Data(df=filtered_df)` at the end, purely to satisfy its `Data | None` return type. The sole consumer (`_prepare_aligned_frames`) immediately called `.full_df` to get the DataFrame back. This triggered a pointless df-to-DataRow-to-df round-trip on every call: ``` DataFrame -> itertuples -> list[DataRow] -> from_records -> regex sort -> cast -> DataFrame ``` Profiling shows this round-trip is **100% overhead** at every scale: | Scale | Rows | Round-trip cost | Replacement (df.copy) | |-----------------|---------|----------------|-----------------------| | tiny (5x10) | 50 | 4ms | <0.1ms | | typical (20x100)| 2,000 | 15ms | <0.1ms | | large (50x200) | 10,000 | 60ms | <0.1ms | | xlarge (100x200)| 20,000 | 160ms | <0.1ms | | huge (200x500) | 100,000 | 733ms | <0.1ms | The cost comes from `Data.__init__` iterating every row via `itertuples()` to build `list[DataRow]` (~35ms at 10k rows), then `Data.full_df` reconstructing the DataFrame via `from_records()` (~12ms) and running regex-based arm name parsing in `sort_by_trial_index_and_arm_name()` (~19ms) -- none of which is needed since we already had the DataFrame. Fix: change `_lookup_and_validate` to return `pd.DataFrame | None` directly. Update `ModelBasedEarlyStoppingStrategy` and `ThresholdEarlyStoppingStrategy` accordingly. End-to-end profiling of `should_stop_trials_early` at the 50x200 scale (10k rows) shows a ~19% speedup (323ms -> ~263ms), with the benefit growing at larger scales (up to ~41% at 100k rows). **4. Naming cleanup** Renamed methods and variables to disambiguate between `Data` objects and `pd.DataFrame`: - `_lookup_and_validate_data` -> `_lookup_and_validate` (returns `pd.DataFrame | None`) - `_prepare_aligned_data` -> `_prepare_aligned_frames` (returns a tuple of DataFrames) - Variables holding `_lookup_and_validate` results: `data`/`map_data`/`data_lookup` -> `df` - Parameter `aligned_data` -> `aligned_frames` **5. Tests** - Removed tests for moved/removed `align_partial_results` validations. - Added arm-to-trial uniqueness check as a subtest in `test_is_eligible`. - Added profiling notebook at `ax/early_stopping/profiling.ipynb`. Reviewed By: saitcakmak Differential Revision: D98544835 --- ax/early_stopping/strategies/base.py | 51 +++--- ax/early_stopping/strategies/percentile.py | 86 +++++++--- ax/early_stopping/strategies/threshold.py | 6 +- ax/early_stopping/tests/test_strategies.py | 173 ++++++++++----------- ax/early_stopping/utils.py | 38 ----- 5 files changed, 180 insertions(+), 174 deletions(-) diff --git a/ax/early_stopping/strategies/base.py b/ax/early_stopping/strategies/base.py index 9a60f8f81c3..c791c980b30 100644 --- a/ax/early_stopping/strategies/base.py +++ b/ax/early_stopping/strategies/base.py @@ -15,7 +15,7 @@ import pandas as pd from ax.adapter.data_utils import _maybe_normalize_map_key from ax.core.batch_trial import BatchTrial -from ax.core.data import Data, MAP_KEY +from ax.core.data import MAP_KEY from ax.core.experiment import Experiment from ax.core.trial_status import TrialStatus from ax.early_stopping.utils import ( @@ -217,13 +217,10 @@ def estimate_early_stopping_savings(self, experiment: Experiment) -> float: return estimate_early_stopping_savings(experiment=experiment) - def _lookup_and_validate_data( + def _lookup_and_validate( self, experiment: Experiment, metric_signatures: list[str] - ) -> Data | None: - """Looks up and validates the `Data` used for early stopping that - is associated with `metric_signatures`. This function also handles normalizing - progressions. - """ + ) -> pd.DataFrame | None: + """Look up and validate experiment data for early stopping.""" data = experiment.lookup_data() if data.df.empty: logger.info( @@ -250,6 +247,17 @@ def _lookup_and_validate_data( full_df = data.full_df full_df = full_df[full_df["metric_signature"].isin(metric_signatures)] + # Check that no arm name appears across multiple trials. + # This can happen with duplicate arm parameterizations that reuse arm + # names across trials, which would corrupt the alignment step. + arm_trial_counts = full_df.groupby("arm_name")["trial_index"].nunique() + bad_arms = arm_trial_counts[arm_trial_counts > 1] + if len(bad_arms) > 0: + raise UnsupportedError( + f"Arm(s) {bad_arms.index.tolist()} appear across multiple " + f"trial indices. Each arm name must map to exactly one trial." + ) + # Drop rows with NaN values in MAP_KEY column to prevent issues in # align_partial_results which uses MAP_KEY as the pivot index nan_mask = full_df[MAP_KEY].isna() @@ -264,7 +272,7 @@ def _lookup_and_validate_data( if self.normalize_progressions: full_df = _maybe_normalize_map_key(df=full_df) - return Data(df=full_df) + return full_df @staticmethod def _log_and_return_no_data( @@ -547,7 +555,7 @@ def _all_objectives_and_directions(self, experiment: Experiment) -> dict[str, bo return directions - def _prepare_aligned_data( + def _prepare_aligned_frames( self, experiment: Experiment, metric_signatures: list[str] ) -> tuple[pd.DataFrame, pd.DataFrame] | None: """Get raw experiment data and align it for early stopping evaluation. @@ -564,15 +572,15 @@ def _prepare_aligned_data( with first level ["mean", "sem"] and second level metric signatures Returns None if data cannot be retrieved or aligned. """ - data = self._lookup_and_validate_data( + long_df = self._lookup_and_validate( experiment=experiment, metric_signatures=metric_signatures ) - if data is None: + if long_df is None: return None try: multilevel_wide_df = align_partial_results( - df=(long_df := data.full_df), + df=long_df, metrics=metric_signatures, ) except Exception as e: @@ -651,18 +659,15 @@ def __init__( ) self.min_progression_modeling = min_progression_modeling - def _lookup_and_validate_data( + def _lookup_and_validate( self, experiment: Experiment, metric_signatures: list[str] - ) -> Data | None: - """Looks up and validates the `Data` used for early stopping that - is associated with `metric_signatures`. This function also handles normalizing - progressions. + ) -> pd.DataFrame | None: + """Look up and validate experiment data for early stopping, applying + ``min_progression_modeling`` filter if configured. """ - map_data = super()._lookup_and_validate_data( + df = super()._lookup_and_validate( experiment=experiment, metric_signatures=metric_signatures ) - if map_data is not None and self.min_progression_modeling is not None: - full_df = map_data.full_df - full_df = full_df[full_df[MAP_KEY] >= self.min_progression_modeling] - map_data = Data(df=full_df) - return map_data + if df is not None and self.min_progression_modeling is not None: + df = df[df[MAP_KEY] >= self.min_progression_modeling] + return df diff --git a/ax/early_stopping/strategies/percentile.py b/ax/early_stopping/strategies/percentile.py index 77995c29daa..84e727b4367 100644 --- a/ax/early_stopping/strategies/percentile.py +++ b/ax/early_stopping/strategies/percentile.py @@ -124,10 +124,63 @@ def __init__( "with multiple metrics." ) + def should_stop_trials_early( + self, + trial_indices: set[int], + experiment: Experiment, + current_node: GenerationNode | None = None, + ) -> dict[int, str | None]: + """Decide whether trials should be stopped before evaluation is fully concluded. + + Overrides the base class to compute aligned data once and reuse it for + both the safety check (``_is_harmful``) and the stopping decision + (``_should_stop_trials_early``), avoiding redundant data lookups and + alignment when ``check_safe=True``. + + Args: + trial_indices: Indices of candidate trials to stop early. + experiment: Experiment that contains the trials and other contextual data. + current_node: The current ``GenerationNode`` on the ``GenerationStrategy`` + used to generate trials for the ``Experiment``. + + Returns: + A dictionary mapping trial indices that should be early stopped to + (optional) messages with the associated reason. + """ + metric_signature, minimize = self._default_objective_and_direction( + experiment=experiment + ) + aligned_frames = self._prepare_aligned_frames( + experiment=experiment, metric_signatures=[metric_signature] + ) + if aligned_frames is None: + return {} + + if self.check_safe and self._is_harmful( + trial_indices=trial_indices, + experiment=experiment, + metric_signature=metric_signature, + minimize=minimize, + aligned_frames=aligned_frames, + ): + return {} + + return self._should_stop_trials_early( + trial_indices=trial_indices, + experiment=experiment, + current_node=current_node, + metric_signature=metric_signature, + minimize=minimize, + aligned_frames=aligned_frames, + ) + def _is_harmful( self, trial_indices: set[int], experiment: Experiment, + metric_signature: str, + minimize: bool, + aligned_frames: tuple[pd.DataFrame, pd.DataFrame], ) -> bool: """Check if the early stopping strategy would stop the globally best trial. @@ -139,21 +192,16 @@ def _is_harmful( Args: trial_indices: Set of trial indices being evaluated (ignored). experiment: Experiment that contains the trials and other contextual data. + metric_signature: The metric signature to evaluate. + minimize: Whether the metric is being minimized. + aligned_frames: Pre-computed ``(long_df, multilevel_wide_df)`` tuple + from ``_prepare_aligned_frames``. Returns: True if the strategy would have stopped the globally best trial, False otherwise. """ - metric_signature, minimize = self._default_objective_and_direction( - experiment=experiment - ) - maybe_aligned_dataframes = self._prepare_aligned_data( - experiment=experiment, metric_signatures=[metric_signature] - ) - if maybe_aligned_dataframes is None: - return False - - long_df, multilevel_wide_df = maybe_aligned_dataframes + long_df, multilevel_wide_df = aligned_frames wide_df = multilevel_wide_df["mean"][metric_signature] # Get completed trials @@ -179,6 +227,9 @@ def _should_stop_trials_early( self, trial_indices: set[int], experiment: Experiment, + metric_signature: str, + minimize: bool, + aligned_frames: tuple[pd.DataFrame, pd.DataFrame], current_node: GenerationNode | None = None, ) -> dict[int, str | None]: """Stop a trial if its performance is in the bottom `percentile_threshold` @@ -187,6 +238,10 @@ def _should_stop_trials_early( Args: trial_indices: Indices of candidate trials to consider for early stopping. experiment: Experiment that contains the trials and other contextual data. + metric_signature: The metric signature to evaluate. + minimize: Whether the metric is being minimized. + aligned_frames: Pre-computed ``(long_df, multilevel_wide_df)`` tuple + from ``_prepare_aligned_frames``. current_node: The current ``GenerationNode`` on the ``GenerationStrategy`` used to generate trials for the ``Experiment``. Early stopping strategies may utilize components of the current node when making @@ -197,16 +252,7 @@ def _should_stop_trials_early( (optional) messages with the associated reason. An empty dictionary means no suggested updates to any trial's status. """ - metric_signature, minimize = self._default_objective_and_direction( - experiment=experiment - ) - maybe_aligned_dataframes = self._prepare_aligned_data( - experiment=experiment, metric_signatures=[metric_signature] - ) - if maybe_aligned_dataframes is None: - return {} - - long_df, multilevel_wide_df = maybe_aligned_dataframes + long_df, multilevel_wide_df = aligned_frames wide_df = multilevel_wide_df["mean"][metric_signature] # default checks on `min_progression` and `min_curves`; if not met, don't do diff --git a/ax/early_stopping/strategies/threshold.py b/ax/early_stopping/strategies/threshold.py index 5dac5d62a48..d63a1686aef 100644 --- a/ax/early_stopping/strategies/threshold.py +++ b/ax/early_stopping/strategies/threshold.py @@ -112,15 +112,13 @@ def _should_stop_trials_early( metric_signature, minimize = self._default_objective_and_direction( experiment=experiment ) - data = self._lookup_and_validate_data( + df = self._lookup_and_validate( experiment=experiment, metric_signatures=[metric_signature] ) - if data is None: + if df is None: # don't stop any trials if we don't get data back return {} - df = data.full_df - # default checks on `min_progression` and `min_curves`; if not met, don't do # early stopping at all and return {} if not self.is_eligible_any( diff --git a/ax/early_stopping/tests/test_strategies.py b/ax/early_stopping/tests/test_strategies.py index 0af83a47b05..b9731d6ff68 100644 --- a/ax/early_stopping/tests/test_strategies.py +++ b/ax/early_stopping/tests/test_strategies.py @@ -132,11 +132,12 @@ def test_normalize_progressions(self) -> None: # Test with normalize_progressions=True es_strategy_normalized = FakeStrategy(normalize_progressions=True) - normalized_data = es_strategy_normalized._lookup_and_validate_data( - experiment, metric_signatures=[metric_signature] + normalized_df = none_throws( + es_strategy_normalized._lookup_and_validate( + experiment, metric_signatures=[metric_signature] + ) ) - normalized_data = none_throws(normalized_data) - normalized_progressions = normalized_data.full_df[MAP_KEY].astype(float) + normalized_progressions = normalized_df[MAP_KEY].astype(float) # Verify normalized progressions are in [0, 1] range self.assertAlmostEqual(normalized_progressions.min(), 0.0) @@ -164,12 +165,12 @@ def test_normalize_progressions(self) -> None: original_max = original_df[MAP_KEY].astype(float).max() es_strategy_unnormalized = FakeStrategy(normalize_progressions=False) - unnormalized_data = es_strategy_unnormalized._lookup_and_validate_data( - experiment, metric_signatures=[metric_signature] - ) - unnormalized_progressions = ( - none_throws(unnormalized_data).full_df[MAP_KEY].astype(float) + unnormalized_df = none_throws( + es_strategy_unnormalized._lookup_and_validate( + experiment, metric_signatures=[metric_signature] + ) ) + unnormalized_progressions = unnormalized_df[MAP_KEY].astype(float) # Verify progressions are NOT normalized (should match original range) self.assertAlmostEqual(unnormalized_progressions.min(), 0.0) @@ -201,11 +202,12 @@ def test_normalize_progressions(self) -> None: experiment.attach_data(data=Data(df=modified_df)) es_strategy = FakeStrategy(normalize_progressions=True) - normalized_data = es_strategy._lookup_and_validate_data( - experiment, metric_signatures=[metric_signature] + normalized_df = none_throws( + es_strategy._lookup_and_validate( + experiment, metric_signatures=[metric_signature] + ) ) - normalized_data = none_throws(normalized_data) - normalized_progressions = normalized_data.full_df[MAP_KEY].astype(float) + normalized_progressions = normalized_df[MAP_KEY].astype(float) # Verify min-max normalization produces [0, 1] range self.assertAlmostEqual(normalized_progressions.min(), 0.0) @@ -232,7 +234,7 @@ def test_nan_map_key_values_dropped_with_warning(self) -> None: # Set some MAP_KEY values to NaN for specific trials # This simulates corrupted or missing progression data - # Use metric_signature to match the filter in _lookup_and_validate_data + # Use metric_signature to match the filter in _lookup_and_validate trial_0_mask = (modified_df["trial_index"] == 0) & ( modified_df["metric_signature"] == metric_signature ) @@ -246,7 +248,7 @@ def test_nan_map_key_values_dropped_with_warning(self) -> None: # Verify warning is logged when NaN values are dropped with patch.object(logger, "warning") as mock_warning: - result = es_strategy._lookup_and_validate_data( + df = es_strategy._lookup_and_validate( experiment, metric_signatures=[metric_signature] ) @@ -261,10 +263,10 @@ def test_nan_map_key_values_dropped_with_warning(self) -> None: ) # Verify result is not None and NaN rows are dropped - self.assertIsNotNone(result) + self.assertIsNotNone(df) # Verify no NaN values remain in MAP_KEY column - self.assertFalse(result.full_df[MAP_KEY].isna().any()) + self.assertFalse(df[MAP_KEY].isna().any()) def test_all_objectives_and_directions_raises_error_when_lower_is_better_is_none( self, @@ -408,8 +410,8 @@ def test_is_eligible(self, _: MagicMock) -> None: experiment=experiment ) - map_data = none_throws( - es_strategy._lookup_and_validate_data( + df = none_throws( + es_strategy._lookup_and_validate( experiment, metric_signatures=[metric_signature], ) @@ -418,12 +420,12 @@ def test_is_eligible(self, _: MagicMock) -> None: es_strategy.is_eligible( trial_index=0, experiment=experiment, - df=map_data.full_df, + df=df, )[0] ) # try to get data from different metric name - fake_df = deepcopy(map_data.full_df) + fake_df = deepcopy(df) trial_index = 0 fake_df = fake_df.drop(fake_df.index[fake_df["trial_index"] == trial_index]) fake_es, fake_reason = es_strategy.is_eligible( @@ -436,18 +438,18 @@ def test_is_eligible(self, _: MagicMock) -> None: fake_reason, "No data available to make an early stopping decision." ) - fake_map_data = es_strategy._lookup_and_validate_data( + fake_df = es_strategy._lookup_and_validate( experiment, metric_signatures=["fake_metric_name"], ) - self.assertIsNone(fake_map_data) + self.assertIsNone(fake_df) es_strategy = FakeStrategy(min_progression=5) self.assertFalse( es_strategy.is_eligible( trial_index=0, experiment=experiment, - df=map_data.full_df, + df=df, )[0] ) @@ -456,7 +458,7 @@ def test_is_eligible(self, _: MagicMock) -> None: es_strategy.is_eligible( trial_index=0, experiment=experiment, - df=map_data.full_df, + df=df, )[0] ) @@ -468,8 +470,26 @@ def test_is_eligible(self, _: MagicMock) -> None: es_strategy.is_eligible_any( trial_indices={0}, experiment=experiment, - df=map_data.full_df, + df=df, + ) + + # testing shared arm names across trials + with self.subTest("rejects_shared_arm_names"): + exp = get_test_map_data_experiment( + num_trials=5, num_fetches=3, num_complete=2 ) + strategy = FakeStrategy() + arm_df = exp.fetch_data().full_df.copy() + arm_df["arm_name"] = "shared_arm" + exp.attach_data(Data(df=arm_df)) + + with self.assertRaisesRegex( + UnsupportedError, + "Arm.*appear across multiple trial indices", + ): + strategy._lookup_and_validate( + experiment=exp, metric_signatures=["branin_map"] + ) def test_progression_interval(self) -> None: """Test progression interval with min_progression=0.""" @@ -485,18 +505,16 @@ def test_progression_interval(self) -> None: experiment=experiment ) - map_data = es_strategy._lookup_and_validate_data( - experiment, - metric_signatures=[metric_signature], + df = none_throws( + es_strategy._lookup_and_validate( + experiment, + metric_signatures=[metric_signature], + ) ) - full_df = none_throws(map_data).full_df - - # Trial 0 has progressions at 0, 1, 2, 3, 4 - # Simulate orchestrator checks at different progressions - # Check 1: Trial at progression 1 (between boundaries 0 and 2) + # Check 1: Trial at progression 1 # First check, so should be eligible - df_at_1 = full_df[full_df[MAP_KEY] <= 1] + df_at_1 = df[df[MAP_KEY] <= 1] is_eligible, reason = es_strategy.is_eligible( trial_index=0, experiment=experiment, @@ -507,7 +525,7 @@ def test_progression_interval(self) -> None: # Check 2: Trial at progression 2 (at boundary 2) # Has crossed boundary from 1 to 2, should be eligible - df_at_2 = full_df[full_df[MAP_KEY] <= 2] + df_at_2 = df[df[MAP_KEY] <= 2] is_eligible, reason = es_strategy.is_eligible( trial_index=0, experiment=experiment, @@ -518,7 +536,7 @@ def test_progression_interval(self) -> None: # Check 3: Trial at progression 3 (between boundaries 2 and 4) # Has NOT crossed boundary from 2 to 3, should NOT be eligible - df_at_3 = full_df[full_df[MAP_KEY] <= 3] + df_at_3 = df[df[MAP_KEY] <= 3] is_eligible, reason = es_strategy.is_eligible( trial_index=0, experiment=experiment, @@ -540,7 +558,7 @@ def test_progression_interval(self) -> None: is_eligible, reason = es_strategy.is_eligible( trial_index=0, experiment=experiment, - df=full_df, + df=df, ) self.assertTrue(is_eligible) self.assertIsNone(reason) @@ -556,11 +574,12 @@ def test_progression_interval_with_min_progression(self) -> None: experiment=experiment ) - map_data = es_strategy._lookup_and_validate_data( - experiment, - metric_signatures=[metric_signature], + full_df = none_throws( + es_strategy._lookup_and_validate( + experiment, + metric_signatures=[metric_signature], + ) ) - full_df = none_throws(map_data).full_df # Trial 0 has progressions at 0, 1, 2, 3, 4 # With min_progression=1.0, boundaries are at 1, 3, 5, 7... @@ -1084,19 +1103,19 @@ def test_percentile_reason_messages(self) -> None: min_progression=0.1, ) # Use _should_stop_trial_early directly to get reason for non-stopped trial - data = none_throws( - early_stopping_strategy._lookup_and_validate_data( + df = none_throws( + early_stopping_strategy._lookup_and_validate( experiment, metric_signatures=["branin_map"] ) ) - aligned_df = align_partial_results(df=data.full_df, metrics=["branin_map"]) + aligned_df = align_partial_results(df=df, metrics=["branin_map"]) aligned_means = aligned_df["mean"]["branin_map"] should_stop, reason = early_stopping_strategy._should_stop_trial_early( trial_index=2, # Best trial experiment=experiment, wide_df=aligned_means, - long_df=data.full_df, + long_df=df, minimize=True, ) self.assertFalse(should_stop) @@ -1139,10 +1158,10 @@ def test_top_trials_reason_messages_with_percentile_info(self) -> None: n_best_trials_to_complete=3, ) - data = none_throws( - early_stopping_strategy._lookup_and_validate_data(exp, ["branin_map"]) + df = none_throws( + early_stopping_strategy._lookup_and_validate(exp, ["branin_map"]) ) - aligned_df = align_partial_results(df=data.full_df, metrics=["branin_map"]) + aligned_df = align_partial_results(df=df, metrics=["branin_map"]) aligned_means = aligned_df["mean"]["branin_map"] # Test trial 1 which is in top 3 but below percentile threshold @@ -1150,7 +1169,7 @@ def test_top_trials_reason_messages_with_percentile_info(self) -> None: trial_index=1, # Trial 1 is in top 3 but below percentile experiment=exp, wide_df=aligned_means, - long_df=data.full_df, + long_df=df, minimize=True, ) @@ -1200,12 +1219,10 @@ def test_early_stopping_with_n_best_protection_handles_ties(self) -> None: n_best_trials_to_complete=2, # Less than the 3 tied top trials ) - data_lookup = none_throws( - early_stopping_strategy._lookup_and_validate_data(exp, ["branin_map"]) - ) - aligned_df = align_partial_results( - df=data_lookup.full_df, metrics=["branin_map"] + df = none_throws( + early_stopping_strategy._lookup_and_validate(exp, ["branin_map"]) ) + aligned_df = align_partial_results(df=df, metrics=["branin_map"]) aligned_means = aligned_df["mean"]["branin_map"] # Test that ALL three tied top trials (0, 1, 2) are protected @@ -1215,7 +1232,7 @@ def test_early_stopping_with_n_best_protection_handles_ties(self) -> None: trial_index=trial_idx, experiment=exp, wide_df=aligned_means, - long_df=data_lookup.full_df, + long_df=df, minimize=True, ) @@ -1445,12 +1462,12 @@ def test_patience_with_insufficient_data(self) -> None: patience=1, # Window [1, 2] ) - data_lookup = none_throws( - early_stopping_strategy._lookup_and_validate_data(exp, ["branin_map"]) + df = none_throws( + early_stopping_strategy._lookup_and_validate(exp, ["branin_map"]) ) # Modify the data to simulate insufficient trials at progression 2 - modified_df = data_lookup.full_df.copy() + modified_df = df.copy() selector = ~( (modified_df["metric_name"] == "branin_map") & (modified_df[MAP_KEY] == 2) @@ -1533,10 +1550,10 @@ def test_patience_reason_messages(self) -> None: patience=2, ) - data = none_throws( - early_stopping_strategy._lookup_and_validate_data(exp, ["branin_map"]) + df = none_throws( + early_stopping_strategy._lookup_and_validate(exp, ["branin_map"]) ) - aligned_df = align_partial_results(df=data.full_df, metrics=["branin_map"]) + aligned_df = align_partial_results(df=df, metrics=["branin_map"]) aligned_means = aligned_df["mean"]["branin_map"] # Test trial that should be stopped @@ -1544,7 +1561,7 @@ def test_patience_reason_messages(self) -> None: trial_index=0, experiment=exp, wide_df=aligned_means, - long_df=data.full_df, + long_df=df, minimize=True, ) @@ -1663,28 +1680,6 @@ def test_early_stopping_with_unaligned_results(self) -> None: ) self.assertEqual(set(should_stop), {0, 1}) - # test error throwing in align partial results, with non-unique trial / arm name - exp = get_test_map_data_experiment(num_trials=5, num_fetches=3, num_complete=2) - - # manually "unalign" timestamps to simulate real-world scenario - # where each curve reports results at different steps - data = exp.fetch_data() - df_with_single_arm_name = data.full_df.copy() - df_with_single_arm_name["arm_name"] = "0_0" - with self.assertRaisesRegex( - UnsupportedError, - "Arm 0_0 has multiple trial indices", - ): - align_partial_results(df=df_with_single_arm_name, metrics=["branin_map"]) - - df_with_single_trial_index = data.full_df.copy() - df_with_single_trial_index["trial_index"] = 0 - with self.assertRaisesRegex( - UnsupportedError, - "Trial 0 has multiple arm names", - ): - align_partial_results(df=df_with_single_trial_index, metrics=["branin_map"]) - class TestThresholdEarlyStoppingStrategy(TestCase): # to avoid log spam in tests, we test the logger output explicitly in the percentile @@ -1868,10 +1863,10 @@ def _evaluate_early_stopping_with_df( ) -> dict[int, str | None]: """Helper function for testing PercentileEarlyStoppingStrategy on an arbitrary (Data) df.""" - data = none_throws( - early_stopping_strategy._lookup_and_validate_data(experiment, [metric_name]) + df = none_throws( + early_stopping_strategy._lookup_and_validate(experiment, [metric_name]) ) - aligned_df = align_partial_results(df=data.full_df, metrics=[metric_name]) + aligned_df = align_partial_results(df=df, metrics=[metric_name]) metric_to_aligned_means = aligned_df["mean"] aligned_means = metric_to_aligned_means[metric_name] decisions = { @@ -1879,7 +1874,7 @@ def _evaluate_early_stopping_with_df( trial_index=trial_index, experiment=experiment, wide_df=aligned_means, - long_df=data.full_df, + long_df=df, minimize=cast( OptimizationConfig, experiment.optimization_config ).objective.minimize, diff --git a/ax/early_stopping/utils.py b/ax/early_stopping/utils.py index d434f4d7369..faeaf2ce660 100644 --- a/ax/early_stopping/utils.py +++ b/ax/early_stopping/utils.py @@ -6,7 +6,6 @@ # pyre-strict -from logging import Logger from typing import Any import numpy.typing as npt @@ -14,10 +13,6 @@ from ax.core.data import MAP_KEY from ax.core.experiment import Experiment from ax.core.trial_status import TrialStatus -from ax.exceptions.core import UnsupportedError -from ax.utils.common.logger import get_logger - -logger: Logger = get_logger(__name__) # Early stopping message constants for use in analysis and reporting EARLY_STOPPING_STATUS_MSG = ( @@ -163,40 +158,7 @@ def align_partial_results( ) } """ - missing_metrics = set(metrics) - set(df["metric_signature"]) - if missing_metrics: - raise ValueError(f"Metrics {missing_metrics} not found in input dataframe") - # select relevant metrics df = df[df["metric_signature"].isin(metrics)] - # log some information about raw data - for m in metrics: - df_m = df[df["metric_signature"] == m] - if len(df_m) > 0: - logger.debug( - f"Metric {m} raw data has observations from " - f"{df_m[MAP_KEY].min()} to {df_m[MAP_KEY].max()}." - ) - else: - logger.info(f"No data from metric {m} yet.") - # drop arm names (assumes 1:1 map between trial indices and arm names) - # NOTE: this is not the case for BatchTrials and repeated arms - # if we didn't catch that there were multiple arms per trial, the interpolation - # code below would interpolate between data points from potentially different arms, - # as only the trial index is used to differentiate distinct data for interpolation. - for trial_index, trial_group in df.groupby("trial_index"): - if len(trial_group["arm_name"].unique()) != 1: - raise UnsupportedError( - f"Trial {trial_index} has multiple arm names: " - f"{trial_group['arm_name'].unique()}." - ) - - for arm_name, arm_group in df.groupby("arm_name"): - if len(arm_group["trial_index"].unique()) != 1: - raise UnsupportedError( - f"Arm {arm_name} has multiple trial indices: " - f"{arm_group['trial_index'].unique()}." - ) - df = df.drop("arm_name", axis=1) # remove duplicates (same trial, metric, step), which can happen # if the same progression is erroneously reported more than once