From 197957c6d731ff869760b0508980f75a506e3b35 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Fri, 3 Apr 2026 14:14:43 -0700 Subject: [PATCH] Remove deprecated `steps` kwarg from GenerationStrategy (#5142) Summary: `GenerationStep(...)` already returns a `GenerationNode` (it's a factory class using `__new__`). The `steps=` kwarg on `GenerationStrategy.__init__` was a deprecated alias for `nodes=` that was scheduled for removal in early 2026. This diff: - Removes the `steps=` kwarg and `**kwargs` from `GenerationStrategy.__init__`, leaving only `nodes: list[GenerationNode]` - Removes the associated deprecation warning and XOR validation logic - Updates all 22 call sites across production code, tests, benchmarks, and the JSON decoder to use `nodes=` instead of `steps=` - Removes unused `warnings` and `Any` imports Reviewed By: lena-kashtelyan Differential Revision: D99453408 --- ax/adapter/tests/test_model_fit_metrics.py | 2 +- .../tests/test_generation_strategy_graph.py | 2 +- ax/benchmark/methods/modular_botorch.py | 2 +- ax/benchmark/methods/sobol.py | 2 +- ax/early_stopping/experiment_replay.py | 2 +- ax/generation_strategy/dispatch_utils.py | 2 +- ax/generation_strategy/generation_strategy.py | 24 +------ .../tests/test_generation_strategy.py | 67 ++++++------------- .../tests/test_transition_criterion.py | 8 +-- ax/orchestration/tests/test_orchestrator.py | 12 ++-- ax/service/tests/test_ax_client.py | 18 ++--- ax/service/tests/test_interactive_loop.py | 2 +- ax/service/tests/test_managed_loop.py | 6 +- ax/service/tests/test_report_utils.py | 2 +- ax/storage/json_store/decoder.py | 2 +- ax/utils/testing/modeling_stubs.py | 1 - 16 files changed, 56 insertions(+), 98 deletions(-) diff --git a/ax/adapter/tests/test_model_fit_metrics.py b/ax/adapter/tests/test_model_fit_metrics.py index 99dd6c3042e..854e1800bda 100644 --- a/ax/adapter/tests/test_model_fit_metrics.py +++ b/ax/adapter/tests/test_model_fit_metrics.py @@ -44,7 +44,7 @@ def setUp(self) -> None: True ) self.generation_strategy = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=NUM_SOBOL, diff --git a/ax/analysis/graphviz/tests/test_generation_strategy_graph.py b/ax/analysis/graphviz/tests/test_generation_strategy_graph.py index b58f81629a8..a030bacfffe 100644 --- a/ax/analysis/graphviz/tests/test_generation_strategy_graph.py +++ b/ax/analysis/graphviz/tests/test_generation_strategy_graph.py @@ -34,7 +34,7 @@ def setUp(self) -> None: # Create a simple step-based generation strategy self.step_gs = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=5, diff --git a/ax/benchmark/methods/modular_botorch.py b/ax/benchmark/methods/modular_botorch.py index 3904adcbcdb..fe6c611b825 100644 --- a/ax/benchmark/methods/modular_botorch.py +++ b/ax/benchmark/methods/modular_botorch.py @@ -97,7 +97,7 @@ def get_sobol_mbm_generation_strategy( generation_strategy = GenerationStrategy( name=name, - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=num_sobol_trials, diff --git a/ax/benchmark/methods/sobol.py b/ax/benchmark/methods/sobol.py index b17ba7299b0..2b4c4eb44b4 100644 --- a/ax/benchmark/methods/sobol.py +++ b/ax/benchmark/methods/sobol.py @@ -17,7 +17,7 @@ def get_sobol_generation_strategy() -> GenerationStrategy: return GenerationStrategy( name="Sobol", - steps=[ + nodes=[ GenerationStep(generator=Generators.SOBOL, num_trials=-1), ], ) diff --git a/ax/early_stopping/experiment_replay.py b/ax/early_stopping/experiment_replay.py index a6b81b7226e..a2efee76346 100644 --- a/ax/early_stopping/experiment_replay.py +++ b/ax/early_stopping/experiment_replay.py @@ -95,7 +95,7 @@ def replay_experiment( # Setup a Orchestrator with a dummy gs to replay the historical experiment dummy_sobol_gs = GenerationStrategy( name="sobol", - steps=[ + nodes=[ GenerationStep(generator=Generators.SOBOL, num_trials=-1), ], ) diff --git a/ax/generation_strategy/dispatch_utils.py b/ax/generation_strategy/dispatch_utils.py index 9322daadfbe..c1b7494c9d8 100644 --- a/ax/generation_strategy/dispatch_utils.py +++ b/ax/generation_strategy/dispatch_utils.py @@ -541,7 +541,7 @@ def choose_generation_strategy_legacy( ) else: # `force_random_search` is True or we could not suggest BO model gs = GenerationStrategy( - steps=[ + nodes=[ _make_sobol_step( seed=random_seed, should_deduplicate=should_deduplicate, diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index 398eb4bb547..aeaea7793a6 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -8,10 +8,10 @@ from __future__ import annotations -import warnings +from collections.abc import Sequence from copy import deepcopy from logging import Logger -from typing import Any, TypeVar +from typing import TypeVar from ax.adapter.base import Adapter from ax.core.data import Data @@ -77,27 +77,10 @@ class GenerationStrategy(Base): def __init__( self, *, - nodes: list[GenerationNode] | None = None, + nodes: Sequence[GenerationNode], name: str | None = None, - **kwargs: Any, ) -> None: self._generator_runs = [] - if not (bool(steps := kwargs.get("steps")) ^ bool(nodes)): # Steps XOR nodes - raise GenerationStrategyMisconfiguredException( - "GenerationStrategy must contain either steps or nodes. " - f"Got: nodes={nodes}, steps={steps}." - ) - - if steps: - warnings.warn( - DeprecationWarning( - "Specifying `steps` input is no longer supported. Please use " - "`nodes`. `steps` argument will be removed in early 2026." - ), - stacklevel=2, - ) - nodes = steps - self._validate_and_set_node_graph(nodes=nodes) # Set name to an explicit value ahead of time to avoid @@ -399,7 +382,6 @@ def _unset_non_persistent_state_fields(self) -> None: n._trials_from_node_cache = set() n._cached_trial_count = None - # TODO: Deprecate `steps` argument fully in Q1'26. def _validate_and_set_step_sequence(self, steps: list[GenerationNode]) -> None: """Initialize and validate the steps provided to this GenerationStrategy. diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index 190d73b54bc..8e5af201baa 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -209,7 +209,7 @@ def test_gen_with_parameter_constraints(self) -> None: # Create generation strategy with Sobol gs = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=10, @@ -308,7 +308,7 @@ def setUp(self) -> None: self.step_generator_kwargs = {"silently_filter_kwargs": True} self.hss_experiment = get_hierarchical_search_space_experiment() self.sobol_GS = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( Generators.SOBOL, num_trials=-1, @@ -465,7 +465,7 @@ def _get_sobol_mbm_step_gs( ) -> GenerationStrategy: return GenerationStrategy( name="Sobol+MBM", - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=num_sobol_trials, @@ -502,7 +502,7 @@ def test_validation(self) -> None: # num_trials can be positive or -1. with self.assertRaises(UserInputError): GenerationStrategy( - steps=[ + nodes=[ GenerationStep(generator=Generators.SOBOL, num_trials=5), GenerationStep( generator=Generators.BOTORCH_MODULAR, num_trials=-10 @@ -515,7 +515,7 @@ def test_validation(self) -> None: ) with self.assertRaisesRegex(UserInputError, "Maximum concurrency should be"): GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=5, max_parallelism=-1 ), @@ -527,7 +527,7 @@ def test_custom_callables_for_models(self) -> None: with self.assertRaises(UserInputError): GenerationStrategy( # pyre-ignore [6]: Testing deprecated input. - steps=[GenerationStep(generator=get_sobol, num_trials=-1)] + nodes=[GenerationStep(generator=get_sobol, num_trials=-1)] ) def test_string_representation(self) -> None: @@ -548,7 +548,7 @@ def test_string_representation(self) -> None: "pausing_criteria=None)])", ) gs2 = GenerationStrategy( - steps=[GenerationStep(generator=Generators.SOBOL, num_trials=-1)] + nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=-1)] ) self.assertEqual( str(gs2), @@ -601,7 +601,7 @@ def test_min_observed(self) -> None: # pyre-fixme[6]: For 1st param expected `bool` but got `Experiment`. exp = get_branin_experiment(get_branin_experiment()) gs = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=5, min_trials_observed=5 ), @@ -619,7 +619,7 @@ def test_one_node_with_finite_num_trials(self) -> None: # pyre-fixme[6]: For 1st param expected `bool` but got `Experiment`. exp = get_branin_experiment(get_branin_experiment()) gs = GenerationStrategy( - steps=[ + nodes=[ GenerationStep(generator=Generators.SOBOL, num_trials=5), ] ) @@ -634,7 +634,7 @@ def test_do_not_enforce_min_observations(self) -> None: # case the previous model should be used until there is enough data. exp = get_branin_experiment() gs = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=1, @@ -711,7 +711,7 @@ def test_sobol_MBM_strategy_keep_generating(self) -> None: def test_sobol_strategy(self) -> None: exp = get_branin_experiment() sobol_generation_strategy = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=5, @@ -730,7 +730,7 @@ def test_sobol_strategy(self) -> None: def test_factorial_thompson_strategy(self, _: MagicMock) -> None: exp = get_branin_experiment() factorial_thompson_gs = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.FACTORIAL, num_trials=1, @@ -761,7 +761,7 @@ def test_factorial_thompson_strategy(self, _: MagicMock) -> None: def test_clone_reset(self) -> None: ftgs = GenerationStrategy( - steps=[ + nodes=[ GenerationStep(generator=Generators.FACTORIAL, num_trials=1), GenerationStep(generator=Generators.THOMPSON, num_trials=2), ] @@ -774,7 +774,7 @@ def test_clone_reset(self) -> None: def test_kwargs_passed(self) -> None: gs = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=1, @@ -817,7 +817,7 @@ def test_sobol_MBM_strategy_batches(self) -> None: def test_store_experiment(self) -> None: exp = get_branin_experiment() sobol_generation_strategy = GenerationStrategy( - steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)] + nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=5)] ) self.assertIsNone(sobol_generation_strategy._experiment) sobol_generation_strategy.gen_single_trial(exp) @@ -828,7 +828,7 @@ def test_gen_single_trial_extracts_pending_observations(self) -> None: experiment when none are passed in.""" exp = get_branin_experiment() gs = GenerationStrategy( - steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)] + nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=5)] ) # Create a trial and mark it as running so it becomes a pending observation. trial = exp.new_trial(generator_run=gs.gen_single_trial(exp)) @@ -856,7 +856,7 @@ def test_gen_single_trial_raises_error_for_multiple_trials(self) -> None: """Test that gen_single_trial raises AxError if gen returns multiple trials.""" exp = get_branin_experiment() gs = GenerationStrategy( - steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)] + nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=5)] ) gr = gs.gen_single_trial(exp) # Mock gen to return multiple trials @@ -869,7 +869,7 @@ def test_gen_single_trial_raises_error_for_multiple_generator_runs(self) -> None GeneratorRuns for a single trial.""" exp = get_branin_experiment() gs = GenerationStrategy( - steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)] + nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=5)] ) gr = gs.gen_single_trial(exp) # Mock gen to return a single trial with multiple GeneratorRuns @@ -880,7 +880,7 @@ def test_gen_single_trial_raises_error_for_multiple_generator_runs(self) -> None def test_max_parallelism_reached(self) -> None: exp = get_branin_experiment() sobol_generation_strategy = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=5, max_parallelism=1 ) @@ -983,7 +983,7 @@ def test_current_generator_run_limit(self) -> None: NUM_ROUNDS = 4 exp = get_branin_experiment() sobol_gs_with_parallelism_limits = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=NUM_INIT_TRIALS, @@ -1026,7 +1026,7 @@ def test_current_generator_run_limit_unlimited_second_step(self) -> None: NUM_ROUNDS = 4 exp = get_branin_experiment() sobol_gs_with_parallelism_limits = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=NUM_INIT_TRIALS, @@ -1281,7 +1281,7 @@ def test_gen_with_fixed_features( ) -> None: exp = get_branin_experiment() gs = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=1, @@ -1377,29 +1377,6 @@ def test_gs_setup_with_nodes(self) -> None: node_3, ], ) - # check error raised if provided both steps and nodes - with self.assertRaisesRegex( - GenerationStrategyMisconfiguredException, "contain either steps or nodes" - ): - GenerationStrategy( - nodes=[ - node_1, - node_3, - ], - steps=[ - GenerationStep( - generator=Generators.SOBOL, - num_trials=5, - generator_kwargs=self.step_generator_kwargs, - ), - GenerationStep( - generator=Generators.BOTORCH_MODULAR, - num_trials=-1, - generator_kwargs=self.step_generator_kwargs, - ), - ], - ) - # check error raised if two transition criterion defining a single edge have # differing `continue_trial_generation` settings with self.assertRaisesRegex( diff --git a/ax/generation_strategy/tests/test_transition_criterion.py b/ax/generation_strategy/tests/test_transition_criterion.py index 93563a08f0f..362626dd932 100644 --- a/ax/generation_strategy/tests/test_transition_criterion.py +++ b/ax/generation_strategy/tests/test_transition_criterion.py @@ -164,7 +164,7 @@ def test_default_step_criterion_setup(self) -> None: experiment = get_experiment() gs = GenerationStrategy( name="SOBOL+MBM::default", - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=3, @@ -230,7 +230,7 @@ def test_min_trials_is_met(self) -> None: experiment = self.branin_experiment gs = GenerationStrategy( name="SOBOL::default", - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=4, @@ -312,7 +312,7 @@ def test_min_trials_count_only_with_data(self) -> None: experiment = self.branin_experiment gs = GenerationStrategy( name="SOBOL::default", - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=4, @@ -508,7 +508,7 @@ def test_trials_from_node_empty(self) -> None: experiment = get_experiment() gs = GenerationStrategy( name="SOBOL::default", - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=4, diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 6cead6408ac..3d1ea48e8a7 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -213,7 +213,7 @@ def setUp(self) -> None: search_space=get_branin_search_space() ) self.two_sobol_steps_GS = GenerationStrategy( # Contrived GS to ensure - steps=[ # that `DataRequiredError` is property handled in orchestrator. + nodes=[ # that `DataRequiredError` is property handled in orchestrator. GenerationStep( # This error is raised when not enough trials generator=Generators.SOBOL, # have been observed to proceed to next num_trials=5, # geneneration step. @@ -227,7 +227,7 @@ def setUp(self) -> None: ) # GS to force the orchestrator to poll completed trials after each ran trial. self.sobol_GS_no_parallelism = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=-1, max_parallelism=1 ) @@ -2025,7 +2025,7 @@ def test_get_fitted_adapter(self) -> None: # generation strategy NUM_SOBOL = 5 generation_strategy = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=NUM_SOBOL, @@ -2297,7 +2297,7 @@ def test_it_works_with_multitask_models( self, ) -> None: gs = GenerationStrategy( - steps=[ + nodes=[ GenerationStep(generator=Generators.SOBOL, num_trials=1), GenerationStep(generator=Generators.BOTORCH_MODULAR, num_trials=1), GenerationStep( @@ -2984,7 +2984,7 @@ def setUp(self) -> None: search_space=get_branin_search_space() ) self.two_sobol_steps_GS = GenerationStrategy( # Contrived GS to ensure - steps=[ # that `DataRequiredError` is property handled in orchestrator. + nodes=[ # that `DataRequiredError` is property handled in orchestrator. GenerationStep( # This error is raised when not enough trials generator=Generators.SOBOL, # have been observed to proceed to next num_trials=5, # geneneration step. @@ -2998,7 +2998,7 @@ def setUp(self) -> None: ) # GS to force the Orchestrator to poll completed trials after each ran trial. self.sobol_GS_no_parallelism = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=-1, max_parallelism=1 ) diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 386f508165b..bc31ae36663 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -224,7 +224,7 @@ def get_client_with_simple_discrete_moo_problem( use_y2_constraint: bool, ) -> AxClient: gs = GenerationStrategy( - steps=[ + nodes=[ GenerationStep(generator=Generators.SOBOL, num_trials=3), GenerationStep( generator=Generators.BOTORCH_MODULAR, @@ -587,7 +587,7 @@ def test_optimization_complete(self, _mock_gen, _mock_sig_to_metric) -> None: def test_sobol_generation_strategy_completion(self) -> None: ax_client = get_branin_optimization( generation_strategy=GenerationStrategy( - steps=[GenerationStep(Generators.SOBOL, num_trials=3)] + nodes=[GenerationStep(Generators.SOBOL, num_trials=3)] ) ) # All Sobol trials should be able to be generated at once and optimization @@ -607,7 +607,7 @@ def test_save_and_load_generation_strategy(self) -> None: decoder = Decoder(config=config) db_settings = DBSettings(encoder=encoder, decoder=decoder) generation_strategy = GenerationStrategy( - steps=[GenerationStep(Generators.SOBOL, num_trials=-1)] + nodes=[GenerationStep(Generators.SOBOL, num_trials=-1)] ) ax_client = AxClient( db_settings=db_settings, generation_strategy=generation_strategy @@ -752,7 +752,7 @@ def test_create_experiment(self) -> None: """Test basic experiment creation.""" ax_client = AxClient( GenerationStrategy( - steps=[GenerationStep(generator=Generators.SOBOL, num_trials=30)] + nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=30)] ) ) with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"): @@ -914,7 +914,7 @@ def test_create_multitype_experiment(self) -> None: """ ax_client = AxClient( GenerationStrategy( - steps=[GenerationStep(generator=Generators.SOBOL, num_trials=30)] + nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=30)] ) ) ax_client.create_experiment( @@ -1011,7 +1011,7 @@ def test_create_multitype_experiment(self) -> None: def test_create_single_objective_experiment_with_objectives_dict(self) -> None: ax_client = AxClient( GenerationStrategy( - steps=[GenerationStep(generator=Generators.SOBOL, num_trials=30)] + nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=30)] ) ) with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"): @@ -1696,7 +1696,7 @@ def test_create_moo_experiment(self) -> None: """Test basic experiment creation.""" ax_client = AxClient( GenerationStrategy( - steps=[GenerationStep(generator=Generators.SOBOL, num_trials=30)] + nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=30)] ) ) with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"): @@ -1857,7 +1857,7 @@ def test_constraint_same_as_objective(self) -> None: """Check that we do not allow constraints on the objective metric.""" ax_client = AxClient( GenerationStrategy( - steps=[GenerationStep(generator=Generators.SOBOL, num_trials=30)] + nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=30)] ) ) with self.assertRaises(ValueError): @@ -3265,7 +3265,7 @@ def test_torch_device(self) -> None: with self.assertWarnsRegex(RuntimeWarning, "a `torch_device` were specified."): AxClient( generation_strategy=GenerationStrategy( - steps=[GenerationStep(Generators.SOBOL, num_trials=3)] + nodes=[GenerationStep(Generators.SOBOL, num_trials=3)] ), torch_device=device, ) diff --git a/ax/service/tests/test_interactive_loop.py b/ax/service/tests/test_interactive_loop.py index b417fd58426..fac29c73f12 100644 --- a/ax/service/tests/test_interactive_loop.py +++ b/ax/service/tests/test_interactive_loop.py @@ -36,7 +36,7 @@ class TestInteractiveLoop(TestCase): def setUp(self) -> None: generation_strategy = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, max_parallelism=1, num_trials=-1 ) diff --git a/ax/service/tests/test_managed_loop.py b/ax/service/tests/test_managed_loop.py index a403a0c684d..f843059073b 100644 --- a/ax/service/tests/test_managed_loop.py +++ b/ax/service/tests/test_managed_loop.py @@ -363,7 +363,7 @@ def test_custom_gs(self) -> None: """Managed loop with custom generation strategy""" strategy0 = GenerationStrategy( name="Sobol", - steps=[GenerationStep(generator=Generators.SOBOL, num_trials=-1)], + nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=-1)], ) loop = OptimizationLoop.with_evaluation_function( parameters=[ @@ -406,7 +406,7 @@ def test_optimize_graceful_exit_on_exception(self) -> None: total_trials=6, generation_strategy=GenerationStrategy( name="Sobol", - steps=[GenerationStep(generator=Generators.SOBOL, num_trials=3)], + nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=3)], ), ) self.assertEqual(len(exp.trials), 3) # Check that we stopped at 3 trials. @@ -425,7 +425,7 @@ def test_optimize_graceful_exit_on_exception(self) -> None: def test_annotate_exception(self, _: Mock) -> None: strategy0 = GenerationStrategy( name="Sobol", - steps=[GenerationStep(generator=Generators.SOBOL, num_trials=-1)], + nodes=[GenerationStep(generator=Generators.SOBOL, num_trials=-1)], ) loop = OptimizationLoop.with_evaluation_function( parameters=[ diff --git a/ax/service/tests/test_report_utils.py b/ax/service/tests/test_report_utils.py index 0733f65cd0c..bbb4a9ad18b 100644 --- a/ax/service/tests/test_report_utils.py +++ b/ax/service/tests/test_report_utils.py @@ -613,7 +613,7 @@ def test_warn_if_unpredictable_metrics(self) -> None: # Create Orchestrator and run a few trials. exp = get_branin_experiment() gs = GenerationStrategy( - steps=[ + nodes=[ GenerationStep( generator=Generators.SOBOL, num_trials=3, diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 4259a071c17..b0cd780a0a4 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -1262,7 +1262,7 @@ def generation_strategy_from_json( class_decoder_registry=class_decoder_registry, ) if len(steps) > 0: - gs = GenerationStrategy(steps=steps, name=generation_strategy_json.pop("name")) + gs = GenerationStrategy(nodes=steps, name=generation_strategy_json.pop("name")) gs._curr = gs._nodes[generation_strategy_json.pop("curr_index")] else: gs = GenerationStrategy(nodes=nodes, name=generation_strategy_json.pop("name")) diff --git a/ax/utils/testing/modeling_stubs.py b/ax/utils/testing/modeling_stubs.py index 66fb3ba28c2..ceb8bed545d 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -313,7 +313,6 @@ def sobol_gpei_generation_node_gs( sobol_mbm_GS_nodes = GenerationStrategy( name="Sobol+MBM_Nodes", nodes=[sobol_node, mbm_node], - steps=None, ) return sobol_mbm_GS_nodes