From 8ac7d52a3638182be7adf50735ea028debf4c68e Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Tue, 7 Apr 2026 07:41:09 -0700 Subject: [PATCH 1/4] Change get_trace to return dict[int, float] and filter by MetricAvailability Summary: `get_trace` previously returned `list[float]` with positional indexing, which forced callers to fabricate sequential trial indices. This change: 1. Changes the return type to `dict[int, float]` mapping trial_index to performance value, so callers get real trial indices. 2. Adds `MetricAvailability` filtering to exclude trials with incomplete metric data before pivoting. This prevents the `ValueError("Some metrics are not present for all trials and arms")` that `_pivot_data_with_feasibility` raises when a completed trial is missing any metric (e.g., partial fetches, fetch failures, metrics added mid-experiment). 3. Removes the carry-forward expansion loop for abandoned/failed trials. These trials are now simply excluded from the returned dict. Callers updated: - `UtilityProgressionAnalysis` now shows real trial indices on x-axis - `pareto_frontier.hypervolume_trace_plot` uses real trial indices - `benchmark.py` extracts values from dict Differential Revision: D99448053 --- .../plotly/tests/test_utility_progression.py | 4 +- ax/analysis/plotly/utility_progression.py | 24 +++--- ax/benchmark/benchmark.py | 28 +++++-- ax/plot/pareto_frontier.py | 4 +- ax/service/tests/test_best_point.py | 82 ++++++++++--------- ax/service/utils/best_point.py | 62 ++++++-------- 6 files changed, 108 insertions(+), 96 deletions(-) diff --git a/ax/analysis/plotly/tests/test_utility_progression.py b/ax/analysis/plotly/tests/test_utility_progression.py index cfaf90e214e..d2fe59ecd7c 100644 --- a/ax/analysis/plotly/tests/test_utility_progression.py +++ b/ax/analysis/plotly/tests/test_utility_progression.py @@ -36,7 +36,7 @@ def _assert_valid_utility_card( """Assert that a card has valid structure for utility progression.""" self.assertIsInstance(card, PlotlyAnalysisCard) self.assertEqual(card.name, "UtilityProgressionAnalysis") - self.assertIn("trace_index", card.df.columns) + self.assertIn("trial_index", card.df.columns) self.assertIn("utility", card.df.columns) def test_utility_progression_soo(self) -> None: @@ -211,7 +211,7 @@ def test_all_infeasible_points_raises_error(self) -> None: with ( patch( "ax.analysis.plotly.utility_progression.get_trace", - return_value=[math.inf, -math.inf, math.inf], + return_value={0: math.inf, 1: -math.inf, 2: math.inf}, ), self.assertRaises(ExperimentNotReadyError) as cm, ): diff --git a/ax/analysis/plotly/utility_progression.py b/ax/analysis/plotly/utility_progression.py index 5ffc765d761..ee738d3845c 100644 --- a/ax/analysis/plotly/utility_progression.py +++ b/ax/analysis/plotly/utility_progression.py @@ -28,11 +28,9 @@ _UTILITY_PROGRESSION_TITLE = "Utility Progression" _TRACE_INDEX_EXPLANATION = ( - "The x-axis shows trace index, which counts completed or early-stopped trials " - "sequentially (1, 2, 3, ...). This differs from trial index, which may have " - "gaps if some trials failed or were abandoned. For example, if trials 0, 2, " - "and 5 completed while trials 1, 3, and 4 failed, the trace indices would be " - "1, 2, 3 corresponding to trial indices 0, 2, 5." + "The x-axis shows trial index. Only completed or early-stopped trials with " + "complete metric data are included, so there may be gaps if some trials " + "failed, were abandoned, or have incomplete data." ) _CUMULATIVE_BEST_EXPLANATION = ( @@ -57,7 +55,8 @@ class UtilityProgressionAnalysis(Analysis): The DataFrame computed will contain one row per completed trial and the following columns: - - trace_index: Sequential index of completed/early-stopped trials (1, 2, 3, ...) + - trial_index: The trial index of each completed/early-stopped trial + that has complete metric avilability. - utility: The cumulative best utility value at that trial """ @@ -114,7 +113,7 @@ def compute( ) # Check if all points are infeasible (inf or -inf values) - if all(np.isinf(value) for value in trace): + if all(np.isinf(value) for value in trace.values()): raise ExperimentNotReadyError( "All trials in the utility trace are infeasible i.e., they violate " "outcome constraints, so there are no feasible points to plot. During " @@ -125,12 +124,11 @@ def compute( "space, or (2) relaxing outcome constraints." ) - # Create DataFrame with 1-based trace index for user-friendly display - # (1st completed trial, 2nd completed trial, etc. instead of 0-indexed) + # Create DataFrame with trial indices from the trace df = pd.DataFrame( { - "trace_index": list(range(1, len(trace) + 1)), - "utility": trace, + "trial_index": list(trace.keys()), + "utility": list(trace.values()), } ) @@ -185,14 +183,14 @@ def compute( # Create the plot fig = px.line( data_frame=df, - x="trace_index", + x="trial_index", y="utility", markers=True, color_discrete_sequence=[AX_BLUE], ) # Update axis labels and format x-axis to show integers only - fig.update_xaxes(title_text="Trace Index", dtick=1, rangemode="nonnegative") + fig.update_xaxes(title_text="Trial Index", dtick=1, rangemode="nonnegative") fig.update_yaxes(title_text=y_label) return create_plotly_analysis_card( diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index bd257c92d9a..2b066442183 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -320,9 +320,10 @@ def _get_oracle_value_of_params( dummy_experiment = get_oracle_experiment_from_params( problem=problem, dict_of_dict_of_params={0: {"0_0": params}} ) - (inference_value,) = get_trace( + trace = get_trace( experiment=dummy_experiment, optimization_config=problem.optimization_config ) + inference_value = next(iter(trace.values())) return inference_value @@ -510,12 +511,27 @@ def get_benchmark_result_from_experiment_and_gs( dict_of_dict_of_params=dict_of_dict_of_params, trial_statuses=trial_statuses, ) - oracle_trace = np.array( - get_trace( - experiment=actual_params_oracle_dummy_experiment, - optimization_config=problem.optimization_config, - ) + oracle_trace_dict = get_trace( + experiment=actual_params_oracle_dummy_experiment, + optimization_config=problem.optimization_config, + ) + # Expand trace dict to a positional array aligned with all trials, + # carry-forwarding the last best value for trials without data (e.g., + # failed or abandoned trials preserved via trial_statuses). + maximize = ( + isinstance(problem.optimization_config, MultiObjectiveOptimizationConfig) + or problem.optimization_config.objective.is_scalarized_objective + or not problem.optimization_config.objective.minimize ) + all_trial_indices = sorted(actual_params_oracle_dummy_experiment.trials.keys()) + last_best = -float("inf") if maximize else float("inf") + oracle_trace_list: list[float] = [] + for idx in all_trial_indices: + if idx in oracle_trace_dict: + last_best = oracle_trace_dict[idx] + oracle_trace_list.append(last_best) + oracle_trace = np.array(oracle_trace_list) + is_feasible_trace = np.array( get_is_feasible_trace( experiment=actual_params_oracle_dummy_experiment, diff --git a/ax/plot/pareto_frontier.py b/ax/plot/pareto_frontier.py index fec10dd7fac..5d11638f4e5 100644 --- a/ax/plot/pareto_frontier.py +++ b/ax/plot/pareto_frontier.py @@ -61,8 +61,8 @@ def scatter_plot_with_hypervolume_trace_plotly(experiment: Experiment) -> go.Fig df = pd.DataFrame( { - "hypervolume": hypervolume_trace, - "trial_index": [*range(len(hypervolume_trace))], + "trial_index": list(hypervolume_trace.keys()), + "hypervolume": list(hypervolume_trace.values()), } ) diff --git a/ax/service/tests/test_best_point.py b/ax/service/tests/test_best_point.py index 03393107221..e4a14ebfba5 100644 --- a/ax/service/tests/test_best_point.py +++ b/ax/service/tests/test_best_point.py @@ -57,7 +57,7 @@ def test_get_trace(self) -> None: exp = get_experiment_with_observations( observations=[[11], [10], [9], [15], [5]], minimize=True ) - self.assertEqual(get_trace(exp), [11, 10, 9, 9, 5]) + self.assertEqual(get_trace(exp), {0: 11, 1: 10, 2: 9, 3: 9, 4: 5}) # Same experiment with maximize via new optimization config. opt_conf = none_throws(exp.optimization_config).clone() @@ -67,7 +67,7 @@ def test_get_trace(self) -> None: opt_conf.objective.metric_names[0]: opt_conf.objective.metric_names[0] }, ) - self.assertEqual(get_trace(exp, opt_conf), [11, 11, 11, 15, 15]) + self.assertEqual(get_trace(exp, opt_conf), {0: 11, 1: 11, 2: 11, 3: 15, 4: 15}) with self.subTest("Single objective with constraints"): # The second metric is the constraint and needs to be >= 0 @@ -76,48 +76,52 @@ def test_get_trace(self) -> None: minimize=False, constrained=True, ) - self.assertEqual(get_trace(exp), [float("-inf"), 10, 10, 10, 11]) + self.assertEqual( + get_trace(exp), + {0: float("-inf"), 1: 10, 2: 10, 3: 10, 4: 11}, + ) exp = get_experiment_with_observations( observations=[[11, -1], [10, 1], [9, 1], [15, -1], [11, 1]], minimize=True, constrained=True, ) - self.assertEqual(get_trace(exp), [float("inf"), 10, 9, 9, 9]) + self.assertEqual(get_trace(exp), {0: float("inf"), 1: 10, 2: 9, 3: 9, 4: 9}) # Scalarized. exp = get_experiment_with_observations( observations=[[1, 1], [2, 2], [3, 3]], scalarized=True, ) - self.assertEqual(get_trace(exp), [2, 4, 6]) + self.assertEqual(get_trace(exp), {0: 2, 1: 4, 2: 6}) # Multi objective. exp = get_experiment_with_observations( observations=[[1, 1], [-1, 100], [1, 2], [3, 3], [2, 4], [2, 1]], ) - self.assertEqual(get_trace(exp), [1, 1, 2, 9, 11, 11]) + self.assertEqual(get_trace(exp), {0: 1, 1: 1, 2: 2, 3: 9, 4: 11, 5: 11}) # W/o ObjectiveThresholds (inferring ObjectiveThresholds from scaled nadir) assert_is_instance( exp.optimization_config, MultiObjectiveOptimizationConfig ).objective_thresholds = [] trace = get_trace(exp) + trace_values = list(trace.values()) # With inferred thresholds via scaled nadir, check trace properties: # - All values should be non-negative - self.assertTrue(all(v >= 0.0 for v in trace)) + self.assertTrue(all(v >= 0.0 for v in trace_values)) # - Trace should be non-decreasing (cumulative best) - for i in range(1, len(trace)): - self.assertGreaterEqual(trace[i], trace[i - 1]) + for i in range(1, len(trace_values)): + self.assertGreaterEqual(trace_values[i], trace_values[i - 1]) # - Final value should be positive (non-trivial HV) - self.assertGreater(trace[-1], 0.0) + self.assertGreater(trace_values[-1], 0.0) # Multi-objective w/ constraints. exp = get_experiment_with_observations( observations=[[-1, 1, 1], [1, 2, 1], [3, 3, -1], [2, 4, 1], [2, 1, 1]], constrained=True, ) - self.assertEqual(get_trace(exp), [0, 2, 2, 8, 8]) + self.assertEqual(get_trace(exp), {0: 0, 1: 2, 2: 2, 3: 8, 4: 8}) # W/ relative constraints & status quo. exp.status_quo = Arm(parameters={"x": 0.5, "y": 0.5}, name="status_quo") @@ -149,17 +153,17 @@ def test_get_trace(self) -> None: ] status_quo_data = Data(df=pd.DataFrame.from_records(df_dict)) exp.attach_data(data=status_quo_data) - self.assertEqual(get_trace(exp), [0, 2, 2, 8, 8]) + self.assertEqual(get_trace(exp), {0: 0, 1: 2, 2: 2, 3: 8, 4: 8}) # W/ first objective being minimized. exp = get_experiment_with_observations( observations=[[1, 1], [-1, 2], [3, 3], [-2, 4], [2, 1]], minimize=True ) - self.assertEqual(get_trace(exp), [0, 2, 2, 8, 8]) + self.assertEqual(get_trace(exp), {0: 0, 1: 2, 2: 2, 3: 8, 4: 8}) # W/ empty data. exp = get_experiment_with_trial() - self.assertEqual(get_trace(exp), []) + self.assertEqual(get_trace(exp), {}) # test batch trial exp = get_experiment_with_batch_trial(with_status_quo=False) @@ -191,7 +195,7 @@ def test_get_trace(self) -> None: ] ) exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict))) - self.assertEqual(get_trace(exp), [2.0]) + self.assertEqual(get_trace(exp), {0: 2.0}) # test that there is performance metric in the trace for each # completed/early-stopped trial trial1 = assert_is_instance(trial, BatchTrial).clone_to(include_sq=False) @@ -214,7 +218,7 @@ def test_get_trace(self) -> None: ] ) exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict2))) - self.assertEqual(get_trace(exp), [2.0, 2.0, 20.0]) + self.assertEqual(get_trace(exp), {0: 2.0, 2: 20.0}) def test_get_trace_with_non_completed_trials(self) -> None: with self.subTest("minimize with abandoned trial"): @@ -224,12 +228,11 @@ def test_get_trace_with_non_completed_trials(self) -> None: # Mark trial 2 (value=9) as abandoned exp.trials[2].mark_abandoned(unsafe=True) - # Abandoned trial carries forward the last best value + # Abandoned trial is excluded from trace trace = get_trace(exp) - self.assertEqual(len(trace), 5) - # Trial 0: 11, Trial 1: 10, Trial 2 (abandoned): carry forward 10 + # Trial 0: 11, Trial 1: 10, Trial 2 (abandoned): excluded # Trial 3: 10 (15 > 10), Trial 4: 5 - self.assertEqual(trace, [11, 10, 10, 10, 5]) + self.assertEqual(trace, {0: 11, 1: 10, 3: 10, 4: 5}) with self.subTest("maximize with abandoned trial"): exp = get_experiment_with_observations( @@ -238,12 +241,11 @@ def test_get_trace_with_non_completed_trials(self) -> None: # Mark trial 1 (value=3) as abandoned exp.trials[1].mark_abandoned(unsafe=True) - # Abandoned trial carries forward the last best value + # Abandoned trial is excluded from trace trace = get_trace(exp) - self.assertEqual(len(trace), 5) - # Trial 0: 1, Trial 1 (abandoned): carry forward 1, + # Trial 0: 1, Trial 1 (abandoned): excluded, # Trial 2: 2, Trial 3: 5, Trial 4: 5 - self.assertEqual(trace, [1, 1, 2, 5, 5]) + self.assertEqual(trace, {0: 1, 2: 2, 3: 5, 4: 5}) with self.subTest("minimize with failed trial"): exp = get_experiment_with_observations( @@ -252,12 +254,11 @@ def test_get_trace_with_non_completed_trials(self) -> None: # Mark trial 2 (value=9) as failed exp.trials[2].mark_failed(unsafe=True) - # Failed trial carries forward the last best value + # Failed trial is excluded from trace trace = get_trace(exp) - self.assertEqual(len(trace), 5) - # Trial 0: 11, Trial 1: 10, Trial 2 (failed): carry forward 10 + # Trial 0: 11, Trial 1: 10, Trial 2 (failed): excluded # Trial 3: 10 (15 > 10), Trial 4: 5 - self.assertEqual(trace, [11, 10, 10, 10, 5]) + self.assertEqual(trace, {0: 11, 1: 10, 3: 10, 4: 5}) def test_get_trace_with_include_status_quo(self) -> None: with self.subTest("Multi-objective: status quo dominates in some trials"): @@ -347,9 +348,11 @@ def test_get_trace_with_include_status_quo(self) -> None: # The last value MUST differ because status quo dominates # Without status quo, only poor arms contribute (low hypervolume) # With status quo, excellent values contribute (high hypervolume) + last_without = list(trace_without_sq.values())[-1] + last_with = list(trace_with_sq.values())[-1] self.assertGreater( - trace_with_sq[-1], - trace_without_sq[-1], + last_with, + last_without, f"Status quo dominates in trial 3, so trace with SQ should be higher. " f"Without SQ: {trace_without_sq}, With SQ: {trace_with_sq}", ) @@ -418,9 +421,11 @@ def test_get_trace_with_include_status_quo(self) -> None: # The last value MUST differ because status quo is best # Without status quo: best in trial 3 is 15.0, cumulative min is 9 # With status quo: best in trial 3 is 5.0, cumulative min is 5 + last_without = list(trace_without_sq.values())[-1] + last_with = list(trace_with_sq.values())[-1] self.assertLess( - trace_with_sq[-1], - trace_without_sq[-1], + last_with, + last_without, f"Status quo is best in trial 3, so trace with SQ should be " f"lower (minimize). Without SQ: {trace_without_sq}, " f"With SQ: {trace_with_sq}", @@ -502,19 +507,20 @@ def _make_pref_opt_config(self, profile_name: str) -> PreferenceOptimizationConf preference_profile_name=profile_name, ) - def _assert_valid_trace(self, trace: list[float], expected_len: int) -> None: + def _assert_valid_trace(self, trace: dict[int, float], expected_len: int) -> None: """Assert trace has expected length, contains floats, is non-decreasing and has more than one unique value.""" self.assertEqual(len(trace), expected_len) - for value in trace: + trace_values = list(trace.values()) + for value in trace_values: self.assertIsInstance(value, float) - for i in range(1, len(trace)): + for i in range(1, len(trace_values)): self.assertGreaterEqual( - trace[i], - trace[i - 1], + trace_values[i], + trace_values[i - 1], msg=f"Trace not monotonically increasing at index {i}: {trace}", ) - unique_values = set(trace) + unique_values = set(trace_values) self.assertGreater( len(unique_values), 1, diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index 0b43001a4a5..40fbebf861c 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -45,6 +45,7 @@ from ax.core.outcome_constraint import _op_to_str, OutcomeConstraint from ax.core.trial import Trial from ax.core.types import ComparisonOp, TModelPredictArm, TParameterization +from ax.core.utils import compute_metric_availability, MetricAvailability from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.generators.torch_base import TorchGenerator @@ -1212,7 +1213,7 @@ def get_trace( experiment: Experiment, optimization_config: OptimizationConfig | None = None, include_status_quo: bool = False, -) -> list[float]: +) -> dict[int, float]: """ Compute the optimization trace at each iteration. Given an experiment and an optimization config, compute the performance @@ -1229,10 +1230,10 @@ def get_trace( improvements in the optimization trace. If the first trial(s) are infeasible, the trace can start at inf or -inf. - An iteration here refers to a completed or early-stopped (batch) trial. - There will be one performance metric in the trace for each iteration. - Trials without data (e.g. abandoned or failed) carry forward the last - best value. + An iteration here refers to a completed or early-stopped (batch) trial + with complete metric data for all metrics in the optimization config. + Trials with incomplete data, or trials that are abandoned or failed, are + excluded from the trace. Args: experiment: The experiment to get the trace for. @@ -1243,14 +1244,15 @@ def get_trace( behavior. Returns: - A list of performance values at each iteration. + A dict mapping trial index to performance value, ordered by trial + index. Only trials with complete metric data are included. """ optimization_config = optimization_config or none_throws( experiment.optimization_config ) df = experiment.lookup_data().df if len(df) == 0: - return [] + return {} # Get the names of the metrics in optimization config. metric_names = set(optimization_config.objective.metric_names) @@ -1271,7 +1273,22 @@ def get_trace( idx &= df["arm_name"] != status_quo.name df = df.loc[idx, :] if len(df) == 0: - return [] + return {} + + # Filter to trials with complete metric data. + availability = compute_metric_availability( + experiment=experiment, + trial_indices=df["trial_index"].unique().tolist(), + optimization_config=optimization_config, + ) + complete_trials = { + idx + for idx, avail in availability.items() + if avail == MetricAvailability.COMPLETE + } + df = df[df["trial_index"].isin(complete_trials)] + if len(df) == 0: + return {} # Derelativize the optimization config only if needed (i.e., if there are # relative constraints). This avoids unnecessary data pivoting that can @@ -1289,7 +1306,7 @@ def get_trace( use_cumulative_best=True, experiment=experiment, ) - # Aggregate by trial, then. compute cumulative best + # Aggregate by trial, then compute cumulative best objective = optimization_config.objective maximize = ( isinstance(optimization_config, MultiObjectiveOptimizationConfig) @@ -1303,32 +1320,7 @@ def get_trace( keep_order=False, # sort by trial index ) - compact_trace = cumulative_value.tolist() - - # Expand trace to include trials without data (e.g. ABANDONED, FAILED) - # with carry-forward values. - data_trial_indices = set(cumulative_value.index) - expanded_trace = [] - compact_idx = 0 - last_best_value = -float("inf") if maximize else float("inf") - - for trial_index in sorted(experiment.trials.keys()): - trial = experiment.trials[trial_index] - if trial_index in data_trial_indices: - # Trial has data in compact trace - if compact_idx < len(compact_trace): - value = compact_trace[compact_idx] - expanded_trace.append(value) - last_best_value = value - compact_idx += 1 - else: - # Should not happen, but handle gracefully - expanded_trace.append(last_best_value) - elif trial.status in (TrialStatus.ABANDONED, TrialStatus.FAILED): - # Trial has no data; carry forward the last best value. - expanded_trace.append(last_best_value) - - return expanded_trace + return {int(k): float(v) for k, v in cumulative_value.items()} def get_tensor_converter_adapter( From 7f401eb563779b4ceb20bdf778fe28495a98170e Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Tue, 7 Apr 2026 07:41:09 -0700 Subject: [PATCH 2/4] Filter by MetricAvailability in get_best_raw_objective_point_with_trial_index Summary: `get_best_raw_objective_point_with_trial_index` filters to COMPLETED trials but assumes all metrics in the optimization config are present in the data. This can fail with a hard `ValueError` from `_pivot_data_with_feasibility` when a completed trial has incomplete metric data (e.g., due to partial metric fetches, fetch failures, or metrics added mid-experiment). This diff adds `MetricAvailability` filtering after the completed-trials filter to exclude trials with incomplete data before they reach the pivot step. Trials without complete metric data are now silently excluded rather than causing a crash, with a clear error if no trials remain. Differential Revision: D99451732 --- ax/service/tests/test_best_point_utils.py | 59 +++++++++++++---------- ax/service/utils/best_point.py | 20 ++++++++ 2 files changed, 54 insertions(+), 25 deletions(-) diff --git a/ax/service/tests/test_best_point_utils.py b/ax/service/tests/test_best_point_utils.py index 4101f2fd6d9..ce82aaabd7b 100644 --- a/ax/service/tests/test_best_point_utils.py +++ b/ax/service/tests/test_best_point_utils.py @@ -504,7 +504,8 @@ def test_best_raw_objective_point(self) -> None: objective=Objective(metric=get_branin_metric(name="not_branin")) ) with self.assertRaisesRegex( - ValueError, "Some metrics are not present for all trials and arms" + ValueError, + "no completed trials have complete metric data", ): get_best_raw_objective_point_with_trial_index( experiment=exp, optimization_config=opt_conf @@ -543,8 +544,8 @@ def test_best_raw_objective_point_unsatisfiable(self) -> None: experiment=exp, optimization_config=opt_conf ) - # adding a new metric that is not present in the data should raise an error, - # even if the other metrics are satisfied + # Adding a constraint on an unobserved metric causes the trial to be + # filtered out by MetricAvailability before reaching feasibility check. opt_conf.outcome_constraints.pop() unobserved_metric = get_branin_metric(name="unobserved") opt_conf.outcome_constraints.append( @@ -552,31 +553,39 @@ def test_best_raw_objective_point_unsatisfiable(self) -> None: metric=unobserved_metric, op=ComparisonOp.LEQ, bound=0, relative=False ) ) - # also add a constraint that is always satisfied, as the Branin metric is - # non-negative, and check that only the "unobserved" metric shows up in the - # error message - opt_conf.outcome_constraints.append( - OutcomeConstraint( - metric=get_branin_metric(), op=ComparisonOp.GEQ, bound=0, relative=False + with self.assertRaisesRegex( + ValueError, + "no completed trials have complete metric data", + ): + get_best_raw_objective_point_with_trial_index( + experiment=exp, optimization_config=opt_conf ) - ) - with self.assertLogs(logger=best_point_logger, level="WARN") as lg: - with self.assertRaisesRegex( - ValueError, - r"No points satisfied all outcome constraints within 95 percent " - r"confidence interval\. The feasibility of 1 arm\(s\) could not be " - r"determined: \['0_0'\]\.", - ): - get_best_raw_objective_point_with_trial_index( - experiment=exp, optimization_config=opt_conf - ) - self.assertEqual(len(lg.output), 1) - self.assertRegex( - lg.output[0], - r"Arm 0_0 is missing data for one or more constrained metrics: " - r"\{'unobserved'\}\.", + # Using a constrained experiment where all metrics are observed + # (passes MetricAvailability), but constraints are unsatisfiable. + constrained_exp = get_experiment_with_observations( + observations=[[1.0, 2.0]], + constrained=True, + minimize=False, ) + constrained_opt = none_throws(constrained_exp.optimization_config).clone() + # Make constraint unsatisfiable: require m2 >= 9999 (observed m2=2.0). + constrained_opt.outcome_constraints = [ + OutcomeConstraint( + metric=Metric(name="m2"), + op=ComparisonOp.GEQ, + bound=9999, + relative=False, + ), + ] + with self.assertRaisesRegex( + ValueError, + r"No points satisfied all outcome constraints", + ): + get_best_raw_objective_point_with_trial_index( + experiment=constrained_exp, + optimization_config=constrained_opt, + ) def test_best_raw_objective_point_unsatisfiable_relative(self) -> None: exp = get_experiment_with_observations( diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index 40fbebf861c..124cc6e3ccd 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -134,6 +134,26 @@ def get_best_raw_objective_point_with_trial_index( raise ValueError("Cannot identify best point if no trials are completed.") completed_df = dat.df[dat.df["trial_index"].isin(completed_indices)] + # Filter to trials with complete metric data to avoid errors in + # _pivot_data_with_feasibility when some metrics are missing (e.g., due to + # partial metric fetches, fetch failures, or metrics added mid-experiment). + availability = compute_metric_availability( + experiment=experiment, + trial_indices=sorted(completed_indices), + optimization_config=optimization_config, + ) + complete_trials = { + idx + for idx, avail in availability.items() + if avail == MetricAvailability.COMPLETE + } + completed_df = completed_df[completed_df["trial_index"].isin(complete_trials)] + if len(completed_df) == 0: + raise ValueError( + "Cannot identify best point: no completed trials have complete " + "metric data for all metrics in the optimization config." + ) + is_feasible = is_row_feasible( df=completed_df, optimization_config=optimization_config, From 02e06aecc25450a639b505f25f4f6bc90c090e68 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Tue, 7 Apr 2026 07:41:46 -0700 Subject: [PATCH 3/4] Decouple metric fetch errors from trial status in Orchestrator Summary: Design doc: D98741656 When `fetch_trials_data_results` returned a `MetricFetchE` for an optimization config metric, the orchestrator marked the trial as ABANDONED. This discarded good data, inflated the failure rate, and was inconsistent with the Client layer which keeps trials COMPLETED with incomplete metrics via `MetricAvailability` (D93924193). This diff removes the trial abandonment behavior. Metric fetch errors are now logged (with traceback via `logger.exception`) but trial status is unchanged. `MetricAvailability` tracks data completeness, and the failure rate check uses it to detect persistent metric issues. Changes: - `_fetch_and_process_trials_data_results`: Removed the branch that marked trials ABANDONED for metric fetch errors and the separate `is_available_while_running` branch. All metric fetch errors are now simply logged and the method continues. The `_report_metric_fetch_e` hook is still called so subclasses (e.g. `AxSweepOrchestrator`) can react to errors (create pastes, build error tables, etc.). - `error_if_failure_rate_exceeded`: Merged `_check_if_failure_rate_exceeded` into this method to avoid duplicate computation. Now counts both runner failures (FAILED/ABANDONED) and metric-incomplete trials (via `compute_metric_availability`) toward the failure rate. - `_get_failure_rate_exceeded_error`: Rewritten with an actionable error message listing runner failures, metric-incomplete trials, missing metrics, and affected trial indices. - Removed dead code: `_mark_err_trial_status`, `_num_trials_bad_due_to_err`, `_num_metric_fetch_e_encountered`, `_check_if_failure_rate_exceeded`, `METRIC_FETCH_ERR_MESSAGE`. - Kept `_report_metric_fetch_e` as a no-op hook so subclasses like `AxSweepOrchestrator` can still react to metric fetch errors. - Updated telemetry (`OrchestratorCompletedRecord`) to use `_count_metric_incomplete_trials` (via `compute_metric_availability`) for both `num_metric_fetch_e_encountered` and `num_trials_bad_due_to_err`. - Updated `AxSweepOrchestrator` test assertions: trials now stay COMPLETED (not ABANDONED) after metric fetch errors. - `Metric.recoverable_exceptions` and `Metric.is_recoverable_fetch_e` are kept for now since `pts/` metrics still reference them; cleanup will follow in a separate diff. Differential Revision: D98924467 --- ax/orchestration/orchestrator.py | 283 +++++++++++--------- ax/orchestration/tests/test_orchestrator.py | 144 ++++------ 2 files changed, 203 insertions(+), 224 deletions(-) diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index d1b6e1f827b..95e1672753b 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -32,6 +32,7 @@ from ax.core.runner import Runner from ax.core.trial import Trial from ax.core.trial_status import TrialStatus +from ax.core.utils import compute_metric_availability, MetricAvailability from ax.exceptions.core import ( AxError, DataRequiredError, @@ -80,13 +81,6 @@ "of an optimization and if at least {min_failed} trials have been " "failed/abandoned, potentially automatically due to issues with the trial." ) -METRIC_FETCH_ERR_MESSAGE = ( - "A majority of the trial failures encountered are due to metric fetching errors. " - "This could mean the metrics are flaky, broken, or misconfigured. Please check " - "that the trial processes/jobs are successfully producing the expected metrics and " - "that the metric is correctly configured." -) - EXPECTED_STAGED_MSG = ( "Expected all trials to be in status {expected} after running or staging, " "found {t_idx_to_status}." @@ -191,13 +185,6 @@ class Orchestrator(WithDBSettingsBase, BestPointMixin): # Saved as a property so that it can be accessed after optimization is complex (ex. # for global stopping saving calculation). _num_remaining_requested_trials: int = 0 - # Total number of MetricFetchEs encountered during the course of optimization. Note - # this is different from and may be greater than the number of trials that have - # been marked either FAILED or ABANDONED due to metric fetching errors. - _num_metric_fetch_e_encountered: int = 0 - # Number of trials that have been marked either FAILED or ABANDONED due to - # MetricFetchE being encountered during _fetch_and_process_trials_data_results - _num_trials_bad_due_to_err: int = 0 # Keeps track of whether the allowed failure rate has been exceeded during # the optimization. If true, allows any pending trials to finish and raises # an error through self._complete_optimization. @@ -1073,83 +1060,63 @@ def summarize_final_result(self) -> OptimizationResult: """ return OptimizationResult() - def _check_if_failure_rate_exceeded(self, force_check: bool = False) -> bool: - """Checks if the failure rate (set in Orchestrator options) has been exceeded at - any point during the optimization. + def error_if_failure_rate_exceeded(self, force_check: bool = False) -> None: + """Raises an exception if the failure rate (set in Orchestrator options) has + been exceeded at any point during the optimization. - NOTE: Both FAILED and ABANDONED trial statuses count towards the failure rate. + The failure rate is computed as the ratio of "bad" trials to total trials + created by this orchestrator. "Bad" trials include: + - Execution failures: trials with FAILED or ABANDONED status. + - Metric-incomplete trials: COMPLETED trials whose metric data is not + fully available (as determined by ``compute_metric_availability``). Args: force_check: Indicates whether to force a failure-rate check regardless of the number of trials that have been executed. If False (default), the check will be skipped if the optimization has fewer than - five failed trials. If True, the check will be performed unless there - are 0 failures. + ``min_failed_trials_for_failure_rate_check`` bad trials. If True, the + check will be performed unless there are 0 bad trials. + """ + # Count runner-level failures (FAILED + ABANDONED). + num_execution_failures = self._num_bad_in_orchestrator() - Effect on state: - If the failure rate has been exceeded, a warning is logged and the private - attribute `_failure_rate_has_been_exceeded` is set to True, which causes the - `_get_max_pending_trials` to return zero, so that no further trials are - scheduled and an error is raised at the end of the optimization. + # Count completed trials with incomplete metric availability. + num_metric_incomplete, missing_metrics_by_trial = ( + self._get_metric_incomplete_trials() + ) - Returns: - Boolean representing whether the failure rate has been exceeded. - """ - if self._failure_rate_has_been_exceeded: - return True + num_bad = num_execution_failures + num_metric_incomplete - num_bad_in_orchestrator = self._num_bad_in_orchestrator() - # skip check if 0 failures - if num_bad_in_orchestrator == 0: - return False + if not self._failure_rate_has_been_exceeded: + # Skip check if 0 bad trials. + if num_bad == 0: + return - # skip check if fewer than min_failed_trials_for_failure_rate_check failures - # unless force_check is True - if ( - num_bad_in_orchestrator - < self.options.min_failed_trials_for_failure_rate_check - and not force_check - ): - return False + # Skip check if fewer than min threshold unless force_check. + if ( + num_bad < self.options.min_failed_trials_for_failure_rate_check + and not force_check + ): + return - num_ran_in_orchestrator = self._num_ran_in_orchestrator() - failure_rate_exceeded = ( - num_bad_in_orchestrator / num_ran_in_orchestrator - ) > self.options.tolerated_trial_failure_rate + num_ran_in_orchestrator = self._num_ran_in_orchestrator() + failure_rate_exceeded = ( + num_bad / num_ran_in_orchestrator + ) > self.options.tolerated_trial_failure_rate + + if not failure_rate_exceeded: + return - if failure_rate_exceeded: - if self._num_trials_bad_due_to_err > num_bad_in_orchestrator / 2: - self.logger.warning( - "MetricFetchE INFO: Sweep aborted due to an exceeded error rate, " - "which was primarily caused by failure to fetch metrics. Please " - "check if anything could cause your metrics to be flaky or " - "broken." - ) # NOTE: this private attribute causes `_get_max_pending_trials` to # return zero, which causes no further trials to be scheduled. self._failure_rate_has_been_exceeded = True - return True - - return False - - def error_if_failure_rate_exceeded(self, force_check: bool = False) -> None: - """Raises an exception if the failure rate (set in Orchestrator options) has - been exceeded at any point during the optimization. - NOTE: Both FAILED and ABANDONED trial statuses count towards the failure rate. - - Args: - force_check: Indicates whether to force a failure-rate check - regardless of the number of trials that have been executed. If False - (default), the check will be skipped if the optimization has fewer than - five failed trials. If True, the check will be performed unless there - are 0 failures. - """ - if self._check_if_failure_rate_exceeded(force_check=force_check): - raise self._get_failure_rate_exceeded_error( - num_bad_in_orchestrator=self._num_bad_in_orchestrator(), - num_ran_in_orchestrator=self._num_ran_in_orchestrator(), - ) + raise self._get_failure_rate_exceeded_error( + num_execution_failures=num_execution_failures, + num_metric_incomplete=num_metric_incomplete, + num_ran_in_orchestrator=self._num_ran_in_orchestrator(), + missing_metrics_by_trial=missing_metrics_by_trial, + ) def _error_if_status_quo_infeasible(self) -> None: """Raises an exception if the status-quo arm is infeasible and the @@ -2032,9 +1999,13 @@ def _fetch_and_process_trials_data_results( self, trial_indices: Iterable[int], ) -> dict[int, dict[str, MetricFetchResult]]: - """ - Fetches results from experiment and modifies trial statuses depending on - success or failure. + """Fetch trial data results and log any metric fetch errors. + + Metric fetch errors are logged but do NOT change trial status. + ``MetricAvailability`` (computed via ``compute_metric_availability``) + tracks data completeness separately, and the failure rate check in + ``error_if_failure_rate_exceeded`` uses it to detect persistent + metric issues. """ try: @@ -2085,41 +2056,12 @@ def _fetch_and_process_trials_data_results( f"Failed to fetch {metric_name} for trial {trial_index} with " f"status {status}, found {metric_fetch_e}." ) - self._num_metric_fetch_e_encountered += 1 self._report_metric_fetch_e( trial=self.experiment.trials[trial_index], metric_name=metric_name, metric_fetch_e=metric_fetch_e, ) - # If the fetch failure was for a metric in the optimization config (an - # objective or constraint) mark the trial as failed - optimization_config = self.experiment.optimization_config - if ( - optimization_config is not None - and metric_name in optimization_config.metric_names - and not self.experiment.metrics[metric_name].is_recoverable_fetch_e( - metric_fetch_e=metric_fetch_e - ) - ): - status = self._mark_err_trial_status( - trial=self.experiment.trials[trial_index], - metric_name=metric_name, - metric_fetch_e=metric_fetch_e, - ) - self.logger.warning( - f"MetricFetchE INFO: Because {metric_name} is an objective, " - f"marking trial {trial_index} as {status}." - ) - self._num_trials_bad_due_to_err += 1 - continue - - self.logger.info( - "MetricFetchE INFO: Continuing optimization even though " - "MetricFetchE encountered." - ) - continue - return results def _report_metric_fetch_e( @@ -2128,39 +2070,122 @@ def _report_metric_fetch_e( metric_name: str, metric_fetch_e: MetricFetchE, ) -> None: + """Hook for subclasses to react to metric fetch errors. + + Called once per metric fetch error during + ``_fetch_and_process_trials_data_results``. The default + implementation is a no-op; override in subclasses to add custom + reporting (e.g., creating error tables or pastes). + """ pass - def _mark_err_trial_status( + def _get_metric_incomplete_trials( self, - trial: BaseTrial, - metric_name: str | None = None, - metric_fetch_e: MetricFetchE | None = None, - ) -> TrialStatus: - trial.mark_abandoned( - reason=metric_fetch_e.message if metric_fetch_e else None, unsafe=True + ) -> tuple[int, dict[int, set[str]]]: + """Count completed trials with incomplete metric availability and identify + which metrics are missing for each. + + Required metrics include optimization config metrics and any explicitly + defined early stopping strategy metrics. + + Returns: + A tuple of (num_metric_incomplete, missing_metrics_by_trial) where + missing_metrics_by_trial maps trial index to the set of missing + metric names. + """ + opt_config = self.experiment.optimization_config + if opt_config is None: + return 0, {} + + completed_trial_indices = [ + t.index + for t in self.experiment.trials.values() + if t.status == TrialStatus.COMPLETED + and t.index >= self._num_preexisting_trials + ] + if len(completed_trial_indices) == 0: + return 0, {} + + required_metrics = set(opt_config.metric_names) + + # Include explicitly defined early stopping strategy metrics. + # ESS stores metric *signatures*, which may differ from metric names, + # so we resolve them via experiment.signature_to_metric. + ess = self.options.early_stopping_strategy + ess_signatures = ess.metric_signatures if ess is not None else None + if ess_signatures is not None: + for sig in ess_signatures: + metric = self.experiment.signature_to_metric[sig] + required_metrics.add(metric.name) + + metric_availabilities = compute_metric_availability( + experiment=self.experiment, + trial_indices=completed_trial_indices, + metric_names=required_metrics, ) - return TrialStatus.ABANDONED + + # Identify which specific metrics are missing per trial. + data = self.experiment.lookup_data(trial_indices=completed_trial_indices) + metrics_per_trial: dict[int, set[str]] = {} + if len(data.metric_names) > 0: + df = data.full_df + for trial_idx, group in df.groupby("trial_index")["metric_name"]: + metrics_per_trial[int(trial_idx)] = set(group.unique()) + + missing_metrics_by_trial: dict[int, set[str]] = {} + for idx, avail in metric_availabilities.items(): + if avail != MetricAvailability.COMPLETE: + available = metrics_per_trial.get(idx, set()) + missing_metrics_by_trial[idx] = required_metrics - available + + return len(missing_metrics_by_trial), missing_metrics_by_trial def _get_failure_rate_exceeded_error( self, - num_bad_in_orchestrator: int, + num_execution_failures: int, + num_metric_incomplete: int, num_ran_in_orchestrator: int, + missing_metrics_by_trial: dict[int, set[str]], ) -> FailureRateExceededError: - return FailureRateExceededError( - ( - f"{METRIC_FETCH_ERR_MESSAGE}\n" - if self._num_trials_bad_due_to_err > num_bad_in_orchestrator / 2 - else "" + """Build an actionable error message describing why the failure rate was + exceeded, including runner failures, metric-incomplete trials, which + metrics are missing, and which trials are affected. + """ + num_bad = num_execution_failures + num_metric_incomplete + observed_rate = num_bad / num_ran_in_orchestrator + + parts: list[str] = [] + parts.append( + f"Failure rate exceeded: {num_bad} of {num_ran_in_orchestrator} " + f"trials were unsuccessful (observed rate: {observed_rate:.0%}, tolerance: " + f"{self.options.tolerated_trial_failure_rate:.0%}). " + f"Checks are triggered when at least " + f"{self.options.min_failed_trials_for_failure_rate_check} trials " + "are unsuccessful or at the end of the optimization." + ) + + if num_execution_failures > 0: + parts.append( + f"{num_execution_failures} trial(s) failed at the execution " + "level (FAILED or ABANDONED). Check any trial evaluation " + "processes/jobs to see why they are failing." ) - + " Orignal error message: " - + FAILURE_EXCEEDED_MSG.format( - f_rate=self.options.tolerated_trial_failure_rate, - n_failed=num_bad_in_orchestrator, - n_ran=num_ran_in_orchestrator, - min_failed=self.options.min_failed_trials_for_failure_rate_check, - observed_rate=float(num_bad_in_orchestrator) / num_ran_in_orchestrator, + + if num_metric_incomplete > 0: + all_missing: set[str] = set() + for missing in missing_metrics_by_trial.values(): + all_missing.update(missing) + affected_trials = sorted(missing_metrics_by_trial.keys()) + + parts.append( + f"{num_metric_incomplete} trial(s) have incomplete metric data. " + f"Missing metrics: {sorted(all_missing)}. " + f"Affected trials: {affected_trials}. " + "Check that your metric fetching infrastructure is healthy " + "and that the metrics are being logged correctly." ) - ) + + return FailureRateExceededError("\n".join(parts)) def _warn_if_non_terminal_trials(self) -> None: """Warns if there are any non-terminal trials on the experiment.""" diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 023f970fca6..7252e3a4142 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -38,12 +38,7 @@ get_pending_observation_features_based_on_trial_status, ) from ax.early_stopping.strategies import BaseEarlyStoppingStrategy -from ax.exceptions.core import ( - AxError, - OptimizationComplete, - UnsupportedError, - UserInputError, -) +from ax.exceptions.core import OptimizationComplete, UnsupportedError, UserInputError from ax.exceptions.generation_strategy import AxGenerationException from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy from ax.generation_strategy.generation_strategy import ( @@ -1834,6 +1829,12 @@ def test_fetch_and_process_trials_data_results_failed_non_objective( ) def test_fetch_and_process_trials_data_results_failed_objective(self) -> None: + """Metric fetch errors on objective metrics do NOT change trial status. + + The trial remains COMPLETED, and MetricAvailability reflects the missing + data. The failure rate check uses MetricAvailability to detect persistent + metric issues. + """ gs = self.two_sobol_steps_GS orchestrator = Orchestrator( experiment=self.branin_experiment, @@ -1854,97 +1855,44 @@ def test_fetch_and_process_trials_data_results_failed_objective(self) -> None: ), self.assertLogs(logger="ax.orchestration.orchestrator") as lg, ): - # This trial will fail + # The trial completes but has incomplete metrics, triggering + # the failure rate check. with self.assertRaises(FailureRateExceededError): orchestrator.run_n_trials(max_trials=1) - self.assertTrue( - any( - re.search(r"Failed to fetch (branin|m1) for trial 0", warning) - is not None - for warning in lg.output - ) - ) + # Verify the error was logged (not the old "marking trial as ABANDONED"). self.assertTrue( any( re.search( - r"Because (branin|m1) is an objective, marking trial 0 as " - "TrialStatus.ABANDONED", - warning, - ) - is not None - for warning in lg.output - ) - ) - self.assertEqual( - orchestrator.experiment.trials[0].status, TrialStatus.ABANDONED - ) - - def test_fetch_and_process_trials_data_results_failed_objective_but_recoverable( - self, - ) -> None: - gs = self.two_sobol_steps_GS - orchestrator = Orchestrator( - experiment=self.branin_experiment, - generation_strategy=gs, - options=OrchestratorOptions( - enforce_immutable_search_space_and_opt_config=False, - **self.orchestrator_options_kwargs, - ), - db_settings=self.db_settings_if_always_needed, - ) - BraninMetric.recoverable_exceptions = {AxError, TypeError} - # we're throwing a recoverable exception because UserInputError - # is a subclass of AxError - with ( - patch( - f"{BraninMetric.__module__}.BraninMetric.f", - side_effect=UserInputError("yikes!"), - ), - patch( - f"{BraninMetric.__module__}.BraninMetric.is_available_while_running", - return_value=False, - ), - self.assertLogs(logger="ax.orchestration.orchestrator") as lg, - ): - orchestrator.run_n_trials(max_trials=1) - self.assertTrue( - any( - re.search(r"Failed to fetch (branin|m1) for trial 0", warning) - is not None - for warning in lg.output - ), - lg.output, - ) - self.assertTrue( - any( - re.search( - "MetricFetchE INFO: Continuing optimization even though " - "MetricFetchE encountered", + r"Failed to fetch (branin|m1) for trial 0", warning, ) is not None for warning in lg.output ) ) + # Trial stays COMPLETED -- not ABANDONED. self.assertEqual( orchestrator.experiment.trials[0].status, TrialStatus.COMPLETED ) - def test_fetch_and_process_trials_data_results_failed_objective_not_recoverable( - self, - ) -> None: + def test_failure_rate_metric_incomplete(self) -> None: + """Failure rate check uses MetricAvailability to count metric-incomplete + trials and raises FailureRateExceededError with an actionable message + listing missing metrics and affected trials. + """ gs = self.two_sobol_steps_GS + tolerated_failure_rate = 0.5 + min_failed = 1 orchestrator = Orchestrator( experiment=self.branin_experiment, generation_strategy=gs, options=OrchestratorOptions( + tolerated_trial_failure_rate=tolerated_failure_rate, + min_failed_trials_for_failure_rate_check=min_failed, **self.orchestrator_options_kwargs, ), db_settings=self.db_settings_if_always_needed, ) - # we're throwing a unrecoverable exception because Exception is not subclass - # of either error type in recoverable_exceptions - BraninMetric.recoverable_exceptions = {AxError, TypeError} with ( patch( f"{BraninMetric.__module__}.BraninMetric.f", @@ -1954,33 +1902,39 @@ def test_fetch_and_process_trials_data_results_failed_objective_not_recoverable( f"{BraninMetric.__module__}.BraninMetric.is_available_while_running", return_value=False, ), - self.assertLogs(logger="ax.orchestration.orchestrator") as lg, ): - # This trial will fail - with self.assertRaises(FailureRateExceededError): + with self.assertRaises(FailureRateExceededError) as cm: orchestrator.run_n_trials(max_trials=1) - self.assertTrue( - any( - re.search(r"Failed to fetch (branin|m1) for trial 0", warning) - is not None - for warning in lg.output - ) - ) - self.assertTrue( - any( - re.search( - r"Because (branin|m1) is an objective, marking trial 0 as " - "TrialStatus.ABANDONED", - warning, - ) - is not None - for warning in lg.output - ) - ) + + # Trial stays COMPLETED -- metric fetch errors do not change status. self.assertEqual( - orchestrator.experiment.trials[0].status, TrialStatus.ABANDONED + orchestrator.experiment.trials[0].status, TrialStatus.COMPLETED ) + # Build the expected error message from orchestrator config values. + # 1 trial ran, 0 execution failures, 1 metric-incomplete trial. + opt_config = none_throws(orchestrator.experiment.optimization_config) + opt_metric_names = sorted(opt_config.metric_names) + expected_parts = [ + ( + f"Failure rate exceeded: 1 of 1 trials were unsuccessful " + f"(observed rate: 100%, tolerance: " + f"{tolerated_failure_rate:.0%}). " + f"Checks are triggered when at least " + f"{min_failed} trials " + f"are unsuccessful or at the end of the optimization." + ), + ( + f"1 trial(s) have incomplete metric data. " + f"Missing metrics: {opt_metric_names}. " + f"Affected trials: [0]. " + f"Check that your metric fetching infrastructure is healthy " + f"and that the metrics are being logged correctly." + ), + ] + expected_msg = "\n".join(expected_parts) + self.assertEqual(str(cm.exception), expected_msg) + def test_should_consider_optimization_complete(self) -> None: # Tests non-GSS parts of the completion criterion. gs = self.sobol_MBM_GS From 70f60e5deaae24e30e48bbd264aacf7fbfbc8751 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Tue, 7 Apr 2026 07:46:38 -0700 Subject: [PATCH 4/4] Remove dead `recoverable_exceptions` and `is_recoverable_fetch_e` (#5120) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/5120 Follow-up to D98924467, which decoupled metric fetch errors from trial status in the Orchestrator. The orchestrator no longer uses `recoverable_exceptions` or `is_recoverable_fetch_e` to decide trial fate, making them dead code. - Remove `Metric.recoverable_exceptions` class attribute and `Metric.is_recoverable_fetch_e` classmethod from `ax/core/metric.py`. Reviewed By: bernardbeckerman Differential Revision: D98932195 --- ax/core/metric.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/ax/core/metric.py b/ax/core/metric.py index 24240dfe4b6..ecb975c1981 100644 --- a/ax/core/metric.py +++ b/ax/core/metric.py @@ -88,11 +88,6 @@ class Metric(SortableBase, SerializationMixin): properties: Properties specific to a particular metric. """ - # The set of exception types stored in a ``MetchFetchE.exception`` that are - # recoverable ``orchestrator._fetch_and_process_trials_data_results()``. - # Exception may be a subclass of any of these types. If you want your metric - # to never fail the trial, set this to ``{Exception}`` in your metric subclass. - recoverable_exceptions: set[type[Exception]] = set() has_map_data: bool = False def __init__( @@ -164,17 +159,6 @@ def period_of_new_data_after_trial_completion(cls) -> timedelta: """ return timedelta(0) - @classmethod - def is_recoverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool: - """Checks whether the given MetricFetchE is recoverable for this metric class - in ``orchestrator._fetch_and_process_trials_data_results``. - """ - if metric_fetch_e.exception is None: - return False - return any( - isinstance(metric_fetch_e.exception, e) for e in cls.recoverable_exceptions - ) - # NOTE: This is rarely overridden –– oonly if you want to fetch data in groups # consisting of multiple different metric classes, for data to be fetched together. # This makes sense only if `fetch_trial data_multi` or `fetch_experiment_data_multi`