Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion ax/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@

from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any, Self, TYPE_CHECKING
from dataclasses import dataclass, fields
from typing import Any, ClassVar, Self, TYPE_CHECKING

from ax.utils.common.base import Base
from ax.utils.common.sentinel import Unset
from ax.utils.common.serialization import SerializationMixin


Expand All @@ -21,9 +23,27 @@
from ax import core # noqa F401


class RunnerConfig:
@dataclass(frozen=True)
class SearchSpaceUpdateArguments:
"""Base arguments for search space updates. Override in RunnerConfig
subclasses to add runner-specific fields."""

pass

@dataclass(frozen=True)
class RunnerUpdateArguments:
"""Base arguments for general runner updates. Override in RunnerConfig
subclasses to add runner-specific fields."""

pass


class Runner(Base, SerializationMixin, ABC):
"""Abstract base class for custom runner classes"""

config_type: ClassVar[type[RunnerConfig]] = RunnerConfig

@property
def staging_required(self) -> bool:
"""Whether the trial goes to staged or running state once deployed."""
Expand Down Expand Up @@ -145,6 +165,69 @@ def stop(
f"{self.__class__.__name__} does not implement a `stop` method."
)

def on_search_space_update(
self,
search_space: core.search_space.SearchSpace,
arguments: RunnerConfig.SearchSpaceUpdateArguments | None = None,
) -> None:
"""Called after the experiment's search space has been updated.

Validates the proposed runner-side changes, then applies them.
Subclasses should override ``_validate_on_search_space_update``
to add validation logic.

Args:
search_space: The updated search space.
arguments: Optional typed arguments carrying runner-specific
data. Subclasses should define a ``RunnerConfig`` subclass
with a nested ``SearchSpaceUpdateArguments`` dataclass to
declare supported fields.
"""
if arguments is not None:
UpdateArgsClass = type(self).config_type.SearchSpaceUpdateArguments
if not isinstance(arguments, UpdateArgsClass):
raise TypeError(
f"Expected {UpdateArgsClass.__name__}, "
f"got {type(arguments).__name__}."
)
self._validate_on_search_space_update(search_space, arguments)
if arguments is not None:
self._set_attributes(arguments)

def _validate_on_search_space_update(
self,
search_space: core.search_space.SearchSpace,
arguments: RunnerConfig.SearchSpaceUpdateArguments | None = None,
) -> None:
"""Override in subclasses to reject invalid search space updates
before the runner's state is modified. The runner's attributes still
hold their old values at this point; use the ``arguments`` to determine
the proposed new state.

Args:
search_space: The already-updated search space.
arguments: The proposed runner-side changes, if any.
"""
pass

def _set_attributes(
self,
arguments: (
RunnerConfig.RunnerUpdateArguments | RunnerConfig.SearchSpaceUpdateArguments
),
) -> None:
"""Apply dataclass field values to self, skipping UNSET fields.

Shared by ``update`` and ``on_search_space_update`` to ensure both
follow the same validate-then-mutate pattern.
"""
for field in fields(arguments):
value = getattr(arguments, field.name)
if isinstance(value, Unset):
continue
attr_name = field.metadata.get("attr", field.name)
setattr(self, attr_name, value)

def clone(self) -> Self:
"""Create a copy of this Runner."""
cls = type(self)
Expand Down
90 changes: 89 additions & 1 deletion ax/core/tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,57 @@

# pyre-strict

from dataclasses import dataclass, field
from unittest import mock

from ax.core.base_trial import BaseTrial
from ax.core.runner import Runner
from ax.core.runner import Runner, RunnerConfig
from ax.core.search_space import SearchSpace
from ax.utils.common.sentinel import Unset, UNSET
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_batch_trial, get_trial
from pyre_extensions import override


class DummyRunner(Runner):
def run(self, trial: BaseTrial) -> dict[str, str]:
return {"metadatum": f"value_for_trial_{trial.index}"}


class DummyRunnerConfig(RunnerConfig):
@dataclass(frozen=True)
class RunnerUpdateArguments(RunnerConfig.RunnerUpdateArguments):
name: str | None | Unset = UNSET
count: int | None | Unset = UNSET
tag: str | None | Unset = field(default=UNSET, metadata={"attr": "_tag"})

@dataclass(frozen=True)
class SearchSpaceUpdateArguments(RunnerConfig.SearchSpaceUpdateArguments):
label: str | None | Unset = UNSET


class DummyUpdatableRunner(Runner):
config_type = DummyRunnerConfig

def __init__(self) -> None:
self.name: str | None = "original"
self.count: int | None = 5
self.label: str | None = "default_label"
self._tag: str | None = "old_tag"
self.events: list[str] = []

@override
def _validate_on_search_space_update(
self,
search_space: SearchSpace,
arguments: RunnerConfig.SearchSpaceUpdateArguments | None = None,
) -> None:
self.events.append(f"_validate_on_search_space_update:label={self.label}")

def run(self, trial: BaseTrial) -> dict[str, str]:
return {}


class RunnerTest(TestCase):
def setUp(self) -> None:
super().setUp()
Expand Down Expand Up @@ -56,5 +94,55 @@ def test_base_runner_poll_exception(self) -> None:
def test_poll_available_capacity(self) -> None:
self.assertEqual(self.dummy_runner.poll_available_capacity(), -1)

def test_on_search_space_update_default_is_noop(self) -> None:
self.assertIsNone(
self.dummy_runner.on_search_space_update(search_space=mock.Mock())
)

def test_run_metadata_report_keys(self) -> None:
self.assertEqual(self.dummy_runner.run_metadata_report_keys, [])

def test_set_attributes(self) -> None:
"""_set_attributes applies non-UNSET fields, skips UNSET ones,
and respects 'attr' metadata for private attribute names."""
runner = DummyUpdatableRunner()
with self.subTest("applies non-UNSET, skips UNSET"):
runner._set_attributes(
DummyRunnerConfig.RunnerUpdateArguments(name="new_name")
)
self.assertEqual(runner.name, "new_name")
self.assertEqual(runner.count, 5)
with self.subTest("respects attr metadata"):
runner._set_attributes(
DummyRunnerConfig.RunnerUpdateArguments(tag="new_tag")
)
self.assertEqual(runner._tag, "new_tag")

def test_on_search_space_update(self) -> None:
"""on_search_space_update validates before applying, applies fields,
rejects wrong argument types, and is a no-op without arguments."""
runner = DummyUpdatableRunner()
ss = mock.Mock()
with self.subTest("no-op without arguments"):
runner.on_search_space_update(search_space=ss)
self.assertEqual(runner.label, "default_label")
with self.subTest("validates before applying"):
runner.on_search_space_update(
search_space=ss,
arguments=DummyRunnerConfig.SearchSpaceUpdateArguments(label="updated"),
)
self.assertEqual(
runner.events,
[
"_validate_on_search_space_update:label=default_label",
"_validate_on_search_space_update:label=default_label",
],
)
self.assertEqual(runner.label, "updated")
with self.subTest("rejects wrong argument type"):
with self.assertRaisesRegex(TypeError, "Expected"):
runner.on_search_space_update(
search_space=ss,
# pyre-ignore[6]: Intentionally passing wrong type.
arguments=RunnerConfig.RunnerUpdateArguments(),
)
Loading
Loading