diff --git a/src/taskgraph/transforms/task.py b/src/taskgraph/transforms/task.py index d2383635..5b8de43d 100644 --- a/src/taskgraph/transforms/task.py +++ b/src/taskgraph/transforms/task.py @@ -16,11 +16,14 @@ from dataclasses import dataclass from typing import Callable, Literal, Optional, Union +import voluptuous + from taskgraph.transforms.base import TransformSequence from taskgraph.util.hash import hash_path from taskgraph.util.keyed_by import evaluate_keyed_by from taskgraph.util.schema import ( IndexSchema, + LegacySchema, OptimizationType, Schema, TaskPriority, @@ -194,6 +197,14 @@ class PayloadBuilder: def payload_builder(name, schema): + if isinstance(schema, dict): + schema = LegacySchema( + { + voluptuous.Required("implementation"): name, + voluptuous.Optional("os"): str, + } + ).extend(schema) + def wrap(func): assert name not in payload_builders, f"duplicate payload builder name {name}" payload_builders[name] = PayloadBuilder(schema, func) diff --git a/src/taskgraph/util/schema.py b/src/taskgraph/util/schema.py index 0afdb001..29df7429 100644 --- a/src/taskgraph/util/schema.py +++ b/src/taskgraph/util/schema.py @@ -48,6 +48,9 @@ def validate_schema(schema, obj, msg_prefix): # Handle plain Python types (e.g. str, int) via msgspec.convert elif isinstance(schema, type): msgspec.convert(obj, schema) + # Handle plain dict schemas (e.g. from downstream payload builders) + elif isinstance(schema, dict): + voluptuous.Schema(schema)(obj) else: raise TypeError(f"Unsupported schema type: {type(schema)}") except ( diff --git a/test/test_transforms_task.py b/test/test_transforms_task.py index ad5697f4..3d49e799 100644 --- a/test/test_transforms_task.py +++ b/test/test_transforms_task.py @@ -7,6 +7,7 @@ from pprint import pprint import pytest +import voluptuous from pytest_taskgraph import FakeParameters from taskgraph.transforms import task @@ -966,3 +967,36 @@ def test_task_priority(run_transform, graph_config, test_task): assert task_dict["task"]["priority"] == priority else: assert task_dict["task"]["priority"] == graph_config["task-priority"] + + +@pytest.fixture +def dict_schema_builder(): + @task.payload_builder("test-builder", schema={"command": [str]}) + def _builder(config, task, task_def): + pass + + yield task.payload_builders["test-builder"].schema + task.payload_builders.pop("test-builder", None) + + +@pytest.mark.parametrize( + "payload", + ( + {"implementation": "test-builder", "command": ["echo"]}, + {"implementation": "test-builder", "command": ["echo"], "os": "linux"}, + ), +) +def test_dict_schema_accepts_valid_payload(dict_schema_builder, payload): + dict_schema_builder(payload) + + +@pytest.mark.parametrize( + "payload", + ( + {"implementation": "wrong-name", "command": ["echo"]}, + {"command": ["echo"]}, + ), +) +def test_dict_schema_rejects_invalid_payload(dict_schema_builder, payload): + with pytest.raises(voluptuous.MultipleInvalid): + dict_schema_builder(payload) diff --git a/test/test_util_schema.py b/test/test_util_schema.py index 4da2e58a..b3701a89 100644 --- a/test/test_util_schema.py +++ b/test/test_util_schema.py @@ -271,6 +271,18 @@ def test_index_schema_accepts_all_fields(self): ) +class TestValidateSchemaDictHandler(unittest.TestCase): + """validate_schema must accept plain dict schemas passed + by downstream payload builders without raising TypeError.""" + + def test_dict_schema_valid(self): + validate_schema({"name": str, "count": int}, {"name": "a", "count": 1}, "pfx") + + def test_dict_schema_invalid(self): + with self.assertRaises(Exception): + validate_schema({"name": str}, {"name": 123}, "pfx") + + def test_optionally_keyed_by(): typ = optionally_keyed_by("foo", str, use_msgspec=True) assert msgspec.convert("baz", typ) == "baz"