Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ax/analysis/plotly/tests/test_utility_progression.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _assert_valid_utility_card(
"""Assert that a card has valid structure for utility progression."""
self.assertIsInstance(card, PlotlyAnalysisCard)
self.assertEqual(card.name, "UtilityProgressionAnalysis")
self.assertIn("trace_index", card.df.columns)
self.assertIn("trial_index", card.df.columns)
self.assertIn("utility", card.df.columns)

def test_utility_progression_soo(self) -> None:
Expand Down Expand Up @@ -211,7 +211,7 @@ def test_all_infeasible_points_raises_error(self) -> None:
with (
patch(
"ax.analysis.plotly.utility_progression.get_trace",
return_value=[math.inf, -math.inf, math.inf],
return_value={0: math.inf, 1: -math.inf, 2: math.inf},
),
self.assertRaises(ExperimentNotReadyError) as cm,
):
Expand Down
24 changes: 11 additions & 13 deletions ax/analysis/plotly/utility_progression.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,9 @@
_UTILITY_PROGRESSION_TITLE = "Utility Progression"

_TRACE_INDEX_EXPLANATION = (
"The x-axis shows trace index, which counts completed or early-stopped trials "
"sequentially (1, 2, 3, ...). This differs from trial index, which may have "
"gaps if some trials failed or were abandoned. For example, if trials 0, 2, "
"and 5 completed while trials 1, 3, and 4 failed, the trace indices would be "
"1, 2, 3 corresponding to trial indices 0, 2, 5."
"The x-axis shows trial index. Only completed or early-stopped trials with "
"complete metric data are included, so there may be gaps if some trials "
"failed, were abandoned, or have incomplete data."
)

_CUMULATIVE_BEST_EXPLANATION = (
Expand All @@ -57,7 +55,8 @@ class UtilityProgressionAnalysis(Analysis):

The DataFrame computed will contain one row per completed trial and the
following columns:
- trace_index: Sequential index of completed/early-stopped trials (1, 2, 3, ...)
- trial_index: The trial index of each completed/early-stopped trial
that has complete metric availability.
- utility: The cumulative best utility value at that trial
"""

Expand Down Expand Up @@ -114,7 +113,7 @@ def compute(
)

# Check if all points are infeasible (inf or -inf values)
if all(np.isinf(value) for value in trace):
if all(np.isinf(value) for value in trace.values()):
raise ExperimentNotReadyError(
"All trials in the utility trace are infeasible i.e., they violate "
"outcome constraints, so there are no feasible points to plot. During "
Expand All @@ -125,12 +124,11 @@ def compute(
"space, or (2) relaxing outcome constraints."
)

# Create DataFrame with 1-based trace index for user-friendly display
# (1st completed trial, 2nd completed trial, etc. instead of 0-indexed)
# Create DataFrame with trial indices from the trace
df = pd.DataFrame(
{
"trace_index": list(range(1, len(trace) + 1)),
"utility": trace,
"trial_index": list(trace.keys()),
"utility": list(trace.values()),
}
)

Expand Down Expand Up @@ -185,14 +183,14 @@ def compute(
# Create the plot
fig = px.line(
data_frame=df,
x="trace_index",
x="trial_index",
y="utility",
markers=True,
color_discrete_sequence=[AX_BLUE],
)

# Update axis labels and format x-axis to show integers only
fig.update_xaxes(title_text="Trace Index", dtick=1, rangemode="nonnegative")
fig.update_xaxes(title_text="Trial Index", dtick=1, rangemode="nonnegative")
fig.update_yaxes(title_text=y_label)

return create_plotly_analysis_card(
Expand Down
28 changes: 22 additions & 6 deletions ax/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,10 @@ def _get_oracle_value_of_params(
dummy_experiment = get_oracle_experiment_from_params(
problem=problem, dict_of_dict_of_params={0: {"0_0": params}}
)
(inference_value,) = get_trace(
trace = get_trace(
experiment=dummy_experiment, optimization_config=problem.optimization_config
)
inference_value = next(iter(trace.values()))
return inference_value


Expand Down Expand Up @@ -510,12 +511,27 @@ def get_benchmark_result_from_experiment_and_gs(
dict_of_dict_of_params=dict_of_dict_of_params,
trial_statuses=trial_statuses,
)
oracle_trace = np.array(
get_trace(
experiment=actual_params_oracle_dummy_experiment,
optimization_config=problem.optimization_config,
)
oracle_trace_dict = get_trace(
experiment=actual_params_oracle_dummy_experiment,
optimization_config=problem.optimization_config,
)
# Expand trace dict to a positional array aligned with all trials,
# carrying forward the last best value for trials without data (e.g.,
# failed or abandoned trials preserved via trial_statuses).
maximize = (
isinstance(problem.optimization_config, MultiObjectiveOptimizationConfig)
or problem.optimization_config.objective.is_scalarized_objective
or not problem.optimization_config.objective.minimize
)
all_trial_indices = sorted(actual_params_oracle_dummy_experiment.trials.keys())
last_best = -float("inf") if maximize else float("inf")
oracle_trace_list: list[float] = []
for idx in all_trial_indices:
if idx in oracle_trace_dict:
last_best = oracle_trace_dict[idx]
oracle_trace_list.append(last_best)
oracle_trace = np.array(oracle_trace_list)

is_feasible_trace = np.array(
get_is_feasible_trace(
experiment=actual_params_oracle_dummy_experiment,
Expand Down
16 changes: 0 additions & 16 deletions ax/core/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,6 @@ class Metric(SortableBase, SerializationMixin):
properties: Properties specific to a particular metric.
"""

# The set of exception types stored in a ``MetricFetchE.exception`` that are
# recoverable in ``orchestrator._fetch_and_process_trials_data_results()``.
# Exception may be a subclass of any of these types. If you want your metric
# to never fail the trial, set this to ``{Exception}`` in your metric subclass.
recoverable_exceptions: set[type[Exception]] = set()
has_map_data: bool = False

def __init__(
Expand Down Expand Up @@ -164,17 +159,6 @@ def period_of_new_data_after_trial_completion(cls) -> timedelta:
"""
return timedelta(0)

@classmethod
def is_recoverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool:
"""Checks whether the given MetricFetchE is recoverable for this metric class
in ``orchestrator._fetch_and_process_trials_data_results``.
"""
if metric_fetch_e.exception is None:
return False
return any(
isinstance(metric_fetch_e.exception, e) for e in cls.recoverable_exceptions
)

# NOTE: This is rarely overridden –– only if you want to fetch data in groups
# consisting of multiple different metric classes, for data to be fetched together.
# This makes sense only if `fetch_trial_data_multi` or `fetch_experiment_data_multi`
Expand Down
Loading
Loading