diff --git a/ax/analysis/__init__.py b/ax/analysis/__init__.py index d93f376f6fa..42b9df71412 100644 --- a/ax/analysis/__init__.py +++ b/ax/analysis/__init__.py @@ -6,7 +6,7 @@ # pyre-strict from ax.analysis.analysis import Analysis -from ax.analysis.best_trials import BestTrials +from ax.analysis.best_arms import BestArms from ax.analysis.metric_summary import MetricSummary from ax.analysis.search_space_summary import SearchSpaceSummary from ax.analysis.summary import Summary @@ -16,7 +16,7 @@ __all__ = [ "Analysis", - "BestTrials", + "BestArms", "MetricSummary", "SearchSpaceSummary", "Summary", diff --git a/ax/analysis/best_trials.py b/ax/analysis/best_arms.py similarity index 92% rename from ax/analysis/best_trials.py rename to ax/analysis/best_arms.py index 5a134d63061..d1e5d105897 100644 --- a/ax/analysis/best_trials.py +++ b/ax/analysis/best_arms.py @@ -14,6 +14,7 @@ from ax.analysis.summary import Summary from ax.analysis.utils import validate_experiment from ax.core.analysis_card import AnalysisCard +from ax.core.batch_trial import BatchTrial from ax.core.experiment import Experiment from ax.core.trial_status import TrialStatus from ax.exceptions.core import DataRequiredError @@ -44,7 +45,7 @@ @final -class BestTrials(Analysis): +class BestArms(Analysis): """ High-level summary of the best trial(s) in the Experiment with one row per arm. Any values missing at compute time will be represented as None. Columns where @@ -100,7 +101,7 @@ def validate_applicable_state( # Validate optimization config exists if experiment is None or experiment.optimization_config is None: return ( - "`BestTrials` analysis requires an `OptimizationConfig`. " + "`BestArms` analysis requires an `OptimizationConfig`. " "Ensure the `Experiment` has an `optimization_config` set to compute " "this analysis." 
) @@ -121,7 +122,7 @@ def validate_applicable_state( if self.use_model_predictions or optimization_config.is_moo_problem: if generation_strategy is None: return ( - "`BestTrials` analysis requires a `GenerationStrategy` input " + "`BestArms` analysis requires a `GenerationStrategy` input " "when using model predictions or for multi-objective " "optimization problems." ) @@ -155,7 +156,7 @@ def compute( if not trial_indices: raise DataRequiredError( - "No best trial(s) could be identified. This could be due to " + "No best arm(s) could be identified. This could be due to " "insufficient data or no trials meeting the optimization criteria." ) @@ -199,8 +200,15 @@ def compute( if "relativized" in summary_card.subtitle: subtitle += " Metric values are shown relative to the status quo baseline." - return self._create_analysis_card( + has_batch_trials = any( + isinstance(trial, BatchTrial) for trial in exp.trials.values() + ) + display_name = "BestArm" if has_batch_trials else "BestTrials" + + card = self._create_analysis_card( title=(f"{title_prefix} for Experiment"), subtitle=subtitle, df=summary_card.df, ) + card.name = display_name + return card diff --git a/ax/analysis/overview.py b/ax/analysis/overview.py index a7e454c33df..bf1f311b532 100644 --- a/ax/analysis/overview.py +++ b/ax/analysis/overview.py @@ -78,7 +78,7 @@ class OverviewAnalysis(Analysis): * BanditRollout * UtilityProgressionAnalysis * ProgressionPlots - * BestTrials + * BestArms * Summary * Insights * Sensitivity Plots diff --git a/ax/analysis/results.py b/ax/analysis/results.py index a5481e8e133..4b114c31ec8 100644 --- a/ax/analysis/results.py +++ b/ax/analysis/results.py @@ -11,7 +11,7 @@ from ax.adapter.base import Adapter from ax.analysis.analysis import Analysis -from ax.analysis.best_trials import BestTrials +from ax.analysis.best_arms import BestArms from ax.analysis.plotly.arm_effects import ArmEffectsPlot from ax.analysis.plotly.bandit_rollout import BanditRollout from 
ax.analysis.plotly.progression import ( @@ -207,19 +207,18 @@ def compute( else None ) - # Compute best trials, skip for experiments with ScalarizedOutcomeConstraints or - # BatchTrials as it is not supported yet + # Compute best arms, skip for experiments with ScalarizedOutcomeConstraints has_scalarized_outcome_constraints = optimization_config is not None and any( isinstance(oc, ScalarizedOutcomeConstraint) for oc in optimization_config.outcome_constraints ) - best_trials_card = ( - BestTrials().compute_or_error_card( + best_arms_card = ( + BestArms().compute_or_error_card( experiment=experiment, generation_strategy=generation_strategy, adapter=adapter, ) - if not has_batch_trials and not has_scalarized_outcome_constraints + if not has_scalarized_outcome_constraints else None ) @@ -286,7 +285,7 @@ def compute( bandit_rollout_card, utility_progression_card, progression_group, - best_trials_card, + best_arms_card, summary, ) if child is not None diff --git a/ax/analysis/tests/test_best_trials.py b/ax/analysis/tests/test_best_arms.py similarity index 74% rename from ax/analysis/tests/test_best_trials.py rename to ax/analysis/tests/test_best_arms.py index c91c8b07caf..c925919b861 100644 --- a/ax/analysis/tests/test_best_trials.py +++ b/ax/analysis/tests/test_best_arms.py @@ -5,15 +5,20 @@ # pyre-strict -from ax.analysis.best_trials import BestTrials +from ax.analysis.best_arms import BestArms from ax.api.client import Client from ax.api.configs import RangeParameterConfig from ax.core.base_trial import TrialStatus from ax.exceptions.core import DataRequiredError from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import ( + get_branin_experiment, + get_branin_experiment_with_multi_objective, +) +from ax.utils.testing.modeling_stubs import get_default_generation_strategy_at_MBM_node -class TestBestTrials(TestCase): +class TestBestArms(TestCase): def setUp(self) -> None: super().setUp() self.client = Client() @@ -36,7 +41,7 @@ def
setUp(self) -> None: self.experiment = self.client._experiment def test_compute_soo(self) -> None: - """Test BestTrials for single-objective optimization.""" + """Test BestArms for single-objective optimization.""" client = self.client # Setup: Create multiple trials with different objective values client.get_next_trials(max_trials=3) @@ -44,8 +49,8 @@ def test_compute_soo(self) -> None: client.complete_trial(trial_index=1, raw_data={"foo": 1.0}) client.complete_trial(trial_index=2, raw_data={"foo": 2.0}) - # Execute: Compute BestTrials analysis - analysis = BestTrials() + # Execute: Compute BestArms analysis + analysis = BestArms() card = analysis.compute( experiment=self.experiment, @@ -78,7 +83,7 @@ def test_compute_soo(self) -> None: ) def test_compute_moo(self) -> None: - """Test BestTrials for multi-objective optimization.""" + """Test BestArms for multi-objective optimization.""" client = self.client # Reconfigure as multi-objective client.configure_optimization( @@ -97,8 +102,8 @@ def test_compute_moo(self) -> None: trial_index=2, raw_data={"foo": 3.0, "bar": 1.0} ) # Pareto optimal - # Execute: Compute BestTrials analysis - analysis = BestTrials() + # Execute: Compute BestArms analysis + analysis = BestArms() card = analysis.compute( experiment=client._experiment, @@ -116,15 +121,15 @@ def test_compute_moo(self) -> None: self.assertEqual(pareto_indices, {0, 2}) def test_no_eligible_trials_returns_validation_error(self) -> None: - """Test that BestTrials returns validation error when no eligible trials.""" + """Test that BestArms returns validation error when no eligible trials.""" client = self.client # Setup: Create and complete a trial, then filter by a different status client.get_next_trials(max_trials=1) client.complete_trial(trial_index=0, raw_data={"foo": 1.0}) - # Execute: Attempt to validate BestTrials with FAILED status filter + # Execute: Attempt to validate BestArms with FAILED status filter # (no trials are FAILED, so this should return an 
error) - analysis = BestTrials(trial_statuses=[TrialStatus.FAILED]) + analysis = BestArms(trial_statuses=[TrialStatus.FAILED]) # Assert: Should return error string when no trials match the status filter error = analysis.validate_applicable_state( @@ -144,7 +149,7 @@ def test_generation_strategy_requirements(self) -> None: with self.subTest(msg="GS not required for raw observations"): # Execute & Assert: Should succeed without generation_strategy # when using raw observations - analysis = BestTrials(use_model_predictions=False) + analysis = BestArms(use_model_predictions=False) card = analysis.compute( experiment=self.experiment, generation_strategy=None ) @@ -157,7 +162,7 @@ def test_generation_strategy_requirements(self) -> None: with self.subTest(msg="GS required for model predictions"): # Execute & Assert: Should return error from validation # when generation_strategy is None with model predictions - analysis = BestTrials(use_model_predictions=True) + analysis = BestArms(use_model_predictions=True) error = analysis.validate_applicable_state( experiment=self.experiment, generation_strategy=None ) @@ -176,8 +181,8 @@ def test_trial_status_filter(self) -> None: # Mark trial 2 as failed self.experiment.trials[2].mark_failed() - # Execute: Compute BestTrials with only COMPLETED status filter - analysis = BestTrials(trial_statuses=[TrialStatus.COMPLETED]) + # Execute: Compute BestArms with only COMPLETED status filter + analysis = BestArms(trial_statuses=[TrialStatus.COMPLETED]) card = analysis.compute( experiment=self.experiment, generation_strategy=client._generation_strategy, @@ -200,11 +205,45 @@ def test_use_model_predictions_insufficient_data(self) -> None: client.complete_trial(trial_index=2, raw_data={"foo": 3.0}) # Execute & Assert: Should raise error when model cannot make predictions - analysis = BestTrials(use_model_predictions=True) + analysis = BestArms(use_model_predictions=True) with self.assertRaisesRegex( - DataRequiredError, "No best trial.*could be 
identified" + DataRequiredError, "No best arm.*could be identified" ): analysis.compute( experiment=self.experiment, generation_strategy=client._generation_strategy, ) + + def test_compute_soo_multi_batch(self) -> None: + """Test SOO with batch trials: card.name is 'BestArm' and output contains + all arms from the winning batch.""" + exp = get_branin_experiment( + with_completed_batch=True, num_batch_trial=2, num_arms_per_trial=3 + ) + + card = BestArms().compute(experiment=exp) + + # Batch trials produce "BestArm" display name + self.assertEqual(card.name, "BestArm") + self.assertEqual(card.title, "Best Trial for Experiment") + # Output should contain all arms from the winning batch, not just one + self.assertGreater(len(card.df), 1) + # All returned arms should be from the same trial + self.assertEqual(len(card.df["trial_index"].unique()), 1) + + def test_compute_moo_multi_batch(self) -> None: + """Test MOO Pareto frontier across multiple batch trials.""" + exp = get_branin_experiment_with_multi_objective( + with_completed_batch=True, + with_status_quo=True, + has_objective_thresholds=True, + ) + gs = get_default_generation_strategy_at_MBM_node(experiment=exp) + + card = BestArms().compute(experiment=exp, generation_strategy=gs) + + self.assertEqual(card.name, "BestArm") + self.assertEqual(card.title, "Pareto Frontier Trials for Experiment") + self.assertIn("pareto", card.subtitle.lower()) + # Pareto frontier should return at least one trial + self.assertGreater(len(card.df), 0) diff --git a/ax/analysis/tests/test_results.py b/ax/analysis/tests/test_results.py index 5282bc5736f..0126a79a1e9 100644 --- a/ax/analysis/tests/test_results.py +++ b/ax/analysis/tests/test_results.py @@ -124,10 +124,13 @@ def test_compute_with_single_objective_no_constraints(self) -> None: "Should have arm effects in children", ) - # Assert: Should have best trials + # Assert: Should have best arms self.assertTrue( - any("BestTrials" in name for name in child_names), - "Should have best 
trials in children", + any( + "BestArms" in name or "BestTrials" in name or "BestArm" in name + for name in child_names + ), + "Should have best arms in children", ) # Assert: No error cards should be present