From 2f76d18064a9beeecd34ac434fd795cd9d19d5b8 Mon Sep 17 00:00:00 2001 From: Shruti Patel Date: Thu, 2 Apr 2026 14:38:10 -0700 Subject: [PATCH] Enable BestTrials for Online Experiments Summary: Pass use_model_predictions=has_batch_trials to BestTrials in ResultsAnalysis. Online experiments (which use BatchTrials) get GP model predictions for noise smoothing via shrinkage. AutoML and other direct callers get the default (False). Per design meeting 3/26: 'If online experiment, set model_predictions=True. If AutoML, say False. Set based on if has batch trials or not. For now.' --- > Generated by [Metamate](https://fb.workplace.com/groups/metamate.feedback/) [Metamate Session](https://internalfb.com/intern/bunny?q=eggmate%20f5ba8eb9-b3ef-4a16-86ce-18bb5271aff5), [Trace](https://www.internalfb.com/confucius?session_id=f5ba8eb9-b3ef-4a16-86ce-18bb5271aff5&tab=Trace) Differential Revision: D99128865 --- ax/analysis/__init__.py | 4 +- ax/analysis/{best_trials.py => best_arms.py} | 18 +++-- ax/analysis/overview.py | 2 +- ax/analysis/results.py | 13 ++-- ...{test_best_trials.py => test_best_arms.py} | 73 ++++++++++++++----- ax/analysis/tests/test_results.py | 9 ++- 6 files changed, 84 insertions(+), 35 deletions(-) rename ax/analysis/{best_trials.py => best_arms.py} (92%) rename ax/analysis/tests/{test_best_trials.py => test_best_arms.py} (74%) diff --git a/ax/analysis/__init__.py b/ax/analysis/__init__.py index d93f376f6fa..42b9df71412 100644 --- a/ax/analysis/__init__.py +++ b/ax/analysis/__init__.py @@ -6,7 +6,7 @@ # pyre-strict from ax.analysis.analysis import Analysis -from ax.analysis.best_trials import BestTrials +from ax.analysis.best_arms import BestArms from ax.analysis.metric_summary import MetricSummary from ax.analysis.search_space_summary import SearchSpaceSummary from ax.analysis.summary import Summary @@ -16,7 +16,7 @@ __all__ = [ "Analysis", - "BestTrials", + "BestArms", "MetricSummary", "SearchSpaceSummary", "Summary", diff --git a/ax/analysis/best_trials.py b/ax/analysis/best_arms.py similarity index 92% rename from ax/analysis/best_trials.py rename to ax/analysis/best_arms.py index 5a134d63061..d1e5d105897 100644 --- a/ax/analysis/best_trials.py +++ b/ax/analysis/best_arms.py @@ -14,6 +14,7 @@ from ax.analysis.summary import Summary from ax.analysis.utils import validate_experiment from ax.core.analysis_card import AnalysisCard +from ax.core.batch_trial import BatchTrial from ax.core.experiment import Experiment from ax.core.trial_status import TrialStatus from ax.exceptions.core import DataRequiredError @@ -44,7 +45,7 @@ @final -class BestTrials(Analysis): +class BestArms(Analysis): """ High-level summary of the best trial(s) in the Experiment with one row per arm. Any values missing at compute time will be represented as None. Columns where @@ -100,7 +101,7 @@ def validate_applicable_state( # Validate optimization config exists if experiment is None or experiment.optimization_config is None: return ( - "`BestTrials` analysis requires an `OptimizationConfig`. " + "`BestArms` analysis requires an `OptimizationConfig`. " "Ensure the `Experiment` has an `optimization_config` set to compute " "this analysis." ) @@ -121,7 +122,7 @@ def validate_applicable_state( if self.use_model_predictions or optimization_config.is_moo_problem: if generation_strategy is None: return ( - "`BestTrials` analysis requires a `GenerationStrategy` input " + "`BestArms` analysis requires a `GenerationStrategy` input " "when using model predictions or for multi-objective " "optimization problems." ) @@ -155,7 +156,7 @@ def compute( if not trial_indices: raise DataRequiredError( - "No best trial(s) could be identified. This could be due to " + "No best arm(s) could be identified. This could be due to " "insufficient data or no trials meeting the optimization criteria." ) @@ -199,8 +200,15 @@ def compute( if "relativized" in summary_card.subtitle: subtitle += " Metric values are shown relative to the status quo baseline." - return self._create_analysis_card( + has_batch_trials = any( + isinstance(trial, BatchTrial) for trial in exp.trials.values() + ) + display_name = "BestArm" if has_batch_trials else "BestTrials" + + card = self._create_analysis_card( title=(f"{title_prefix} for Experiment"), subtitle=subtitle, df=summary_card.df, ) + card.name = display_name + return card diff --git a/ax/analysis/overview.py b/ax/analysis/overview.py index a7e454c33df..bf1f311b532 100644 --- a/ax/analysis/overview.py +++ b/ax/analysis/overview.py @@ -78,7 +78,7 @@ class OverviewAnalysis(Analysis): * BanditRollout * UtilityProgressionAnalysis * ProgressionPlots - * BestTrials + * BestArms * Summary * Insights * Sensitivity Plots diff --git a/ax/analysis/results.py b/ax/analysis/results.py index a5481e8e133..4b114c31ec8 100644 --- a/ax/analysis/results.py +++ b/ax/analysis/results.py @@ -11,7 +11,7 @@ from ax.adapter.base import Adapter from ax.analysis.analysis import Analysis -from ax.analysis.best_trials import BestTrials +from ax.analysis.best_arms import BestArms from ax.analysis.plotly.arm_effects import ArmEffectsPlot from ax.analysis.plotly.bandit_rollout import BanditRollout from ax.analysis.plotly.progression import ( @@ -207,19 +207,18 @@ def compute( else None ) - # Compute best trials, skip for experiments with ScalarizedOutcomeConstraints or - # BatchTrials as it is not supported yet + # Compute best trials, skip for experiments with ScalarizedOutcomeConstraints has_scalarized_outcome_constraints = optimization_config is not None and any( isinstance(oc, ScalarizedOutcomeConstraint) for oc in optimization_config.outcome_constraints ) - best_trials_card = ( - BestTrials().compute_or_error_card( + best_arms_card = ( + BestArms().compute_or_error_card( experiment=experiment, generation_strategy=generation_strategy, adapter=adapter, ) - if not has_batch_trials and not has_scalarized_outcome_constraints + if not has_scalarized_outcome_constraints else None ) @@ -286,7 +285,7 @@ def compute( bandit_rollout_card, utility_progression_card, progression_group, - best_trials_card, + best_arms_card, summary, ) if child is not None diff --git a/ax/analysis/tests/test_best_trials.py b/ax/analysis/tests/test_best_arms.py similarity index 74% rename from ax/analysis/tests/test_best_trials.py rename to ax/analysis/tests/test_best_arms.py index c91c8b07caf..c925919b861 100644 --- a/ax/analysis/tests/test_best_trials.py +++ b/ax/analysis/tests/test_best_arms.py @@ -5,15 +5,20 @@ # pyre-strict -from ax.analysis.best_trials import BestTrials +from ax.analysis.best_arms import BestArms from ax.api.client import Client from ax.api.configs import RangeParameterConfig from ax.core.base_trial import TrialStatus from ax.exceptions.core import DataRequiredError from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import ( + get_branin_experiment, + get_branin_experiment_with_multi_objective, +) +from ax.utils.testing.modeling_stubs import get_default_generation_strategy_at_MBM_node -class TestBestTrials(TestCase): +class TestBestArms(TestCase): def setUp(self) -> None: super().setUp() self.client = Client() @@ -36,7 +41,7 @@ def setUp(self) -> None: self.experiment = self.client._experiment def test_compute_soo(self) -> None: - """Test BestTrials for single-objective optimization.""" + """Test BestArms for single-objective optimization.""" client = self.client # Setup: Create multiple trials with different objective values client.get_next_trials(max_trials=3) @@ -44,8 +49,8 @@ def test_compute_soo(self) -> None: client.complete_trial(trial_index=1, raw_data={"foo": 1.0}) client.complete_trial(trial_index=2, raw_data={"foo": 2.0}) - # Execute: Compute BestTrials analysis - analysis = BestTrials() + # Execute: Compute BestArms analysis + analysis = BestArms() card = analysis.compute( experiment=self.experiment, @@ -78,7 +83,7 @@ def test_compute_soo(self) -> None: ) def test_compute_moo(self) -> None: - """Test BestTrials for multi-objective optimization.""" + """Test BestArms for multi-objective optimization.""" client = self.client # Reconfigure as multi-objective client.configure_optimization( @@ -97,8 +102,8 @@ def test_compute_moo(self) -> None: trial_index=2, raw_data={"foo": 3.0, "bar": 1.0} ) # Pareto optimal - # Execute: Compute BestTrials analysis - analysis = BestTrials() + # Execute: Compute BestArms analysis + analysis = BestArms() card = analysis.compute( experiment=client._experiment, @@ -116,15 +121,15 @@ def test_compute_moo(self) -> None: self.assertEqual(pareto_indices, {0, 2}) def test_no_eligible_trials_returns_validation_error(self) -> None: - """Test that BestTrials returns validation error when no eligible trials.""" + """Test that BestArms returns validation error when no eligible trials.""" client = self.client # Setup: Create and complete a trial, then filter by a different status client.get_next_trials(max_trials=1) client.complete_trial(trial_index=0, raw_data={"foo": 1.0}) - # Execute: Attempt to validate BestTrials with FAILED status filter + # Execute: Attempt to validate BestArms with FAILED status filter # (no trials are FAILED, so this should return an error) - analysis = BestTrials(trial_statuses=[TrialStatus.FAILED]) + analysis = BestArms(trial_statuses=[TrialStatus.FAILED]) # Assert: Should return error string when no trials match the status filter error = analysis.validate_applicable_state( @@ -144,7 +149,7 @@ def test_generation_strategy_requirements(self) -> None: with self.subTest(msg="GS not required for raw observations"): # Execute & Assert: Should succeed without generation_strategy # when using raw observations - analysis = BestTrials(use_model_predictions=False) + analysis = BestArms(use_model_predictions=False) card = analysis.compute( experiment=self.experiment, generation_strategy=None ) @@ -157,7 +162,7 @@ def test_generation_strategy_requirements(self) -> None: with self.subTest(msg="GS required for model predictions"): # Execute & Assert: Should return error from validation # when generation_strategy is None with model predictions - analysis = BestTrials(use_model_predictions=True) + analysis = BestArms(use_model_predictions=True) error = analysis.validate_applicable_state( experiment=self.experiment, generation_strategy=None ) @@ -176,8 +181,8 @@ def test_trial_status_filter(self) -> None: # Mark trial 2 as failed self.experiment.trials[2].mark_failed() - # Execute: Compute BestTrials with only COMPLETED status filter - analysis = BestTrials(trial_statuses=[TrialStatus.COMPLETED]) + # Execute: Compute BestArms with only COMPLETED status filter + analysis = BestArms(trial_statuses=[TrialStatus.COMPLETED]) card = analysis.compute( experiment=self.experiment, generation_strategy=client._generation_strategy, @@ -200,11 +205,45 @@ def test_use_model_predictions_insufficient_data(self) -> None: client.complete_trial(trial_index=2, raw_data={"foo": 3.0}) # Execute & Assert: Should raise error when model cannot make predictions - analysis = BestTrials(use_model_predictions=True) + analysis = BestArms(use_model_predictions=True) with self.assertRaisesRegex( - DataRequiredError, "No best trial.*could be identified" + DataRequiredError, "No best arm.*could be identified" ): analysis.compute( experiment=self.experiment, generation_strategy=client._generation_strategy, ) + + def test_compute_soo_multi_batch(self) -> None: + """Test SOO with batch trials: card.name is 'BestArm' and output contains + all arms from the winning batch.""" + exp = get_branin_experiment( + with_completed_batch=True, num_batch_trial=2, num_arms_per_trial=3 + ) + + card = BestArms().compute(experiment=exp) + + # Batch trials produce "BestArm" display name + self.assertEqual(card.name, "BestArm") + self.assertEqual(card.title, "Best Trial for Experiment") + # Output should contain all arms from the winning batch, not just one + self.assertGreater(len(card.df), 1) + # All returned arms should be from the same trial + self.assertEqual(len(card.df["trial_index"].unique()), 1) + + def test_compute_moo_multi_batch(self) -> None: + """Test MOO Pareto frontier across multiple batch trials.""" + exp = get_branin_experiment_with_multi_objective( + with_completed_batch=True, + with_status_quo=True, + has_objective_thresholds=True, + ) + gs = get_default_generation_strategy_at_MBM_node(experiment=exp) + + card = BestArms().compute(experiment=exp, generation_strategy=gs) + + self.assertEqual(card.name, "BestArm") + self.assertEqual(card.title, "Pareto Frontier Trials for Experiment") + self.assertIn("pareto", card.subtitle.lower()) + # Pareto frontier should return at least one trial + self.assertGreater(len(card.df), 0) diff --git a/ax/analysis/tests/test_results.py b/ax/analysis/tests/test_results.py index 5282bc5736f..0126a79a1e9 100644 --- a/ax/analysis/tests/test_results.py +++ b/ax/analysis/tests/test_results.py @@ -124,10 +124,13 @@ def test_compute_with_single_objective_no_constraints(self) -> None: "Should have arm effects in children", ) - # Assert: Should have best trials + # Assert: Should have best arms self.assertTrue( - any("BestTrials" in name for name in child_names), - "Should have best trials in children", + any( + "BestArms" in name or "BestTrials" in name or "BestArm" in name + for name in child_names + ), + "Should have best arms in children", ) # Assert: No error cards should be present