From e84260122805b03b99cedf6c7fdc53640d2507b8 Mon Sep 17 00:00:00 2001 From: Jerry Lin Date: Sun, 5 Apr 2026 21:35:42 -0700 Subject: [PATCH 1/2] Fix in-sample arm pool exhaustion from FAILED LILO labeling trials (#5145) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: `InSampleUniformGenerator` selects existing arms for LILO labeling by drawing from the `generated_points` pool constructed in `RandomAdapter._gen()`. This pool started from `arms_by_signature_for_deduplication`, which excludes arm signatures from *any* FAILED trial. Because LILO labeling trials borrow arms from regular BO trials (same signatures), a FAILED labeling trial incorrectly removes the original arm from the selection pool — even though it still exists in a non-FAILED trial. Within a single LILO labeling loop run, failed iterations accumulate and progressively poison the pool until no arms remain, crashing with: ValueError: Cannot select 2 arms: only 0 eligible arms available The fix: for in-sample generators, start from `arms_by_signature` (all arms) instead of `arms_by_signature_for_deduplication`. The existing `expecting_sigs` filter already handles the real restriction (only data-expecting, non-abandoned arms), so the FAILED-arm exclusion was just an accidental side effect of piggybacking on the dedup infrastructure. Reviewed By: bletham Differential Revision: D99611303 --- ax/adapter/random.py | 21 ++++++-- ax/adapter/tests/test_random_adapter.py | 59 +++++++++++++++++++++ ax/orchestration/tests/test_orchestrator.py | 2 +- 3 files changed, 77 insertions(+), 5 deletions(-) diff --git a/ax/adapter/random.py b/ax/adapter/random.py index 276fb7bab4e..8e0846df45f 100644 --- a/ax/adapter/random.py +++ b/ax/adapter/random.py @@ -95,13 +95,26 @@ def _gen( linear_constraints = extract_parameter_constraints( search_space.parameter_constraints, self.parameters ) - # Extract generated points to deduplicate against. - # Exclude out-of-design arms (which can only be manual arms - # instead of adapter-generated arms). + # Extract generated points. + # For normal generators these are used to deduplicate against. + # For in-sample generators (LILO labeling) they are the selection + # pool from which arms are drawn — not a dedup set. The two use + # cases have been shoehorned into the same code path; consider + # splitting them into separate methods in a future refactor. generated_points = None is_in_sample = isinstance(self.generator, InSampleUniformGenerator) if self.generator.deduplicate: - arms_to_deduplicate = self._experiment.arms_by_signature_for_deduplication + # For normal generators, exclude arms from FAILED trials so the + # model may re-suggest them. For in-sample generators this + # exclusion is harmful: LILO labeling trials borrow arms from + # regular trials, so a FAILED labeling trial would incorrectly + # remove the original arm from the selection pool. Use the + # full arms_by_signature instead. + arms_to_deduplicate = ( + self._experiment.arms_by_signature + if is_in_sample + else self._experiment.arms_by_signature_for_deduplication + ) # For in-sample generators, restrict to arms from trials that # have or expect observed data (COMPLETED, EARLY_STOPPED, # RUNNING). This prevents selecting arms from CANDIDATE/STAGED diff --git a/ax/adapter/tests/test_random_adapter.py b/ax/adapter/tests/test_random_adapter.py index 6c3dcf9e3d5..ce8971904e5 100644 --- a/ax/adapter/tests/test_random_adapter.py +++ b/ax/adapter/tests/test_random_adapter.py @@ -352,6 +352,65 @@ def test_in_sample_excludes_non_data_bearing_trial_arms(self) -> None: assert generated_points is not None self.assertEqual(len(generated_points), 2) + def test_in_sample_failed_lilo_trial_does_not_poison_arm_pool(self) -> None: + """A FAILED LILO labeling trial shares arm signatures with COMPLETED + regular trials. The in-sample pool must still include those arms — + the FAILED trial should not remove them from the selection pool. + + Regression test for the arm-pool exhaustion bug where + ``arms_by_signature_for_deduplication`` blindly excluded signatures + that appeared in *any* FAILED trial, even if the same arm existed + in a non-FAILED trial. + """ + search_space = SearchSpace(self.parameters[:2]) + exp = Experiment(search_space=search_space) + + # Trials 0 and 1: COMPLETED regular trials with 1 arm each. + arm_a = Arm(parameters={"x": 0.5, "y": 1.5}) + arm_b = Arm(parameters={"x": 0.3, "y": 1.3}) + t0 = exp.new_trial() + t0.add_arm(arm_a) + t0.mark_running(no_runner_required=True) + exp.trials[0].mark_completed() + + t1 = exp.new_trial() + t1.add_arm(arm_b) + t1.mark_running(no_runner_required=True) + exp.trials[1].mark_completed() + + # Trial 2: FAILED LILO labeling trial re-using the same arms. + # Simulates a LILO labeling trial that borrowed arm_a but whose + # LLM metric call failed. + t2 = exp.new_trial() + t2.add_arm(Arm(parameters={"x": 0.5, "y": 1.5})) # same as arm_a + t2.mark_running(no_runner_required=True) + t2.mark_failed() + + # Sanity: arms_by_signature_for_deduplication removes arm_a's sig. + dedup = exp.arms_by_signature_for_deduplication + self.assertNotIn(arm_a.signature, dedup) + + # But InSampleUniformGenerator should still see both arms. + generator = InSampleUniformGenerator(seed=0) + adapter = RandomAdapter( + experiment=exp, + generator=generator, + transforms=Cont_X_trans, + ) + + with mock.patch.object( + generator, + "gen", + wraps=generator.gen, + ) as mock_gen: + adapter.gen(n=2) + + # Both arms from COMPLETED trials must be in generated_points, + # despite arm_a's signature also appearing in a FAILED trial. + generated_points = mock_gen.call_args.kwargs["generated_points"] + assert generated_points is not None + self.assertEqual(len(generated_points), 2) + def test_generation_with_all_fixed(self) -> None: # Make sure candidate generation succeeds and returns correct parameters # when all parameters are fixed. diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 6cead6408ac..023f970fca6 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -1549,7 +1549,7 @@ def test_get_best_trial(self) -> None: # We override the optimization config but not objectives, so a # ValueError results when extract_objective_weights tries to find # the MOO metric signature in the outcomes list. - with self.assertRaisesRegex(ValueError, "branin_a"): + with self.assertRaisesRegex(ValueError, "not in list"): orchestrator.get_pareto_optimal_parameters( optimization_config=get_branin_multi_objective_optimization_config( has_objective_thresholds=True From d6a975df56c5797bcd3a5ef0db61c96b80b9f5c7 Mon Sep 17 00:00:00 2001 From: Jerry Lin Date: Sun, 5 Apr 2026 21:35:42 -0700 Subject: [PATCH 2/2] Exclude LILO labeling trials from trials_expecting_data (#5144) Summary: LILO labeling trials have their pairwise preference data fetched inline during the labeling loop (`_run_lilo_labeling_loop`), not via the normal data refetch paths (orchestrator `poll_and_process_results`, PTSClient `refetch_data`). Including them in `trials_expecting_data` causes unnecessary data fetch attempts for metrics (e.g., Deltoid) that don't exist on these trials, producing noisy errors and wasting time. This filters LILO labeling trials (`trial_type == Keys.LILO_LABELING`) from the `trials_expecting_data` property on `Experiment`, which is the centralized source used by all downstream data refetch consumers. Differential Revision: D99571562 --- ax/core/base_trial.py | 10 ++++++++++ ax/core/experiment.py | 23 +++++++---------------- ax/core/multi_type_experiment.py | 2 +- ax/core/tests/test_experiment.py | 18 ++++++++++++++++++ ax/core/tests/test_trial.py | 15 +++++++++++++++ ax/orchestration/orchestrator.py | 2 +- 6 files changed, 52 insertions(+), 18 deletions(-) diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index 8c529111e65..c9563e959b0 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -198,6 +198,16 @@ def status(self) -> TrialStatus: self._mark_stale_if_past_TTL() return none_throws(self._status) + @property + def expecting_data(self) -> bool: + """Whether this trial expects data via the standard data-fetch pipeline. + + Returns ``False`` for LILO labeling trials because their pairwise + preference data is fetched inline during the labeling loop and is + never refetched through the normal orchestration path. + """ + return self.status.expecting_data and self.trial_type != Keys.LILO_LABELING + @status.setter def status(self, status: TrialStatus) -> None: raise NotImplementedError("Use `trial.mark_*` methods to set trial status.") diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 8cb018f1900..d5c3fabd732 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -45,11 +45,7 @@ from ax.core.runner import Runner from ax.core.search_space import SearchSpace from ax.core.trial import Trial -from ax.core.trial_status import ( - DEFAULT_STATUSES_TO_WARM_START, - STATUSES_EXPECTING_DATA, - TrialStatus, -) +from ax.core.trial_status import DEFAULT_STATUSES_TO_WARM_START, TrialStatus from ax.core.types import ComparisonOp, TParameterization from ax.exceptions.core import ( AxError, @@ -1406,10 +1402,10 @@ def trials_by_status(self) -> dict[TrialStatus, list[BaseTrial]]: @property def trials_expecting_data(self) -> list[BaseTrial]: - """list[BaseTrial]: the list of all trials for which data has arrived - or is expected to arrive. + """list[BaseTrial]: the list of all trials that expect data via the + standard data-fetch pipeline. """ - return [trial for trial in self.trials.values() if trial.status.expecting_data] + return [trial for trial in self.trials.values() if trial.expecting_data] @property def completed_trials(self) -> list[BaseTrial]: @@ -1433,15 +1429,10 @@ def running_trial_indices(self) -> set[int]: @property def trial_indices_expecting_data(self) -> set[int]: - """Set of indices of trials, statuses of which indicate that we expect - these trials to have data, either already or in the future. + """Set of indices of trials that expect data via the standard + data-fetch pipeline. """ - return set.union( - *( - self.trial_indices_by_status[status] - for status in STATUSES_EXPECTING_DATA - ) - ) + return {trial.index for trial in self.trials.values() if trial.expecting_data} def trial_indices_with_data( self, critical_metrics_only: bool | None = True diff --git a/ax/core/multi_type_experiment.py b/ax/core/multi_type_experiment.py index f6a1ce91add..7003c29167d 100644 --- a/ax/core/multi_type_experiment.py +++ b/ax/core/multi_type_experiment.py @@ -275,7 +275,7 @@ def fetch_data( [ ( trial.fetch_data(**kwargs, metrics=metrics) - if trial.status.expecting_data + if trial.expecting_data else Data() ) for trial in self.trials.values() diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index d716a11efc9..31432d5a3ef 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -881,14 +881,26 @@ def test_experiment_runner(self) -> None: candidate_batch.run() candidate_batch._status = TrialStatus.CANDIDATE self.assertEqual(self.experiment.trials_expecting_data, [batch]) + + # LILO labeling trials are excluded from trials_expecting_data + # (their data is fetched inline during the labeling loop). + lilo_batch = self.experiment.new_batch_trial( + trial_type=Keys.LILO_LABELING, + ) + lilo_batch.run() + lilo_batch.mark_completed() + self.assertEqual(self.experiment.trials_expecting_data, [batch]) + tbs = self.experiment.trials_by_status # All statuses should be present self.assertEqual(len(tbs), len(TrialStatus)) self.assertEqual(tbs[TrialStatus.RUNNING], [batch]) self.assertEqual(tbs[TrialStatus.CANDIDATE], [candidate_batch]) + self.assertEqual(tbs[TrialStatus.COMPLETED], [lilo_batch]) tibs = self.experiment.trial_indices_by_status self.assertEqual(len(tibs), len(TrialStatus)) self.assertEqual(tibs[TrialStatus.RUNNING], {0}) self.assertEqual(tibs[TrialStatus.CANDIDATE], {1}) + self.assertEqual(tibs[TrialStatus.COMPLETED], {2}) identifier = {"new_runner": True} # pyre-fixme[6]: For 1st param expected `Optional[str]` but got `Dict[str, @@ -1727,6 +1739,12 @@ def test_trial_indices(self) -> None: ) self.assertEqual(experiment.trial_indices_expecting_data, {2, 5}) + # LILO labeling trials are excluded from trial_indices_expecting_data. + lilo_trial = experiment.new_batch_trial(trial_type=Keys.LILO_LABELING) + lilo_trial.mark_running(no_runner_required=True) + lilo_trial.mark_completed() + self.assertEqual(experiment.trial_indices_expecting_data, {2, 5}) + def test_trial_indices_with_data(self) -> None: exp = get_branin_experiment_with_multi_objective( with_status_quo=True, diff --git a/ax/core/tests/test_trial.py b/ax/core/tests/test_trial.py index 0bfd29d27f2..4fa43495e04 100644 --- a/ax/core/tests/test_trial.py +++ b/ax/core/tests/test_trial.py @@ -26,6 +26,7 @@ from ax.exceptions.core import TrialMutationError, UnsupportedError, UserInputError from ax.metrics.branin import BraninMetric from ax.runners.synthetic import SyntheticRunner +from ax.utils.common.constants import Keys from ax.utils.common.result import Ok from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -271,8 +272,22 @@ def test_mark_as(self) -> None: TrialStatus.COMPLETED, ]: self.assertTrue(self.trial.status.expecting_data) + # trial.expecting_data follows status for normal trials. + self.assertTrue(self.trial.expecting_data) else: self.assertFalse(self.trial.status.expecting_data) + self.assertFalse(self.trial.expecting_data) + + def test_expecting_data_excludes_lilo(self) -> None: + """LILO labeling trials never expect data via the standard pipeline.""" + self.trial._trial_type = Keys.LILO_LABELING + self.trial.mark_running(no_runner_required=True) + self.assertTrue(self.trial.status.expecting_data) + self.assertFalse(self.trial.expecting_data) + + self.trial.mark_completed() + self.assertTrue(self.trial.status.expecting_data) + self.assertFalse(self.trial.expecting_data) def test_stop(self) -> None: # test bad old status diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index f6ec4abebe0..d1b6e1f827b 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -477,7 +477,7 @@ def trials_expecting_data(self) -> list[BaseTrial]: """ trials = [] for trial in self.experiment.trials.values(): - if trial.status.expecting_data: + if trial.expecting_data: if self.trial_type is None or trial.trial_type == self.trial_type: trials.append(trial) return trials