From 2e3bb3376826ff1ee9c2f3f1a54c9b14db99a32c Mon Sep 17 00:00:00 2001 From: Antonio Aranda <102337110+arandito@users.noreply.github.com> Date: Wed, 11 Mar 2026 12:11:19 -0400 Subject: [PATCH] Add smithy-xml package --- .../smithy-core/src/smithy_core/traits.py | 40 ++ packages/smithy-xml/CHANGELOG.md | 1 + packages/smithy-xml/NOTICE | 1 + packages/smithy-xml/README.md | 4 + packages/smithy-xml/pyproject.toml | 50 ++ .../smithy-xml/src/smithy_xml/__init__.py | 55 ++ .../src/smithy_xml/_private/__init__.py | 2 + .../src/smithy_xml/_private/deserializers.py | 378 ++++++++++ .../src/smithy_xml/_private/readers.py | 55 ++ .../src/smithy_xml/_private/serializers.py | 372 ++++++++++ packages/smithy-xml/src/smithy_xml/py.typed | 1 + .../smithy-xml/src/smithy_xml/settings.py | 16 + packages/smithy-xml/tests/__init__.py | 2 + packages/smithy-xml/tests/unit/__init__.py | 662 ++++++++++++++++++ .../tests/unit/test_deserializers.py | 145 ++++ .../smithy-xml/tests/unit/test_serializers.py | 209 ++++++ pyproject.toml | 1 + uv.lock | 11 + 18 files changed, 2005 insertions(+) create mode 100644 packages/smithy-xml/CHANGELOG.md create mode 100644 packages/smithy-xml/NOTICE create mode 100644 packages/smithy-xml/README.md create mode 100644 packages/smithy-xml/pyproject.toml create mode 100644 packages/smithy-xml/src/smithy_xml/__init__.py create mode 100644 packages/smithy-xml/src/smithy_xml/_private/__init__.py create mode 100644 packages/smithy-xml/src/smithy_xml/_private/deserializers.py create mode 100644 packages/smithy-xml/src/smithy_xml/_private/readers.py create mode 100644 packages/smithy-xml/src/smithy_xml/_private/serializers.py create mode 100644 packages/smithy-xml/src/smithy_xml/py.typed create mode 100644 packages/smithy-xml/src/smithy_xml/settings.py create mode 100644 packages/smithy-xml/tests/__init__.py create mode 100644 packages/smithy-xml/tests/unit/__init__.py create mode 100644 packages/smithy-xml/tests/unit/test_deserializers.py create mode 100644 
packages/smithy-xml/tests/unit/test_serializers.py diff --git a/packages/smithy-core/src/smithy_core/traits.py b/packages/smithy-core/src/smithy_core/traits.py index d7dfd22cf..5b64c82f8 100644 --- a/packages/smithy-core/src/smithy_core/traits.py +++ b/packages/smithy-core/src/smithy_core/traits.py @@ -350,3 +350,43 @@ def name(self) -> str: @property def scheme(self) -> str | None: return self.document_value.get("scheme") # type: ignore + + +@dataclass(init=False, frozen=True) +class XmlNameTrait(Trait, id=ShapeID("smithy.api#xmlName")): + document_value: str | None = None + + def __post_init__(self): + assert isinstance(self.document_value, str) + + @property + def value(self) -> str: + return self.document_value # type: ignore + + +@dataclass(init=False, frozen=True) +class XmlNamespaceTrait(Trait, id=ShapeID("smithy.api#xmlNamespace")): + def __post_init__(self): + assert isinstance(self.document_value, Mapping) + assert isinstance(self.document_value["uri"], str) + assert isinstance(self.document_value.get("prefix"), str | None) + + @property + def uri(self) -> str: + return self.document_value["uri"] # type: ignore + + @property + def prefix(self) -> str | None: + return self.document_value.get("prefix") # type: ignore + + +@dataclass(init=False, frozen=True) +class XmlFlattenedTrait(Trait, id=ShapeID("smithy.api#xmlFlattened")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class XmlAttributeTrait(Trait, id=ShapeID("smithy.api#xmlAttribute")): + def __post_init__(self): + assert self.document_value is None diff --git a/packages/smithy-xml/CHANGELOG.md b/packages/smithy-xml/CHANGELOG.md new file mode 100644 index 000000000..5ddad421e --- /dev/null +++ b/packages/smithy-xml/CHANGELOG.md @@ -0,0 +1 @@ +# Changelog \ No newline at end of file diff --git a/packages/smithy-xml/NOTICE b/packages/smithy-xml/NOTICE new file mode 100644 index 000000000..616fc5889 --- /dev/null +++ b/packages/smithy-xml/NOTICE 
@@ -0,0 +1 @@ +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/packages/smithy-xml/README.md b/packages/smithy-xml/README.md new file mode 100644 index 000000000..bcd7da956 --- /dev/null +++ b/packages/smithy-xml/README.md @@ -0,0 +1,4 @@ +# smithy-xml + +This package provides generic XML serialization and deserialization support +for Smithy clients and servers. diff --git a/packages/smithy-xml/pyproject.toml b/packages/smithy-xml/pyproject.toml new file mode 100644 index 000000000..eca3b80f9 --- /dev/null +++ b/packages/smithy-xml/pyproject.toml @@ -0,0 +1,50 @@ +[project] +name = "smithy-xml" +dynamic = ["version"] +requires-python = ">=3.12" +authors = [ + {name = "Amazon Web Services"}, +] +description = "XML serialization and deserialization support for Smithy tooling." +readme = "README.md" +license = {text = "Apache License 2.0"} +keywords = ["smithy", "sdk", "xml"] +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "Natural Language :: English", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Software Development :: Libraries" +] +dependencies = [ + "smithy-core", +] + +[project.urls] +"Changelog" = "https://github.com/smithy-lang/smithy-python/blob/develop/packages/smithy-xml/CHANGELOG.md" +"Code" = "https://github.com/smithy-lang/smithy-python/blob/develop/packages/smithy-xml/" +"Issue tracker" = "https://github.com/smithy-lang/smithy-python/issues" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.version] +path 
= "src/smithy_xml/__init__.py" + +[tool.hatch.build] +exclude = [ + "tests", +] + +[tool.ruff] +src = ["src"] diff --git a/packages/smithy-xml/src/smithy_xml/__init__.py b/packages/smithy-xml/src/smithy_xml/__init__.py new file mode 100644 index 000000000..9616aa8ef --- /dev/null +++ b/packages/smithy-xml/src/smithy_xml/__init__.py @@ -0,0 +1,55 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from io import BytesIO +from xml.etree.ElementTree import iterparse + +from smithy_core.codecs import Codec +from smithy_core.deserializers import ShapeDeserializer +from smithy_core.interfaces import BytesReader, BytesWriter +from smithy_core.serializers import ShapeSerializer +from smithy_core.types import TimestampFormat + +from ._private.deserializers import XMLShapeDeserializer as _XMLShapeDeserializer +from ._private.readers import XMLEventReader as _XMLEventReader +from ._private.serializers import XMLShapeSerializer as _XMLShapeSerializer +from .settings import XMLSettings + +__version__ = "0.0.1" +__all__ = ("XMLCodec", "XMLSettings") + + +class XMLCodec(Codec): + """A codec for converting shapes to/from XML.""" + + def __init__( + self, + default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, + default_namespace: str | None = None, + ) -> None: + self._settings = XMLSettings( + default_timestamp_format=default_timestamp_format, + default_namespace=default_namespace, + ) + + @property + def media_type(self) -> str: + return "application/xml" + + def create_serializer(self, sink: BytesWriter) -> ShapeSerializer: + return _XMLShapeSerializer(sink=sink, settings=self._settings) + + def create_deserializer( + self, + source: bytes | BytesReader, + *, + wrapper_elements: tuple[str, ...] 
= (), + ) -> ShapeDeserializer: + if isinstance(source, bytes): + source = BytesIO(source) + reader = _XMLEventReader( + iterparse(source, events=("start", "end")) # noqa: S314 + ) + return _XMLShapeDeserializer( + settings=self._settings, reader=reader, wrapper_elements=wrapper_elements + ) diff --git a/packages/smithy-xml/src/smithy_xml/_private/__init__.py b/packages/smithy-xml/src/smithy_xml/_private/__init__.py new file mode 100644 index 000000000..33cbe867a --- /dev/null +++ b/packages/smithy-xml/src/smithy_xml/_private/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/packages/smithy-xml/src/smithy_xml/_private/deserializers.py b/packages/smithy-xml/src/smithy_xml/_private/deserializers.py new file mode 100644 index 000000000..ae1ece78f --- /dev/null +++ b/packages/smithy-xml/src/smithy_xml/_private/deserializers.py @@ -0,0 +1,378 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import datetime +from base64 import b64decode +from collections.abc import Callable +from decimal import Decimal +from xml.etree.ElementTree import Element + +from smithy_core.deserializers import ShapeDeserializer, SpecificShapeDeserializer +from smithy_core.documents import Document +from smithy_core.exceptions import SmithyError +from smithy_core.schemas import Schema +from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import ( + TimestampFormatTrait, + XmlAttributeTrait, + XmlFlattenedTrait, + XmlNameTrait, +) + +from ..settings import XMLSettings +from .readers import XMLEvent, XMLEventReader + + +def _local_name(tag: str) -> str: + """Strip namespace URI from an element tag: {uri}local -> local.""" + if tag.startswith("{"): + return tag.split("}", 1)[1] + return tag + + +def _expected_root_name(schema: Schema) -> str | None: + """Get the expected root element name for root validation.""" + if schema.shape_type not in (ShapeType.STRUCTURE, ShapeType.UNION): + return None + if xml_name := schema.get_trait(XmlNameTrait): + return xml_name.value + return schema.id.name + + +def _validate_element_name(expected: str, elem: Element) -> None: + """Raise XMLParseError if the element's local name doesn't match expected.""" + found = _local_name(elem.tag) + if found != expected: + raise XMLParseError(f"Expected element '{expected}', got '{found}'") + + +def _xml_member_name(member_schema: Schema) -> str: + """Get the XML element name for a member, respecting @xmlName.""" + if xml_name := member_schema.get_trait(XmlNameTrait): + return xml_name.value + return member_schema.expect_member_name() + + +def _parse_xml_float(text: str) -> float: + """Parse an XML float string, handling NaN and Infinity.""" + match text: + case "NaN": + return float("nan") + case "Infinity": + return float("inf") + case "-Infinity": + return float("-inf") + case _: + return float(text) + + +class XMLParseError(SmithyError): + def 
__init__(self, message: str) -> None: + super().__init__(f"Error parsing XML: {message}") + + +class XMLShapeDeserializer(ShapeDeserializer): + """Deserializer that reads XML from a streaming pull parser.""" + + def __init__( + self, + settings: XMLSettings, + reader: XMLEventReader, + wrapper_elements: tuple[str, ...] = (), + ) -> None: + self._settings = settings + self._reader = reader + self._is_root = not bool(wrapper_elements) + self._xml_names: dict[ShapeID, dict[str, Schema]] = {} + self._preconsumed_start: Element | None = None + + # Wrapper elements are protocol transport containers (e.g. awsQuery's + # <OperationNameResponse><OperationNameResult>). The last wrapper's start element is kept + # so that the next read can reuse it. + for wrapper in wrapper_elements: + event = next(self._reader) + if event.type != "start": + raise XMLParseError(f"Expected start element, got '{event.type}'") + _validate_element_name(wrapper, event.elem) + self._preconsumed_start = event.elem + + def is_null(self) -> bool: + return False + + def read_null(self) -> None: + return None + + def read_boolean(self, schema: Schema) -> bool: + text = self._read_text() + match text: + case "true": + return True + case "false": + return False + case _: + raise XMLParseError(f"Expected 'true' or 'false', got '{text}'") + + def read_blob(self, schema: Schema) -> bytes: + return b64decode(self._read_text()) + + def read_integer(self, schema: Schema) -> int: + return int(self._read_text()) + + def read_float(self, schema: Schema) -> float: + return _parse_xml_float(self._read_text()) + + def read_big_decimal(self, schema: Schema) -> Decimal: + return Decimal(self._read_text()) + + def read_string(self, schema: Schema) -> str: + return self._read_text() + + def read_document(self, schema: Schema) -> Document: + raise NotImplementedError("XML does not support document types") + + def read_timestamp(self, schema: Schema) -> datetime.datetime: + fmt = self._settings.default_timestamp_format + if format_trait := 
schema.get_trait(TimestampFormatTrait): + fmt = format_trait.format + + text = self._read_text() + return fmt.deserialize(text) + + def read_struct( + self, + schema: Schema, + consumer: Callable[[Schema, "ShapeDeserializer"], None], + ) -> None: + xml_names = self._get_xml_names(schema) + start_from_wrapper = self._preconsumed_start is not None + start_elem = self._consume_start_event() + if self._is_root: + self._is_root = False + expected = _expected_root_name(schema) + if expected is not None: + _validate_element_name(expected, start_elem) + + # Wrapper elements are protocol transport containers, not modeled structs, + # so their attributes cannot be deserialized as struct members. + if not start_from_wrapper: + for member_schema in schema.members.values(): + if member_schema.get_trait(XmlAttributeTrait) is None: + continue + expected_attr_name = _xml_member_name(member_schema) + for attr_name, attr_value in start_elem.attrib.items(): + attr_local_name = _local_name(attr_name) + if attr_local_name == expected_attr_name: + consumer( + member_schema, + _AttributeDeserializer(attr_value, self._settings), + ) + break + + # Flattened members lack an enclosing element, so there is no way to + # know when all items have been parsed. Their events are collected + # during iteration and replayed through a bounded reader afterwards. 
+ flattened_buffers: dict[str, list[XMLEvent]] = {} + flattened_names = { + xml_name: member_schema + for xml_name, member_schema in xml_names.items() + if member_schema.get_trait(XmlFlattenedTrait) is not None + } + + while self._reader.peek().type != "end": + tag = _local_name(self._reader.peek().elem.tag) + + if tag in flattened_names: + flattened_buffers.setdefault(tag, []).extend(self._buffer_element()) + elif tag in xml_names: + consumer(xml_names[tag], self) + else: + # Skip unknown tag + self._consume_start_event() + self._skip_to_end() + + next(self._reader) + + for tag, events in flattened_buffers.items(): + member_schema = flattened_names[tag] + buffered_de = XMLShapeDeserializer( + self._settings, + XMLEventReader(iter(events)), + ) + consumer(member_schema, buffered_de) + + def read_list( + self, + schema: Schema, + consumer: Callable[["ShapeDeserializer"], None], + ) -> None: + is_flattened = schema.get_trait(XmlFlattenedTrait) is not None + if not is_flattened: + self._consume_start_event() + while self._reader.peek().type != "end": + consumer(self) + else: + while self._reader.has_next(): + consumer(self) + + if not is_flattened: + next(self._reader) + + def read_map( + self, + schema: Schema, + consumer: Callable[[str, "ShapeDeserializer"], None], + ) -> None: + is_flattened = schema.get_trait(XmlFlattenedTrait) is not None + key_schema = schema.members["key"] + value_schema = schema.members["value"] + key_tag = _xml_member_name(key_schema) + value_tag = _xml_member_name(value_schema) + + if not is_flattened: + self._consume_start_event() + while self._reader.peek().type != "end": + self._read_map_entry(key_tag, value_tag, consumer) + else: + while self._reader.has_next(): + self._read_map_entry(key_tag, value_tag, consumer) + + if not is_flattened: + next(self._reader) + + def _read_text(self) -> str: + """Consume a complete element (start through end) and return its text.""" + elem = self._consume_start_event() + self._skip_to_end() + # elem.text 
is populated only after consuming the "end" event + return elem.text or "" + + def _consume_start_event(self) -> Element: + """Consume and return the next start element. + + If a start element was pre-consumed (e.g. from consuming wrapper elements), + it is returned first and cleared. + """ + if self._preconsumed_start is not None: + elem = self._preconsumed_start + self._preconsumed_start = None + return elem + event = next(self._reader) + if event.type != "start": + raise XMLParseError(f"Expected start element, got '{event.type}'") + return event.elem + + def _skip_to_end(self) -> None: + """Skip to the matching end event. Assumes start was already consumed.""" + depth = 1 + while depth > 0: + event = next(self._reader) + if event.type == "start": + depth += 1 + elif event.type == "end": + depth -= 1 + + def _buffer_element(self) -> list[XMLEvent]: + """Buffer a complete element's events (start through matching end).""" + events: list[XMLEvent] = [] + event = next(self._reader) + events.append(event) + depth = 1 + while depth > 0: + event = next(self._reader) + events.append(event) + if event.type == "start": + depth += 1 + elif event.type == "end": + depth -= 1 + return events + + def _get_xml_names(self, schema: Schema) -> dict[str, Schema]: + """Get or build the XML element name -> member schema mapping for a shape.""" + if schema.id in self._xml_names: + return self._xml_names[schema.id] + result: dict[str, Schema] = {} + for member_schema in schema.members.values(): + if member_schema.get_trait(XmlAttributeTrait) is not None: + continue + xml_name = _xml_member_name(member_schema) + result[xml_name] = member_schema + self._xml_names[schema.id] = result + return result + + def _read_map_entry( + self, + key_tag: str, + value_tag: str, + consumer: Callable[[str, "ShapeDeserializer"], None], + ) -> None: + """Read one map entry element and emit key/value pairs via consumer.""" + self._consume_start_event() + + key: str | None = None + while 
self._reader.peek().type != "end": + child_tag = _local_name(self._reader.peek().elem.tag) + if child_tag == key_tag: + key = self._read_text() + elif child_tag == value_tag: + if key is None: + raise XMLParseError( + "Map key element must appear before value element" + ) + consumer(key, self) + else: + # Skip unknown child tag + self._consume_start_event() + self._skip_to_end() + + next(self._reader) + + +class _AttributeDeserializer(SpecificShapeDeserializer): + """Deserializer for a value extracted from an XML attribute string.""" + + def __init__(self, value: str, settings: XMLSettings) -> None: + self._value = value + self._settings = settings + + def read_string(self, schema: Schema) -> str: + return self._value + + def read_boolean(self, schema: Schema) -> bool: + match self._value: + case "true": + return True + case "false": + return False + case _: + raise XMLParseError(f"Expected 'true' or 'false', got '{self._value}'") + + def read_byte(self, schema: Schema) -> int: + return self.read_integer(schema) + + def read_short(self, schema: Schema) -> int: + return self.read_integer(schema) + + def read_integer(self, schema: Schema) -> int: + return int(self._value) + + def read_long(self, schema: Schema) -> int: + return self.read_integer(schema) + + def read_big_integer(self, schema: Schema) -> int: + return self.read_integer(schema) + + def read_float(self, schema: Schema) -> float: + return _parse_xml_float(self._value) + + def read_double(self, schema: Schema) -> float: + return self.read_float(schema) + + def read_big_decimal(self, schema: Schema) -> Decimal: + return Decimal(self._value) + + def read_timestamp(self, schema: Schema) -> datetime.datetime: + fmt = self._settings.default_timestamp_format + if format_trait := schema.get_trait(TimestampFormatTrait): + fmt = format_trait.format + + return fmt.deserialize(self._value) diff --git a/packages/smithy-xml/src/smithy_xml/_private/readers.py b/packages/smithy-xml/src/smithy_xml/_private/readers.py new 
file mode 100644 index 000000000..98164f1e2 --- /dev/null +++ b/packages/smithy-xml/src/smithy_xml/_private/readers.py @@ -0,0 +1,55 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Iterator +from typing import NamedTuple +from xml.etree.ElementTree import Element + + +class XMLEvent(NamedTuple): + type: str + elem: Element + + +class XMLEventReader: + """Buffered iterator over XML pull parser events with peek support. + + Wraps an iterator of ``(event, element)`` tuples — either from + ``iterparse`` (streaming from a byte source) or from an in-memory list + (for flattened member replay). + """ + + def __init__(self, events: Iterator[tuple[str, Element] | XMLEvent]) -> None: + self._iter = events + self._pending: XMLEvent | None = None + + def __iter__(self): + return self + + def __next__(self) -> XMLEvent: + if self._pending is not None: + result = self._pending + self._pending = None + return result + return self._next() + + def _next(self) -> XMLEvent: + event = next(self._iter) + if isinstance(event, XMLEvent): + return event + event_type, elem = event + return XMLEvent(event_type, elem) + + def has_next(self) -> bool: + if self._pending is not None: + return True + try: + self._pending = self._next() + return True + except StopIteration: + return False + + def peek(self) -> XMLEvent: + if self._pending is None: + self._pending = self._next() + return self._pending diff --git a/packages/smithy-xml/src/smithy_xml/_private/serializers.py b/packages/smithy-xml/src/smithy_xml/_private/serializers.py new file mode 100644 index 000000000..c3f242070 --- /dev/null +++ b/packages/smithy-xml/src/smithy_xml/_private/serializers.py @@ -0,0 +1,372 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from base64 import b64encode +from collections.abc import Callable +from contextlib import AbstractContextManager +from datetime import datetime +from decimal import Decimal +from types import TracebackType +from typing import Self +from xml.etree.ElementTree import Element, SubElement, tostring + +from smithy_core.documents import Document +from smithy_core.interfaces import BytesWriter +from smithy_core.schemas import Schema +from smithy_core.serializers import ( + InterceptingSerializer, + MapSerializer, + ShapeSerializer, + SpecificShapeSerializer, +) +from smithy_core.shapes import ShapeType +from smithy_core.traits import ( + TimestampFormatTrait, + XmlAttributeTrait, + XmlFlattenedTrait, + XmlNamespaceTrait, + XmlNameTrait, +) + +from ..settings import XMLSettings + +_INF: float = float("inf") +_NEG_INF: float = float("-inf") + + +def _xml_member_name(member_schema: Schema) -> str: + """Get the XML element name for a member, respecting @xmlName.""" + if xml_name := member_schema.get_trait(XmlNameTrait): + return xml_name.value + return member_schema.expect_member_name() + + +def _xml_root_name(schema: Schema) -> str: + """Get the XML root element name, respecting @xmlName and member targets.""" + if xml_name := schema.get_trait(XmlNameTrait): + return xml_name.value + if schema.member_target is not None: + return schema.expect_member_target().id.name + return schema.id.name + + +def _set_xml_namespace( + element: Element, + schema: Schema, + settings: XMLSettings, + *, + is_root: bool = False, +) -> None: + """Apply @xmlNamespace to an element, or the default namespace if root.""" + if namespace_trait := schema.get_trait(XmlNamespaceTrait): + if namespace_trait.prefix: + element.set(f"xmlns:{namespace_trait.prefix}", namespace_trait.uri) + else: + element.set("xmlns", namespace_trait.uri) + return + + if is_root and settings.default_namespace: + element.set("xmlns", settings.default_namespace) + + +def 
_format_xml_float(value: float) -> str: + """Format a float for XML, handling NaN and Infinity.""" + if value != value: + return "NaN" + if value == _INF: + return "Infinity" + if value == _NEG_INF: + return "-Infinity" + return repr(value) + + +def _is_flattened_collection_schema(schema: Schema) -> bool: + """Check if a schema is a flattened list or map.""" + return schema.get_trait(XmlFlattenedTrait) is not None and schema.shape_type in ( + ShapeType.LIST, + ShapeType.MAP, + ) + + +class XMLShapeSerializer(ShapeSerializer): + """Serializes Smithy shapes into XML and writes the result to a BytesWriter. + + Builds an in-memory XML tree backed by an element stack. ``write_*`` + methods target the top element, and struct/list/map serializers push and + pop child elements to control nesting. ``flush`` writes the tree to the + sink. + """ + + def __init__(self, sink: BytesWriter, settings: XMLSettings) -> None: + self._sink = sink + self.settings = settings + self._root: Element | None = None + self.element_stack: list[Element] = [] + + @property + def current(self) -> Element: + return self.element_stack[-1] + + def ensure_root(self, schema: Schema) -> None: + if self._root is not None: + return + root = Element(_xml_root_name(schema)) + _set_xml_namespace(root, schema, self.settings, is_root=True) + self._root = root + self.element_stack.append(root) + + def begin_struct( + self, schema: "Schema" + ) -> AbstractContextManager["ShapeSerializer"]: + return XMLStructSerializer(self, schema) + + def begin_list( + self, schema: "Schema", size: int + ) -> AbstractContextManager["ShapeSerializer"]: + return XMLListSerializer(self, schema) + + def begin_map( + self, schema: "Schema", size: int + ) -> AbstractContextManager["MapSerializer"]: + return XMLMapSerializer(self, schema) + + def write_null(self, schema: "Schema") -> None: + self.ensure_root(schema) + + def write_boolean(self, schema: "Schema", value: bool) -> None: + self.ensure_root(schema) + self.current.text = 
"true" if value else "false" + + def write_integer(self, schema: "Schema", value: int) -> None: + self.ensure_root(schema) + self.current.text = str(value) + + def write_float(self, schema: "Schema", value: float) -> None: + self.ensure_root(schema) + self.current.text = _format_xml_float(value) + + def write_big_decimal(self, schema: "Schema", value: Decimal) -> None: + self.ensure_root(schema) + self.current.text = str(value.normalize()) + + def write_string(self, schema: "Schema", value: str) -> None: + self.ensure_root(schema) + self.current.text = value + + def write_blob(self, schema: "Schema", value: bytes) -> None: + self.ensure_root(schema) + self.current.text = b64encode(value).decode("utf-8") + + def write_timestamp(self, schema: "Schema", value: datetime) -> None: + self.ensure_root(schema) + fmt = self.settings.default_timestamp_format + if format_trait := schema.get_trait(TimestampFormatTrait): + fmt = format_trait.format + self.current.text = str(fmt.serialize(value)) + + def write_document(self, schema: "Schema", value: Document) -> None: + raise NotImplementedError("XML does not support document types.") + + def flush(self) -> None: + if self._root is None: + return + xml_bytes = tostring(self._root, encoding="utf-8", xml_declaration=False) + self._sink.write(xml_bytes) + + self._root = None + self.element_stack.clear() + + +class XMLStructSerializer(InterceptingSerializer): + """Serializes struct members as child XML elements. + + ``before`` pushes a child element for the member onto the parent's stack, + and ``after`` pops it. Attributes and flattened collections are special-cased + to skip the push/pop. 
+ """ + + def __init__(self, parent: XMLShapeSerializer, schema: Schema) -> None: + self._parent = parent + self._schema = schema + + def __enter__(self) -> Self: + self._parent.ensure_root(self._schema) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + def before(self, schema: "Schema") -> ShapeSerializer: + member_name = _xml_member_name(schema) + + # Attributes are written on the current element, not as children. + if schema.get_trait(XmlAttributeTrait) is not None: + return XMLAttributeSerializer( + self._parent.current, member_name, self._parent.settings + ) + + # Flattened collections have no wrapper element. Items are added + # directly under the current element without changing the stack. + if _is_flattened_collection_schema(schema): + return self._parent + + # Non-flattened collections push a wrapper element onto the stack. + child = SubElement(self._parent.current, member_name) + _set_xml_namespace(child, schema, self._parent.settings) + self._parent.element_stack.append(child) + return self._parent + + def after(self, schema: "Schema") -> None: + # Attributes and flattened collections didn't push, so don't pop. + if schema.get_trait(XmlAttributeTrait) is not None: + return + if _is_flattened_collection_schema(schema): + return + self._parent.element_stack.pop() + + +class XMLListSerializer(InterceptingSerializer): + """Serializes list items as repeated child elements. + + ``before`` pushes a child element for each item, ``after`` pops it. 
+ """ + + def __init__(self, parent: XMLShapeSerializer, schema: Schema) -> None: + self._parent = parent + self._schema = schema + is_flattened = schema.get_trait(XmlFlattenedTrait) is not None + + if is_flattened: + if schema.member_target is not None: + self._item_tag = _xml_member_name(schema) + else: + self._item_tag = _xml_root_name(schema) + else: + self._item_tag = _xml_member_name(schema.members["member"]) + + def __enter__(self) -> Self: + self._parent.ensure_root(self._schema) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + def before(self, schema: "Schema") -> ShapeSerializer: + child = SubElement(self._parent.current, self._item_tag) + _set_xml_namespace(child, schema, self._parent.settings) + self._parent.element_stack.append(child) + return self._parent + + def after(self, schema: "Schema") -> None: + self._parent.element_stack.pop() + + +class XMLMapSerializer(MapSerializer): + """Serializes map entries as ``<entry><key>...</key><value>...</value></entry>`` elements. + + Each ``entry`` call pushes the value element onto the stack for the + value writer callback, then pops it. 
+ """ + + def __init__(self, parent: XMLShapeSerializer, schema: Schema) -> None: + self._parent = parent + self._schema = schema + self._is_flattened = schema.get_trait(XmlFlattenedTrait) is not None + + self._key_schema = schema.members["key"] + self._value_schema = schema.members["value"] + self._key_tag = _xml_member_name(self._key_schema) + self._value_tag = _xml_member_name(self._value_schema) + + if self._is_flattened: + if schema.member_target is not None: + self._entry_tag = _xml_member_name(schema) + else: + self._entry_tag = _xml_root_name(schema) + else: + self._entry_tag = "entry" + + def __enter__(self) -> Self: + self._parent.ensure_root(self._schema) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + def entry(self, key: str, value_writer: Callable[[ShapeSerializer], None]) -> None: + settings = self._parent.settings + entry_el = SubElement(self._parent.current, self._entry_tag) + if self._is_flattened: + _set_xml_namespace(entry_el, self._schema, settings) + + key_el = SubElement(entry_el, self._key_tag) + _set_xml_namespace(key_el, self._key_schema, settings) + key_el.text = key + + value_el = SubElement(entry_el, self._value_tag) + _set_xml_namespace(value_el, self._value_schema, settings) + self._parent.element_stack.append(value_el) + value_writer(self._parent) + self._parent.element_stack.pop() + + +class XMLAttributeSerializer(SpecificShapeSerializer): + """Serializer that writes values as XML attributes on the parent element.""" + + def __init__(self, element: Element, attr_name: str, settings: XMLSettings) -> None: + self._element = element + self._attr_name = attr_name + self._settings = settings + + def write_null(self, schema: "Schema") -> None: + pass + + def write_boolean(self, schema: "Schema", value: bool) -> None: + self._element.set(self._attr_name, "true" if value else "false") + + def write_byte(self, schema: 
"Schema", value: int) -> None: + self.write_integer(schema, value) + + def write_short(self, schema: "Schema", value: int) -> None: + self.write_integer(schema, value) + + def write_integer(self, schema: "Schema", value: int) -> None: + self._element.set(self._attr_name, str(value)) + + def write_long(self, schema: "Schema", value: int) -> None: + self.write_integer(schema, value) + + def write_big_integer(self, schema: "Schema", value: int) -> None: + self.write_integer(schema, value) + + def write_float(self, schema: "Schema", value: float) -> None: + self._element.set(self._attr_name, _format_xml_float(value)) + + def write_double(self, schema: "Schema", value: float) -> None: + self.write_float(schema, value) + + def write_big_decimal(self, schema: "Schema", value: Decimal) -> None: + self._element.set(self._attr_name, str(value.normalize())) + + def write_string(self, schema: "Schema", value: str) -> None: + self._element.set(self._attr_name, value) + + def write_timestamp(self, schema: "Schema", value: datetime) -> None: + fmt = self._settings.default_timestamp_format + if format_trait := schema.get_trait(TimestampFormatTrait): + fmt = format_trait.format + self._element.set(self._attr_name, str(fmt.serialize(value))) diff --git a/packages/smithy-xml/src/smithy_xml/py.typed b/packages/smithy-xml/src/smithy_xml/py.typed new file mode 100644 index 000000000..f5642f79f --- /dev/null +++ b/packages/smithy-xml/src/smithy_xml/py.typed @@ -0,0 +1 @@ +Marker diff --git a/packages/smithy-xml/src/smithy_xml/settings.py b/packages/smithy-xml/src/smithy_xml/settings.py new file mode 100644 index 000000000..ea25cd1b6 --- /dev/null +++ b/packages/smithy-xml/src/smithy_xml/settings.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +from smithy_core.types import TimestampFormat + + +@dataclass(frozen=True) +class XMLSettings: + """Configuration for XML serialization/deserialization.""" + + default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME + """Default timestamp format when a member does not define @timestampFormat.""" + + default_namespace: str | None = None + """Default XML namespace (``xmlns``) applied to the root element during serialization.""" diff --git a/packages/smithy-xml/tests/__init__.py b/packages/smithy-xml/tests/__init__.py new file mode 100644 index 000000000..04f8b7b76 --- /dev/null +++ b/packages/smithy-xml/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/packages/smithy-xml/tests/unit/__init__.py b/packages/smithy-xml/tests/unit/__init__.py new file mode 100644 index 000000000..630200031 --- /dev/null +++ b/packages/smithy-xml/tests/unit/__init__.py @@ -0,0 +1,662 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from datetime import UTC, datetime +from decimal import Decimal +from typing import Any, Self + +from smithy_core.deserializers import ShapeDeserializer +from smithy_core.prelude import ( + BIG_DECIMAL, + BLOB, + BOOLEAN, + FLOAT, + INTEGER, + STRING, + TIMESTAMP, +) +from smithy_core.schemas import Schema +from smithy_core.serializers import ShapeSerializer +from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import ( + TimestampFormatTrait, + XmlAttributeTrait, + XmlFlattenedTrait, + XmlNamespaceTrait, + XmlNameTrait, +) + +STRING_LIST_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#StringList"), + shape_type=ShapeType.LIST, + members={ + "member": { + "target": STRING, + } + }, +) + +STRING_MAP_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#StringMap"), + shape_type=ShapeType.MAP, + members={ + "key": { + "target": STRING, + }, + "value": { + "target": STRING, + }, + }, +) + +# List with @xmlName on the member element +RENAMED_LIST_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#RenamedList"), + shape_type=ShapeType.LIST, + members={ + "member": { + "target": STRING, + "traits": [XmlNameTrait("item")], + } + }, +) + +# Map with @xmlName on key/value members +RENAMED_MAP_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#RenamedMap"), + shape_type=ShapeType.MAP, + members={ + "key": { + "target": STRING, + "traits": [XmlNameTrait("Attribute")], + }, + "value": { + "target": STRING, + "traits": [XmlNameTrait("Setting")], + }, + }, +) + +# Map with @xmlName and @xmlNamespace on key/value members +RENAMED_NS_MAP_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#RenamedNsMap"), + shape_type=ShapeType.MAP, + members={ + "key": { + "target": STRING, + "traits": [ + XmlNameTrait("K"), + XmlNamespaceTrait({"uri": "https://the-key.example.com"}), + ], + }, + "value": { + "target": STRING, + "traits": [ + XmlNameTrait("V"), + 
XmlNamespaceTrait({"uri": "https://the-value.example.com"}), + ], + }, + }, +) + +# List with @xmlNamespace on member +NAMESPACED_LIST_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#NamespacedList"), + shape_type=ShapeType.LIST, + members={ + "member": { + "target": STRING, + "traits": [XmlNamespaceTrait({"uri": "http://bux.com"})], + } + }, +) + +# Struct with @xmlNamespace (default xmlns) +NAMESPACED_STRUCT_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#NsStruct"), + traits=[XmlNamespaceTrait({"uri": "https://example.com"})], + members={ + "value": {"target": STRING}, + }, +) + +# Struct with @xmlNamespace (prefixed xmlns) +PREFIXED_NS_STRUCT_SCHEMA = Schema.collection( + id=ShapeID("smithy.example#PrefixedNsStruct"), + traits=[XmlNamespaceTrait({"uri": "https://example.com", "prefix": "baz"})], + members={ + "value": {"target": STRING}, + }, +) + +SCHEMA: Schema = Schema.collection( + id=ShapeID("smithy.example#SerdeShape"), + members={ + "booleanMember": { + "target": BOOLEAN, + }, + "integerMember": { + "target": INTEGER, + }, + "floatMember": { + "target": FLOAT, + }, + "bigDecimalMember": { + "target": BIG_DECIMAL, + }, + "stringMember": { + "target": STRING, + }, + "xmlNameMember": { + "target": STRING, + "traits": [XmlNameTrait("CustomName")], + }, + "blobMember": { + "target": BLOB, + }, + "timestampMember": { + "target": TIMESTAMP, + }, + "dateTimeMember": { + "target": TIMESTAMP, + "traits": [TimestampFormatTrait("date-time")], + }, + "httpDateMember": { + "target": TIMESTAMP, + "traits": [TimestampFormatTrait("http-date")], + }, + "epochSecondsMember": { + "target": TIMESTAMP, + "traits": [TimestampFormatTrait("epoch-seconds")], + }, + "listMember": { + "target": STRING_LIST_SCHEMA, + }, + "mapMember": { + "target": STRING_MAP_SCHEMA, + }, + "structMember": None, + "xmlAttributeMember": { + "target": STRING, + "traits": [XmlAttributeTrait()], + }, + "renamedListMember": { + "target": RENAMED_LIST_SCHEMA, + }, + 
"flattenedListMember": { + "target": STRING_LIST_SCHEMA, + "traits": [XmlFlattenedTrait()], + }, + "flattenedMapMember": { + "target": STRING_MAP_SCHEMA, + "traits": [XmlFlattenedTrait()], + }, + "flattenedRenamedListMember": { + "target": STRING_LIST_SCHEMA, + "traits": [XmlFlattenedTrait(), XmlNameTrait("customItem")], + }, + "flattenedRenamedMapMember": { + "target": RENAMED_MAP_SCHEMA, + "traits": [XmlFlattenedTrait(), XmlNameTrait("KVP")], + }, + "xmlAttributeNamedMember": { + "target": STRING, + "traits": [XmlAttributeTrait(), XmlNameTrait("test")], + }, + }, +) +SCHEMA.members["structMember"] = Schema.member( + id=SCHEMA.id.with_member("structMember"), + target=SCHEMA, + index=13, +) + + +@dataclass +class SerdeShape: + boolean_member: bool | None = None + integer_member: int | None = None + float_member: float | None = None + big_decimal_member: Decimal | None = None + string_member: str | None = None + xml_name_member: str | None = None + blob_member: bytes | None = None + timestamp_member: datetime | None = None + date_time_member: datetime | None = None + http_date_member: datetime | None = None + epoch_seconds_member: datetime | None = None + list_member: list[str] | None = None + map_member: dict[str, str] | None = None + struct_member: "SerdeShape | None" = None + xml_attribute_member: str | None = None + renamed_list_member: list[str] | None = None + flattened_list_member: list[str] | None = None + flattened_map_member: dict[str, str] | None = None + flattened_renamed_list_member: list[str] | None = None + flattened_renamed_map_member: dict[str, str] | None = None + xml_attribute_named_member: str | None = None + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(SCHEMA, self) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + if self.boolean_member is not None: + serializer.write_boolean( + SCHEMA.members["booleanMember"], self.boolean_member + ) + if self.integer_member is not None: + 
serializer.write_integer( + SCHEMA.members["integerMember"], self.integer_member + ) + if self.float_member is not None: + serializer.write_float(SCHEMA.members["floatMember"], self.float_member) + if self.big_decimal_member is not None: + serializer.write_big_decimal( + SCHEMA.members["bigDecimalMember"], self.big_decimal_member + ) + if self.string_member is not None: + serializer.write_string(SCHEMA.members["stringMember"], self.string_member) + if self.xml_name_member is not None: + serializer.write_string( + SCHEMA.members["xmlNameMember"], self.xml_name_member + ) + if self.blob_member is not None: + serializer.write_blob(SCHEMA.members["blobMember"], self.blob_member) + if self.timestamp_member is not None: + serializer.write_timestamp( + SCHEMA.members["timestampMember"], self.timestamp_member + ) + if self.date_time_member is not None: + serializer.write_timestamp( + SCHEMA.members["dateTimeMember"], self.date_time_member + ) + if self.http_date_member is not None: + serializer.write_timestamp( + SCHEMA.members["httpDateMember"], self.http_date_member + ) + if self.epoch_seconds_member is not None: + serializer.write_timestamp( + SCHEMA.members["epochSecondsMember"], self.epoch_seconds_member + ) + if self.list_member is not None: + schema = SCHEMA.members["listMember"] + target_schema = schema.expect_member_target().members["member"] + with serializer.begin_list(schema, len(self.list_member)) as ls: + for element in self.list_member: + ls.write_string(target_schema, element) + if self.map_member is not None: + schema = SCHEMA.members["mapMember"] + target_schema = schema.expect_member_target().members["value"] + with serializer.begin_map(schema, len(self.map_member)) as ms: + for key, value in self.map_member.items(): + ms.entry(key, lambda vs: vs.write_string(target_schema, value)) # type: ignore + if self.struct_member is not None: + serializer.write_struct(SCHEMA.members["structMember"], self.struct_member) + if self.xml_attribute_member is not None: + 
serializer.write_string( + SCHEMA.members["xmlAttributeMember"], self.xml_attribute_member + ) + if self.renamed_list_member is not None: + schema = SCHEMA.members["renamedListMember"] + target_schema = schema.expect_member_target().members["member"] + with serializer.begin_list(schema, len(self.renamed_list_member)) as ls: + for element in self.renamed_list_member: + ls.write_string(target_schema, element) + if self.flattened_list_member is not None: + schema = SCHEMA.members["flattenedListMember"] + target_schema = schema.expect_member_target().members["member"] + with serializer.begin_list(schema, len(self.flattened_list_member)) as ls: + for element in self.flattened_list_member: + ls.write_string(target_schema, element) + if self.flattened_map_member is not None: + schema = SCHEMA.members["flattenedMapMember"] + target_schema = schema.expect_member_target().members["value"] + with serializer.begin_map(schema, len(self.flattened_map_member)) as ms: + for key, value in self.flattened_map_member.items(): + ms.entry(key, lambda vs: vs.write_string(target_schema, value)) # type: ignore + if self.flattened_renamed_list_member is not None: + schema = SCHEMA.members["flattenedRenamedListMember"] + target_schema = schema.expect_member_target().members["member"] + with serializer.begin_list( + schema, len(self.flattened_renamed_list_member) + ) as ls: + for element in self.flattened_renamed_list_member: + ls.write_string(target_schema, element) + if self.flattened_renamed_map_member is not None: + schema = SCHEMA.members["flattenedRenamedMapMember"] + target_schema = schema.expect_member_target().members["value"] + with serializer.begin_map( + schema, len(self.flattened_renamed_map_member) + ) as ms: + for key, value in self.flattened_renamed_map_member.items(): + ms.entry(key, lambda vs: vs.write_string(target_schema, value)) # type: ignore + if self.xml_attribute_named_member is not None: + serializer.write_string( + SCHEMA.members["xmlAttributeNamedMember"], + 
self.xml_attribute_named_member, + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["boolean_member"] = de.read_boolean( + SCHEMA.members["booleanMember"] + ) + case 1: + kwargs["integer_member"] = de.read_integer( + SCHEMA.members["integerMember"] + ) + case 2: + kwargs["float_member"] = de.read_float( + SCHEMA.members["floatMember"] + ) + case 3: + kwargs["big_decimal_member"] = de.read_big_decimal( + SCHEMA.members["bigDecimalMember"] + ) + case 4: + kwargs["string_member"] = de.read_string( + SCHEMA.members["stringMember"] + ) + case 5: + kwargs["xml_name_member"] = de.read_string( + SCHEMA.members["xmlNameMember"] + ) + case 6: + kwargs["blob_member"] = de.read_blob(SCHEMA.members["blobMember"]) + case 7: + kwargs["timestamp_member"] = de.read_timestamp( + SCHEMA.members["timestampMember"] + ) + case 8: + kwargs["date_time_member"] = de.read_timestamp( + SCHEMA.members["dateTimeMember"] + ) + case 9: + kwargs["http_date_member"] = de.read_timestamp( + SCHEMA.members["httpDateMember"] + ) + case 10: + kwargs["epoch_seconds_member"] = de.read_timestamp( + SCHEMA.members["epochSecondsMember"] + ) + case 11: + list_value: list[str] = [] + de.read_list( + SCHEMA.members["listMember"], + lambda d: list_value.append(d.read_string(STRING)), + ) + kwargs["list_member"] = list_value + case 12: + map_value: dict[str, str] = {} + de.read_map( + SCHEMA.members["mapMember"], + lambda k, d: map_value.__setitem__(k, d.read_string(STRING)), + ) + kwargs["map_member"] = map_value + case 13: + kwargs["struct_member"] = SerdeShape.deserialize(de) + case 14: + kwargs["xml_attribute_member"] = de.read_string( + SCHEMA.members["xmlAttributeMember"] + ) + case 15: + renamed_list_value: list[str] = [] + de.read_list( + SCHEMA.members["renamedListMember"], + lambda d: 
renamed_list_value.append(d.read_string(STRING)), + ) + kwargs["renamed_list_member"] = renamed_list_value + case 16: + flat_list_value: list[str] = [] + de.read_list( + SCHEMA.members["flattenedListMember"], + lambda d: flat_list_value.append(d.read_string(STRING)), + ) + kwargs["flattened_list_member"] = flat_list_value + case 17: + flat_map_value: dict[str, str] = {} + de.read_map( + SCHEMA.members["flattenedMapMember"], + lambda k, d: flat_map_value.__setitem__( + k, d.read_string(STRING) + ), + ) + kwargs["flattened_map_member"] = flat_map_value + case 18: + flat_renamed_list: list[str] = [] + de.read_list( + SCHEMA.members["flattenedRenamedListMember"], + lambda d: flat_renamed_list.append(d.read_string(STRING)), + ) + kwargs["flattened_renamed_list_member"] = flat_renamed_list + case 19: + flat_renamed_map: dict[str, str] = {} + de.read_map( + SCHEMA.members["flattenedRenamedMapMember"], + lambda k, d: flat_renamed_map.__setitem__( + k, d.read_string(STRING) + ), + ) + kwargs["flattened_renamed_map_member"] = flat_renamed_map + case 20: + kwargs["xml_attribute_named_member"] = de.read_string( + SCHEMA.members["xmlAttributeNamedMember"] + ) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +# ---- Serde test cases ---- +# Each entry is (value, xml_bytes) for round-trip testing. 
+# Inspired by the awsQuery/restXml protocol compliance tests + +XML_SERDE_CASES: list[tuple[Any, bytes]] = [ + # Scalars + (True, b"true"), + (False, b"false"), + (1, b"1"), + (1.5, b"1.5"), + (float("inf"), b"Infinity"), + (float("-inf"), b"-Infinity"), + (Decimal("1.1"), b"1.1"), + (b"value", b"dmFsdWU="), + ("foo", b"foo"), + ( + datetime(2014, 4, 29, 18, 30, 38, tzinfo=UTC), + b"2014-04-29T18:30:38Z", + ), + # Wrapped list — elements + ( + ["foo", "bar"], + b"foobar", + ), + # Wrapped map — + ( + {"foo": "bar"}, + b"foobar", + ), + # Struct with single scalar + ( + SerdeShape(string_member="foo"), + b"foo", + ), + ( + SerdeShape(boolean_member=True), + b"true", + ), + ( + SerdeShape(integer_member=3), + b"3", + ), + ( + SerdeShape(float_member=5.5), + b"5.5", + ), + ( + SerdeShape(big_decimal_member=Decimal("1.1")), + b"1.1", + ), + ( + SerdeShape(blob_member=b"value"), + b"dmFsdWU=", + ), + # @xmlName — member serialized under custom element name + ( + SerdeShape(xml_name_member="bar"), + b"bar", + ), + # Timestamps with different formats + ( + SerdeShape(timestamp_member=datetime(2014, 4, 29, 18, 30, 38, tzinfo=UTC)), + b"2014-04-29T18:30:38Z", + ), + ( + SerdeShape(date_time_member=datetime(2014, 4, 29, 18, 30, 38, tzinfo=UTC)), + b"2014-04-29T18:30:38Z", + ), + ( + SerdeShape(http_date_member=datetime(2014, 4, 29, 18, 30, 38, tzinfo=UTC)), + b"Tue, 29 Apr 2014 18:30:38 GMT", + ), + ( + SerdeShape(epoch_seconds_member=datetime(2014, 4, 29, 18, 30, 38, tzinfo=UTC)), + b"1398796238", + ), + # List inside struct + ( + SerdeShape(list_member=["foo", "bar"]), + ( + b"" + b"foobar" + b"" + ), + ), + # Map inside struct + ( + SerdeShape(map_member={"foo": "bar"}), + ( + b"" + b"foobar" + b"" + ), + ), + # Nested struct + ( + SerdeShape(struct_member=SerdeShape(string_member="nested")), + ( + b"" + b"nested" + b"" + ), + ), + # @xmlAttribute — member as attribute on parent element + ( + SerdeShape(xml_attribute_member="attr_val"), + b'', + ), + # List with 
@xmlName("item") on member + ( + SerdeShape(renamed_list_member=["foo", "bar"]), + ( + b"" + b"foobar" + b"" + ), + ), + # @xmlFlattened list + ( + SerdeShape(flattened_list_member=["hi", "bye"]), + ( + b"" + b"hi" + b"bye" + b"" + ), + ), + # @xmlFlattened map + ( + SerdeShape(flattened_map_member={"foo": "Foo", "baz": "Baz"}), + ( + b"" + b"fooFoo" + b"bazBaz" + b"" + ), + ), + # @xmlFlattened + @xmlName on list member + ( + SerdeShape(flattened_renamed_list_member=["hi", "bye"]), + ( + b"" + b"hi" + b"bye" + b"" + ), + ), + # @xmlFlattened + @xmlName on map member with renamed key/value + ( + SerdeShape(flattened_renamed_map_member={"foo": "Foo"}), + ( + b"" + b"fooFoo" + b"" + ), + ), + # @xmlAttribute + @xmlName + ( + SerdeShape(xml_attribute_named_member="attr_val"), + b'', + ), + # Multiple members in one struct — realistic multi-member test + ( + SerdeShape( + boolean_member=True, + integer_member=42, + string_member="hello", + list_member=["a", "b"], + ), + ( + b"" + b"true" + b"42" + b"hello" + b"ab" + b"" + ), + ), + # Nested struct 3 levels deep + ( + SerdeShape( + struct_member=SerdeShape(struct_member=SerdeShape(string_member="deep")) + ), + ( + b"" + b"" + b"deep" + b"" + b"" + ), + ), + # XML escaping in text content + ( + SerdeShape(string_member=""), + b"<foo&bar>", + ), + # Empty collections — wrapper element with no children + ( + SerdeShape(list_member=[]), + b"", + ), + ( + SerdeShape(map_member={}), + b"", + ), +] diff --git a/packages/smithy-xml/tests/unit/test_deserializers.py b/packages/smithy-xml/tests/unit/test_deserializers.py new file mode 100644 index 000000000..f3ee9e703 --- /dev/null +++ b/packages/smithy-xml/tests/unit/test_deserializers.py @@ -0,0 +1,145 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +import math +from datetime import datetime +from decimal import Decimal +from typing import Any + +import pytest +from smithy_core.prelude import ( + BIG_DECIMAL, + BLOB, + BOOLEAN, + DOCUMENT, + FLOAT, + INTEGER, + STRING, + TIMESTAMP, +) +from smithy_xml import XMLCodec + +from . import ( + STRING_LIST_SCHEMA, + STRING_MAP_SCHEMA, + XML_SERDE_CASES, + SerdeShape, +) + + +@pytest.mark.parametrize("expected, given", XML_SERDE_CASES) +def test_xml_deserializer(expected: Any, given: bytes) -> None: + codec = XMLCodec() + deserializer = codec.create_deserializer(given) + match expected: + case bool(): + actual = deserializer.read_boolean(BOOLEAN) + case int(): + actual = deserializer.read_integer(INTEGER) + case float(): + actual = deserializer.read_float(FLOAT) + case Decimal(): + actual = deserializer.read_big_decimal(BIG_DECIMAL) + case bytes(): + actual = deserializer.read_blob(BLOB) + case str(): + actual = deserializer.read_string(STRING) + case datetime(): + actual = deserializer.read_timestamp(TIMESTAMP) + case list(): + actual_list: list[str] = [] + deserializer.read_list( + STRING_LIST_SCHEMA, + lambda d: actual_list.append(d.read_string(STRING)), + ) + actual = actual_list + case dict(): + actual_map: dict[str, str] = {} + deserializer.read_map( + STRING_MAP_SCHEMA, + lambda k, d: actual_map.__setitem__(k, d.read_string(STRING)), + ) + actual = actual_map + case SerdeShape(): + actual = SerdeShape.deserialize(deserializer) + case _: + raise Exception(f"Unexpected type: {type(expected)}") + + assert actual == expected + + +def test_read_document_raises() -> None: + """XML does not support document types.""" + deserializer = XMLCodec().create_deserializer(b"foo") + with pytest.raises( + NotImplementedError, match="XML does not support document types" + ): + deserializer.read_document(DOCUMENT) + + +def test_deserialize_nan() -> None: + actual = XMLCodec().create_deserializer(b"NaN").read_float(FLOAT) + assert 
math.isnan(actual) + + +def test_deserialize_empty_string_self_closed() -> None: + assert XMLCodec().create_deserializer(b"").read_string(STRING) == "" + + +def test_deserialize_empty_string_open_close() -> None: + assert XMLCodec().create_deserializer(b"").read_string(STRING) == "" + + +def test_deserialize_empty_blob() -> None: + assert XMLCodec().create_deserializer(b"").read_blob(BLOB) == b"" + + +def test_deserialize_empty_blob_self_closed() -> None: + assert XMLCodec().create_deserializer(b"").read_blob(BLOB) == b"" + + +def test_wrapper_elements() -> None: + """Deserializer can unwrap awsQuery-style response wrappers.""" + xml = ( + b"" + b"hello" + b"" + ) + deserializer = XMLCodec().create_deserializer( + xml, wrapper_elements=("OpResponse", "OpResult") + ) + result = SerdeShape.deserialize(deserializer) + assert result.string_member == "hello" + + +def test_wrapper_elements_scalar_read() -> None: + xml = b"hello" + deserializer = XMLCodec().create_deserializer( + xml, wrapper_elements=("OpResponse", "OpResult") + ) + assert deserializer.read_string(STRING) == "hello" + + +def test_flattened_list_interleaved_with_other_members() -> None: + """Flattened list elements can be interleaved with other struct members.""" + xml = ( + b"" + b"first" + b"middle" + b"second" + b"" + ) + result = SerdeShape.deserialize(XMLCodec().create_deserializer(xml)) + assert result.flattened_list_member == ["first", "second"] + assert result.string_member == "middle" + + +def test_unknown_members_skipped() -> None: + xml = ( + b"" + b"keep" + b"ignore" + b"5" + b"" + ) + result = SerdeShape.deserialize(XMLCodec().create_deserializer(xml)) + assert result == SerdeShape(string_member="keep", integer_member=5) diff --git a/packages/smithy-xml/tests/unit/test_serializers.py b/packages/smithy-xml/tests/unit/test_serializers.py new file mode 100644 index 000000000..2356c0df9 --- /dev/null +++ b/packages/smithy-xml/tests/unit/test_serializers.py @@ -0,0 +1,209 @@ +# Copyright 
Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from datetime import datetime +from decimal import Decimal +from io import BytesIO +from typing import Any, cast +from xml.etree.ElementTree import canonicalize + +import pytest +from smithy_core.prelude import ( + BIG_DECIMAL, + BLOB, + BOOLEAN, + FLOAT, + INTEGER, + STRING, + TIMESTAMP, +) +from smithy_xml import XMLCodec + +from . import ( + NAMESPACED_LIST_SCHEMA, + NAMESPACED_STRUCT_SCHEMA, + PREFIXED_NS_STRUCT_SCHEMA, + RENAMED_NS_MAP_SCHEMA, + STRING_LIST_SCHEMA, + STRING_MAP_SCHEMA, + XML_SERDE_CASES, + SerdeShape, +) + + +def _canonicalize(xml_bytes: bytes) -> str: + """Canonicalize XML for comparison, stripping whitespace differences.""" + return canonicalize(xml_bytes, strip_text=True) + + +@pytest.mark.parametrize("given, expected", XML_SERDE_CASES) +def test_xml_serializer(given: Any, expected: bytes) -> None: + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + match given: + case bool(): + serializer.write_boolean(BOOLEAN, given) + case int(): + serializer.write_integer(INTEGER, given) + case float(): + serializer.write_float(FLOAT, given) + case Decimal(): + serializer.write_big_decimal(BIG_DECIMAL, given) + case bytes(): + serializer.write_blob(BLOB, given) + case str(): + serializer.write_string(STRING, given) + case datetime(): + serializer.write_timestamp(TIMESTAMP, given) + case list(): + given = cast(list[str], given) + with serializer.begin_list(STRING_LIST_SCHEMA, len(given)) as ls: + member_schema = STRING_LIST_SCHEMA.members["member"] + for e in given: + ls.write_string(member_schema, e) + case dict(): + given = cast(dict[str, str], given) + with serializer.begin_map(STRING_MAP_SCHEMA, len(given)) as ms: + member_schema = STRING_MAP_SCHEMA.members["value"] + for k, v in given.items(): + ms.entry(k, lambda vs: vs.write_string(member_schema, v)) # type: ignore + case SerdeShape(): + given.serialize(serializer) + case _: + raise 
Exception(f"Unexpected type: {type(given)}") + + serializer.flush() + sink.seek(0) + actual = sink.read() + assert _canonicalize(actual) == _canonicalize(expected) + + +def test_write_null() -> None: + """write_null creates an empty element (no text content).""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + serializer.write_null(STRING) + serializer.flush() + sink.seek(0) + actual = sink.read() + assert actual == b"" + + +def test_write_document_raises() -> None: + """XML does not support document types.""" + from smithy_core.documents import Document + from smithy_core.prelude import DOCUMENT + + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + with pytest.raises(NotImplementedError, match="XML does not support document"): + serializer.write_document(DOCUMENT, Document(None)) + + +def test_float_nan() -> None: + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + serializer.write_float(FLOAT, float("nan")) + serializer.flush() + sink.seek(0) + assert sink.read() == b"NaN" + + +def test_default_namespace() -> None: + """Default namespace is set on the root element.""" + sink = BytesIO() + serializer = XMLCodec(default_namespace="https://example.com/").create_serializer( + sink + ) + serializer.write_string(STRING, "hi") + serializer.flush() + sink.seek(0) + actual = sink.read() + assert actual == b'hi' + + +def test_xml_escaping_in_attribute() -> None: + """XML special characters are escaped in attribute values.""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + shape = SerdeShape(xml_attribute_named_member='<"test">&') + shape.serialize(serializer) + serializer.flush() + sink.seek(0) + actual = sink.read() + assert actual == b'' + + +def test_flush_with_no_writes() -> None: + """Flushing without any writes produces no output.""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + serializer.flush() + sink.seek(0) + assert sink.read() == b"" + + +def 
test_list_with_namespace_on_member() -> None: + """@xmlNamespace on list member adds xmlns to each item element.""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + items = ["Bar"] + with serializer.begin_list(NAMESPACED_LIST_SCHEMA, len(items)) as ls: + member_schema = NAMESPACED_LIST_SCHEMA.members["member"] + for e in items: + ls.write_string(member_schema, e) + serializer.flush() + sink.seek(0) + actual = sink.read() + assert ( + actual + == b'Bar' + ) + + +def test_map_with_xmlname_and_namespace() -> None: + """Map with @xmlName + @xmlNamespace on key and value members.""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + data = {"a": "A"} + with serializer.begin_map(RENAMED_NS_MAP_SCHEMA, len(data)) as ms: + member_schema = RENAMED_NS_MAP_SCHEMA.members["value"] + for k, v in data.items(): + ms.entry(k, lambda vs: vs.write_string(member_schema, v)) + serializer.flush() + sink.seek(0) + actual = sink.read() + assert actual == ( + b"" + b'a' + b'A' + b"" + ) + + +def test_struct_with_xml_namespace() -> None: + """@xmlNamespace on struct adds default xmlns to root element.""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + with serializer.begin_struct(NAMESPACED_STRUCT_SCHEMA) as ss: + ss.write_string(NAMESPACED_STRUCT_SCHEMA.members["value"], "hi") + serializer.flush() + sink.seek(0) + actual = sink.read() + assert ( + actual == b'hi' + ) + + +def test_struct_with_xml_namespace_prefix() -> None: + """@xmlNamespace with prefix adds prefixed xmlns to root element.""" + sink = BytesIO() + serializer = XMLCodec().create_serializer(sink) + with serializer.begin_struct(PREFIXED_NS_STRUCT_SCHEMA) as ss: + ss.write_string(PREFIXED_NS_STRUCT_SCHEMA.members["value"], "hi") + serializer.flush() + sink.seek(0) + actual = sink.read() + assert ( + actual + == b'hi' + ) diff --git a/pyproject.toml b/pyproject.toml index d7a23d99e..fa33308f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ 
members = ["packages/*"] smithy_core = { workspace = true } smithy_http = { workspace = true } smithy_json = { workspace = true } +smithy_xml = { workspace = true } smithy_aws_core = { workspace = true } smithy_aws_event_stream = { workspace = true } aws_sdk_signers = {workspace = true } diff --git a/uv.lock b/uv.lock index ca20fac7e..53f054dc5 100644 --- a/uv.lock +++ b/uv.lock @@ -11,6 +11,7 @@ members = [ "smithy-http", "smithy-json", "smithy-python", + "smithy-xml", ] [[package]] @@ -777,6 +778,16 @@ test = [ ] typing = [{ name = "pyright", specifier = ">=1.1.400" }] +[[package]] +name = "smithy-xml" +source = { editable = "packages/smithy-xml" } +dependencies = [ + { name = "smithy-core" }, +] + +[package.metadata] +requires-dist = [{ name = "smithy-core", editable = "packages/smithy-core" }] + [[package]] name = "typing-extensions" version = "4.13.2"