Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions ax/adapter/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,26 @@ def _gen(
linear_constraints = extract_parameter_constraints(
search_space.parameter_constraints, self.parameters
)
# Extract generated points to deduplicate against.
# Exclude out-of-design arms (which can only be manual arms
# instead of adapter-generated arms).
# Extract generated points.
# For normal generators these are used to deduplicate against.
# For in-sample generators (LILO labeling) they are the selection
# pool from which arms are drawn — not a dedup set. The two use
# cases have been shoehorned into the same code path; consider
# splitting them into separate methods in a future refactor.
generated_points = None
is_in_sample = isinstance(self.generator, InSampleUniformGenerator)
if self.generator.deduplicate:
arms_to_deduplicate = self._experiment.arms_by_signature_for_deduplication
# For normal generators, exclude arms from FAILED trials so the
# model may re-suggest them. For in-sample generators this
# exclusion is harmful: LILO labeling trials borrow arms from
# regular trials, so a FAILED labeling trial would incorrectly
# remove the original arm from the selection pool. Use the
# full arms_by_signature instead.
arms_to_deduplicate = (
self._experiment.arms_by_signature
if is_in_sample
else self._experiment.arms_by_signature_for_deduplication
)
# For in-sample generators, restrict to arms from trials that
# have or expect observed data (COMPLETED, EARLY_STOPPED,
# RUNNING). This prevents selecting arms from CANDIDATE/STAGED
Expand Down
59 changes: 59 additions & 0 deletions ax/adapter/tests/test_random_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,65 @@ def test_in_sample_excludes_non_data_bearing_trial_arms(self) -> None:
assert generated_points is not None
self.assertEqual(len(generated_points), 2)

def test_in_sample_failed_lilo_trial_does_not_poison_arm_pool(self) -> None:
    """Arms shared between COMPLETED trials and a FAILED LILO labeling
    trial must remain in the in-sample selection pool.

    Regression test: ``arms_by_signature_for_deduplication`` drops any
    signature that appears in a FAILED trial, even when the same arm also
    lives in a non-FAILED trial. The in-sample generator must not use
    that view, or its arm pool gets exhausted.
    """
    exp = Experiment(search_space=SearchSpace(self.parameters[:2]))

    # Two COMPLETED regular trials, one arm each.
    completed_arms = [
        Arm(parameters={"x": 0.5, "y": 1.5}),
        Arm(parameters={"x": 0.3, "y": 1.3}),
    ]
    for arm in completed_arms:
        regular_trial = exp.new_trial()
        regular_trial.add_arm(arm)
        regular_trial.mark_running(no_runner_required=True)
        regular_trial.mark_completed()

    # A FAILED trial simulating a LILO labeling trial that borrowed the
    # first arm's parameterization but whose LLM metric call failed.
    failed_trial = exp.new_trial()
    failed_trial.add_arm(Arm(parameters={"x": 0.5, "y": 1.5}))
    failed_trial.mark_running(no_runner_required=True)
    failed_trial.mark_failed()

    # Sanity check: the dedup view indeed drops the shared signature.
    shared_signature = completed_arms[0].signature
    self.assertNotIn(
        shared_signature, exp.arms_by_signature_for_deduplication
    )

    # The in-sample generator must still see both COMPLETED arms.
    generator = InSampleUniformGenerator(seed=0)
    adapter = RandomAdapter(
        experiment=exp,
        generator=generator,
        transforms=Cont_X_trans,
    )

    with mock.patch.object(
        generator, "gen", wraps=generator.gen
    ) as mock_gen:
        adapter.gen(n=2)

    # Both arms from COMPLETED trials must be in generated_points,
    # despite the shared signature also appearing in a FAILED trial.
    generated_points = mock_gen.call_args.kwargs["generated_points"]
    assert generated_points is not None
    self.assertEqual(len(generated_points), 2)

def test_generation_with_all_fixed(self) -> None:
# Make sure candidate generation succeeds and returns correct parameters
# when all parameters are fixed.
Expand Down
10 changes: 10 additions & 0 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,16 @@ def status(self) -> TrialStatus:
self._mark_stale_if_past_TTL()
return none_throws(self._status)

@property
def expecting_data(self) -> bool:
    """Whether this trial expects data via the standard data-fetch pipeline.

    LILO labeling trials are excluded regardless of status: their pairwise
    preference data is fetched inline during the labeling loop and is never
    refetched through the normal orchestration path.
    """
    # Read status first: the status property may update a stale trial.
    status_expects_data = self.status.expecting_data
    return status_expects_data and self.trial_type != Keys.LILO_LABELING

@status.setter
def status(self, status: TrialStatus) -> None:
    """Disallowed: trial status may only be changed via ``mark_*`` methods."""
    raise NotImplementedError("Use `trial.mark_*` methods to set trial status.")
Expand Down
23 changes: 7 additions & 16 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,7 @@
from ax.core.runner import Runner
from ax.core.search_space import SearchSpace
from ax.core.trial import Trial
from ax.core.trial_status import (
DEFAULT_STATUSES_TO_WARM_START,
STATUSES_EXPECTING_DATA,
TrialStatus,
)
from ax.core.trial_status import DEFAULT_STATUSES_TO_WARM_START, TrialStatus
from ax.core.types import ComparisonOp, TParameterization
from ax.exceptions.core import (
AxError,
Expand Down Expand Up @@ -1406,10 +1402,10 @@ def trials_by_status(self) -> dict[TrialStatus, list[BaseTrial]]:

@property
def trials_expecting_data(self) -> list[BaseTrial]:
    """list[BaseTrial]: the list of all trials that expect data via the
    standard data-fetch pipeline.
    """
    # Delegates to the per-trial `expecting_data` property (which excludes
    # LILO labeling trials) rather than the raw status-level flag.
    return [trial for trial in self.trials.values() if trial.expecting_data]

@property
def completed_trials(self) -> list[BaseTrial]:
Expand All @@ -1433,15 +1429,10 @@ def running_trial_indices(self) -> set[int]:

@property
def trial_indices_expecting_data(self) -> set[int]:
    """Set of indices of trials that expect data via the standard
    data-fetch pipeline.
    """
    # Delegates to the per-trial `expecting_data` property (which excludes
    # LILO labeling trials) instead of unioning status-indexed buckets.
    return {trial.index for trial in self.trials.values() if trial.expecting_data}

def trial_indices_with_data(
self, critical_metrics_only: bool | None = True
Expand Down
2 changes: 1 addition & 1 deletion ax/core/multi_type_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def fetch_data(
[
(
trial.fetch_data(**kwargs, metrics=metrics)
if trial.status.expecting_data
if trial.expecting_data
else Data()
)
for trial in self.trials.values()
Expand Down
18 changes: 18 additions & 0 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,14 +881,26 @@ def test_experiment_runner(self) -> None:
candidate_batch.run()
candidate_batch._status = TrialStatus.CANDIDATE
self.assertEqual(self.experiment.trials_expecting_data, [batch])

# LILO labeling trials are excluded from trials_expecting_data
# (their data is fetched inline during the labeling loop).
lilo_batch = self.experiment.new_batch_trial(
trial_type=Keys.LILO_LABELING,
)
lilo_batch.run()
lilo_batch.mark_completed()
self.assertEqual(self.experiment.trials_expecting_data, [batch])

tbs = self.experiment.trials_by_status # All statuses should be present
self.assertEqual(len(tbs), len(TrialStatus))
self.assertEqual(tbs[TrialStatus.RUNNING], [batch])
self.assertEqual(tbs[TrialStatus.CANDIDATE], [candidate_batch])
self.assertEqual(tbs[TrialStatus.COMPLETED], [lilo_batch])
tibs = self.experiment.trial_indices_by_status
self.assertEqual(len(tibs), len(TrialStatus))
self.assertEqual(tibs[TrialStatus.RUNNING], {0})
self.assertEqual(tibs[TrialStatus.CANDIDATE], {1})
self.assertEqual(tibs[TrialStatus.COMPLETED], {2})

identifier = {"new_runner": True}
# pyre-fixme[6]: For 1st param expected `Optional[str]` but got `Dict[str,
Expand Down Expand Up @@ -1727,6 +1739,12 @@ def test_trial_indices(self) -> None:
)
self.assertEqual(experiment.trial_indices_expecting_data, {2, 5})

# LILO labeling trials are excluded from trial_indices_expecting_data.
lilo_trial = experiment.new_batch_trial(trial_type=Keys.LILO_LABELING)
lilo_trial.mark_running(no_runner_required=True)
lilo_trial.mark_completed()
self.assertEqual(experiment.trial_indices_expecting_data, {2, 5})

def test_trial_indices_with_data(self) -> None:
exp = get_branin_experiment_with_multi_objective(
with_status_quo=True,
Expand Down
15 changes: 15 additions & 0 deletions ax/core/tests/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ax.exceptions.core import TrialMutationError, UnsupportedError, UserInputError
from ax.metrics.branin import BraninMetric
from ax.runners.synthetic import SyntheticRunner
from ax.utils.common.constants import Keys
from ax.utils.common.result import Ok
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand Down Expand Up @@ -271,8 +272,22 @@ def test_mark_as(self) -> None:
TrialStatus.COMPLETED,
]:
self.assertTrue(self.trial.status.expecting_data)
# trial.expecting_data follows status for normal trials.
self.assertTrue(self.trial.expecting_data)
else:
self.assertFalse(self.trial.status.expecting_data)
self.assertFalse(self.trial.expecting_data)

def test_expecting_data_excludes_lilo(self) -> None:
    """``expecting_data`` is always False for LILO labeling trials.

    The status-level flag still reports True in data-bearing statuses;
    the trial-level property overrides it for LILO labeling trials.
    """
    self.trial._trial_type = Keys.LILO_LABELING
    status_transitions = (
        lambda: self.trial.mark_running(no_runner_required=True),
        lambda: self.trial.mark_completed(),
    )
    for advance in status_transitions:
        advance()
        # Status alone would expect data; the LILO exclusion wins.
        self.assertTrue(self.trial.status.expecting_data)
        self.assertFalse(self.trial.expecting_data)

def test_stop(self) -> None:
# test bad old status
Expand Down
2 changes: 1 addition & 1 deletion ax/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def trials_expecting_data(self) -> list[BaseTrial]:
"""
trials = []
for trial in self.experiment.trials.values():
if trial.status.expecting_data:
if trial.expecting_data:
if self.trial_type is None or trial.trial_type == self.trial_type:
trials.append(trial)
return trials
Expand Down
2 changes: 1 addition & 1 deletion ax/orchestration/tests/test_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,7 +1549,7 @@ def test_get_best_trial(self) -> None:
# We override the optimization config but not objectives, so a
# ValueError results when extract_objective_weights tries to find
# the MOO metric signature in the outcomes list.
with self.assertRaisesRegex(ValueError, "branin_a"):
with self.assertRaisesRegex(ValueError, "not in list"):
orchestrator.get_pareto_optimal_parameters(
optimization_config=get_branin_multi_objective_optimization_config(
has_objective_thresholds=True
Expand Down
Loading