Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions ax/adapter/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,26 @@ def _gen(
linear_constraints = extract_parameter_constraints(
search_space.parameter_constraints, self.parameters
)
# Extract generated points to deduplicate against.
# Exclude out-of-design arms (which can only be manual arms
# instead of adapter-generated arms).
# Extract generated points.
# For normal generators these are used to deduplicate against.
# For in-sample generators (LILO labeling) they are the selection
# pool from which arms are drawn — not a dedup set. The two use
# cases have been shoehorned into the same code path; consider
# splitting them into separate methods in a future refactor.
generated_points = None
is_in_sample = isinstance(self.generator, InSampleUniformGenerator)
if self.generator.deduplicate:
arms_to_deduplicate = self._experiment.arms_by_signature_for_deduplication
# For normal generators, exclude arms from FAILED trials so the
# model may re-suggest them. For in-sample generators this
# exclusion is harmful: LILO labeling trials borrow arms from
# regular trials, so a FAILED labeling trial would incorrectly
# remove the original arm from the selection pool. Use the
# full arms_by_signature instead.
arms_to_deduplicate = (
self._experiment.arms_by_signature
if is_in_sample
else self._experiment.arms_by_signature_for_deduplication
)
# For in-sample generators, restrict to arms from trials that
# have or expect observed data (COMPLETED, EARLY_STOPPED,
# RUNNING). This prevents selecting arms from CANDIDATE/STAGED
Expand Down
59 changes: 59 additions & 0 deletions ax/adapter/tests/test_random_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,65 @@ def test_in_sample_excludes_non_data_bearing_trial_arms(self) -> None:
assert generated_points is not None
self.assertEqual(len(generated_points), 2)

def test_in_sample_failed_lilo_trial_does_not_poison_arm_pool(self) -> None:
"""A FAILED LILO labeling trial shares arm signatures with COMPLETED
regular trials. The in-sample pool must still include those arms —
the FAILED trial should not remove them from the selection pool.

Regression test for the arm-pool exhaustion bug where
``arms_by_signature_for_deduplication`` blindly excluded signatures
that appeared in *any* FAILED trial, even if the same arm existed
in a non-FAILED trial.
"""
search_space = SearchSpace(self.parameters[:2])
exp = Experiment(search_space=search_space)

# Trials 0 and 1: COMPLETED regular trials with 1 arm each.
arm_a = Arm(parameters={"x": 0.5, "y": 1.5})
arm_b = Arm(parameters={"x": 0.3, "y": 1.3})
t0 = exp.new_trial()
t0.add_arm(arm_a)
t0.mark_running(no_runner_required=True)
exp.trials[0].mark_completed()

t1 = exp.new_trial()
t1.add_arm(arm_b)
t1.mark_running(no_runner_required=True)
exp.trials[1].mark_completed()

# Trial 2: FAILED LILO labeling trial re-using the same arms.
# Simulates a LILO labeling trial that borrowed arm_a but whose
# LLM metric call failed.
t2 = exp.new_trial()
t2.add_arm(Arm(parameters={"x": 0.5, "y": 1.5})) # same as arm_a
t2.mark_running(no_runner_required=True)
t2.mark_failed()

# Sanity: arms_by_signature_for_deduplication removes arm_a's sig.
dedup = exp.arms_by_signature_for_deduplication
self.assertNotIn(arm_a.signature, dedup)

# But InSampleUniformGenerator should still see both arms.
generator = InSampleUniformGenerator(seed=0)
adapter = RandomAdapter(
experiment=exp,
generator=generator,
transforms=Cont_X_trans,
)

with mock.patch.object(
generator,
"gen",
wraps=generator.gen,
) as mock_gen:
adapter.gen(n=2)

# Both arms from COMPLETED trials must be in generated_points,
# despite arm_a's signature also appearing in a FAILED trial.
generated_points = mock_gen.call_args.kwargs["generated_points"]
assert generated_points is not None
self.assertEqual(len(generated_points), 2)

def test_generation_with_all_fixed(self) -> None:
# Make sure candidate generation succeeds and returns correct parameters
# when all parameters are fixed.
Expand Down
2 changes: 1 addition & 1 deletion ax/orchestration/tests/test_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,7 +1549,7 @@ def test_get_best_trial(self) -> None:
# We override the optimization config but not objectives, so a
# ValueError results when extract_objective_weights tries to find
# the MOO metric signature in the outcomes list.
with self.assertRaisesRegex(ValueError, "branin_a"):
with self.assertRaisesRegex(ValueError, "not in list"):
orchestrator.get_pareto_optimal_parameters(
optimization_config=get_branin_multi_objective_optimization_config(
has_objective_thresholds=True
Expand Down
Loading