Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ax/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# pyre-strict

from ax.analysis.analysis import Analysis
from ax.analysis.best_trials import BestTrials
from ax.analysis.best_arms import BestArms
from ax.analysis.metric_summary import MetricSummary
from ax.analysis.search_space_summary import SearchSpaceSummary
from ax.analysis.summary import Summary
Expand All @@ -16,7 +16,7 @@

__all__ = [
"Analysis",
"BestTrials",
"BestArms",
"MetricSummary",
"SearchSpaceSummary",
"Summary",
Expand Down
18 changes: 13 additions & 5 deletions ax/analysis/best_trials.py → ax/analysis/best_arms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ax.analysis.summary import Summary
from ax.analysis.utils import validate_experiment
from ax.core.analysis_card import AnalysisCard
from ax.core.batch_trial import BatchTrial
from ax.core.experiment import Experiment
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import DataRequiredError
Expand Down Expand Up @@ -44,7 +45,7 @@


@final
class BestTrials(Analysis):
class BestArms(Analysis):
"""
High-level summary of the best arm(s) in the Experiment with one row per arm.
Any values missing at compute time will be represented as None. Columns where
Expand Down Expand Up @@ -100,7 +101,7 @@ def validate_applicable_state(
# Validate optimization config exists
if experiment is None or experiment.optimization_config is None:
return (
"`BestTrials` analysis requires an `OptimizationConfig`. "
"`BestArms` analysis requires an `OptimizationConfig`. "
"Ensure the `Experiment` has an `optimization_config` set to compute "
"this analysis."
)
Expand All @@ -121,7 +122,7 @@ def validate_applicable_state(
if self.use_model_predictions or optimization_config.is_moo_problem:
if generation_strategy is None:
return (
"`BestTrials` analysis requires a `GenerationStrategy` input "
"`BestArms` analysis requires a `GenerationStrategy` input "
"when using model predictions or for multi-objective "
"optimization problems."
)
Expand Down Expand Up @@ -155,7 +156,7 @@ def compute(

if not trial_indices:
raise DataRequiredError(
"No best trial(s) could be identified. This could be due to "
"No best arm(s) could be identified. This could be due to "
"insufficient data or no trials meeting the optimization criteria."
)

Expand Down Expand Up @@ -199,8 +200,15 @@ def compute(
if "relativized" in summary_card.subtitle:
subtitle += " Metric values are shown relative to the status quo baseline."

return self._create_analysis_card(
has_batch_trials = any(
isinstance(trial, BatchTrial) for trial in exp.trials.values()
)
display_name = "BestArm" if has_batch_trials else "BestTrials"

card = self._create_analysis_card(
title=(f"{title_prefix} for Experiment"),
subtitle=subtitle,
df=summary_card.df,
)
card.name = display_name
return card
2 changes: 1 addition & 1 deletion ax/analysis/overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class OverviewAnalysis(Analysis):
* BanditRollout
* UtilityProgressionAnalysis
* ProgressionPlots
* BestTrials
* BestArms
* Summary
* Insights
* Sensitivity Plots
Expand Down
13 changes: 6 additions & 7 deletions ax/analysis/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ax.adapter.base import Adapter
from ax.analysis.analysis import Analysis
from ax.analysis.best_trials import BestTrials
from ax.analysis.best_arms import BestArms
from ax.analysis.plotly.arm_effects import ArmEffectsPlot
from ax.analysis.plotly.bandit_rollout import BanditRollout
from ax.analysis.plotly.progression import (
Expand Down Expand Up @@ -207,19 +207,18 @@ def compute(
else None
)

# Compute best trials, skip for experiments with ScalarizedOutcomeConstraints or
# BatchTrials as it is not supported yet
# Compute best arms; skip for experiments with ScalarizedOutcomeConstraints
has_scalarized_outcome_constraints = optimization_config is not None and any(
isinstance(oc, ScalarizedOutcomeConstraint)
for oc in optimization_config.outcome_constraints
)
best_trials_card = (
BestTrials().compute_or_error_card(
best_arms_card = (
BestArms().compute_or_error_card(
experiment=experiment,
generation_strategy=generation_strategy,
adapter=adapter,
)
if not has_batch_trials and not has_scalarized_outcome_constraints
if not has_scalarized_outcome_constraints
else None
)

Expand Down Expand Up @@ -286,7 +285,7 @@ def compute(
bandit_rollout_card,
utility_progression_card,
progression_group,
best_trials_card,
best_arms_card,
summary,
)
if child is not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@

# pyre-strict

from ax.analysis.best_trials import BestTrials
from ax.analysis.best_arms import BestArms
from ax.api.client import Client
from ax.api.configs import RangeParameterConfig
from ax.core.base_trial import TrialStatus
from ax.exceptions.core import DataRequiredError
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_branin_experiment,
get_branin_experiment_with_multi_objective,
)
from ax.utils.testing.modeling_stubs import get_default_generation_strategy_at_MBM_node


class TestBestTrials(TestCase):
class TestBestArms(TestCase):
def setUp(self) -> None:
super().setUp()
self.client = Client()
Expand All @@ -36,16 +41,16 @@ def setUp(self) -> None:
self.experiment = self.client._experiment

def test_compute_soo(self) -> None:
"""Test BestTrials for single-objective optimization."""
"""Test BestArms for single-objective optimization."""
client = self.client
# Setup: Create multiple trials with different objective values
client.get_next_trials(max_trials=3)
client.complete_trial(trial_index=0, raw_data={"foo": 3.0})
client.complete_trial(trial_index=1, raw_data={"foo": 1.0})
client.complete_trial(trial_index=2, raw_data={"foo": 2.0})

# Execute: Compute BestTrials analysis
analysis = BestTrials()
# Execute: Compute BestArms analysis
analysis = BestArms()

card = analysis.compute(
experiment=self.experiment,
Expand Down Expand Up @@ -78,7 +83,7 @@ def test_compute_soo(self) -> None:
)

def test_compute_moo(self) -> None:
"""Test BestTrials for multi-objective optimization."""
"""Test BestArms for multi-objective optimization."""
client = self.client
# Reconfigure as multi-objective
client.configure_optimization(
Expand All @@ -97,8 +102,8 @@ def test_compute_moo(self) -> None:
trial_index=2, raw_data={"foo": 3.0, "bar": 1.0}
) # Pareto optimal

# Execute: Compute BestTrials analysis
analysis = BestTrials()
# Execute: Compute BestArms analysis
analysis = BestArms()

card = analysis.compute(
experiment=client._experiment,
Expand All @@ -116,15 +121,15 @@ def test_compute_moo(self) -> None:
self.assertEqual(pareto_indices, {0, 2})

def test_no_eligible_trials_returns_validation_error(self) -> None:
"""Test that BestTrials returns validation error when no eligible trials."""
"""Test that BestArms returns validation error when no eligible trials."""
client = self.client
# Setup: Create and complete a trial, then filter by a different status
client.get_next_trials(max_trials=1)
client.complete_trial(trial_index=0, raw_data={"foo": 1.0})

# Execute: Attempt to validate BestTrials with FAILED status filter
# Execute: Attempt to validate BestArms with FAILED status filter
# (no trials are FAILED, so this should return an error)
analysis = BestTrials(trial_statuses=[TrialStatus.FAILED])
analysis = BestArms(trial_statuses=[TrialStatus.FAILED])

# Assert: Should return error string when no trials match the status filter
error = analysis.validate_applicable_state(
Expand All @@ -144,7 +149,7 @@ def test_generation_strategy_requirements(self) -> None:
with self.subTest(msg="GS not required for raw observations"):
# Execute & Assert: Should succeed without generation_strategy
# when using raw observations
analysis = BestTrials(use_model_predictions=False)
analysis = BestArms(use_model_predictions=False)
card = analysis.compute(
experiment=self.experiment, generation_strategy=None
)
Expand All @@ -157,7 +162,7 @@ def test_generation_strategy_requirements(self) -> None:
with self.subTest(msg="GS required for model predictions"):
# Execute & Assert: Should return error from validation
# when generation_strategy is None with model predictions
analysis = BestTrials(use_model_predictions=True)
analysis = BestArms(use_model_predictions=True)
error = analysis.validate_applicable_state(
experiment=self.experiment, generation_strategy=None
)
Expand All @@ -176,8 +181,8 @@ def test_trial_status_filter(self) -> None:
# Mark trial 2 as failed
self.experiment.trials[2].mark_failed()

# Execute: Compute BestTrials with only COMPLETED status filter
analysis = BestTrials(trial_statuses=[TrialStatus.COMPLETED])
# Execute: Compute BestArms with only COMPLETED status filter
analysis = BestArms(trial_statuses=[TrialStatus.COMPLETED])
card = analysis.compute(
experiment=self.experiment,
generation_strategy=client._generation_strategy,
Expand All @@ -200,11 +205,45 @@ def test_use_model_predictions_insufficient_data(self) -> None:
client.complete_trial(trial_index=2, raw_data={"foo": 3.0})

# Execute & Assert: Should raise error when model cannot make predictions
analysis = BestTrials(use_model_predictions=True)
analysis = BestArms(use_model_predictions=True)
with self.assertRaisesRegex(
DataRequiredError, "No best trial.*could be identified"
DataRequiredError, "No best arm.*could be identified"
):
analysis.compute(
experiment=self.experiment,
generation_strategy=client._generation_strategy,
)

def test_compute_soo_multi_batch(self) -> None:
    """Single-objective optimization over batch trials.

    Verifies that the analysis card is named 'BestArm' when batch trials are
    present, and that the output includes every arm of the winning batch
    rather than a single arm.
    """
    experiment = get_branin_experiment(
        with_completed_batch=True, num_batch_trial=2, num_arms_per_trial=3
    )

    analysis_card = BestArms().compute(experiment=experiment)

    # With batch trials present the display name switches to "BestArm".
    self.assertEqual(analysis_card.name, "BestArm")
    self.assertEqual(analysis_card.title, "Best Trial for Experiment")
    # More than one row: the full winning batch is reported, not a single arm.
    self.assertGreater(len(analysis_card.df), 1)
    # Every reported arm belongs to one and the same trial.
    distinct_trials = analysis_card.df["trial_index"].unique()
    self.assertEqual(len(distinct_trials), 1)

def test_compute_moo_multi_batch(self) -> None:
    """Multi-objective optimization: Pareto frontier across batch trials."""
    experiment = get_branin_experiment_with_multi_objective(
        with_completed_batch=True,
        with_status_quo=True,
        has_objective_thresholds=True,
    )
    generation_strategy = get_default_generation_strategy_at_MBM_node(
        experiment=experiment
    )

    analysis_card = BestArms().compute(
        experiment=experiment, generation_strategy=generation_strategy
    )

    self.assertEqual(analysis_card.name, "BestArm")
    self.assertEqual(
        analysis_card.title, "Pareto Frontier Trials for Experiment"
    )
    self.assertIn("pareto", analysis_card.subtitle.lower())
    # The Pareto frontier must be non-empty.
    self.assertGreater(len(analysis_card.df), 0)
9 changes: 6 additions & 3 deletions ax/analysis/tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,13 @@ def test_compute_with_single_objective_no_constraints(self) -> None:
"Should have arm effects in children",
)

# Assert: Should have best trials
# Assert: Should have best arms
self.assertTrue(
any("BestTrials" in name for name in child_names),
"Should have best trials in children",
any(
"BestArms" in name or "BestTrials" in name or "BestArm" in name
for name in child_names
),
"Should have best arms in children",
)

# Assert: No error cards should be present
Expand Down
Loading