Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ax/adapter/tests/test_model_fit_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def setUp(self) -> None:
True
)
self.generation_strategy = GenerationStrategy(
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=NUM_SOBOL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def setUp(self) -> None:

# Create a simple step-based generation strategy
self.step_gs = GenerationStrategy(
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=5,
Expand Down
2 changes: 1 addition & 1 deletion ax/benchmark/methods/modular_botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def get_sobol_mbm_generation_strategy(

generation_strategy = GenerationStrategy(
name=name,
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=num_sobol_trials,
Expand Down
2 changes: 1 addition & 1 deletion ax/benchmark/methods/sobol.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
def get_sobol_generation_strategy() -> GenerationStrategy:
return GenerationStrategy(
name="Sobol",
steps=[
nodes=[
GenerationStep(generator=Generators.SOBOL, num_trials=-1),
],
)
Expand Down
2 changes: 1 addition & 1 deletion ax/early_stopping/experiment_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def replay_experiment(
# Setup a Orchestrator with a dummy gs to replay the historical experiment
dummy_sobol_gs = GenerationStrategy(
name="sobol",
steps=[
nodes=[
GenerationStep(generator=Generators.SOBOL, num_trials=-1),
],
)
Expand Down
2 changes: 1 addition & 1 deletion ax/generation_strategy/dispatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ def choose_generation_strategy_legacy(
)
else: # `force_random_search` is True or we could not suggest BO model
gs = GenerationStrategy(
steps=[
nodes=[
_make_sobol_step(
seed=random_seed,
should_deduplicate=should_deduplicate,
Expand Down
24 changes: 3 additions & 21 deletions ax/generation_strategy/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

from __future__ import annotations

import warnings
from collections.abc import Sequence
from copy import deepcopy
from logging import Logger
from typing import Any, TypeVar
from typing import TypeVar

from ax.adapter.base import Adapter
from ax.core.data import Data
Expand Down Expand Up @@ -77,27 +77,10 @@ class GenerationStrategy(Base):
def __init__(
self,
*,
nodes: list[GenerationNode] | None = None,
nodes: Sequence[GenerationNode],
name: str | None = None,
**kwargs: Any,
) -> None:
self._generator_runs = []
if not (bool(steps := kwargs.get("steps")) ^ bool(nodes)): # Steps XOR nodes
raise GenerationStrategyMisconfiguredException(
"GenerationStrategy must contain either steps or nodes. "
f"Got: nodes={nodes}, steps={steps}."
)

if steps:
warnings.warn(
DeprecationWarning(
"Specifying `steps` input is no longer supported. Please use "
"`nodes`. `steps` argument will be removed in early 2026."
),
stacklevel=2,
)
nodes = steps

self._validate_and_set_node_graph(nodes=nodes)

# Set name to an explicit value ahead of time to avoid
Expand Down Expand Up @@ -399,7 +382,6 @@ def _unset_non_persistent_state_fields(self) -> None:
n._trials_from_node_cache = set()
n._cached_trial_count = None

# TODO: Deprecate `steps` argument fully in Q1'26.
def _validate_and_set_step_sequence(self, steps: list[GenerationNode]) -> None:
"""Initialize and validate the steps provided to this GenerationStrategy.

Expand Down
67 changes: 22 additions & 45 deletions ax/generation_strategy/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def test_gen_with_parameter_constraints(self) -> None:

# Create generation strategy with Sobol
gs = GenerationStrategy(
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=10,
Expand Down Expand Up @@ -308,7 +308,7 @@ def setUp(self) -> None:
self.step_generator_kwargs = {"silently_filter_kwargs": True}
self.hss_experiment = get_hierarchical_search_space_experiment()
self.sobol_GS = GenerationStrategy(
steps=[
nodes=[
GenerationStep(
Generators.SOBOL,
num_trials=-1,
Expand Down Expand Up @@ -465,7 +465,7 @@ def _get_sobol_mbm_step_gs(
) -> GenerationStrategy:
return GenerationStrategy(
name="Sobol+MBM",
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=num_sobol_trials,
Expand Down Expand Up @@ -502,7 +502,7 @@ def test_validation(self) -> None:
# num_trials can be positive or -1.
with self.assertRaises(UserInputError):
GenerationStrategy(
steps=[
nodes=[
GenerationStep(generator=Generators.SOBOL, num_trials=5),
GenerationStep(
generator=Generators.BOTORCH_MODULAR, num_trials=-10
Expand All @@ -515,7 +515,7 @@ def test_validation(self) -> None:
)
with self.assertRaisesRegex(UserInputError, "Maximum concurrency should be"):
GenerationStrategy(
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL, num_trials=5, max_parallelism=-1
),
Expand All @@ -527,7 +527,7 @@ def test_custom_callables_for_models(self) -> None:
with self.assertRaises(UserInputError):
GenerationStrategy(
# pyre-ignore [6]: Testing deprecated input.
steps=[GenerationStep(generator=get_sobol, num_trials=-1)]
nodes=[GenerationStep(generator=get_sobol, num_trials=-1)]
)

def test_string_representation(self) -> None:
Expand All @@ -548,7 +548,7 @@ def test_string_representation(self) -> None:
"pausing_criteria=None)])",
)
gs2 = GenerationStrategy(
steps=[GenerationStep(generator=Generators.SOBOL, num_trials=-1)]
nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=-1)]
)
self.assertEqual(
str(gs2),
Expand Down Expand Up @@ -601,7 +601,7 @@ def test_min_observed(self) -> None:
# pyre-fixme[6]: For 1st param expected `bool` but got `Experiment`.
exp = get_branin_experiment(get_branin_experiment())
gs = GenerationStrategy(
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL, num_trials=5, min_trials_observed=5
),
Expand All @@ -619,7 +619,7 @@ def test_one_node_with_finite_num_trials(self) -> None:
# pyre-fixme[6]: For 1st param expected `bool` but got `Experiment`.
exp = get_branin_experiment(get_branin_experiment())
gs = GenerationStrategy(
steps=[
nodes=[
GenerationStep(generator=Generators.SOBOL, num_trials=5),
]
)
Expand All @@ -634,7 +634,7 @@ def test_do_not_enforce_min_observations(self) -> None:
# case the previous model should be used until there is enough data.
exp = get_branin_experiment()
gs = GenerationStrategy(
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=1,
Expand Down Expand Up @@ -711,7 +711,7 @@ def test_sobol_MBM_strategy_keep_generating(self) -> None:
def test_sobol_strategy(self) -> None:
exp = get_branin_experiment()
sobol_generation_strategy = GenerationStrategy(
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=5,
Expand All @@ -730,7 +730,7 @@ def test_sobol_strategy(self) -> None:
def test_factorial_thompson_strategy(self, _: MagicMock) -> None:
exp = get_branin_experiment()
factorial_thompson_gs = GenerationStrategy(
steps=[
nodes=[
GenerationStep(
generator=Generators.FACTORIAL,
num_trials=1,
Expand Down Expand Up @@ -761,7 +761,7 @@ def test_factorial_thompson_strategy(self, _: MagicMock) -> None:

def test_clone_reset(self) -> None:
ftgs = GenerationStrategy(
steps=[
nodes=[
GenerationStep(generator=Generators.FACTORIAL, num_trials=1),
GenerationStep(generator=Generators.THOMPSON, num_trials=2),
]
Expand All @@ -774,7 +774,7 @@ def test_clone_reset(self) -> None:

def test_kwargs_passed(self) -> None:
gs = GenerationStrategy(
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=1,
Expand Down Expand Up @@ -817,7 +817,7 @@ def test_sobol_MBM_strategy_batches(self) -> None:
def test_store_experiment(self) -> None:
exp = get_branin_experiment()
sobol_generation_strategy = GenerationStrategy(
steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)]
nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=5)]
)
self.assertIsNone(sobol_generation_strategy._experiment)
sobol_generation_strategy.gen_single_trial(exp)
Expand All @@ -828,7 +828,7 @@ def test_gen_single_trial_extracts_pending_observations(self) -> None:
experiment when none are passed in."""
exp = get_branin_experiment()
gs = GenerationStrategy(
steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)]
nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=5)]
)
# Create a trial and mark it as running so it becomes a pending observation.
trial = exp.new_trial(generator_run=gs.gen_single_trial(exp))
Expand Down Expand Up @@ -856,7 +856,7 @@ def test_gen_single_trial_raises_error_for_multiple_trials(self) -> None:
"""Test that gen_single_trial raises AxError if gen returns multiple trials."""
exp = get_branin_experiment()
gs = GenerationStrategy(
steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)]
nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=5)]
)
gr = gs.gen_single_trial(exp)
# Mock gen to return multiple trials
Expand All @@ -869,7 +869,7 @@ def test_gen_single_trial_raises_error_for_multiple_generator_runs(self) -> None
GeneratorRuns for a single trial."""
exp = get_branin_experiment()
gs = GenerationStrategy(
steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)]
nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=5)]
)
gr = gs.gen_single_trial(exp)
# Mock gen to return a single trial with multiple GeneratorRuns
Expand All @@ -880,7 +880,7 @@ def test_gen_single_trial_raises_error_for_multiple_generator_runs(self) -> None
def test_max_parallelism_reached(self) -> None:
exp = get_branin_experiment()
sobol_generation_strategy = GenerationStrategy(
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL, num_trials=5, max_parallelism=1
)
Expand Down Expand Up @@ -983,7 +983,7 @@ def test_current_generator_run_limit(self) -> None:
NUM_ROUNDS = 4
exp = get_branin_experiment()
sobol_gs_with_parallelism_limits = GenerationStrategy(
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=NUM_INIT_TRIALS,
Expand Down Expand Up @@ -1026,7 +1026,7 @@ def test_current_generator_run_limit_unlimited_second_step(self) -> None:
NUM_ROUNDS = 4
exp = get_branin_experiment()
sobol_gs_with_parallelism_limits = GenerationStrategy(
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=NUM_INIT_TRIALS,
Expand Down Expand Up @@ -1281,7 +1281,7 @@ def test_gen_with_fixed_features(
) -> None:
exp = get_branin_experiment()
gs = GenerationStrategy(
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=1,
Expand Down Expand Up @@ -1377,29 +1377,6 @@ def test_gs_setup_with_nodes(self) -> None:
node_3,
],
)
# check error raised if provided both steps and nodes
with self.assertRaisesRegex(
GenerationStrategyMisconfiguredException, "contain either steps or nodes"
):
GenerationStrategy(
nodes=[
node_1,
node_3,
],
steps=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=5,
generator_kwargs=self.step_generator_kwargs,
),
GenerationStep(
generator=Generators.BOTORCH_MODULAR,
num_trials=-1,
generator_kwargs=self.step_generator_kwargs,
),
],
)

# check error raised if two transition criterion defining a single edge have
# differing `continue_trial_generation` settings
with self.assertRaisesRegex(
Expand Down
8 changes: 4 additions & 4 deletions ax/generation_strategy/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def test_default_step_criterion_setup(self) -> None:
experiment = get_experiment()
gs = GenerationStrategy(
name="SOBOL+MBM::default",
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=3,
Expand Down Expand Up @@ -230,7 +230,7 @@ def test_min_trials_is_met(self) -> None:
experiment = self.branin_experiment
gs = GenerationStrategy(
name="SOBOL::default",
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=4,
Expand Down Expand Up @@ -312,7 +312,7 @@ def test_min_trials_count_only_with_data(self) -> None:
experiment = self.branin_experiment
gs = GenerationStrategy(
name="SOBOL::default",
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=4,
Expand Down Expand Up @@ -508,7 +508,7 @@ def test_trials_from_node_empty(self) -> None:
experiment = get_experiment()
gs = GenerationStrategy(
name="SOBOL::default",
steps=[
nodes=[
GenerationStep(
generator=Generators.SOBOL,
num_trials=4,
Expand Down
Loading
Loading