Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,18 +787,22 @@ def _query_historical_experiments_given_parameters(
experiments_params[exp_name].append(sqa_param)
experiments_time_created[exp_name] = time_created

return {
exp_name: (
decoder.search_space_from_sqa(
results: dict[str, tuple[SearchSpace | None, datetime]] = {}
for exp_name, parameters_sqa in experiments_params.items():
try:
search_space = decoder.search_space_from_sqa(
parameters_sqa=parameters_sqa,
# Parameter constraints don't matter for search space
# compatibility
parameter_constraints_sqa=[],
),
experiments_time_created[exp_name],
)
for exp_name, parameters_sqa in experiments_params.items()
}
)
except Exception as e:
logger.warning(
f"Failed to decode search space for experiment '{exp_name}': {e}"
)
search_space = None
results[exp_name] = (search_space, experiments_time_created[exp_name])
return results


def identify_transferable_experiments(
Expand Down
55 changes: 55 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3236,6 +3236,61 @@ def test_query_historical_experiments_given_parameters(self) -> None:
self.assertIn("w", none_throws(returned_ss).parameters)
self.assertIn("x", none_throws(returned_ss).parameters)

with self.subTest("returns_none_search_space_on_decode_failure"):
# Save two experiments
exp1 = get_experiment_with_batch_trial()
exp1.name = "exp_decode_success"
exp1.experiment_type = "TEST"
exp1.is_test = False
trial1 = exp1.trials[0]
exp1.attach_data(get_data(trial_index=trial1.index))
save_experiment(exp1, config=config)

exp2 = get_experiment_with_batch_trial()
exp2.name = "exp_decode_failure"
exp2.experiment_type = "TEST"
exp2.is_test = False
trial2 = exp2.trials[0]
exp2.attach_data(get_data(trial_index=trial2.index))
save_experiment(exp2, config=config)

# Look up exp2's ID before mocking to avoid nested session_scope
with session_scope() as session:
exp2_id: int = (
session.query(SQAExperiment.id)
.filter(SQAExperiment.name == "exp_decode_failure")
.scalar()
)

# Mock decoder to raise on the second experiment's parameters
original_search_space_from_sqa: Callable[..., SearchSpace | None] = (
Decoder.search_space_from_sqa
)

def _side_effect(self: Decoder, **kwargs: Any) -> SearchSpace | None:
params = kwargs.get("parameters_sqa", [])
exp_ids = {p.experiment_id for p in params}
if exp_ids == {exp2_id}:
raise RuntimeError("Simulated decode failure")
return original_search_space_from_sqa(self, **kwargs)

with patch.object(Decoder, "search_space_from_sqa", _side_effect):
result = _query_historical_experiments_given_parameters(
parameter_names=["w", "x"],
experiment_types=["TEST"],
config=config,
)

# The successfully decoded experiment should have a SearchSpace
self.assertIn("exp_decode_success", result)
ss_success, _ = result["exp_decode_success"]
self.assertIsNotNone(ss_success)

# The failed experiment should have None search space
self.assertIn("exp_decode_failure", result)
ss_failure, _ = result["exp_decode_failure"]
self.assertIsNone(ss_failure)

def test_identify_transferable_experiments(
self,
) -> None:
Expand Down
Loading