Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
from ax.core.batch_trial import BatchTrial
from ax.core.data import combine_data_rows_favoring_recent, Data
from ax.core.derived_metric import DerivedMetric
from ax.core.experiment_design import (
AutomationSettings,
EXPERIMENT_DESIGN_KEY,
ExperimentDesign,
)
from ax.core.experiment_status import ExperimentStatus
from ax.core.generator_run import GeneratorRun
from ax.core.llm_provider import LLMMessage
Expand Down Expand Up @@ -155,6 +160,13 @@ def __init__(
self._status: ExperimentStatus | None = None
self._trials: dict[int, BaseTrial] = {}
self._properties: dict[str, Any] = properties or {}
# TODO[mpolson64]: Replace with proper storage as part of the refactor.
self._design: ExperimentDesign = (
ExperimentDesign.from_json(design_dict)
if (design_dict := self._properties.pop(EXPERIMENT_DESIGN_KEY, None))
is not None
else ExperimentDesign()
)

# Initialize trial type to runner mapping
self._default_trial_type = default_trial_type
Expand Down Expand Up @@ -313,6 +325,99 @@ def experiment_status_from_generator_runs(

return suggested_statuses.pop()

@property
def design(self) -> ExperimentDesign:
"""The experiment design configuration.

Holds a dictionary of ``AutomationSettings`` keyed by trial type.
Use the setter to replace the entire design; it validates that the
trial types in the new design match the experiment's supported
trial types.
"""
return self._design

@design.setter
def design(self, design: ExperimentDesign) -> None:
"""Set the experiment design, validating that all trial type keys
in ``design.automation_settings`` are supported by this experiment.
"""
for trial_type in design.automation_settings:
if not self.supports_trial_type(trial_type):
raise ValueError(
f"Trial type {trial_type!r} in automation_settings is not "
f"supported by this experiment. Supported trial types: "
f"{list(self._trial_type_to_runner.keys())}."
)
self._design = design

def get_concurrency_limit(
self,
trial_type: str | None = None,
) -> int | None:
"""Return the concurrency limit for the given trial type.

Args:
trial_type: The trial type to look up. Defaults to the
experiment's ``default_trial_type``.

Raises:
ValueError: If the resolved trial type has no
``AutomationSettings`` entry.
"""
if trial_type is None:
trial_type = self._default_trial_type
settings = self._design.automation_settings.get(trial_type)
if settings is None:
if not self._design.automation_settings:
return None
raise ValueError(
f"No AutomationSettings for trial type {trial_type!r}. "
f"Available trial types: "
f"{list(self._design.automation_settings.keys())}. "
f"Pass the trial_type argument explicitly."
)
return settings.concurrency_limit

def set_concurrency_limit(
self,
concurrency_limit: int | None,
trial_type: str | None = None,
) -> None:
"""Set the concurrency limit for the given trial type.

Creates an ``AutomationSettings`` entry if one does not exist for the
resolved trial type. Validates the trial type via the ``design``
setter.

Args:
concurrency_limit: The concurrency limit to set.
trial_type: The trial type to set for. Defaults to the
experiment's ``default_trial_type``.
"""
if trial_type is None:
trial_type = self._default_trial_type
new_settings = dict(self._design.automation_settings)
if trial_type in new_settings:
# Copy existing settings and update concurrency_limit.
existing = new_settings[trial_type]
new_settings[trial_type] = AutomationSettings(
concurrency_limit=concurrency_limit,
generation_lookahead=existing.generation_lookahead,
budget=existing.budget,
stage_after_seconds=existing.stage_after_seconds,
run_after_seconds=existing.run_after_seconds,
)
else:
new_settings[trial_type] = AutomationSettings(
concurrency_limit=concurrency_limit,
)
# Use the design setter to validate trial types.
self.design = ExperimentDesign(
automation_settings=new_settings,
analysis_frequency_seconds=self._design.analysis_frequency_seconds,
generation_frequency_seconds=self._design.generation_frequency_seconds,
)

@property
def search_space(self) -> SearchSpace:
"""The search space for this experiment.
Expand Down
117 changes: 117 additions & 0 deletions ax/core/experiment_design.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import dataclasses
from dataclasses import dataclass, field
from typing import Any

EXPERIMENT_DESIGN_KEY: str = "experiment_design"

# Sentinel used to serialize ``None`` dict keys in JSON, which does not
# support non-string keys.
_NULL_TRIAL_TYPE_KEY: str = "null"


@dataclass
class AutomationSettings:
"""Per-trial-type settings that govern how the Orchestrator automates
experiment execution for a given trial type.

Multi-type experiments specify one ``AutomationSettings`` per trial type;
single-type experiments use a single entry keyed by the experiment's
default trial type (typically ``None``).

NOTE: For now all durations are expressed in seconds for convenience;
units may change later.

Args:
concurrency_limit: Maximum number of arms to run concurrently for
this trial type. ``None`` means unlimited.
generation_lookahead: Number of candidate arms to pre-generate even
when concurrency is reached, so users can choose whether to
deploy them instead of existing ones. ``None`` means no
lookahead.
budget: Maximum total number of arms to run for this trial type
across the entire experiment. ``None`` means unlimited.
stage_after_seconds: Seconds to wait before automatically staging
a trial after it is created. ``None`` means do not auto-stage.
0 means auto-stage without waiting.
run_after_seconds: Seconds to wait before automatically running a
staged or candidate (if staging is not required) trial.
``None`` means do not auto-run. 0 means auto-run
without waiting.
"""

concurrency_limit: int | None = None
generation_lookahead: int | None = None
# NOTE: In the future, we may want a more complex notion for an overarching
# budget in the experiment, across multiple trial types. When we get there,
# we will likely want to hold that in `ExperimentDesign` and validate that
# this `budget` is `None` when that is specified.
budget: int | None = None
stage_after_seconds: int | None = None
run_after_seconds: int | None = None


@dataclass
class ExperimentDesign:
"""Holds experiment-level execution configuration.

Experiment-level settings (frequencies) live directly on this class,
while per-trial-type settings are stored in ``automation_settings``,
a dictionary mapping trial type names to ``AutomationSettings``
instances.

During prototyping, this is serialized into ``experiment.properties``
via storage encoders. First-class storage support will follow.

NOTE: in ax/storage/sqa_store/encoder.py, attributes of this class
are automatically serialized and stored in experiment.properties.

Args:
automation_settings: Mapping from trial type to per-trial-type
automation configuration.
analysis_frequency_seconds: How often (in seconds) to poll trial
statuses, fetch data, and run automated analysis across all
trial types. ``None`` means no automated analysis.
generation_frequency_seconds: How often (in seconds) to trigger
automated candidate generation across all trial types.
``None`` means no automated candidate generation.
"""

automation_settings: dict[str | None, AutomationSettings] = field(
default_factory=dict,
)
analysis_frequency_seconds: int | None = None
generation_frequency_seconds: int | None = None

def to_json(self) -> dict[str, Any]:
"""Serialize to a JSON-compatible dict for storage in
``experiment.properties``.
"""
return {
"automation_settings": {
(_NULL_TRIAL_TYPE_KEY if k is None else k): dataclasses.asdict(v)
for k, v in self.automation_settings.items()
},
"analysis_frequency_seconds": self.analysis_frequency_seconds,
"generation_frequency_seconds": self.generation_frequency_seconds,
}

@classmethod
def from_json(cls, json_dict: dict[str, Any]) -> ExperimentDesign:
"""Deserialize from a JSON dict stored in ``experiment.properties``."""
return cls(
automation_settings={
(None if k == _NULL_TRIAL_TYPE_KEY else k): AutomationSettings(**v)
for k, v in json_dict["automation_settings"].items()
},
analysis_frequency_seconds=json_dict.get("analysis_frequency_seconds"),
generation_frequency_seconds=json_dict.get("generation_frequency_seconds"),
)
128 changes: 128 additions & 0 deletions ax/core/tests/test_experiment_design.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any

from ax.core import Experiment
from ax.core.experiment_design import (
AutomationSettings,
EXPERIMENT_DESIGN_KEY,
ExperimentDesign,
)
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_search_space


class ExperimentDesignTest(TestCase):
"""Tests covering ExperimentDesign, AutomationSettings, and their usage
in ax Experiment.
"""

def setUp(self) -> None:
super().setUp()
self.experiment = Experiment(
name="test",
search_space=get_branin_search_space(),
)

def test_default_design_and_setting_automation_settings(self) -> None:
"""Default ExperimentDesign has no automation_settings; setting design
with AutomationSettings stores them correctly.
"""
self.assertIsInstance(self.experiment.design, ExperimentDesign)
self.assertEqual(self.experiment.design.automation_settings, {})
self.assertIsNone(self.experiment.get_concurrency_limit())

settings = AutomationSettings(
concurrency_limit=5,
budget=100,
)
design = ExperimentDesign(
automation_settings={None: settings},
analysis_frequency_seconds=3600,
)
self.experiment.design = design

self.assertEqual(
self.experiment.design.automation_settings[None].concurrency_limit, 5
)
self.assertEqual(self.experiment.design.automation_settings[None].budget, 100)
self.assertEqual(self.experiment.design.analysis_frequency_seconds, 3600)
self.assertEqual(self.experiment.get_concurrency_limit(), 5)

def test_design_setter_validates_trial_types(self) -> None:
"""Setting design with unsupported trial type raises ValueError."""
design = ExperimentDesign(
automation_settings={
"unsupported_type": AutomationSettings(concurrency_limit=3),
},
)
with self.assertRaisesRegex(ValueError, "unsupported_type"):
self.experiment.design = design

def test_concurrency_limit_convenience_methods(self) -> None:
"""set_concurrency_limit creates an entry; get_concurrency_limit
retrieves it; missing trial type with other entries present raises.
"""
self.experiment.set_concurrency_limit(concurrency_limit=10)
self.assertEqual(self.experiment.get_concurrency_limit(), 10)
self.assertIn(None, self.experiment.design.automation_settings)

with self.assertRaisesRegex(ValueError, "No AutomationSettings"):
self.experiment.get_concurrency_limit(trial_type="nonexistent")

# set_concurrency_limit validates trial types via the design setter.
with self.assertRaisesRegex(ValueError, "unsupported_type"):
self.experiment.set_concurrency_limit(
concurrency_limit=5, trial_type="unsupported_type"
)

def test_serialization_roundtrip(self) -> None:
"""ExperimentDesign survives to_json / from_json and deserialization
from experiment properties.
"""
with self.subTest("to_json / from_json roundtrip"):
design = ExperimentDesign(
automation_settings={
None: AutomationSettings(concurrency_limit=42, budget=100),
},
analysis_frequency_seconds=3600,
generation_frequency_seconds=7200,
)
json_dict = design.to_json()
restored = ExperimentDesign.from_json(json_dict)
self.assertEqual(restored.automation_settings[None].concurrency_limit, 42)
self.assertEqual(restored.automation_settings[None].budget, 100)
self.assertEqual(restored.analysis_frequency_seconds, 3600)
self.assertEqual(restored.generation_frequency_seconds, 7200)

with self.subTest("deserialization from properties"):
properties: dict[str, Any] = {
EXPERIMENT_DESIGN_KEY: {
"automation_settings": {
"null": {
"concurrency_limit": 42,
"generation_lookahead": None,
"budget": 100,
"stage_after_seconds": None,
"run_after_seconds": None,
},
},
"analysis_frequency_seconds": 3600,
"generation_frequency_seconds": None,
},
}
exp = Experiment(
name="test_new",
search_space=get_branin_search_space(),
properties=properties,
)
self.assertEqual(exp.get_concurrency_limit(), 42)
self.assertEqual(exp.design.automation_settings[None].budget, 100)
self.assertEqual(exp.design.analysis_frequency_seconds, 3600)
self.assertIsNone(exp.design.generation_frequency_seconds)
Loading
Loading