diff --git a/ax/core/runner.py b/ax/core/runner.py index 39d71ed99f8..dd1270c4646 100644 --- a/ax/core/runner.py +++ b/ax/core/runner.py @@ -10,9 +10,11 @@ from abc import ABC, abstractmethod from collections.abc import Iterable -from typing import Any, Self, TYPE_CHECKING +from dataclasses import dataclass, fields +from typing import Any, ClassVar, Self, TYPE_CHECKING from ax.utils.common.base import Base +from ax.utils.common.sentinel import Unset from ax.utils.common.serialization import SerializationMixin @@ -21,9 +23,27 @@ from ax import core # noqa F401 +class RunnerConfig: + @dataclass(frozen=True) + class SearchSpaceUpdateArguments: + """Base arguments for search space updates. Override in RunnerConfig + subclasses to add runner-specific fields.""" + + pass + + @dataclass(frozen=True) + class RunnerUpdateArguments: + """Base arguments for general runner updates. Override in RunnerConfig + subclasses to add runner-specific fields.""" + + pass + + class Runner(Base, SerializationMixin, ABC): """Abstract base class for custom runner classes""" + config_type: ClassVar[type[RunnerConfig]] = RunnerConfig + @property def staging_required(self) -> bool: """Whether the trial goes to staged or running state once deployed.""" @@ -145,6 +165,69 @@ def stop( f"{self.__class__.__name__} does not implement a `stop` method." ) + def on_search_space_update( + self, + search_space: core.search_space.SearchSpace, + arguments: RunnerConfig.SearchSpaceUpdateArguments | None = None, + ) -> None: + """Called after the experiment's search space has been updated. + + Validates the proposed runner-side changes, then applies them. + Subclasses should override ``_validate_on_search_space_update`` + to add validation logic. + + Args: + search_space: The updated search space. + arguments: Optional typed arguments carrying runner-specific + data. Subclasses should define a ``RunnerConfig`` subclass + with a nested ``SearchSpaceUpdateArguments`` dataclass to + declare supported fields. + """ + if arguments is not None: + UpdateArgsClass = type(self).config_type.SearchSpaceUpdateArguments + if not isinstance(arguments, UpdateArgsClass): + raise TypeError( + f"Expected {UpdateArgsClass.__name__}, " + f"got {type(arguments).__name__}." + ) + self._validate_on_search_space_update(search_space, arguments) + if arguments is not None: + self._set_attributes(arguments) + + def _validate_on_search_space_update( + self, + search_space: core.search_space.SearchSpace, + arguments: RunnerConfig.SearchSpaceUpdateArguments | None = None, + ) -> None: + """Override in subclasses to reject invalid search space updates + before the runner's state is modified. The runner's attributes still + hold their old values at this point; use the ``arguments`` to determine + the proposed new state. + + Args: + search_space: The already-updated search space. + arguments: The proposed runner-side changes, if any. + """ + pass + + def _set_attributes( + self, + arguments: ( + RunnerConfig.RunnerUpdateArguments | RunnerConfig.SearchSpaceUpdateArguments + ), + ) -> None: + """Apply dataclass field values to self, skipping UNSET fields. + + Shared by ``update`` and ``on_search_space_update`` to ensure both + follow the same validate-then-mutate pattern. + """ + for field in fields(arguments): + value = getattr(arguments, field.name) + if isinstance(value, Unset): + continue + attr_name = field.metadata.get("attr", field.name) + setattr(self, attr_name, value) + def clone(self) -> Self: """Create a copy of this Runner.""" cls = type(self) diff --git a/ax/core/tests/test_runner.py b/ax/core/tests/test_runner.py index 40a34d33276..12416522275 100644 --- a/ax/core/tests/test_runner.py +++ b/ax/core/tests/test_runner.py @@ -6,12 +6,16 @@ # pyre-strict +from dataclasses import dataclass, field from unittest import mock from ax.core.base_trial import BaseTrial -from ax.core.runner import Runner +from ax.core.runner import Runner, RunnerConfig +from ax.core.search_space import SearchSpace +from ax.utils.common.sentinel import Unset, UNSET from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_batch_trial, get_trial +from pyre_extensions import override class DummyRunner(Runner): @@ -19,6 +23,40 @@ def run(self, trial: BaseTrial) -> dict[str, str]: return {"metadatum": f"value_for_trial_{trial.index}"} +class DummyRunnerConfig(RunnerConfig): + @dataclass(frozen=True) + class RunnerUpdateArguments(RunnerConfig.RunnerUpdateArguments): + name: str | None | Unset = UNSET + count: int | None | Unset = UNSET + tag: str | None | Unset = field(default=UNSET, metadata={"attr": "_tag"}) + + @dataclass(frozen=True) + class SearchSpaceUpdateArguments(RunnerConfig.SearchSpaceUpdateArguments): + label: str | None | Unset = UNSET + + +class DummyUpdatableRunner(Runner): + config_type = DummyRunnerConfig + + def __init__(self) -> None: + self.name: str | None = "original" + self.count: int | None = 5 + self.label: str | None = "default_label" + self._tag: str | None = "old_tag" + self.events: list[str] = [] + + @override + def _validate_on_search_space_update( + self, + search_space: SearchSpace, + arguments: RunnerConfig.SearchSpaceUpdateArguments | None = None, + ) -> None: + self.events.append(f"_validate_on_search_space_update:label={self.label}") + + def run(self, trial: BaseTrial) -> dict[str, str]: + return {} + + class RunnerTest(TestCase): def setUp(self) -> None: super().setUp() @@ -56,5 +94,55 @@ def test_base_runner_poll_exception(self) -> None: def test_poll_available_capacity(self) -> None: self.assertEqual(self.dummy_runner.poll_available_capacity(), -1) + def test_on_search_space_update_default_is_noop(self) -> None: + self.assertIsNone( + self.dummy_runner.on_search_space_update(search_space=mock.Mock()) + ) + def test_run_metadata_report_keys(self) -> None: self.assertEqual(self.dummy_runner.run_metadata_report_keys, []) + + def test_set_attributes(self) -> None: + """_set_attributes applies non-UNSET fields, skips UNSET ones, + and respects 'attr' metadata for private attribute names.""" + runner = DummyUpdatableRunner() + with self.subTest("applies non-UNSET, skips UNSET"): + runner._set_attributes( + DummyRunnerConfig.RunnerUpdateArguments(name="new_name") + ) + self.assertEqual(runner.name, "new_name") + self.assertEqual(runner.count, 5) + with self.subTest("respects attr metadata"): + runner._set_attributes( + DummyRunnerConfig.RunnerUpdateArguments(tag="new_tag") + ) + self.assertEqual(runner._tag, "new_tag") + + def test_on_search_space_update(self) -> None: + """on_search_space_update validates before applying, applies fields, + rejects wrong argument types, and is a no-op without arguments.""" + runner = DummyUpdatableRunner() + ss = mock.Mock() + with self.subTest("no-op without arguments"): + runner.on_search_space_update(search_space=ss) + self.assertEqual(runner.label, "default_label") + with self.subTest("validates before applying"): + runner.on_search_space_update( + search_space=ss, + arguments=DummyRunnerConfig.SearchSpaceUpdateArguments(label="updated"), + ) + self.assertEqual( + runner.events, + [ + "_validate_on_search_space_update:label=default_label", + "_validate_on_search_space_update:label=default_label", + ], + ) + self.assertEqual(runner.label, "updated") + with self.subTest("rejects wrong argument type"): + with self.assertRaisesRegex(TypeError, "Expected"): + runner.on_search_space_update( + search_space=ss, + # pyre-ignore[6]: Intentionally passing wrong type. + arguments=RunnerConfig.RunnerUpdateArguments(), + ) diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 50c8045f0ab..94c72317752 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -9,7 +9,8 @@ import json import logging import warnings -from collections.abc import Callable, Sequence +from collections.abc import Callable, Generator, Sequence +from contextlib import contextmanager from functools import partial from logging import Logger from typing import Any, TypeVar @@ -31,7 +32,7 @@ from ax.core.observation import ObservationFeatures from ax.core.parameter import RangeParameter from ax.core.parameter_constraint import ParameterConstraint -from ax.core.runner import Runner +from ax.core.runner import Runner, RunnerConfig from ax.core.trial import Trial from ax.core.trial_status import TrialStatus from ax.core.types import ( @@ -561,6 +562,7 @@ def add_parameters( backfill_values: TParameterization, status_quo_values: TParameterization | None = None, parameter_constraints: list[str] | None = None, + runner_updates: RunnerConfig.SearchSpaceUpdateArguments | None = None, ) -> None: """ Add new parameters to the experiment's search space. This allows extending @@ -584,6 +586,8 @@ def add_parameters( parameter constraints to add (e.g., ``"x1 + x2 <= 5.0"`` or ``"x1 <= x2"``). May reference both existing and new parameters. + runner_updates: Optional typed context to pass to the runner's + ``on_search_space_update`` hook. """ parameters_to_add = [ parameter_from_config(parameter_config) for parameter_config in parameters @@ -619,16 +623,20 @@ def add_parameters( for c in parameter_constraints ] - self.experiment.add_parameters_to_search_space( - parameters=parameters_to_add, - status_quo_values=status_quo_values or backfill_values, - parameter_constraints=typed_parameter_constraints or None, - ) + with self._with_runner_on_search_space_update( + runner_updates=runner_updates, + ): + self.experiment.add_parameters_to_search_space( + parameters=parameters_to_add, + status_quo_values=status_quo_values or backfill_values, + parameter_constraints=typed_parameter_constraints or None, + ) self._save_experiment_to_db_if_possible(experiment=self.experiment) def disable_parameters( self, default_parameter_values: TParameterization, + runner_updates: RunnerConfig.SearchSpaceUpdateArguments | None = None, ) -> None: """ Disable parameters in the experiment. This allows narrowing the search space @@ -643,15 +651,21 @@ def disable_parameters( default_parameter_values: Fixed values to use for the disabled parameters in all future trials. These values will be used for the parameter in all subsequent trials. + runner_updates: Optional typed context to pass to the runner's + ``on_search_space_update`` hook. """ - self.experiment.disable_parameters_in_search_space( - default_parameter_values=default_parameter_values - ) + with self._with_runner_on_search_space_update( + runner_updates=runner_updates, + ): + self.experiment.disable_parameters_in_search_space( + default_parameter_values=default_parameter_values + ) self._save_experiment_to_db_if_possible(experiment=self.experiment) def update_parameters( self, parameters: Sequence[RangeParameterConfig], + runner_updates: RunnerConfig.SearchSpaceUpdateArguments | None = None, ) -> None: """Update parameters in the experiment's search space. @@ -661,6 +675,8 @@ def update_parameters( Args: parameters: A sequence of ``RangeParameterConfig`` to update in the search space. + runner_updates: Optional typed context to pass to the runner's + ``on_search_space_update`` hook. Raises: UserInputError: If a parameter is not found in the search space or @@ -677,10 +693,29 @@ def update_parameters( f"Parameter {parameter.name} is not a RangeParameter." ) - for parameter in parameters: - search_space.update_parameter(parameter=parameter_from_config(parameter)) + parameters_to_update = [ + parameter_from_config(parameter) for parameter in parameters + ] + with self._with_runner_on_search_space_update( + runner_updates=runner_updates, + ): + for parameter in parameters_to_update: + search_space.update_parameter(parameter=parameter) self._save_experiment_to_db_if_possible(experiment=self.experiment) + @property + def runner_config_type(self) -> type[RunnerConfig] | None: + """The ``RunnerConfig`` subclass declared by the experiment's runner. + + Returns ``None`` if the experiment has no runner. Useful for + discovering the typed context a runner expects:: + + ctx_cls = client.runner_config_type.SearchSpaceUpdateArguments + """ + if self.experiment.runner is None: + return None + return self.experiment.runner.config_type + @retry_on_exception( logger=logger, exception_types=(RuntimeError,), @@ -1504,10 +1539,12 @@ def load_from_json_file( def to_json_snapshot( self, - encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]] - | None = None, - class_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]] - | None = None, + encoder_registry: ( + dict[type[Any], Callable[[Any], dict[str, Any]]] | None + ) = None, + class_encoder_registry: ( + dict[type[Any], Callable[[Any], dict[str, Any]]] | None + ) = None, ) -> dict[str, Any]: """Serialize this `AxClient` to JSON to be able to interrupt and restart optimization and save it to file by the provided path. @@ -1869,6 +1906,36 @@ def _validate_all_required_metrics_present( missing_metrics = required_metrics - provided_metrics return not missing_metrics + @contextmanager + def _with_runner_on_search_space_update( + self, + runner_updates: RunnerConfig.SearchSpaceUpdateArguments | None = None, + ) -> Generator[None, None, None]: + """Context manager that notifies the runner after search space mutations. + + On enter, snapshots (clones) the current search space. The caller + performs the actual search space mutations inside the ``with`` block. + On exit, if the experiment has a runner, calls + ``runner.on_search_space_update`` with the now-mutated real search + space. If ``on_search_space_update`` raises, the search space is + restored from the snapshot and the exception is re-raised. + + If the experiment has no runner, this is a no-op wrapper. + """ + runner = self.experiment.runner + search_space_snapshot = self.experiment.search_space.clone() + yield + if runner is None: + return + try: + runner.on_search_space_update( + search_space=self.experiment.search_space, + arguments=runner_updates, + ) + except Exception: + self.experiment.search_space = search_space_snapshot + raise + # ------------------------------ Validators. ------------------------------- def _validate_early_stopping_strategy( diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 386f508165b..74cb832f169 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -40,6 +40,7 @@ RangeParameter, ) from ax.core.parameter_constraint import ParameterConstraint +from ax.core.runner import RunnerConfig from ax.core.trial import Trial from ax.core.types import ( ComparisonOp, @@ -1692,6 +1693,265 @@ def test_disable_parameters(self) -> None: self.assertTrue(param_x3.is_disabled) self.assertEqual(param_x3.default_value, "b") + def test_search_space_mutations_call_runner_hook(self) -> None: + """Test that add/disable/update_parameters call runner.on_search_space_update + with correct args.""" + ax_client = AxClient() + ax_client.create_experiment( + name="test_experiment", + parameters=[ + { + "name": "x1", + "type": "range", + "bounds": [0.0, 1.0], + "value_type": "float", + }, + { + "name": "x2", + "type": "range", + "bounds": [1, 10], + "value_type": "int", + }, + ], + is_test=True, + immutable_search_space_and_opt_config=False, + ) + mock_runner = Mock() + ax_client.experiment.runner = mock_runner + + with self.subTest("add_parameters_with_runner_updates"): + mock_runner.reset_mock() + ctx = RunnerConfig.SearchSpaceUpdateArguments() + ax_client.add_parameters( + parameters=[ + RangeParameterConfig( + name="x3", + bounds=(0.0, 5.0), + parameter_type="float", + ), + ], + backfill_values={"x3": 1.0}, + runner_updates=ctx, + ) + mock_runner.on_search_space_update.assert_called_once() + call_kwargs = mock_runner.on_search_space_update.call_args[1] + self.assertIsNotNone(call_kwargs["search_space"]) + self.assertIs(call_kwargs["arguments"], ctx) + + with self.subTest("disable_parameters_with_runner_updates"): + mock_runner.reset_mock() + ctx = RunnerConfig.SearchSpaceUpdateArguments() + ax_client.disable_parameters( + default_parameter_values={"x2": 5}, + runner_updates=ctx, + ) + mock_runner.on_search_space_update.assert_called_once() + call_kwargs = mock_runner.on_search_space_update.call_args[1] + self.assertIsNotNone(call_kwargs["search_space"]) + self.assertIs(call_kwargs["arguments"], ctx) + + with self.subTest("update_parameters_with_runner_updates"): + mock_runner.reset_mock() + ctx = RunnerConfig.SearchSpaceUpdateArguments() + ax_client.update_parameters( + parameters=[ + RangeParameterConfig( + name="x1", + bounds=(0.0, 2.0), + parameter_type="float", + ), + ], + runner_updates=ctx, + ) + mock_runner.on_search_space_update.assert_called_once() + call_kwargs = mock_runner.on_search_space_update.call_args[1] + self.assertIsNotNone(call_kwargs["search_space"]) + self.assertIs(call_kwargs["arguments"], ctx) + + with self.subTest("update_parameters_without_runner_updates"): + mock_runner.reset_mock() + ax_client.update_parameters( + parameters=[ + RangeParameterConfig( + name="x1", + bounds=(0.0, 3.0), + parameter_type="float", + ), + ], + ) + mock_runner.on_search_space_update.assert_called_once() + call_kwargs = mock_runner.on_search_space_update.call_args[1] + self.assertIsNone(call_kwargs["arguments"]) + + with self.subTest("no_runner_no_error"): + ax_client.experiment.runner = None + ax_client.update_parameters( + parameters=[ + RangeParameterConfig( + name="x1", + bounds=(0.0, 4.0), + parameter_type="float", + ), + ], + ) # Should not raise + + def test_with_runner_on_search_space_update_no_runner(self) -> None: + """When there is no runner, the context manager is a no-op and mutations + happen normally.""" + ax_client = AxClient() + ax_client.create_experiment( + name="test_experiment", + parameters=[ + { + "name": "x1", + "type": "range", + "bounds": [0.0, 1.0], + "value_type": "float", + }, + ], + is_test=True, + immutable_search_space_and_opt_config=False, + ) + self.assertIsNone(ax_client.experiment.runner) + + with ax_client._with_runner_on_search_space_update(): + ax_client.experiment.search_space.update_parameter( + RangeParameter( + name="x1", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=5.0, + ) + ) + + updated_param = assert_is_instance( + ax_client.experiment.search_space.parameters["x1"], RangeParameter + ) + self.assertEqual(updated_param.upper, 5.0) + + def test_with_runner_on_search_space_update_success(self) -> None: + """When the runner's on_search_space_update succeeds, mutations persist + and the runner receives the mutated search space with runner_updates.""" + ax_client = AxClient() + ax_client.create_experiment( + name="test_experiment", + parameters=[ + { + "name": "x1", + "type": "range", + "bounds": [0.0, 1.0], + "value_type": "float", + }, + ], + is_test=True, + immutable_search_space_and_opt_config=False, + ) + mock_runner = Mock() + ax_client.experiment.runner = mock_runner + ctx = RunnerConfig.SearchSpaceUpdateArguments() + + with ax_client._with_runner_on_search_space_update(runner_updates=ctx): + ax_client.experiment.search_space.update_parameter( + RangeParameter( + name="x1", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=10.0, + ) + ) + + mock_runner.on_search_space_update.assert_called_once() + call_kwargs = mock_runner.on_search_space_update.call_args[1] + self.assertIs(call_kwargs["arguments"], ctx) + passed_ss = call_kwargs["search_space"] + passed_param = assert_is_instance(passed_ss.parameters["x1"], RangeParameter) + self.assertEqual(passed_param.upper, 10.0) + + persisted_param = assert_is_instance( + ax_client.experiment.search_space.parameters["x1"], RangeParameter + ) + self.assertEqual(persisted_param.upper, 10.0) + + def test_with_runner_on_search_space_update_rollback_on_error(self) -> None: + """When on_search_space_update raises, the search space is rolled back + to its pre-mutation state.""" + ax_client = AxClient() + ax_client.create_experiment( + name="test_experiment", + parameters=[ + { + "name": "x1", + "type": "range", + "bounds": [0.0, 1.0], + "value_type": "float", + }, + ], + is_test=True, + immutable_search_space_and_opt_config=False, + ) + mock_runner = Mock() + mock_runner.on_search_space_update.side_effect = ValueError( + "runner rejected update" + ) + ax_client.experiment.runner = mock_runner + + original_param = assert_is_instance( + ax_client.experiment.search_space.parameters["x1"], RangeParameter + ) + self.assertEqual(original_param.upper, 1.0) + + with self.assertRaisesRegex(ValueError, "runner rejected update"): + with ax_client._with_runner_on_search_space_update(): + ax_client.experiment.search_space.update_parameter( + RangeParameter( + name="x1", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=99.0, + ) + ) + + restored_param = assert_is_instance( + ax_client.experiment.search_space.parameters["x1"], RangeParameter + ) + self.assertEqual(restored_param.upper, 1.0) + + def test_with_runner_on_search_space_update_receives_mutated_search_space( + self, + ) -> None: + """The runner receives the real (mutated) search space, not a clone or + the original.""" + ax_client = AxClient() + ax_client.create_experiment( + name="test_experiment", + parameters=[ + { + "name": "x1", + "type": "range", + "bounds": [0.0, 1.0], + "value_type": "float", + }, + ], + is_test=True, + immutable_search_space_and_opt_config=False, + ) + mock_runner = Mock() + ax_client.experiment.runner = mock_runner + + with ax_client._with_runner_on_search_space_update(runner_updates=None): + ax_client.experiment.search_space.update_parameter( + RangeParameter( + name="x1", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=42.0, + ) + ) + + call_kwargs = mock_runner.on_search_space_update.call_args[1] + self.assertIs(call_kwargs["search_space"], ax_client.experiment.search_space) + self.assertIsNone(call_kwargs["arguments"]) + def test_create_moo_experiment(self) -> None: """Test basic experiment creation.""" ax_client = AxClient( diff --git a/ax/utils/common/sentinel.py b/ax/utils/common/sentinel.py new file mode 100644 index 00000000000..2b8a7ca3f39 --- /dev/null +++ b/ax/utils/common/sentinel.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +class Unset: + """Sentinel type for distinguishing "not provided" from an explicit ``None``. + + Use the module-level ``UNSET`` instance as the default value for + optional fields where ``None`` is a valid, meaningful value and a + separate "not set" state is needed. + """ + + def __repr__(self) -> str: + return "UNSET" + + +UNSET: Unset = Unset()