diff --git a/ax/adapter/random.py b/ax/adapter/random.py index 276fb7bab4e..8e0846df45f 100644 --- a/ax/adapter/random.py +++ b/ax/adapter/random.py @@ -95,13 +95,26 @@ def _gen( linear_constraints = extract_parameter_constraints( search_space.parameter_constraints, self.parameters ) - # Extract generated points to deduplicate against. - # Exclude out-of-design arms (which can only be manual arms - # instead of adapter-generated arms). + # Extract generated points. + # For normal generators these are used to deduplicate against. + # For in-sample generators (LILO labeling) they are the selection + # pool from which arms are drawn — not a dedup set. The two use + # cases have been shoehorned into the same code path; consider + # splitting them into separate methods in a future refactor. generated_points = None is_in_sample = isinstance(self.generator, InSampleUniformGenerator) if self.generator.deduplicate: - arms_to_deduplicate = self._experiment.arms_by_signature_for_deduplication + # For normal generators, exclude arms from FAILED trials so the + # model may re-suggest them. For in-sample generators this + # exclusion is harmful: LILO labeling trials borrow arms from + # regular trials, so a FAILED labeling trial would incorrectly + # remove the original arm from the selection pool. Use the + # full arms_by_signature instead. + arms_to_deduplicate = ( + self._experiment.arms_by_signature + if is_in_sample + else self._experiment.arms_by_signature_for_deduplication + ) # For in-sample generators, restrict to arms from trials that # have or expect observed data (COMPLETED, EARLY_STOPPED, # RUNNING). This prevents selecting arms from CANDIDATE/STAGED diff --git a/ax/adapter/tests/test_random_adapter.py b/ax/adapter/tests/test_random_adapter.py index 6c3dcf9e3d5..ce8971904e5 100644 --- a/ax/adapter/tests/test_random_adapter.py +++ b/ax/adapter/tests/test_random_adapter.py @@ -352,6 +352,65 @@ def test_in_sample_excludes_non_data_bearing_trial_arms(self) -> None: assert generated_points is not None self.assertEqual(len(generated_points), 2) + def test_in_sample_failed_lilo_trial_does_not_poison_arm_pool(self) -> None: + """A FAILED LILO labeling trial shares arm signatures with COMPLETED + regular trials. The in-sample pool must still include those arms — + the FAILED trial should not remove them from the selection pool. + + Regression test for the arm-pool exhaustion bug where + ``arms_by_signature_for_deduplication`` blindly excluded signatures + that appeared in *any* FAILED trial, even if the same arm existed + in a non-FAILED trial. + """ + search_space = SearchSpace(self.parameters[:2]) + exp = Experiment(search_space=search_space) + + # Trials 0 and 1: COMPLETED regular trials with 1 arm each. + arm_a = Arm(parameters={"x": 0.5, "y": 1.5}) + arm_b = Arm(parameters={"x": 0.3, "y": 1.3}) + t0 = exp.new_trial() + t0.add_arm(arm_a) + t0.mark_running(no_runner_required=True) + exp.trials[0].mark_completed() + + t1 = exp.new_trial() + t1.add_arm(arm_b) + t1.mark_running(no_runner_required=True) + exp.trials[1].mark_completed() + + # Trial 2: FAILED LILO labeling trial re-using the same arms. + # Simulates a LILO labeling trial that borrowed arm_a but whose + # LLM metric call failed. + t2 = exp.new_trial() + t2.add_arm(Arm(parameters={"x": 0.5, "y": 1.5})) # same as arm_a + t2.mark_running(no_runner_required=True) + t2.mark_failed() + + # Sanity: arms_by_signature_for_deduplication removes arm_a's sig. + dedup = exp.arms_by_signature_for_deduplication + self.assertNotIn(arm_a.signature, dedup) + + # But InSampleUniformGenerator should still see both arms. + generator = InSampleUniformGenerator(seed=0) + adapter = RandomAdapter( + experiment=exp, + generator=generator, + transforms=Cont_X_trans, + ) + + with mock.patch.object( + generator, + "gen", + wraps=generator.gen, + ) as mock_gen: + adapter.gen(n=2) + + # Both arms from COMPLETED trials must be in generated_points, + # despite arm_a's signature also appearing in a FAILED trial. + generated_points = mock_gen.call_args.kwargs["generated_points"] + assert generated_points is not None + self.assertEqual(len(generated_points), 2) + def test_generation_with_all_fixed(self) -> None: # Make sure candidate generation succeeds and returns correct parameters # when all parameters are fixed. diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 6cead6408ac..023f970fca6 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -1549,7 +1549,7 @@ def test_get_best_trial(self) -> None: # We override the optimization config but not objectives, so a # ValueError results when extract_objective_weights tries to find # the MOO metric signature in the outcomes list. - with self.assertRaisesRegex(ValueError, "branin_a"): + with self.assertRaisesRegex(ValueError, "not in list"): orchestrator.get_pareto_optimal_parameters( optimization_config=get_branin_multi_objective_optimization_config( has_objective_thresholds=True