diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 8cb018f1900..f2211d8ad74 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -31,6 +31,11 @@ from ax.core.batch_trial import BatchTrial from ax.core.data import combine_data_rows_favoring_recent, Data from ax.core.derived_metric import DerivedMetric +from ax.core.experiment_design import ( + AutomationSettings, + EXPERIMENT_DESIGN_KEY, + ExperimentDesign, +) from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.llm_provider import LLMMessage @@ -155,6 +160,13 @@ def __init__( self._status: ExperimentStatus | None = None self._trials: dict[int, BaseTrial] = {} self._properties: dict[str, Any] = properties or {} + # TODO[mpolson64]: Replace with proper storage as part of the refactor. + self._design: ExperimentDesign = ( + ExperimentDesign.from_json(design_dict) + if (design_dict := self._properties.pop(EXPERIMENT_DESIGN_KEY, None)) + is not None + else ExperimentDesign() + ) # Initialize trial type to runner mapping self._default_trial_type = default_trial_type @@ -313,6 +325,99 @@ def experiment_status_from_generator_runs( return suggested_statuses.pop() + @property + def design(self) -> ExperimentDesign: + """The experiment design configuration. + + Holds a dictionary of ``AutomationSettings`` keyed by trial type. + Use the setter to replace the entire design; it validates that the + trial types in the new design match the experiment's supported + trial types. + """ + return self._design + + @design.setter + def design(self, design: ExperimentDesign) -> None: + """Set the experiment design, validating that all trial type keys + in ``design.automation_settings`` are supported by this experiment. + """ + for trial_type in design.automation_settings: + if not self.supports_trial_type(trial_type): + raise ValueError( + f"Trial type {trial_type!r} in automation_settings is not " + f"supported by this experiment. Supported trial types: " + f"{list(self._trial_type_to_runner.keys())}." + ) + self._design = design + + def get_concurrency_limit( + self, + trial_type: str | None = None, + ) -> int | None: + """Return the concurrency limit for the given trial type. + + Args: + trial_type: The trial type to look up. Defaults to the + experiment's ``default_trial_type``. + + Raises: + ValueError: If the resolved trial type has no + ``AutomationSettings`` entry. + """ + if trial_type is None: + trial_type = self._default_trial_type + settings = self._design.automation_settings.get(trial_type) + if settings is None: + if not self._design.automation_settings: + return None + raise ValueError( + f"No AutomationSettings for trial type {trial_type!r}. " + f"Available trial types: " + f"{list(self._design.automation_settings.keys())}. " + f"Pass the trial_type argument explicitly." + ) + return settings.concurrency_limit + + def set_concurrency_limit( + self, + concurrency_limit: int | None, + trial_type: str | None = None, + ) -> None: + """Set the concurrency limit for the given trial type. + + Creates an ``AutomationSettings`` entry if one does not exist for the + resolved trial type. Validates the trial type via the ``design`` + setter. + + Args: + concurrency_limit: The concurrency limit to set. + trial_type: The trial type to set for. Defaults to the + experiment's ``default_trial_type``. + """ + if trial_type is None: + trial_type = self._default_trial_type + new_settings = dict(self._design.automation_settings) + if trial_type in new_settings: + # Copy existing settings and update concurrency_limit. + existing = new_settings[trial_type] + new_settings[trial_type] = AutomationSettings( + concurrency_limit=concurrency_limit, + generation_lookahead=existing.generation_lookahead, + budget=existing.budget, + stage_after_seconds=existing.stage_after_seconds, + run_after_seconds=existing.run_after_seconds, + ) + else: + new_settings[trial_type] = AutomationSettings( + concurrency_limit=concurrency_limit, + ) + # Use the design setter to validate trial types. + self.design = ExperimentDesign( + automation_settings=new_settings, + analysis_frequency_seconds=self._design.analysis_frequency_seconds, + generation_frequency_seconds=self._design.generation_frequency_seconds, + ) + @property def search_space(self) -> SearchSpace: """The search space for this experiment. diff --git a/ax/core/experiment_design.py b/ax/core/experiment_design.py new file mode 100644 index 00000000000..cd8d1db1308 --- /dev/null +++ b/ax/core/experiment_design.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from __future__ import annotations + +import dataclasses +from dataclasses import dataclass, field +from typing import Any + +EXPERIMENT_DESIGN_KEY: str = "experiment_design" + +# Sentinel used to serialize ``None`` dict keys in JSON, which does not +# support non-string keys. +_NULL_TRIAL_TYPE_KEY: str = "null" + + +@dataclass +class AutomationSettings: + """Per-trial-type settings that govern how the Orchestrator automates + experiment execution for a given trial type. + + Multi-type experiments specify one ``AutomationSettings`` per trial type; + single-type experiments use a single entry keyed by the experiment's + default trial type (typically ``None``). + + NOTE: For now all durations are expressed in seconds for convenience; + units may change later. + + Args: + concurrency_limit: Maximum number of arms to run concurrently for + this trial type. ``None`` means unlimited. + generation_lookahead: Number of candidate arms to pre-generate even + when concurrency is reached, so users can choose whether to + deploy them instead of existing ones. ``None`` means no + lookahead. + budget: Maximum total number of arms to run for this trial type + across the entire experiment. ``None`` means unlimited. + stage_after_seconds: Seconds to wait before automatically staging + a trial after it is created. ``None`` means do not auto-stage. + 0 means auto-stage without waiting. + run_after_seconds: Seconds to wait before automatically running a + staged or candidate (if staging is not required) trial. + ``None`` means do not auto-run. 0 means auto-run + without waiting. + """ + + concurrency_limit: int | None = None + generation_lookahead: int | None = None + # NOTE: In the future, we may want a more complex notion for an overarching + # budget in the experiment, across multiple trial types. When we get there, + # we will likely want to hold that in `ExperimentDesign` and validate that + # this `budget` is `None` when that is specified. + budget: int | None = None + stage_after_seconds: int | None = None + run_after_seconds: int | None = None + + +@dataclass +class ExperimentDesign: + """Holds experiment-level execution configuration. + + Experiment-level settings (frequencies) live directly on this class, + while per-trial-type settings are stored in ``automation_settings``, + a dictionary mapping trial type names to ``AutomationSettings`` + instances. + + During prototyping, this is serialized into ``experiment.properties`` + via storage encoders. First-class storage support will follow. + + NOTE: in ax/storage/sqa_store/encoder.py, attributes of this class + are automatically serialized and stored in experiment.properties. + + Args: + automation_settings: Mapping from trial type to per-trial-type + automation configuration. + analysis_frequency_seconds: How often (in seconds) to poll trial + statuses, fetch data, and run automated analysis across all + trial types. ``None`` means no automated analysis. + generation_frequency_seconds: How often (in seconds) to trigger + automated candidate generation across all trial types. + ``None`` means no automated candidate generation. + """ + + automation_settings: dict[str | None, AutomationSettings] = field( + default_factory=dict, + ) + analysis_frequency_seconds: int | None = None + generation_frequency_seconds: int | None = None + + def to_json(self) -> dict[str, Any]: + """Serialize to a JSON-compatible dict for storage in + ``experiment.properties``. + """ + return { + "automation_settings": { + (_NULL_TRIAL_TYPE_KEY if k is None else k): dataclasses.asdict(v) + for k, v in self.automation_settings.items() + }, + "analysis_frequency_seconds": self.analysis_frequency_seconds, + "generation_frequency_seconds": self.generation_frequency_seconds, + } + + @classmethod + def from_json(cls, json_dict: dict[str, Any]) -> ExperimentDesign: + """Deserialize from a JSON dict stored in ``experiment.properties``.""" + return cls( + automation_settings={ + (None if k == _NULL_TRIAL_TYPE_KEY else k): AutomationSettings(**v) + for k, v in json_dict["automation_settings"].items() + }, + analysis_frequency_seconds=json_dict.get("analysis_frequency_seconds"), + generation_frequency_seconds=json_dict.get("generation_frequency_seconds"), + ) diff --git a/ax/core/tests/test_experiment_design.py b/ax/core/tests/test_experiment_design.py new file mode 100644 index 00000000000..577a6c930f0 --- /dev/null +++ b/ax/core/tests/test_experiment_design.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any + +from ax.core import Experiment +from ax.core.experiment_design import ( + AutomationSettings, + EXPERIMENT_DESIGN_KEY, + ExperimentDesign, +) +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_branin_search_space + + +class ExperimentDesignTest(TestCase): + """Tests covering ExperimentDesign, AutomationSettings, and their usage + in ax Experiment. + """ + + def setUp(self) -> None: + super().setUp() + self.experiment = Experiment( + name="test", + search_space=get_branin_search_space(), + ) + + def test_default_design_and_setting_automation_settings(self) -> None: + """Default ExperimentDesign has no automation_settings; setting design + with AutomationSettings stores them correctly. + """ + self.assertIsInstance(self.experiment.design, ExperimentDesign) + self.assertEqual(self.experiment.design.automation_settings, {}) + self.assertIsNone(self.experiment.get_concurrency_limit()) + + settings = AutomationSettings( + concurrency_limit=5, + budget=100, + ) + design = ExperimentDesign( + automation_settings={None: settings}, + analysis_frequency_seconds=3600, + ) + self.experiment.design = design + + self.assertEqual( + self.experiment.design.automation_settings[None].concurrency_limit, 5 + ) + self.assertEqual(self.experiment.design.automation_settings[None].budget, 100) + self.assertEqual(self.experiment.design.analysis_frequency_seconds, 3600) + self.assertEqual(self.experiment.get_concurrency_limit(), 5) + + def test_design_setter_validates_trial_types(self) -> None: + """Setting design with unsupported trial type raises ValueError.""" + design = ExperimentDesign( + automation_settings={ + "unsupported_type": AutomationSettings(concurrency_limit=3), + }, + ) + with self.assertRaisesRegex(ValueError, "unsupported_type"): + self.experiment.design = design + + def test_concurrency_limit_convenience_methods(self) -> None: + """set_concurrency_limit creates an entry; get_concurrency_limit + retrieves it; missing trial type with other entries present raises. + """ + self.experiment.set_concurrency_limit(concurrency_limit=10) + self.assertEqual(self.experiment.get_concurrency_limit(), 10) + self.assertIn(None, self.experiment.design.automation_settings) + + with self.assertRaisesRegex(ValueError, "No AutomationSettings"): + self.experiment.get_concurrency_limit(trial_type="nonexistent") + + # set_concurrency_limit validates trial types via the design setter. + with self.assertRaisesRegex(ValueError, "unsupported_type"): + self.experiment.set_concurrency_limit( + concurrency_limit=5, trial_type="unsupported_type" + ) + + def test_serialization_roundtrip(self) -> None: + """ExperimentDesign survives to_json / from_json and deserialization + from experiment properties. + """ + with self.subTest("to_json / from_json roundtrip"): + design = ExperimentDesign( + automation_settings={ + None: AutomationSettings(concurrency_limit=42, budget=100), + }, + analysis_frequency_seconds=3600, + generation_frequency_seconds=7200, + ) + json_dict = design.to_json() + restored = ExperimentDesign.from_json(json_dict) + self.assertEqual(restored.automation_settings[None].concurrency_limit, 42) + self.assertEqual(restored.automation_settings[None].budget, 100) + self.assertEqual(restored.analysis_frequency_seconds, 3600) + self.assertEqual(restored.generation_frequency_seconds, 7200) + + with self.subTest("deserialization from properties"): + properties: dict[str, Any] = { + EXPERIMENT_DESIGN_KEY: { + "automation_settings": { + "null": { + "concurrency_limit": 42, + "generation_lookahead": None, + "budget": 100, + "stage_after_seconds": None, + "run_after_seconds": None, + }, + }, + "analysis_frequency_seconds": 3600, + "generation_frequency_seconds": None, + }, + } + exp = Experiment( + name="test_new", + search_space=get_branin_search_space(), + properties=properties, + ) + self.assertEqual(exp.get_concurrency_limit(), 42) + self.assertEqual(exp.design.automation_settings[None].budget, 100) + self.assertEqual(exp.design.analysis_frequency_seconds, 3600) + self.assertIsNone(exp.design.generation_frequency_seconds) diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index d6f960ffdcc..4d45bba2462 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -17,6 +17,7 @@ from ax.core.auxiliary import AuxiliaryExperiment from ax.core.batch_trial import BatchTrial from ax.core.data import Data +from ax.core.experiment_design import EXPERIMENT_DESIGN_KEY, ExperimentDesign from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric from ax.core.multi_type_experiment import MultiTypeExperiment @@ -107,6 +108,10 @@ def analysis_card_group_to_dict(group: AnalysisCardGroup) -> dict[str, Any]: def experiment_to_dict(experiment: Experiment) -> dict[str, Any]: """Convert Ax experiment to a dictionary.""" opt_config = experiment.optimization_config + # Serialize ExperimentDesign into properties + properties = {**experiment._properties} + if experiment.design != ExperimentDesign(): + properties[EXPERIMENT_DESIGN_KEY] = experiment.design.to_json() return { "__type": experiment.__class__.__name__, "name": experiment._name, @@ -127,7 +132,7 @@ def experiment_to_dict(experiment: Experiment) -> dict[str, Any]: "trials": experiment.trials, "is_test": experiment.is_test, "data_by_trial": data_to_data_by_trial(data=experiment.data), - "properties": experiment._properties, + "properties": properties, "_trial_type_to_runner": experiment._trial_type_to_runner, } diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 783bfce31d9..ea37b97d9e1 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -119,6 +119,7 @@ get_default_orchestrator_options, get_derived_parameter, get_experiment_with_batch_and_single_trial, + get_experiment_with_concurrency_limit, get_experiment_with_data, get_experiment_with_map_data, get_experiment_with_map_data_type, @@ -272,6 +273,7 @@ ("Experiment", get_experiment_with_map_data_type), ("Experiment", get_branin_experiment_with_timestamp_map_metric), ("Experiment", get_experiment_with_map_data), + ("Experiment", get_experiment_with_concurrency_limit), ("FactorialMetric", get_factorial_metric), ("FixedParameter", get_fixed_parameter), ("FixedParameter", partial(get_fixed_parameter, with_dependents=True)), diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 63ab30eaa41..fffe16269ac 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -29,6 +29,7 @@ from ax.core.data import Data from ax.core.evaluations_to_data import DataType from ax.core.experiment import Experiment +from ax.core.experiment_design import EXPERIMENT_DESIGN_KEY, ExperimentDesign from ax.core.generator_run import GeneratorRun from ax.core.llm_provider import LLMMessage from ax.core.metric import Metric @@ -99,6 +100,8 @@ def prepare_experiment_properties_for_storage( use this function to ensure consistent handling. """ properties = experiment._properties.copy() + if experiment.design != ExperimentDesign(): + properties[EXPERIMENT_DESIGN_KEY] = experiment.design.to_json() if ( oc := experiment.optimization_config ) is not None and oc.pruning_target_parameterization is not None: diff --git a/ax/storage/sqa_store/tests/utils.py b/ax/storage/sqa_store/tests/utils.py index 3bc0b7cf1e2..7626ed679c0 100644 --- a/ax/storage/sqa_store/tests/utils.py +++ b/ax/storage/sqa_store/tests/utils.py @@ -18,6 +18,7 @@ get_choice_parameter, get_experiment_with_batch_and_single_trial, get_experiment_with_batch_trial, + get_experiment_with_concurrency_limit, get_experiment_with_data, get_experiment_with_map_data, get_experiment_with_multi_objective, @@ -128,6 +129,12 @@ Encoder.experiment_to_sqa, Decoder.experiment_from_sqa, ), + ( + "Experiment", + get_experiment_with_concurrency_limit, + Encoder.experiment_to_sqa, + Decoder.experiment_from_sqa, + ), ( "FixedParameter", get_fixed_parameter, diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 7c143d45463..1100b1c3875 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -435,6 +435,17 @@ def get_branin_experiment( return exp +def get_experiment_with_concurrency_limit() -> Experiment: + """Return a Branin experiment with AutomationSettings on its design. + + NOTE: ExperimentDesign is still under development; this stub may change + as the design is finalized. + """ + experiment = get_branin_experiment() + experiment.set_concurrency_limit(concurrency_limit=42) + return experiment + + def get_branin_experiment_with_status_quo_trials( num_sobol_trials: int = 5, multi_objective: bool = False,