diff --git a/pyproject.toml b/pyproject.toml index 1e9694f4..91e0cf62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "pyrate-limiter>=3.7.0,<4", "aiomqtt>=2.5.0,<3", "click-shell~=2.1", + "Pillow>=10,<12", ] [project.urls] diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index b1ef6626..1d33f9c0 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -251,7 +251,12 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat trait = b01.q10.create(channel) elif "sc" in model_part: # Q7 devices start with 'sc' in their model naming. - trait = b01.q7.create(channel) + trait = b01.q7.create( + channel, + local_key=device.local_key, + serial=device.sn, + model=product.model, + ) else: raise UnsupportedDeviceError(f"Device {device.name} has unsupported B01 model: {product.model}") case _: diff --git a/roborock/devices/rpc/b01_q7_channel.py b/roborock/devices/rpc/b01_q7_channel.py index add5bc97..eb5660a4 100644 --- a/roborock/devices/rpc/b01_q7_channel.py +++ b/roborock/devices/rpc/b01_q7_channel.py @@ -5,6 +5,7 @@ import asyncio import json import logging +import weakref from typing import Any from roborock.devices.transport.mqtt_channel import MqttChannel @@ -14,10 +15,19 @@ decode_rpc_response, encode_mqtt_payload, ) -from roborock.roborock_message import RoborockMessage +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol _LOGGER = logging.getLogger(__name__) _TIMEOUT = 10.0 +_map_command_locks: weakref.WeakKeyDictionary[MqttChannel, asyncio.Lock] = weakref.WeakKeyDictionary() + + +def _get_map_command_lock(mqtt_channel: MqttChannel) -> asyncio.Lock: + lock = _map_command_locks.get(mqtt_channel) + if lock is None: + lock = asyncio.Lock() + _map_command_locks[mqtt_channel] = lock + return lock async def send_decoded_command( @@ -99,3 +109,62 @@ def find_response(response_message: RoborockMessage) -> None: raise finally: unsub() + + +async def send_map_command(mqtt_channel: MqttChannel, request_message: Q7RequestMessage) -> bytes: + """Send map upload command and wait for MAP_RESPONSE payload bytes. + + Map requests are serialized per channel so concurrent map calls cannot + cross-wire responses between callers. + """ + + async with _get_map_command_lock(mqtt_channel): + roborock_message = encode_mqtt_payload(request_message) + future: asyncio.Future[bytes] = asyncio.get_running_loop().create_future() + + def find_response(response_message: RoborockMessage) -> None: + if future.done(): + return + + if response_message.protocol == RoborockMessageProtocol.MAP_RESPONSE and response_message.payload: + if not future.done(): + future.set_result(response_message.payload) + return + + try: + decoded_dps = decode_rpc_response(response_message) + except RoborockException: + return + + for dps_value in decoded_dps.values(): + if not isinstance(dps_value, str): + continue + try: + inner = json.loads(dps_value) + except (json.JSONDecodeError, TypeError): + continue + if not isinstance(inner, dict) or inner.get("msgId") != str(request_message.msg_id): + continue + code = inner.get("code", 0) + if code != 0: + if not future.done(): + future.set_exception( + RoborockException(f"B01 command failed with code {code} ({request_message})") + ) + return + data = inner.get("data") + if isinstance(data, dict) and isinstance(data.get("payload"), str): + try: + if not future.done(): + future.set_result(bytes.fromhex(data["payload"])) + except ValueError: + pass + + unsub = await mqtt_channel.subscribe(find_response) + try: + await mqtt_channel.publish(roborock_message) + return await asyncio.wait_for(future, timeout=_TIMEOUT) + except TimeoutError as ex: + raise RoborockException(f"B01 map command timed out after {_TIMEOUT}s ({request_message})") from ex + finally: + unsub() diff --git a/roborock/devices/traits/b01/q7/__init__.py b/roborock/devices/traits/b01/q7/__init__.py index 9c09c05c..1af829d3 100644 --- a/roborock/devices/traits/b01/q7/__init__.py +++ b/roborock/devices/traits/b01/q7/__init__.py @@ -21,10 +21,12 @@ from roborock.roborock_typing import RoborockB01Q7Methods from .clean_summary import CleanSummaryTrait +from .map_content import Q7MapContentTrait __all__ = [ "Q7PropertiesApi", "CleanSummaryTrait", + "Q7MapContentTrait", ] @@ -33,11 +35,24 @@ class Q7PropertiesApi(Trait): clean_summary: CleanSummaryTrait """Trait for clean records / clean summary (Q7 `service.get_record_list`).""" - - def __init__(self, channel: MqttChannel) -> None: + map_content: Q7MapContentTrait | None + + def __init__( + self, + channel: MqttChannel, + *, + local_key: str | None = None, + serial: str | None = None, + model: str | None = None, + ) -> None: """Initialize the B01Props API.""" self._channel = channel self.clean_summary = CleanSummaryTrait(channel) + if local_key and serial and model: + self.map_content = Q7MapContentTrait(channel, local_key=local_key, serial=serial, model=model) + else: + # Keep backwards compatibility for direct callers that only use command/query traits. + self.map_content = None async def query_values(self, props: list[RoborockB01Props]) -> B01Props | None: """Query the device for the values of the given Q7 properties.""" @@ -87,6 +102,17 @@ async def start_clean(self) -> None: }, ) + async def clean_segments(self, segment_ids: list[int]) -> None: + """Start segment cleaning for the given ids (Q7 uses room ids).""" + await self.send( + command=RoborockB01Q7Methods.SET_ROOM_CLEAN, + params={ + "clean_type": CleanTaskTypeMapping.ROOM.code, + "ctrl_value": SCDeviceCleanParam.START.code, + "room_ids": segment_ids, + }, + ) + async def pause_clean(self) -> None: """Pause cleaning.""" await self.send( @@ -131,6 +157,12 @@ async def send(self, command: CommandType, params: ParamsType) -> Any: ) -def create(channel: MqttChannel) -> Q7PropertiesApi: - """Create traits for B01 devices.""" - return Q7PropertiesApi(channel) +def create( + channel: MqttChannel, + *, + local_key: str | None = None, + serial: str | None = None, + model: str | None = None, +) -> Q7PropertiesApi: + """Create traits for B01 Q7 devices.""" + return Q7PropertiesApi(channel, local_key=local_key, serial=serial, model=model) diff --git a/roborock/devices/traits/b01/q7/map_content.py b/roborock/devices/traits/b01/q7/map_content.py new file mode 100644 index 00000000..0d9ac6ea --- /dev/null +++ b/roborock/devices/traits/b01/q7/map_content.py @@ -0,0 +1,80 @@ +"""Map content trait for B01/Q7 devices.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from roborock.devices.rpc.b01_q7_channel import send_decoded_command, send_map_command +from roborock.devices.traits import Trait +from roborock.devices.traits.v1.map_content import MapContent +from roborock.devices.transport.mqtt_channel import MqttChannel +from roborock.exceptions import RoborockException +from roborock.map.b01_map_parser import decode_b01_map_payload, parse_scmap_payload, render_map_png +from roborock.protocols.b01_q7_protocol import Q7RequestMessage +from roborock.roborock_typing import RoborockB01Q7Methods + + +@dataclass +class B01MapContent(MapContent): + """B01 map content wrapper.""" + + rooms: dict[int, str] | None = None + + +def _extract_current_map_id(map_list_response: dict[str, Any] | None) -> int | None: + if not isinstance(map_list_response, dict): + return None + map_list = map_list_response.get("map_list") + if not isinstance(map_list, list) or not map_list: + return None + + for entry in map_list: + if isinstance(entry, dict) and entry.get("cur") and isinstance(entry.get("id"), int): + return entry["id"] + + first = map_list[0] + if isinstance(first, dict) and isinstance(first.get("id"), int): + return first["id"] + return None + + +class Q7MapContentTrait(B01MapContent, Trait): + """Fetch and parse map content from B01/Q7 devices.""" + + def __init__(self, channel: MqttChannel, *, local_key: str, serial: str, model: str) -> None: + super().__init__() + self._channel = channel + self._local_key = local_key + self._serial = serial + self._model = model + + async def refresh(self) -> B01MapContent: + map_list_response = await send_decoded_command( + self._channel, + Q7RequestMessage(dps=10000, command=RoborockB01Q7Methods.GET_MAP_LIST, params={}), + ) + map_id = _extract_current_map_id(map_list_response) + if map_id is None: + raise RoborockException(f"Unable to determine map_id from map list response: {map_list_response!r}") + + raw_payload = await send_map_command( + self._channel, + Q7RequestMessage( + dps=10000, + command=RoborockB01Q7Methods.UPLOAD_BY_MAPID, + params={"map_id": map_id}, + ), + ) + inflated = decode_b01_map_payload( + raw_payload, + local_key=self._local_key, + serial=self._serial, + model=self._model, + ) + parsed = parse_scmap_payload(inflated) + self.raw_api_response = raw_payload + self.map_data = None + self.rooms = parsed.rooms + self.image_content = render_map_png(parsed) + return self diff --git a/roborock/map/__init__.py b/roborock/map/__init__.py index 9835b81d..75adb4e1 100644 --- a/roborock/map/__init__.py +++ b/roborock/map/__init__.py @@ -1,7 +1,13 @@ -"""Module for Roborock map related data classes.""" +"""Utilities and data classes for working with Roborock maps.""" +from .b01_map_parser import B01MapData, decode_b01_map_payload, parse_scmap_payload, render_map_png from .map_parser import MapParserConfig, ParsedMapData __all__ = [ + "B01MapData", "MapParserConfig", + "ParsedMapData", + "decode_b01_map_payload", + "parse_scmap_payload", + "render_map_png", ] diff --git a/roborock/map/b01_map_parser.py b/roborock/map/b01_map_parser.py new file mode 100644 index 00000000..0a155fc7 --- /dev/null +++ b/roborock/map/b01_map_parser.py @@ -0,0 +1,259 @@ +"""B01/Q7 SCMap decoding and rendering support.""" + +from __future__ import annotations + +import base64 +import hashlib +import io +import zlib +from dataclasses import dataclass + +from Crypto.Cipher import AES +from Crypto.Util.Padding import pad, unpad +from PIL import Image + +from roborock.exceptions import RoborockException + +_B01_HASH = "5wwh9ikChRjASpMU8cxg7o1d2E" + + +@dataclass +class B01MapData: + """Parsed B01 map payload.""" + + size_x: int + size_y: int + map_data: bytes + rooms: dict[int, str] | None = None + + +def _read_varint(buf: bytes, idx: int) -> tuple[int, int]: + value = 0 + shift = 0 + while True: + if idx >= len(buf): + raise RoborockException("Truncated varint in B01 map payload") + b = buf[idx] + idx += 1 + value |= (b & 0x7F) << shift + if not b & 0x80: + return value, idx + shift += 7 + if shift > 63: + raise RoborockException("Invalid varint in B01 map payload") + + +def _read_len_delimited(buf: bytes, idx: int) -> tuple[bytes, int]: + length, idx = _read_varint(buf, idx) + end = idx + length + if end > len(buf): + raise RoborockException("Invalid length-delimited field in B01 map payload") + return buf[idx:end], end + + +def _parse_map_data_info(blob: bytes) -> bytes: + idx = 0 + while idx < len(blob): + key, idx = _read_varint(blob, idx) + field_no = key >> 3 + wire = key & 0x07 + if wire == 0: + _, idx = _read_varint(blob, idx) + elif wire == 2: + value, idx = _read_len_delimited(blob, idx) + if field_no == 1: + try: + return zlib.decompress(value) + except zlib.error: + return value + elif wire == 5: + idx += 4 + else: + raise RoborockException(f"Unsupported wire type {wire} in B01 map data info") + raise RoborockException("B01 map payload missing mapData") + + +def _parse_room_data_info(blob: bytes) -> tuple[int | None, str | None]: + room_id: int | None = None + room_name: str | None = None + idx = 0 + while idx < len(blob): + key, idx = _read_varint(blob, idx) + field_no = key >> 3 + wire = key & 0x07 + if wire == 0: + value, idx = _read_varint(blob, idx) + if field_no == 1: + room_id = int(value) + elif wire == 2: + value, idx = _read_len_delimited(blob, idx) + if field_no == 2: + room_name = value.decode("utf-8", errors="replace") + elif wire == 5: + idx += 4 + else: + raise RoborockException(f"Unsupported wire type {wire} in B01 room data info") + return room_id, room_name + + +def parse_scmap_payload(payload: bytes) -> B01MapData: + """Parse SCMap protobuf payload and extract occupancy grid bytes and room names.""" + + size_x = 0 + size_y = 0 + grid = b"" + rooms: dict[int, str] = {} + idx = 0 + while idx < len(payload): + key, idx = _read_varint(payload, idx) + field_no = key >> 3 + wire = key & 0x07 + if wire == 0: + _, idx = _read_varint(payload, idx) + continue + if wire != 2: + if wire == 5: + idx += 4 + continue + raise RoborockException(f"Unsupported wire type {wire} in B01 map payload") + value, idx = _read_len_delimited(payload, idx) + if field_no == 3: # mapHead + hidx = 0 + while hidx < len(value): + hkey, hidx = _read_varint(value, hidx) + hfield = hkey >> 3 + hwire = hkey & 0x07 + if hwire == 0: + hvalue, hidx = _read_varint(value, hidx) + if hfield == 2: + size_x = int(hvalue) + elif hfield == 3: + size_y = int(hvalue) + elif hwire == 5: + hidx += 4 + elif hwire == 2: + _, hidx = _read_len_delimited(value, hidx) + else: + raise RoborockException(f"Unsupported wire type {hwire} in B01 map header") + elif field_no == 4: # mapDataInfo + grid = _parse_map_data_info(value) + elif field_no == 12: # roomDataInfo (repeated) + room_id, room_name = _parse_room_data_info(value) + if room_id is not None: + rooms[room_id] = room_name or f"Room {room_id}" + + if not size_x or not size_y or not grid: + raise RoborockException("Failed to parse B01 map header/grid") + if len(grid) < size_x * size_y: + raise RoborockException("B01 map data shorter than expected dimensions") + return B01MapData(size_x=size_x, size_y=size_y, map_data=grid, rooms=rooms or None) + + +def _derive_b01_iv(iv_seed: int) -> bytes: + random_hex = iv_seed.to_bytes(4, "big").hex().lower() + md5 = hashlib.md5((random_hex + _B01_HASH).encode(), usedforsecurity=False).hexdigest() + return md5[9:25].encode() + + +def derive_map_key(serial: str, model: str) -> bytes: + """Derive map decrypt key for B01/Q7 map payloads.""" + + model_suffix = model.split(".")[-1] + model_key = (model_suffix + "0" * 16)[:16].encode() + material = f"{serial}+{model_suffix}+{serial}".encode() + encrypted = AES.new(model_key, AES.MODE_ECB).encrypt(pad(material, AES.block_size)) + md5 = hashlib.md5(base64.b64encode(encrypted), usedforsecurity=False).hexdigest() + return md5[8:24].encode() + + +def _maybe_b64(data: bytes) -> bytes | None: + try: + return base64.b64decode(data, validate=False) + except Exception: + return None + + +def decode_b01_map_payload(raw_payload: bytes, *, local_key: str, serial: str, model: str) -> bytes: + """Decode raw B01 MAP_RESPONSE payload into inflated SCMap protobuf bytes.""" + + layers: list[bytes] = [] + l0 = _maybe_b64(raw_payload) + if l0 is not None: + layers.append(l0) + l1 = _maybe_b64(l0) + if l1 is not None: + layers.append(l1) + else: + layers.append(raw_payload) + + map_key = derive_map_key(serial, model) + for layer in layers: + candidates: list[bytes] = [layer] + if len(layer) > 19 and layer[:3] == b"B01": + iv_seed = int.from_bytes(layer[7:11], "big") + payload_len = int.from_bytes(layer[17:19], "big") + encrypted = layer[19 : 19 + payload_len] + try: + decrypted = AES.new(local_key.encode(), AES.MODE_CBC, _derive_b01_iv(iv_seed)).decrypt(encrypted) + candidates.append(unpad(decrypted, 16)) + except Exception: + pass + + for candidate in list(candidates): + if len(candidate) % 16 == 0: + try: + decrypted = AES.new(map_key, AES.MODE_ECB).decrypt(candidate) + candidates.append(decrypted) + candidates.append(unpad(decrypted, 16)) + except Exception: + pass + + for candidate in candidates: + variants = [candidate] + try: + text = candidate.decode("ascii").strip() + if len(text) > 16 and all(c in "0123456789abcdefABCDEF" for c in text[:32]): + variants.append(bytes.fromhex(text)) + except Exception: + pass + for variant in variants: + try: + inflated = zlib.decompress(variant) + except zlib.error: + continue + parse_scmap_payload(inflated) + return inflated + + raise RoborockException("Failed to decode B01 map payload") + + +def render_map_png(map_data: B01MapData) -> bytes: + """Render occupancy map bytes into PNG.""" + + img = Image.new("RGB", (map_data.size_x, map_data.size_y), (0, 0, 0)) + px = img.load() + room_colors = [ + (80, 150, 255), + (255, 170, 80), + (120, 220, 130), + (210, 130, 255), + (255, 120, 170), + (100, 220, 220), + ] + + for i, value in enumerate(map_data.map_data[: map_data.size_x * map_data.size_y]): + x = i % map_data.size_x + y = map_data.size_y - 1 - (i // map_data.size_x) + if value == 0: + color = (0, 0, 0) + elif value in (1, 127): + color = (180, 180, 180) + elif value >= 128: + color = (255, 255, 255) + else: + color = room_colors[(max(value - 2, 0)) % len(room_colors)] + px[x, y] = color + + output = io.BytesIO() + img.save(output, format="PNG") + return output.getvalue() diff --git a/tests/devices/traits/b01/q7/conftest.py b/tests/devices/traits/b01/q7/conftest.py index 5dc476f6..160c37d3 100644 --- a/tests/devices/traits/b01/q7/conftest.py +++ b/tests/devices/traits/b01/q7/conftest.py @@ -18,7 +18,7 @@ def fake_channel_fixture() -> FakeChannel: @pytest.fixture(name="q7_api") def q7_api_fixture(fake_channel: FakeChannel) -> Q7PropertiesApi: - return Q7PropertiesApi(fake_channel) # type: ignore[arg-type] + return Q7PropertiesApi(fake_channel, local_key="abcdefghijklmnop", serial="test_sn", model="roborock.vacuum.sc05") # type: ignore[arg-type] @pytest.fixture(name="expected_msg_id", autouse=True) diff --git a/tests/devices/traits/b01/q7/test_init.py b/tests/devices/traits/b01/q7/test_init.py index cb16299c..79063487 100644 --- a/tests/devices/traits/b01/q7/test_init.py +++ b/tests/devices/traits/b01/q7/test_init.py @@ -16,8 +16,9 @@ from roborock.devices.rpc.b01_q7_channel import send_decoded_command from roborock.devices.traits.b01.q7 import Q7PropertiesApi from roborock.exceptions import RoborockException +from roborock.map.b01_map_parser import B01MapData from roborock.protocols.b01_q7_protocol import B01_VERSION, Q7RequestMessage -from roborock.roborock_message import RoborockB01Props, RoborockMessageProtocol +from roborock.roborock_message import RoborockB01Props, RoborockMessage, RoborockMessageProtocol from tests.fixtures.channel_fixtures import FakeChannel from . import B01MessageBuilder @@ -257,3 +258,129 @@ async def test_q7_api_find_me(q7_api: Q7PropertiesApi, fake_channel: FakeChannel payload_data = json.loads(unpad(message.payload, AES.block_size)) assert payload_data["dps"]["10000"]["method"] == "service.find_device" assert payload_data["dps"]["10000"]["params"] == {} + + +async def test_q7_api_clean_segments( + q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder +): + """Test room/segment cleaning helper for Q7.""" + fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + await q7_api.clean_segments([10, 11]) + + assert len(fake_channel.published_messages) == 1 + message = fake_channel.published_messages[0] + payload_data = json.loads(unpad(message.payload, AES.block_size)) + assert payload_data["dps"]["10000"]["method"] == "service.set_room_clean" + assert payload_data["dps"]["10000"]["params"] == { + "clean_type": CleanTaskTypeMapping.ROOM.code, + "ctrl_value": SCDeviceCleanParam.START.code, + "room_ids": [10, 11], + } + + +async def test_q7_map_content_refresh_from_map_response( + q7_api: Q7PropertiesApi, + fake_channel: FakeChannel, + message_builder: B01MessageBuilder, + monkeypatch: pytest.MonkeyPatch, +): + """Test Q7 map content refresh wiring through map list + MAP_RESPONSE payload path.""" + + fake_channel.response_queue.append(message_builder.build({"map_list": [{"id": 1772093512, "cur": True}]})) + fake_channel.response_queue.append( + RoborockMessage( + protocol=RoborockMessageProtocol.MAP_RESPONSE, + payload=b"raw-map-payload", + version=b"B01", + seq=message_builder.seq + 1, + ) + ) + + monkeypatch.setattr( + "roborock.devices.traits.b01.q7.map_content.decode_b01_map_payload", + lambda raw_payload, **kwargs: b"inflated-payload", + ) + monkeypatch.setattr( + "roborock.devices.traits.b01.q7.map_content.parse_scmap_payload", + lambda payload: B01MapData(size_x=1, size_y=1, map_data=b"\x01"), + ) + monkeypatch.setattr( + "roborock.devices.traits.b01.q7.map_content.render_map_png", + lambda parsed: b"\x89PNG-test", + ) + + result = await q7_api.map_content.refresh() + + assert result.image_content == b"\x89PNG-test" + assert result.raw_api_response == b"raw-map-payload" + + assert len(fake_channel.published_messages) == 2 + + first = fake_channel.published_messages[0] + first_payload = json.loads(unpad(first.payload, AES.block_size)) + assert first_payload["dps"]["10000"]["method"] == "service.get_map_list" + assert first_payload["dps"]["10000"]["params"] == {} + + second = fake_channel.published_messages[1] + second_payload = json.loads(unpad(second.payload, AES.block_size)) + assert second_payload["dps"]["10000"]["method"] == "service.upload_by_mapid" + assert second_payload["dps"]["10000"]["params"] == {"map_id": 1772093512} + + +async def test_q7_map_content_refresh_falls_back_to_first_map( + q7_api: Q7PropertiesApi, + fake_channel: FakeChannel, + message_builder: B01MessageBuilder, + monkeypatch: pytest.MonkeyPatch, +): + """If no map is marked current, use first map from map_list.""" + + fake_channel.response_queue.append( + message_builder.build({"map_list": [{"id": 111}, {"id": 222, "cur": False}]}) + ) + fake_channel.response_queue.append( + RoborockMessage( + protocol=RoborockMessageProtocol.MAP_RESPONSE, + payload=b"raw-map-payload", + version=b"B01", + seq=message_builder.seq + 1, + ) + ) + + monkeypatch.setattr( + "roborock.devices.traits.b01.q7.map_content.decode_b01_map_payload", + lambda raw_payload, **kwargs: b"inflated-payload", + ) + monkeypatch.setattr( + "roborock.devices.traits.b01.q7.map_content.parse_scmap_payload", + lambda payload: B01MapData(size_x=1, size_y=1, map_data=b"\x01"), + ) + monkeypatch.setattr( + "roborock.devices.traits.b01.q7.map_content.render_map_png", + lambda parsed: b"\x89PNG-test", + ) + + await q7_api.map_content.refresh() + + second = fake_channel.published_messages[1] + second_payload = json.loads(unpad(second.payload, AES.block_size)) + assert second_payload["dps"]["10000"]["params"] == {"map_id": 111} + + +async def test_q7_map_content_refresh_errors_without_map_list( + q7_api: Q7PropertiesApi, + fake_channel: FakeChannel, + message_builder: B01MessageBuilder, +): + """Map refresh should fail clearly when map list response is unusable.""" + + fake_channel.response_queue.append(message_builder.build({"map_list": []})) + + with pytest.raises(RoborockException, match="Unable to determine map_id"): + await q7_api.map_content.refresh() + + +async def test_q7_api_constructor_backwards_compatible_without_map_context(fake_channel: FakeChannel): + """Direct API construction without map context should still work.""" + api = Q7PropertiesApi(fake_channel) # type: ignore[arg-type] + assert api.map_content is None diff --git a/tests/fixtures/b01/raw-mqtt-map301.bin.inflated.bin b/tests/fixtures/b01/raw-mqtt-map301.bin.inflated.bin new file mode 100644 index 00000000..8c762199 Binary files /dev/null and b/tests/fixtures/b01/raw-mqtt-map301.bin.inflated.bin differ diff --git a/tests/map/test_b01_map_parser.py b/tests/map/test_b01_map_parser.py new file mode 100644 index 00000000..d7eaf508 --- /dev/null +++ b/tests/map/test_b01_map_parser.py @@ -0,0 +1,47 @@ +"""Tests for B01/Q7 map decoder/parser/renderer.""" + +from __future__ import annotations + +import base64 +import zlib +from pathlib import Path + +from Crypto.Cipher import AES +from Crypto.Util.Padding import pad + +from roborock.map.b01_map_parser import decode_b01_map_payload, derive_map_key, parse_scmap_payload, render_map_png + +FIXTURE = Path(__file__).resolve().parents[1] / "fixtures" / "b01" / "raw-mqtt-map301.bin.inflated.bin" + + +def test_parse_scmap_payload_fixture() -> None: + payload = FIXTURE.read_bytes() + parsed = parse_scmap_payload(payload) + assert parsed.size_x == 340 + assert parsed.size_y == 300 + assert len(parsed.map_data) >= parsed.size_x * parsed.size_y + assert parsed.rooms is not None + assert parsed.rooms.get(10) == "room1" + + +def test_render_map_png_fixture() -> None: + payload = FIXTURE.read_bytes() + parsed = parse_scmap_payload(payload) + png = render_map_png(parsed) + assert png.startswith(b"\x89PNG\r\n\x1a\n") + assert len(png) > 1024 + + +def test_decode_b01_map_payload_round_trip() -> None: + local_key = "abcdefghijklmnop" + serial = "testsn012345" + model = "roborock.vacuum.sc05" + inflated = FIXTURE.read_bytes() + + compressed = zlib.compress(inflated) + map_key = derive_map_key(serial, model) + encrypted = AES.new(map_key, AES.MODE_ECB).encrypt(pad(compressed.hex().encode(), 16)) + payload = base64.b64encode(base64.b64encode(encrypted)) + + decoded = decode_b01_map_payload(payload, local_key=local_key, serial=serial, model=model) + assert decoded == inflated diff --git a/uv.lock b/uv.lock index 06463df6..afca5390 100644 --- a/uv.lock +++ b/uv.lock @@ -1329,6 +1329,7 @@ dependencies = [ { name = "click-shell" }, { name = "construct" }, { name = "paho-mqtt" }, + { name = "pillow" }, { name = "pycryptodome" }, { name = "pycryptodomex", marker = "sys_platform == 'darwin'" }, { name = "pyrate-limiter" }, @@ -1361,6 +1362,7 @@ requires-dist = [ { name = "click-shell", specifier = "~=2.1" }, { name = "construct", specifier = ">=2.10.57,<3" }, { name = "paho-mqtt", specifier = ">=1.6.1,<3.0.0" }, + { name = "pillow", specifier = ">=10,<12" }, { name = "pycryptodome", specifier = "~=3.18" }, { name = "pycryptodomex", marker = "sys_platform == 'darwin'", specifier = "~=3.18" }, { name = "pyrate-limiter", specifier = ">=3.7.0,<4" },