From 1901d6ec87e44855519cd4b789e1e58365b1a56e Mon Sep 17 00:00:00 2001 From: Ryan Malloy Date: Sun, 10 May 2026 13:02:49 -0600 Subject: [PATCH] Async client + mock panel + e2e roundtrip MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit src/omni_pca/connection.py — low-level OmniConnection - 4-step secure-session handshake (NewSession, SecureSession) - Per-direction monotonic seq with 0xFFFF -> 1 wraparound (skips 0) - TCP framing: read first 16-byte block, decrypt, learn length, read rest - Reader task dispatches solicited replies to Future, unsolicited to queue - Custom exceptions: HandshakeError, InvalidEncryptionKeyError, ProtocolError, RequestTimeoutError src/omni_pca/models.py — typed response objects - SystemInformation (with model_name lookup), SystemStatus, ZoneProperties, UnitProperties, AreaProperties — all frozen+slots dataclasses with .parse(payload) classmethods src/omni_pca/client.py — high-level OmniClient - get_system_information / get_system_status / get_object_properties - list_{zone,unit,area}_names walks via RequestProperties rel=1 - subscribe(callback) for unsolicited messages src/omni_pca/mock_panel.py — async TCP server emulating an Omni Pro II - Full handshake (controller side), seedable MockState - Implements RequestSystemInformation, RequestSystemStatus, RequestProperties (Zone/Unit/Area, both absolute and rel=1 iteration with EOD termination); Nak for everything else - 'omni-pca mock-panel' CLI subcommand tests/ — 85 passed, 1 skip (live fixture) - 23 unit tests for connection/models/client (canned-server fixtures) - 7 unit tests for mock panel (raw protocol drive) - 6 e2e tests: real OmniClient over real TCP to real MockPanel, proves handshake + AES + whitening + sequencing all agree --- src/omni_pca/__main__.py | 113 +++++++ src/omni_pca/client.py | 303 +++++++++++++++++ src/omni_pca/connection.py | 598 ++++++++++++++++++++++++++++++++++ src/omni_pca/mock_panel.py | 538 ++++++++++++++++++++++++++++++ src/omni_pca/models.py | 451 +++++++++++++++++++++++++ tests/test_client.py | 179 ++++++++++ tests/test_connection.py | 297 +++++++++++++++++ tests/test_e2e_client_mock.py | 101 ++++++ tests/test_mock_panel.py | 282 ++++++++++++++++ tests/test_models.py | 206 ++++++++++++ 10 files changed, 3068 insertions(+) create mode 100644 src/omni_pca/client.py create mode 100644 src/omni_pca/connection.py create mode 100644 src/omni_pca/mock_panel.py create mode 100644 src/omni_pca/models.py create mode 100644 tests/test_client.py create mode 100644 tests/test_connection.py create mode 100644 tests/test_e2e_client_mock.py create mode 100644 tests/test_mock_panel.py create mode 100644 tests/test_models.py diff --git a/src/omni_pca/__main__.py b/src/omni_pca/__main__.py index 0d51443..e5fbd6a 100644 --- a/src/omni_pca/__main__.py +++ b/src/omni_pca/__main__.py @@ -2,6 +2,8 @@ Subcommands: decode-pca [--key HEX] [--include-pii] [--field NAME] + mock-panel [--host H] [--port P] [--controller-key HEX] + [--zones-file PATH] [--seed-with-our-house FILE] version The default ``decode-pca`` output is **redacted**: account name, address, @@ -9,19 +11,31 @@ phone, codes and remarks never reach stdout unless the user passes ``--include-pii``. ``--field`` extracts a single value (host, port, controller_key) for shell scripting. +``mock-panel`` runs a local Omni-Link II controller simulator until +SIGINT — useful for driving the in-progress async client without poking +real hardware. + References: pca_file.py — decryption + parsing + mock_panel.py — controller-side TCP simulator """ from __future__ import annotations import argparse +import asyncio +import contextlib +import logging import sys from pathlib import Path from . import __version__ +from .mock_panel import MockPanel, MockState from .pca_file import KEY_EXPORT, KEY_PC01, PcaAccount, parse_pca_file +_DEFAULT_CONTROLLER_KEY_HEX = "00112233445566778899aabbccddeeff" +_DEFAULT_MOCK_PORT = 14369 + _ALLOWED_FIELDS = ("host", "port", "controller_key") @@ -53,6 +67,36 @@ def _build_parser() -> argparse.ArgumentParser: help="Print only one field for scripting (host, port, controller_key).", ) + pm = sub.add_parser( + "mock-panel", + help="Run a local Omni-Link II controller simulator (test harness).", + ) + pm.add_argument("--host", default="127.0.0.1", help="Bind host (default 127.0.0.1).") + pm.add_argument( + "--port", + type=int, + default=_DEFAULT_MOCK_PORT, + help=f"Bind port (default {_DEFAULT_MOCK_PORT}).", + ) + pm.add_argument( + "--controller-key", + default=_DEFAULT_CONTROLLER_KEY_HEX, + help="32 hex chars (16 bytes) for the panel ControllerKey.", + ) + pm.add_argument( + "--zones-file", + type=Path, + help="Plain text file: one 'INDEX NAME' per line, seeds MockState.zones.", + ) + pm.add_argument( + "--seed-with-our-house", + type=Path, + help="Path to a .pca file; its zones/units/areas seed MockState.", + ) + pm.add_argument( + "--debug", action="store_true", help="Verbose mock-panel debug logging." + ) + sub.add_parser("version", help="Print package version and exit.") return p @@ -117,6 +161,70 @@ def _print_field(account: PcaAccount, field: str) -> int: return 0 +def _parse_zones_file(path: Path) -> dict[int, str]: + """Read 'INDEX NAME' lines into a {idx: name} dict; '#' starts a comment.""" + out: dict[int, str] = {} + for raw in path.read_text(encoding="utf-8").splitlines(): + line = raw.split("#", 1)[0].strip() + if not line: + continue + idx_str, _, name = line.partition(" ") + try: + idx = int(idx_str) + except ValueError as exc: + raise ValueError( + f"{path}: cannot parse index from {raw!r} — expected 'INDEX NAME'" + ) from exc + out[idx] = name.strip() + return out + + +def _build_mock_state(args: argparse.Namespace) -> MockState: + state = MockState() + if args.seed_with_our_house is not None: + # The .pca header already gives us model + firmware; zone/unit/area + # name extraction from the body isn't wired up in pca_file yet. + _, account = _try_decode(args.seed_with_our_house, None) + state.model_byte = account.model + state.firmware_major = account.firmware_major + state.firmware_minor = account.firmware_minor + state.firmware_revision = account.firmware_revision + print( + f"# seeded model={account.model} fw={account.firmware_major}." + f"{account.firmware_minor}.{account.firmware_revision}", + file=sys.stderr, + ) + if args.zones_file is not None: + state.zones = _parse_zones_file(args.zones_file) + print(f"# loaded {len(state.zones)} zone names", file=sys.stderr) + return state + + +async def _run_mock_panel(args: argparse.Namespace) -> int: + if args.debug: + logging.basicConfig(level=logging.DEBUG, format="%(levelname)s %(name)s: %(message)s") + try: + controller_key = bytes.fromhex(args.controller_key) + except ValueError as exc: + print(f"error: --controller-key not valid hex: {exc}", file=sys.stderr) + return 2 + if len(controller_key) != 16: + print( + f"error: --controller-key must be 16 bytes (32 hex chars), got {len(controller_key)}", + file=sys.stderr, + ) + return 2 + state = _build_mock_state(args) + panel = MockPanel(controller_key=controller_key, state=state) + async with panel.serve(host=args.host, port=args.port) as (host, port): + print(f"omni-pca mock-panel listening on {host}:{port}") + print("# Ctrl-C to stop", file=sys.stderr) + with contextlib.suppress(asyncio.CancelledError, KeyboardInterrupt): + await asyncio.Event().wait() # block until cancelled + print(f"# served {panel.session_count} session(s)", file=sys.stderr) + return 0 + + def main(argv: list[str] | None = None) -> int: args = _build_parser().parse_args(argv) if args.cmd == "version": @@ -129,6 +237,11 @@ def main(argv: list[str] | None = None) -> int: print(f"# decoded with key=0x{used_key:08X}", file=sys.stderr) _print_summary(account, include_pii=args.include_pii) return 0 + if args.cmd == "mock-panel": + try: + return asyncio.run(_run_mock_panel(args)) + except KeyboardInterrupt: + return 0 return 2 # pragma: no cover diff --git a/src/omni_pca/client.py b/src/omni_pca/client.py new file mode 100644 index 0000000..0a60a85 --- /dev/null +++ b/src/omni_pca/client.py @@ -0,0 +1,303 @@ +"""High-level async client for the HAI/Leviton Omni-Link II protocol. + +This wraps :class:`OmniConnection` with typed methods that send the +appropriate v2 request opcode and parse the reply payload into one of +the dataclasses in :mod:`omni_pca.models`. + +Conventions: + * Indices are 1-based on the wire (zone 1 is index=1, not 0). + * ``RequestProperties`` uses ``relative_direction = 0`` for an exact + lookup (panel returns just that index, or NAK/EOD if absent). + * Walking with ``relative_direction = 1`` returns each next defined + object, used by the ``list_*`` helpers. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import struct +from collections.abc import Awaitable, Callable +from enum import IntEnum +from types import TracebackType +from typing import Self + +from .connection import ( + ConnectionError as OmniConnectionError, +) +from .connection import ( + OmniConnection, + RequestTimeoutError, +) +from .message import Message +from .models import ( + AreaProperties, + PropertiesReply, + SystemInformation, + SystemStatus, + UnitProperties, + ZoneProperties, +) +from .opcodes import OmniLink2MessageType + + +class ObjectType(IntEnum): + """``RequestProperties`` object-type discriminator (matches enuObjectType).""" + + ZONE = 1 + UNIT = 2 + BUTTON = 3 + CODE = 4 + AREA = 5 + THERMOSTAT = 6 + MESSAGE = 7 + AUX_SENSOR = 8 + AUDIO_SOURCE = 9 + AUDIO_ZONE = 10 + EXP_ENCLOSURE = 11 + CONSOLE = 12 + USER_SETTING = 13 + ACCESS_CONTROL = 14 + + +# Maps the request side to the parser side. Only types we actively +# support get an entry; the rest fall through to a generic raw-payload +# return for now. +_PROPERTIES_PARSERS: dict[ObjectType, type[PropertiesReply]] = { + ObjectType.ZONE: ZoneProperties, + ObjectType.UNIT: UnitProperties, + ObjectType.AREA: AreaProperties, +} + + +class OmniClient: + """High-level async Omni-Link II client. + + Use as an async context manager, then call typed methods: + + .. code-block:: python + + async with OmniClient(host, port=4369, controller_key=KEY) as client: + info = await client.get_system_information() + zones = await client.list_zone_names() + """ + + def __init__( + self, + host: str, + port: int = 4369, + *, + controller_key: bytes, + timeout: float = 5.0, + ) -> None: + self._conn = OmniConnection( + host=host, + port=port, + controller_key=controller_key, + timeout=timeout, + ) + self._subscriber_task: asyncio.Task[None] | None = None + + # ---- lifecycle ------------------------------------------------------- + + async def __aenter__(self) -> Self: + await self._conn.connect() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + if self._subscriber_task is not None and not self._subscriber_task.done(): + self._subscriber_task.cancel() + with contextlib.suppress(asyncio.CancelledError, Exception): + await self._subscriber_task + await self._conn.close() + + @property + def connection(self) -> OmniConnection: + """The underlying low-level connection (for advanced use).""" + return self._conn + + # ---- typed requests -------------------------------------------------- + + async def get_system_information(self) -> SystemInformation: + reply = await self._conn.request(OmniLink2MessageType.RequestSystemInformation) + self._expect(reply, OmniLink2MessageType.SystemInformation) + return SystemInformation.parse(reply.payload) + + async def get_system_status(self) -> SystemStatus: + reply = await self._conn.request(OmniLink2MessageType.RequestSystemStatus) + self._expect(reply, OmniLink2MessageType.SystemStatus) + return SystemStatus.parse(reply.payload) + + async def get_object_properties( + self, + object_type: ObjectType, + index: int, + ) -> PropertiesReply: + """Fetch one Properties reply for the given object. + + Returns the appropriate dataclass for ``object_type``. Raises + :class:`ValueError` if the panel doesn't have an object at that + index, or :class:`NotImplementedError` if we don't yet have a + parser for that object type. + """ + parser = _PROPERTIES_PARSERS.get(object_type) + if parser is None: + raise NotImplementedError( + f"no parser for object type {object_type.name}" + ) + payload = self._build_request_properties_payload( + object_type=object_type, + index=index, + relative_direction=0, + ) + reply = await self._conn.request( + OmniLink2MessageType.RequestProperties, payload + ) + if reply.opcode == OmniLink2MessageType.EOD: + raise ValueError( + f"no {object_type.name} at index {index} (panel returned EOD)" + ) + if reply.opcode == OmniLink2MessageType.Nak: + raise ValueError( + f"panel NAK'd Properties request for {object_type.name}#{index}" + ) + self._expect(reply, OmniLink2MessageType.Properties) + return parser.parse(reply.payload) + + async def list_zone_names(self) -> dict[int, str]: + """Walk all zones, returning ``{index: name}`` for those with a name set.""" + return await self._walk_named_objects( + ObjectType.ZONE, + lambda r: (r.index, r.name) if isinstance(r, ZoneProperties) else None, + ) + + async def list_unit_names(self) -> dict[int, str]: + return await self._walk_named_objects( + ObjectType.UNIT, + lambda r: (r.index, r.name) if isinstance(r, UnitProperties) else None, + ) + + async def list_area_names(self) -> dict[int, str]: + return await self._walk_named_objects( + ObjectType.AREA, + lambda r: (r.index, r.name) if isinstance(r, AreaProperties) else None, + ) + + async def subscribe( + self, callback: Callable[[Message], Awaitable[None]] + ) -> None: + """Run ``callback`` for every unsolicited message until cancelled. + + Spawns a background task. If you call ``subscribe`` more than + once the previous subscription is cancelled (we don't fan out). + """ + if self._subscriber_task is not None and not self._subscriber_task.done(): + self._subscriber_task.cancel() + with contextlib.suppress(asyncio.CancelledError, Exception): + await self._subscriber_task + + async def _runner() -> None: + async for msg in self._conn.unsolicited(): + try: + await callback(msg) + except Exception: + # Don't let a bad callback kill the subscription; + # just log via the connection's logger. + import logging + + logging.getLogger(__name__).exception( + "unsolicited callback raised" + ) + + self._subscriber_task = asyncio.create_task( + _runner(), name="omni-client-subscriber" + ) + + # ---- helpers --------------------------------------------------------- + + @staticmethod + def _expect(reply: Message, expected: OmniLink2MessageType) -> None: + if reply.opcode != int(expected): + raise OmniConnectionError( + f"expected opcode {expected.name} ({int(expected)}), " + f"got {reply.opcode}" + ) + + @staticmethod + def _build_request_properties_payload( + object_type: ObjectType, + index: int, + relative_direction: int, + filter1: int = 0, + filter2: int = 0, + filter3: int = 0, + ) -> bytes: + """Build the 7-byte payload for a RequestProperties (opcode 32) message. + + Layout (clsOL2MsgRequestProperties.cs, after stripping opcode): + 0 object type + 1..2 index (BE ushort) + 3 relative direction (signed: 0=exact, +1=next, -1=prev) + 4..6 filters (per-type bitmasks) + """ + if not 0 <= index <= 0xFFFF: + raise ValueError(f"index out of range: {index}") + rd = relative_direction & 0xFF + return struct.pack( + ">BHBBBB", + int(object_type), + index, + rd, + filter1, + filter2, + filter3, + ) + + async def _walk_named_objects( + self, + object_type: ObjectType, + extract: Callable[[PropertiesReply], tuple[int, str] | None], + ) -> dict[int, str]: + """Walk every defined object of ``object_type`` and collect non-empty names. + + We use ``relative_direction=1`` (next) starting from index 0 to + let the panel hand us each defined object in turn until it + returns EOD (end-of-data, opcode 3). + """ + names: dict[int, str] = {} + cursor = 0 + # Bound the walk to the protocol max (ushort) just in case the + # panel keeps echoing. + for _ in range(0xFFFF): + payload = self._build_request_properties_payload( + object_type=object_type, + index=cursor, + relative_direction=1, + ) + try: + reply = await self._conn.request( + OmniLink2MessageType.RequestProperties, payload + ) + except RequestTimeoutError: + break + if reply.opcode == OmniLink2MessageType.EOD: + break + if reply.opcode != OmniLink2MessageType.Properties: + break + parser = _PROPERTIES_PARSERS.get(object_type) + if parser is None: # pragma: no cover - guarded above + break + parsed = parser.parse(reply.payload) + pair = extract(parsed) + if pair is not None and pair[1]: + names[pair[0]] = pair[1] + # Advance: ask for the next index after the one we just got. + cursor = parsed.index + if cursor >= 0xFFFF: + break + return names diff --git a/src/omni_pca/connection.py b/src/omni_pca/connection.py new file mode 100644 index 0000000..5f59301 --- /dev/null +++ b/src/omni_pca/connection.py @@ -0,0 +1,598 @@ +"""Async TCP connection to an HAI/Leviton Omni-Link II controller. + +This is the foundation layer. It owns the asyncio TCP socket, drives the +4-step secure-session handshake, frames inner ``Message`` objects inside +outer ``Packet`` envelopes, and routes solicited replies (matched by the +client's outer sequence number) to per-request ``Future`` objects while +shoveling unsolicited push packets (seq=0) into a queue exposed via +:meth:`OmniConnection.unsolicited`. + +References (line numbers into HAI/pca-re/decompiled/project/HAI_Shared +/clsOmniLinkConnection.cs): + 1688-1697 send empty ClientRequestNewSession on connect + 1714-1758 TCP frame reader: per-block decrypt-to-learn-length pattern + 1796 solicited reply matched by SequenceNumber == pktSequence + 1847-1854 unsolicited dispatch when SequenceNumber == 0 + 1864-1921 step 2 handler: derive key, enqueue step 3 + 1923-1947 step 4 handler: transition to OnlineSecure + 1808 ControllerSessionTerminated during handshake => InvalidEncryptionKey +""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +from collections.abc import AsyncIterator +from enum import IntEnum +from types import TracebackType + +from .crypto import ( + BLOCK_SIZE, + decrypt_message_payload, + derive_session_key, + encrypt_message_payload, +) +from .message import Message, MessageCrcError, encode_v2 +from .opcodes import OmniLink2MessageType, PacketType +from .packet import MIN_PACKET_BYTES, Packet + +_log = logging.getLogger(__name__) + +_DEFAULT_PORT = 4369 +_HEADER_BYTES = MIN_PACKET_BYTES # 4 +_SESSION_ID_LEN = 5 +_PROTO_VERSION = (0x00, 0x01) +_MAX_SEQ = 0xFFFF + + +class ConnectionState(IntEnum): + """High-level state of the secure session.""" + + DISCONNECTED = 0 + CONNECTING = 1 + NEW_SESSION = 2 # ClientRequestNewSession sent, awaiting ControllerAckNewSession + SECURE = 3 # ClientRequestSecureSession sent, awaiting ControllerAckSecureSession + ONLINE = 4 + + +# ---- exceptions ---------------------------------------------------------- + + +class ConnectionError(OSError): # noqa: A001 - intentional shadow at module scope + """Generic transport-level failure (TCP closed unexpectedly, etc.).""" + + +class HandshakeError(ConnectionError): + """The 4-step secure-session handshake did not complete.""" + + +class InvalidEncryptionKeyError(HandshakeError): + """Controller answered ``ControllerSessionTerminated`` during handshake. + + Per clsOmniLinkConnection.cs:1808, this is the panel's way of saying + "your derived SessionKey didn't decrypt my echo correctly" — i.e. the + ControllerKey we used doesn't match the panel's NVRAM. + """ + + +class ProtocolError(ValueError): + """A received frame was structurally invalid.""" + + +class RequestTimeoutError(TimeoutError): + """A solicited request did not receive a reply in time.""" + + +# ---- the connection ------------------------------------------------------ + + +class OmniConnection: + """Low-level async Omni-Link II connection. + + Use as an async context manager: + + .. code-block:: python + + async with OmniConnection(host, port, controller_key) as conn: + reply = await conn.request(OmniLink2MessageType.RequestSystemInformation) + """ + + def __init__( + self, + host: str, + port: int = _DEFAULT_PORT, + controller_key: bytes = b"", + timeout: float = 5.0, + ) -> None: + if len(controller_key) != 16: + raise ValueError( + f"controller_key must be 16 bytes, got {len(controller_key)}" + ) + self._host = host + self._port = port + self._controller_key = bytes(controller_key) + self._default_timeout = timeout + + self._reader: asyncio.StreamReader | None = None + self._writer: asyncio.StreamWriter | None = None + self._state = ConnectionState.DISCONNECTED + + self._session_id: bytes | None = None + self._session_key: bytes | None = None + + # Client-side outbound sequence counter. The very first wire packet + # uses seq=1; every subsequent client packet bumps by 1, skipping 0 + # on wraparound (0 is reserved for unsolicited inbound). + self._next_seq: int = 1 + + # Solicited replies are matched on the seq number they were sent + # with; the controller echoes that seq back on the reply. + self._pending: dict[int, asyncio.Future[Packet]] = {} + + # Unsolicited inbound messages (panel-pushed events) land here. + self._unsolicited_queue: asyncio.Queue[Message] = asyncio.Queue() + + # Hands the handshake's step 2/4 packets to connect() while the + # reader task is running. Step 2 carries the SessionID; step 4 is + # just the encrypted ack. + self._handshake_event: asyncio.Event = asyncio.Event() + self._handshake_packet: Packet | None = None + self._handshake_error: Exception | None = None + + self._reader_task: asyncio.Task[None] | None = None + self._closed = False + + # ---- lifecycle ------------------------------------------------------- + + @property + def state(self) -> ConnectionState: + return self._state + + @property + def session_key(self) -> bytes | None: + """The derived per-session AES key, or ``None`` before handshake.""" + return self._session_key + + async def __aenter__(self) -> OmniConnection: + await self.connect() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + await self.close() + + async def connect(self) -> None: + """Open the TCP socket and run the 4-step secure-session handshake.""" + if self._state is not ConnectionState.DISCONNECTED: + raise ConnectionError(f"already connecting/connected (state={self._state})") + self._state = ConnectionState.CONNECTING + try: + self._reader, self._writer = await asyncio.wait_for( + asyncio.open_connection(self._host, self._port), + timeout=self._default_timeout, + ) + except (TimeoutError, OSError) as exc: + self._state = ConnectionState.DISCONNECTED + raise ConnectionError(f"failed to open TCP socket: {exc}") from exc + + self._reader_task = asyncio.create_task( + self._read_loop(), name=f"omni-conn-reader-{self._host}" + ) + + try: + await self._do_handshake() + except BaseException: + await self.close() + raise + + async def close(self) -> None: + """Tear down the TCP socket and reader task. Idempotent.""" + if self._closed: + return + self._closed = True + self._state = ConnectionState.DISCONNECTED + + # Cancel anyone still waiting for a reply. + for fut in self._pending.values(): + if not fut.done(): + fut.set_exception(ConnectionError("connection closed")) + self._pending.clear() + + if self._writer is not None: + try: + self._writer.close() + await self._writer.wait_closed() + except (OSError, RuntimeError): + pass + self._writer = None + self._reader = None + + if self._reader_task is not None and not self._reader_task.done(): + self._reader_task.cancel() + with contextlib.suppress(asyncio.CancelledError, Exception): + await self._reader_task + self._reader_task = None + + # ---- public request / receive API ----------------------------------- + + async def request( + self, + opcode: OmniLink2MessageType | int, + payload: bytes = b"", + timeout: float | None = None, + ) -> Message: + """Send a v2 request, await the matching reply, return the inner Message. + + The reply is matched on the outer packet sequence number (the + controller echoes the client's seq for solicited replies). On + timeout the pending future is removed and ``RequestTimeoutError`` + is raised. + """ + if self._state is not ConnectionState.ONLINE: + raise ConnectionError( + f"cannot send request, connection state={self._state.name}" + ) + message = encode_v2(opcode, payload) + seq, fut = self._send_encrypted(message) + try: + reply_packet = await asyncio.wait_for( + fut, timeout if timeout is not None else self._default_timeout + ) + except TimeoutError as exc: + self._pending.pop(seq, None) + raise RequestTimeoutError( + f"no reply for opcode={int(opcode)} seq={seq}" + ) from exc + return self._decode_inner(reply_packet) + + def unsolicited(self) -> AsyncIterator[Message]: + """Async iterator over unsolicited inbound messages (seq=0).""" + queue = self._unsolicited_queue + + async def _gen() -> AsyncIterator[Message]: + while True: + msg = await queue.get() + yield msg + + return _gen() + + # ---- handshake ------------------------------------------------------- + + async def _do_handshake(self) -> None: + # Step 1: send empty ClientRequestNewSession (cleartext, seq=1). + self._state = ConnectionState.NEW_SESSION + step1 = Packet( + seq=self._claim_seq(), + type=PacketType.ClientRequestNewSession, + data=b"", + ) + self._write_packet(step1) + + # Step 2: wait for ControllerAckNewSession. + ack1 = await self._await_handshake_packet() + if ack1.type is PacketType.ControllerCannotStartNewSession: + raise HandshakeError("controller cannot start new session (busy?)") + if ack1.type is not PacketType.ControllerAckNewSession: + raise HandshakeError( + f"unexpected step-2 packet type {ack1.type.name}" + ) + if len(ack1.data) < 7: + raise HandshakeError( + f"ControllerAckNewSession payload too short: {len(ack1.data)} bytes" + ) + proto_hi, proto_lo = ack1.data[0], ack1.data[1] + if (proto_hi, proto_lo) != _PROTO_VERSION: + raise HandshakeError( + f"unsupported protocol version {proto_hi:#04x}{proto_lo:02x}, " + f"want {_PROTO_VERSION[0]:#04x}{_PROTO_VERSION[1]:02x}" + ) + self._session_id = bytes(ack1.data[2 : 2 + _SESSION_ID_LEN]) + self._session_key = derive_session_key(self._controller_key, self._session_id) + + # Step 3: send ClientRequestSecureSession with the SessionID echoed + # back, AES-encrypted under the freshly derived SessionKey. The + # crypto layer handles zero-padding to a 16-byte block + the + # per-block sequence-number whitening. + self._state = ConnectionState.SECURE + step3_seq = self._claim_seq() + step3_ct = encrypt_message_payload( + self._session_id, step3_seq, self._session_key + ) + step3 = Packet( + seq=step3_seq, + type=PacketType.ClientRequestSecureSession, + data=step3_ct, + ) + self._write_packet(step3, encrypted=True) + + # Step 4: wait for ControllerAckSecureSession (or termination). + ack2 = await self._await_handshake_packet() + if ack2.type is PacketType.ControllerSessionTerminated: + raise InvalidEncryptionKeyError( + "controller terminated session during handshake " + "(wrong ControllerKey?)" + ) + if ack2.type is not PacketType.ControllerAckSecureSession: + raise HandshakeError( + f"unexpected step-4 packet type {ack2.type.name}" + ) + # We don't bother validating the decrypted plaintext — per + # clsOmniLinkConnection.cs:1933-1937, neither does PC Access. + # If AES decrypted without throwing, we trust the key matched. + self._state = ConnectionState.ONLINE + + async def _await_handshake_packet(self) -> Packet: + try: + await asyncio.wait_for( + self._handshake_event.wait(), self._default_timeout + ) + except TimeoutError as exc: + raise HandshakeError("timeout waiting for controller handshake reply") from exc + if self._handshake_error is not None: + err = self._handshake_error + self._handshake_error = None + raise err + pkt = self._handshake_packet + self._handshake_packet = None + self._handshake_event.clear() + if pkt is None: + raise HandshakeError("handshake event fired with no packet") + return pkt + + # ---- send / receive helpers ----------------------------------------- + + def _claim_seq(self) -> int: + """Allocate the next client-side outbound sequence number. + + Wraparound: after 0xFFFF we go to 1, skipping 0 because seq=0 is + reserved for unsolicited inbound packets and would collide with + the dispatch logic. + """ + seq = self._next_seq + nxt = seq + 1 + if nxt > _MAX_SEQ: + nxt = 1 + if nxt == 0: # paranoia; shouldn't happen with above branch + nxt = 1 + self._next_seq = nxt + return seq + + def _send_encrypted( + self, inner: Message + ) -> tuple[int, asyncio.Future[Packet]]: + """Frame an inner v2 ``Message`` as an encrypted ``OmniLink2Message`` packet.""" + if self._session_key is None: + raise ConnectionError("no session key (handshake not complete)") + seq = self._claim_seq() + plaintext = inner.encode() + ciphertext = encrypt_message_payload(plaintext, seq, self._session_key) + pkt = Packet(seq=seq, type=PacketType.OmniLink2Message, data=ciphertext) + + loop = asyncio.get_running_loop() + fut: asyncio.Future[Packet] = loop.create_future() + self._pending[seq] = fut + self._write_packet(pkt, encrypted=True) + return seq, fut + + def _write_packet(self, pkt: Packet, *, encrypted: bool = False) -> None: + if self._writer is None: + raise ConnectionError("transport not open") + wire = pkt.encode() + _log.debug( + "TX seq=%d type=%s len=%d encrypted=%s", + pkt.seq, + pkt.type.name, + len(pkt.data), + encrypted, + ) + self._writer.write(wire) + + def _decode_inner(self, pkt: Packet) -> Message: + """Decrypt + parse the inner ``Message`` from an OmniLink2Message packet.""" + if self._session_key is None: + raise ConnectionError("no session key") + if not pkt.data: + raise ProtocolError("empty packet data") + plaintext = decrypt_message_payload(pkt.data, pkt.seq, self._session_key) + try: + return Message.decode(plaintext) + except MessageCrcError as exc: + raise ProtocolError(f"inner message CRC mismatch: {exc}") from exc + + # ---- reader loop ----------------------------------------------------- + + async def _read_loop(self) -> None: + """Drain the TCP socket forever, dispatching each frame. + + Frame logic mirrors clsOmniLinkConnection.cs:1714-1758: + * Read 4-byte header (seq, type, reserved=0). + * For OmniLink2Message: read ONE 16-byte block, decrypt, peek + at the inner ``length`` byte to learn how many more 16-byte + blocks remain, then read those. + * For control packets (ack-new-session, etc.): read the + type-specific fixed payload size. + """ + try: + assert self._reader is not None + reader = self._reader + while not self._closed: + header = await self._read_exact(reader, _HEADER_BYTES) + if header is None: + break + if header[3] != 0: + raise ProtocolError( + f"reserved byte non-zero: {header[3]:#04x}" + ) + seq = (header[0] << 8) | header[1] + try: + type_byte = PacketType(header[2]) + except ValueError as exc: + raise ProtocolError( + f"unknown packet type {header[2]:#04x}" + ) from exc + + payload = await self._read_payload(reader, seq, type_byte) + if payload is None: + break + pkt = Packet(seq=seq, type=type_byte, data=payload) + _log.debug( + "RX seq=%d type=%s len=%d", pkt.seq, pkt.type.name, len(pkt.data) + ) + self._dispatch(pkt) + except asyncio.CancelledError: + raise + except Exception as exc: + _log.warning("reader loop crashed: %s", exc, exc_info=True) + # Wake up handshake waiters with the error so connect() unwinds. + if self._state in (ConnectionState.NEW_SESSION, ConnectionState.SECURE): + self._handshake_error = exc + self._handshake_event.set() + for fut in self._pending.values(): + if not fut.done(): + fut.set_exception(exc) + self._pending.clear() + + async def _read_payload( + self, reader: asyncio.StreamReader, seq: int, type_byte: PacketType + ) -> bytes | None: + """Read the payload bytes for one packet, given its already-parsed header. + + Returns ``None`` if the socket closed mid-packet. + """ + if type_byte is PacketType.ControllerAckNewSession: + return await self._read_exact(reader, 7) + if type_byte is PacketType.ControllerAckSecureSession: + return await self._read_exact(reader, BLOCK_SIZE) + if type_byte is PacketType.OmniLink2Message: + return await self._read_encrypted_message(reader, seq) + if type_byte is PacketType.OmniLink2UnencryptedMessage: + return await self._read_unencrypted_message(reader) + if type_byte in ( + PacketType.ControllerSessionTerminated, + PacketType.ControllerCannotStartNewSession, + PacketType.ClientSessionTerminated, + PacketType.NoMessage, + ): + return b"" + raise ProtocolError( + f"unhandled inbound packet type {type_byte.name}" + ) + + async def _read_encrypted_message( + self, reader: asyncio.StreamReader, seq: int + ) -> bytes | None: + """Read N 16-byte blocks for an OmniLink2Message frame. + + We have to decrypt the FIRST block to learn the inner ``length`` + byte, then compute how many more 16-byte blocks the rest of the + message occupies. The reference C# code does this same dance + (clsOmniLinkConnection.cs:1731-1758). + """ + first = await self._read_exact(reader, BLOCK_SIZE) + if first is None: + return None + if self._session_key is None: + # Could happen if we get an encrypted frame before handshake; + # bail out the hard way. + raise ProtocolError("encrypted frame before session key derived") + first_plain = decrypt_message_payload(first, seq, self._session_key) + # first_plain[0] is the StartChar (0x21 for v2), [1] is MessageLength. + message_length = first_plain[1] + # Bytes already consumed inside the first block (after StartChar + # and length): the inner frame is [start][length][data...][crc lo/hi] + # so total inner size is message_length + 4. We have 16 bytes of + # ciphertext == 16 bytes of plaintext, of which the inner frame + # could be shorter (rest is zero pad). Need to read the rest, in + # whole 16-byte blocks. + remaining_inner = message_length + 4 - BLOCK_SIZE + if remaining_inner <= 0: + extra_bytes = 0 + else: + pad = (-remaining_inner) % BLOCK_SIZE + extra_bytes = remaining_inner + pad + if extra_bytes == 0: + return first + rest = await self._read_exact(reader, extra_bytes) + if rest is None: + return None + return first + rest + + async def _read_unencrypted_message( + self, reader: asyncio.StreamReader + ) -> bytes | None: + """Read an OmniLink2UnencryptedMessage frame. + + Cleartext mirrors of the encrypted path; layout is just the inner + ``Message`` bytes one-to-one. We read 5 bytes (start + len + 1 + opcode byte minimum + 2 CRC), then any remaining payload. + """ + head = await self._read_exact(reader, 5) + if head is None: + return None + # head = [start][length][opcode][crc_lo][crc_hi] for length=1. + length = head[1] + if length <= 1: + return head + rest = await self._read_exact(reader, length - 1) + if rest is None: + return None + return head + rest + + async def _read_exact( + self, reader: asyncio.StreamReader, n: int + ) -> bytes | None: + try: + data = await reader.readexactly(n) + except asyncio.IncompleteReadError: + return None + return data + + def _dispatch(self, pkt: Packet) -> None: + """Route an inbound packet to its waiter (handshake / request / unsolicited).""" + # During the handshake, control packets carrying the session + # information go to the handshake awaiter regardless of seq. + if self._state in (ConnectionState.NEW_SESSION, ConnectionState.SECURE): + handshake_types = { + PacketType.ControllerAckNewSession, + PacketType.ControllerAckSecureSession, + PacketType.ControllerSessionTerminated, + PacketType.ControllerCannotStartNewSession, + } + if pkt.type in handshake_types: + self._handshake_packet = pkt + self._handshake_event.set() + return + + # Unsolicited push from the panel — seq=0. + if pkt.seq == 0: + if pkt.type is PacketType.OmniLink2Message: + try: + msg = self._decode_inner(pkt) + except (ProtocolError, ConnectionError) as exc: + _log.warning("dropping malformed unsolicited packet: %s", exc) + return + try: + self._unsolicited_queue.put_nowait(msg) + except asyncio.QueueFull: # pragma: no cover - unbounded queue + _log.warning("unsolicited queue full; dropping message") + return + + # Solicited reply — match on the seq we sent. + fut = self._pending.pop(pkt.seq, None) + if fut is None: + _log.debug( + "no waiter for seq=%d type=%s; dropping", pkt.seq, pkt.type.name + ) + return + if pkt.type is PacketType.ControllerSessionTerminated: + fut.set_exception( + ConnectionError("controller terminated session") + ) + return + if not fut.done(): + fut.set_result(pkt) diff --git a/src/omni_pca/mock_panel.py b/src/omni_pca/mock_panel.py new file mode 100644 index 0000000..8d92b38 --- /dev/null +++ b/src/omni_pca/mock_panel.py @@ -0,0 +1,538 @@ +"""Mock Omni-Link II controller — async TCP server speaking the panel side. + +A drop-in test fixture that lets us exercise the client end of the protocol +without touching real hardware. Reuses the project's own primitives +(``crypto``, ``packet``, ``message``, ``opcodes``) — the wire-level +encryption MUST flow through ``omni_pca.crypto`` to avoid a parallel +implementation drift. + +Coverage today: + +* Full secure-session handshake (NewSession / SecureSession ack pair) +* ``RequestSystemInformation`` (22) -> ``SystemInformation`` (23) +* ``RequestSystemStatus`` (24) -> ``SystemStatus`` (25) +* ``RequestProperties`` (32) -> ``Properties`` (33) for Zone + Unit +* Any other v2 opcode -> ``Nak`` (2) with the request's opcode +* CRC failures on the inner message -> ``Nak`` +* Graceful ``ClientSessionTerminated`` close + +References: + notes/handshake.md (whole document) + clsOmniLinkConnection.cs:1688-1921 (TCP listener / ack flow) + clsOL2MsgSystemInformation.cs / clsOL2MsgSystemStatus.cs + clsOL2MsgRequestProperties.cs / clsOL2MsgProperties.cs +""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +import secrets +from collections.abc import AsyncIterator, Callable +from contextlib import asynccontextmanager +from dataclasses import dataclass, field + +from .crypto import ( + BLOCK_SIZE, + decrypt_message_payload, + derive_session_key, + encrypt_message_payload, +) +from .message import Message, MessageCrcError, MessageFormatError, encode_v2 +from .opcodes import OmniLink2MessageType, PacketType +from .packet import Packet + +_log = logging.getLogger(__name__) + +# enuObjectType (clsOmniLink2.cs / enuObjectType.cs) +_OBJ_ZONE = 1 +_OBJ_UNIT = 2 +_OBJ_AREA = 5 + +# Inner-message size constants (model OMNI_PRO_II) +_ZONE_NAME_LEN = 15 +_UNIT_NAME_LEN = 12 +_AREA_NAME_LEN = 12 +_PHONE_LEN = 24 + +# Wire format for the controller-side ack of NewSession is two literal +# protocol-version bytes followed by the 5-byte SessionID. +_PROTO_HI = 0x00 +_PROTO_LO = 0x01 + +_SESSION_ID_BYTES = 5 + + +@dataclass +class MockState: + """Programmable panel state. Defaults mimic an Omni Pro II out of the box.""" + + model_byte: int = 16 # OMNI_PRO_II + firmware_major: int = 2 + firmware_minor: int = 12 + firmware_revision: int = 1 + local_phone: str = "" + + # Names by 1-based index (matches Omni's user-facing numbering). + zones: dict[int, str] = field(default_factory=dict) + units: dict[int, str] = field(default_factory=dict) + areas: dict[int, str] = field(default_factory=dict) + + # SystemStatus snapshot. Defaults: time set, battery good, no alarms. + time_set: bool = True + year: int = 26 # 2026 + month: int = 5 + day: int = 10 + day_of_week: int = 1 # Sunday=1 in the Omni convention + hour: int = 12 + minute: int = 0 + second: int = 0 + daylight_saving: int = 0 + sunrise_hour: int = 6 + sunrise_minute: int = 30 + sunset_hour: int = 19 + sunset_minute: int = 45 + battery: int = 200 # 0-255 — typical "good" value + + def zone_name_bytes(self, idx: int) -> bytes: + return _name_bytes(self.zones.get(idx, ""), _ZONE_NAME_LEN) + + def unit_name_bytes(self, idx: int) -> bytes: + return _name_bytes(self.units.get(idx, ""), _UNIT_NAME_LEN) + + def area_name_bytes(self, idx: int) -> bytes: + return _name_bytes(self.areas.get(idx, ""), _AREA_NAME_LEN) + + +def _name_bytes(name: str, width: int) -> bytes: + """Encode a panel name as ASCII, right-padded with NULs to a fixed width.""" + raw = name.encode("ascii", errors="replace")[:width] + return raw + b"\x00" * (width - len(raw)) + + +class MockPanel: + """Async TCP server that speaks Omni-Link II from the controller side. + + One client at a time — Omni's real controllers are single-session too. + """ + + def __init__( + self, + controller_key: bytes, + state: MockState | None = None, + session_id_provider: Callable[[], bytes] | None = None, + ) -> None: + if len(controller_key) != 16: + raise ValueError("controller_key must be 16 bytes") + self._controller_key = bytes(controller_key) + self.state = state or MockState() + self._session_id_provider = session_id_provider or ( + lambda: secrets.token_bytes(_SESSION_ID_BYTES) + ) + self._session_count = 0 + self._last_request_opcode: int | None = None + self._busy = asyncio.Lock() # serialise concurrent connection attempts + + # -------- public observables (handy in tests) -------- + + @property + def session_count(self) -> int: + return self._session_count + + @property + def last_request_opcode(self) -> int | None: + return self._last_request_opcode + + # -------- server lifecycle -------- + + @asynccontextmanager + async def serve( + self, host: str = "127.0.0.1", port: int = 0 + ) -> AsyncIterator[tuple[str, int]]: + """Start listening; yield ``(host, actual_port)``; tear down on exit.""" + server = await asyncio.start_server(self._handle_client, host=host, port=port) + sockets = server.sockets or () + if not sockets: # pragma: no cover -- start_server always populates this + raise RuntimeError("asyncio.start_server returned no sockets") + bound_host, bound_port = sockets[0].getsockname()[:2] + _log.debug("mock panel listening on %s:%d", bound_host, bound_port) + try: + async with server: + yield bound_host, bound_port + finally: + server.close() + with contextlib.suppress(Exception): # pragma: no cover + await server.wait_closed() + + # -------- connection handling -------- + + async def _handle_client( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + peer = writer.get_extra_info("peername") + _log.debug("mock panel: client connected from %s", peer) + session_key: bytes | None = None + session_id: bytes | None = None + try: + while True: + header = await _read_exact(reader, 4) + if header is None: + break + seq = (header[0] << 8) | header[1] + try: + pkt_type = PacketType(header[2]) + except ValueError: + _log.debug("mock panel: unknown packet type %#x", header[2]) + break + + if pkt_type is PacketType.ClientRequestNewSession: + session_id, session_key = await self._handle_new_session(seq, writer) + + elif pkt_type is PacketType.ClientRequestSecureSession: + if session_key is None or session_id is None: + _log.debug("mock panel: secure-session before NewSession") + break + body = await _read_exact(reader, BLOCK_SIZE) + if body is None: + break + handled = await self._handle_secure_session( + seq, body, session_id, session_key, writer + ) + if not handled: + break + + elif pkt_type is PacketType.ClientSessionTerminated: + _log.debug("mock panel: client requested teardown") + break + + elif pkt_type is PacketType.OmniLink2Message: + if session_key is None: + _log.debug("mock panel: encrypted message before secure session") + break + cont = await self._handle_encrypted_message( + reader, seq, session_key, writer + ) + if not cont: + break + + else: + _log.debug("mock panel: unhandled packet type %s", pkt_type.name) + break + except (asyncio.IncompleteReadError, ConnectionError): + _log.debug("mock panel: client connection ended unexpectedly") + finally: + writer.close() + with contextlib.suppress(Exception): # pragma: no cover + await writer.wait_closed() + _log.debug("mock panel: client %s disconnected", peer) + + # -------- handshake steps -------- + + async def _handle_new_session( + self, client_seq: int, writer: asyncio.StreamWriter + ) -> tuple[bytes, bytes]: + session_id = self._session_id_provider() + if len(session_id) != _SESSION_ID_BYTES: + raise RuntimeError( + f"session_id_provider returned {len(session_id)} bytes," + f" need {_SESSION_ID_BYTES}" + ) + session_key = derive_session_key(self._controller_key, session_id) + payload = bytes([_PROTO_HI, _PROTO_LO]) + session_id + ack = Packet(seq=client_seq, type=PacketType.ControllerAckNewSession, data=payload) + _log.debug("mock panel: ack new session, sid=%s", session_id.hex()) + writer.write(ack.encode()) + await writer.drain() + return session_id, session_key + + async def _handle_secure_session( + self, + client_seq: int, + ciphertext: bytes, + session_id: bytes, + session_key: bytes, + writer: asyncio.StreamWriter, + ) -> bool: + try: + plaintext = decrypt_message_payload(ciphertext, client_seq, session_key) + except Exception: + _log.debug("mock panel: failed to decrypt secure-session request") + return False + if not plaintext.startswith(session_id): + _log.debug( + "mock panel: secure-session SID mismatch (got %s, want %s)", + plaintext[:_SESSION_ID_BYTES].hex(), + session_id.hex(), + ) + # The real controller replies with ControllerSessionTerminated + # to signal "your key didn't decrypt right". Mirror that. + term = Packet( + seq=client_seq, type=PacketType.ControllerSessionTerminated, data=b"" + ) + writer.write(term.encode()) + await writer.drain() + return False + + # Echo SessionID back, encrypted with the freshly derived key. + echo_plain = session_id # encrypt_message_payload zero-pads for us + ciphertext_out = encrypt_message_payload(echo_plain, client_seq, session_key) + ack = Packet( + seq=client_seq, type=PacketType.ControllerAckSecureSession, data=ciphertext_out + ) + writer.write(ack.encode()) + await writer.drain() + self._session_count += 1 + _log.debug("mock panel: secure session up (#%d)", self._session_count) + return True + + # -------- encrypted message dispatch -------- + + async def _handle_encrypted_message( + self, + reader: asyncio.StreamReader, + client_seq: int, + session_key: bytes, + writer: asyncio.StreamWriter, + ) -> bool: + first_block = await _read_exact(reader, BLOCK_SIZE) + if first_block is None: + return False + first_plain = decrypt_message_payload(first_block, client_seq, session_key) + # first_plain[0] = StartChar (0x21), first_plain[1] = MessageLength + msg_length = first_plain[1] + # Total inner message bytes = msg_length + 4 (start, length, ..., crc1, crc2) + # We have BLOCK_SIZE bytes; need additional bytes rounded up to BLOCK_SIZE. + extra_needed = max(0, msg_length + 4 - BLOCK_SIZE) + rem = (-extra_needed) % BLOCK_SIZE + extra_aligned = extra_needed + rem + ciphertext = first_block + if extra_aligned > 0: + extra = await _read_exact(reader, extra_aligned) + if extra is None: + return False + ciphertext = first_block + extra + plaintext = decrypt_message_payload(ciphertext, client_seq, session_key) + + try: + inner = Message.decode(plaintext) + except MessageCrcError: + _log.debug("mock panel: inner message CRC failure") + await self._send_v2_reply( + client_seq, _build_nak(0), session_key, writer + ) + return True + except MessageFormatError as exc: + _log.debug("mock panel: malformed inner message: %s", exc) + return False + + opcode = inner.opcode + self._last_request_opcode = opcode + try: + opcode_name = OmniLink2MessageType(opcode).name + except ValueError: + opcode_name = f"Unknown({opcode})" + _log.debug("mock panel: dispatch opcode=%s payload=%d bytes", + opcode_name, len(inner.payload)) + + reply = self._dispatch_v2(opcode, inner.payload) + await self._send_v2_reply(client_seq, reply, session_key, writer) + return True + + def _dispatch_v2(self, opcode: int, payload: bytes) -> Message: + if opcode == OmniLink2MessageType.RequestSystemInformation: + return self._reply_system_information() + if opcode == OmniLink2MessageType.RequestSystemStatus: + return self._reply_system_status() + if opcode == OmniLink2MessageType.RequestProperties: + return self._reply_properties(payload) + return _build_nak(opcode) + + # -------- reply builders (byte-exact per clsOL2Msg*.cs) -------- + + def _reply_system_information(self) -> Message: + s = self.state + revision_byte = s.firmware_revision & 0xFF + phone = _name_bytes(s.local_phone, _PHONE_LEN) + body = bytes( + [ + s.model_byte & 0xFF, + s.firmware_major & 0xFF, + s.firmware_minor & 0xFF, + revision_byte, + ] + ) + phone + return encode_v2(OmniLink2MessageType.SystemInformation, body) + + def _reply_system_status(self) -> Message: + s = self.state + body = bytes( + [ + 1 if s.time_set else 0, + s.year & 0xFF, + s.month & 0xFF, + s.day & 0xFF, + s.day_of_week & 0xFF, + s.hour & 0xFF, + s.minute & 0xFF, + s.second & 0xFF, + s.daylight_saving & 0xFF, + s.sunrise_hour & 0xFF, + s.sunrise_minute & 0xFF, + s.sunset_hour & 0xFF, + s.sunset_minute & 0xFF, + s.battery & 0xFF, + ] + ) + # No area alarms appended — real panels can append 2 bytes per area. + return encode_v2(OmniLink2MessageType.SystemStatus, body) + + def _reply_properties(self, payload: bytes) -> Message: + # RequestProperties payload (after opcode): ObjectType, IndexNumber(2), + # RelativeDirection(sbyte), Filter1, Filter2, Filter3. + if len(payload) < 7: + return _build_nak(OmniLink2MessageType.RequestProperties) + obj_type = payload[0] + index = (payload[1] << 8) | payload[2] + rel = payload[3] + + store = self._object_store(obj_type) + if store is None: + return _build_nak(OmniLink2MessageType.RequestProperties) + + # rel: 0 = exact, 1 = next defined > index, -1/0xFF = previous defined < index. + if rel == 0: + target = index if index in store else None + elif rel == 1: + candidates = sorted(i for i in store if i > index) + target = candidates[0] if candidates else None + elif rel in (0xFF, -1 & 0xFF): # signed -1 byte + candidates = sorted((i for i in store if i < index), reverse=True) + target = candidates[0] if candidates else None + else: + return _build_nak(OmniLink2MessageType.RequestProperties) + + if target is None: + # End of iteration: real panels return EOD (opcode 3) here. + return encode_v2(OmniLink2MessageType.EOD, b"") + + if obj_type == _OBJ_ZONE: + return self._build_zone_properties(target) + if obj_type == _OBJ_UNIT: + return self._build_unit_properties(target) + if obj_type == _OBJ_AREA: + return self._build_area_properties(target) + return _build_nak(OmniLink2MessageType.RequestProperties) + + def _object_store(self, obj_type: int) -> dict[int, str] | None: + if obj_type == _OBJ_ZONE: + return self.state.zones + if obj_type == _OBJ_UNIT: + return self.state.units + if obj_type == _OBJ_AREA: + return self.state.areas + return None + + def _build_zone_properties(self, index: int) -> Message: + # Properties.Data layout for Zone (1-indexed offsets are into Data[]): + # [0]=opcode, [1]=ObjectType, [2..3]=ObjectNumber, + # [4]=Status, [5]=Loop, [6]=Type, [7]=Area, [8]=Options, + # [9..23]=Name (15 bytes) + # encode_v2 prepends the opcode, so we emit body = Data[1..23]. + body = ( + bytes( + [ + _OBJ_ZONE, + (index >> 8) & 0xFF, + index & 0xFF, + 0, # Status: closed/secure + 0, # Loop + 0, # Type: EntryExit + 1, # Area: default to area 1 + 0, # Options + ] + ) + + self.state.zone_name_bytes(index) + ) + return encode_v2(OmniLink2MessageType.Properties, body) + + def _build_unit_properties(self, index: int) -> Message: + # Properties.Data for Unit: + # [0]=opcode, [1]=ObjectType, [2..3]=ObjectNumber, + # [4]=UnitStatus, [5..6]=UnitTime, [7]=UnitType, + # [8..19]=Name (12), [20]=reserved, [21]=UnitAreas + body = ( + bytes( + [ + _OBJ_UNIT, + (index >> 8) & 0xFF, + index & 0xFF, + 0, # UnitStatus: off + 0, + 0, # UnitTime + 1, # UnitType: Standard + ] + ) + + self.state.unit_name_bytes(index) + + bytes([0, 1]) # reserved + UnitAreas (default area 1) + ) + return encode_v2(OmniLink2MessageType.Properties, body) + + def _build_area_properties(self, index: int) -> Message: + # Properties.Data for Area: + # [0]=opcode, [1]=ObjectType, [2..3]=ObjectNumber, + # [4]=AreaMode, [5]=AreaAlarms, [6]=EntryTimer, [7]=ExitTimer, + # [8]=Enabled, [9]=ExitDelay, [10]=EntryDelay, + # [11..22]=Name (12 bytes) + body = ( + bytes( + [ + _OBJ_AREA, + (index >> 8) & 0xFF, + index & 0xFF, + 0, # AreaMode: Off + 0, # AreaAlarms + 0, # EntryTimer + 0, # ExitTimer + 1, # Enabled + 60, # ExitDelay (s) + 30, # EntryDelay (s) + ] + ) + + self.state.area_name_bytes(index) + ) + return encode_v2(OmniLink2MessageType.Properties, body) + + # -------- low-level reply send -------- + + async def _send_v2_reply( + self, + client_seq: int, + message: Message, + session_key: bytes, + writer: asyncio.StreamWriter, + ) -> None: + plaintext = message.encode() + ciphertext = encrypt_message_payload(plaintext, client_seq, session_key) + pkt = Packet(seq=client_seq, type=PacketType.OmniLink2Message, data=ciphertext) + writer.write(pkt.encode()) + await writer.drain() + + +def _build_nak(in_reply_to_opcode: int) -> Message: + """Build a v2 Nak. Payload is a single byte echoing the opcode being negged. + + The C# clsOL2MsgNegativeAcknowledge has only the opcode byte; some HAI + docs show a single trailing data byte but it is not defined. We include + the offending opcode for ease of debugging — the client side cares only + that the opcode is Nak. + """ + return encode_v2(OmniLink2MessageType.Nak, bytes([in_reply_to_opcode & 0xFF])) + + +async def _read_exact(reader: asyncio.StreamReader, n: int) -> bytes | None: + """Read exactly ``n`` bytes or return None if EOF arrives early.""" + try: + return await reader.readexactly(n) + except asyncio.IncompleteReadError: + return None diff --git a/src/omni_pca/models.py b/src/omni_pca/models.py new file mode 100644 index 0000000..b3ddcb9 --- /dev/null +++ b/src/omni_pca/models.py @@ -0,0 +1,451 @@ +"""Typed dataclasses for parsed Omni-Link II v2 reply payloads. + +Each class is built from the raw inner-message ``payload`` bytes — i.e. +everything in the ``Message.data`` array AFTER the opcode byte. The +classmethod ``parse(payload)`` does the work; the dataclass itself stays +purely descriptive. + +References: + clsOL2MsgSystemInformation.cs — model byte + firmware + phone + clsOL2MsgSystemStatus.cs — date/time + battery + alarms + clsOL2MsgProperties.cs — per-object-type field offsets + enuModel.cs — model byte → human name + clsUtil.ByteArrayToString — null-terminated, latin-1, fixed-width +""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from typing import ClassVar, Self + +# -------------------------------------------------------------------------- +# enuModel byte → human-friendly name. Built from +# decompiled/project/HAI_Shared/enuModel.cs. +# -------------------------------------------------------------------------- + +MODEL_NAMES: dict[int, str] = { + 0: "Unknown", + 1: "Old Chip v5", + 2: "Omni", + 3: "HAI 2000", + 4: "Omni Pro", + 5: "Aegis 2000", + 6: "HAI 2000 Plus", + 7: "HMS 925", + 8: "HMS 1050", + 9: "Omni LT", + 10: "HMS 800", + 11: "FSN AC", + 12: "Siemens BCM", + 15: "Omni II", + 16: "Omni Pro II", + 17: "HMS 950", + 18: "Aegis 3000", + 19: "HMS 1100", + 20: "Aegis 1000", + 21: "Aegis 1500", + 22: "DOMAIKE D42", + 23: "DOMAIKE D62", + 24: "DOMAIKE D82", + 25: "SC 2000-1", + 26: "SC 2000-2 Plus", + 27: "SC 2000-4", + 28: "Siemens ECM", + 29: "Siemens CCM", + 30: "Omni IIe", + 31: "DOMAIKE D62e", + 32: "HMS 950e", + 33: "SC 2000-2e", + 34: "Aegis 1500e", + 35: "Siemens ECMe", + 36: "Lumina", + 37: "Lumina Pro", + 38: "Omni LTe", + 39: "Omni LTe EU", + 40: "Omni IIe EU", + 41: "Omni Pro II EU", +} + + +def _decode_name(buf: bytes) -> str: + """Decode a fixed-width name field as the C# code does (null-terminated, ASCII). + + clsUtil.ByteArrayToString iterates raw bytes and casts each to a + char, stopping at the first 0 byte. We treat input as latin-1 + (one-byte-one-codepoint) and strip at the first NUL. + """ + nul = buf.find(b"\x00") + if nul >= 0: + buf = buf[:nul] + return buf.decode("latin-1", errors="replace") + + +# -------------------------------------------------------------------------- +# SystemInformation +# -------------------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class SystemInformation: + """Parsed payload of a v2 ``SystemInformation`` (opcode 23) reply. + + Wire layout (clsOL2MsgSystemInformation.cs): + 0 model byte (enuModel) + 1 firmware major + 2 firmware minor + 3 firmware revision (signed; negative = beta) + 4..27 24-byte ASCII local-phone-number, NUL-padded + """ + + model_byte: int + model_name: str + firmware_major: int + firmware_minor: int + firmware_revision: int + local_phone: str + + @property + def firmware_version(self) -> str: + """Human-friendly version string, e.g. ``"2.12r1"`` or ``"2.12b3"``.""" + rev = self.firmware_revision + if rev > 0: + return f"{self.firmware_major}.{self.firmware_minor}r{rev}" + if rev < 0: + return f"{self.firmware_major}.{self.firmware_minor}b{-rev}" + return f"{self.firmware_major}.{self.firmware_minor}" + + @classmethod + def parse(cls, payload: bytes) -> Self: + if len(payload) < 4: + raise ValueError( + f"SystemInformation payload too short: {len(payload)} bytes" + ) + model_byte = payload[0] + major = payload[1] + minor = payload[2] + # Revision is signed (sbyte): negative values mean beta builds. + rev = payload[3] + if rev >= 0x80: + rev -= 0x100 + phone_bytes = payload[4:28] if len(payload) >= 28 else payload[4:] + return cls( + model_byte=model_byte, + model_name=MODEL_NAMES.get(model_byte, f"Unknown ({model_byte})"), + firmware_major=major, + firmware_minor=minor, + firmware_revision=rev, + local_phone=_decode_name(phone_bytes), + ) + + +# -------------------------------------------------------------------------- +# SystemStatus +# -------------------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class SystemStatus: + """Parsed payload of a v2 ``SystemStatus`` (opcode 25) reply. + + Wire layout (clsOL2MsgSystemStatus.cs): + 0 time/date valid flag (0 = not yet set) + 1 year (2-digit, +2000) + 2 month + 3 day + 4 day-of-week (1=Sun..7=Sat) + 5 hour + 6 minute + 7 second + 8 daylight saving flag + 9 sunrise hour + 10 sunrise minute + 11 sunset hour + 12 sunset minute + 13 battery reading (0-255 raw) + 14..N 2 bytes per area alarm flag set + """ + + time_valid: bool + panel_time: datetime | None + sunrise_hour: int + sunrise_minute: int + sunset_hour: int + sunset_minute: int + battery_reading: int + area_alarms: tuple[tuple[int, int], ...] + + # Convenience flags requested in the spec — derived from + # ``battery_reading`` and the absence of any alarms / area errors. + # The wire protocol doesn't expose dedicated AC / comm flags; PC + # Access infers them from System Troubles. We surface the raw byte + # and let a higher layer interpret. + BATTERY_OK_THRESHOLD: ClassVar[int] = 0xC0 # ~75% of 255 + + @property + def battery_ok(self) -> bool: + return self.battery_reading >= self.BATTERY_OK_THRESHOLD + + @property + def ac_ok(self) -> bool: + # Without RequestSystemTroubles we approximate: a battery reading + # of 0 implies "AC down, battery dead too" or "panel hasn't + # initialized" — treat both as not-ok. + return self.battery_reading != 0 + + @property + def communication_ok(self) -> bool: + # We're talking to the panel right now; if any of this parsed, + # comms are by definition working at least for this query. + return True + + @property + def troubles(self) -> tuple[str, ...]: + bad: list[str] = [] + if not self.battery_ok: + bad.append("battery_low") + if not self.ac_ok: + bad.append("ac_loss") + if self.area_alarms: + bad.append("area_alarm") + return tuple(bad) + + @classmethod + def parse(cls, payload: bytes) -> Self: + if len(payload) < 14: + raise ValueError( + f"SystemStatus payload too short: {len(payload)} bytes" + ) + time_valid = payload[0] != 0 + year = payload[1] + month = payload[2] + day = payload[3] + # day_of_week = payload[4] # 1=Sun .. 7=Sat — unused here + hour = payload[5] + minute = payload[6] + second = payload[7] + # daylight = payload[8] + sunrise_h = payload[9] + sunrise_m = payload[10] + sunset_h = payload[11] + sunset_m = payload[12] + battery = payload[13] + + panel_time: datetime | None = None + if time_valid: + try: + panel_time = datetime( + year=2000 + year, + month=month, + day=day, + hour=hour, + minute=minute, + second=second, + ) + except ValueError: + panel_time = None + + # Each area alarm entry is 2 bytes. Pair them up. + alarm_bytes = payload[14:] + usable = len(alarm_bytes) - (len(alarm_bytes) % 2) + alarms = tuple( + (alarm_bytes[i], alarm_bytes[i + 1]) for i in range(0, usable, 2) + ) + return cls( + time_valid=time_valid, + panel_time=panel_time, + sunrise_hour=sunrise_h, + sunrise_minute=sunrise_m, + sunset_hour=sunset_h, + sunset_minute=sunset_m, + battery_reading=battery, + area_alarms=alarms, + ) + + +# -------------------------------------------------------------------------- +# Properties — common header +# -------------------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class _PropertiesHeader: + object_type: int + object_number: int + + @classmethod + def from_payload(cls, payload: bytes) -> Self: + if len(payload) < 3: + raise ValueError( + f"Properties payload too short: {len(payload)} bytes" + ) + return cls( + object_type=payload[0], + object_number=(payload[1] << 8) | payload[2], + ) + + +# -------------------------------------------------------------------------- +# ZoneProperties +# -------------------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class ZoneProperties: + """Parsed Properties (opcode 33) reply for a Zone object. + + Wire layout (clsOL2MsgProperties.cs, ObjectType=Zone): + 0 object type byte (Zone = 1) + 1..2 object number (BE ushort) + 3 zone status (raw) + 4 zone loop reading + 5 zone type (enuZoneType) + 6 area number + 7 options bitfield + 8..22 15-byte name, NUL-padded + """ + + index: int + name: str + zone_type: int + area: int + options: int + status: int + loop: int + + @classmethod + def parse(cls, payload: bytes) -> Self: + hdr = _PropertiesHeader.from_payload(payload) + if hdr.object_type != 1: + raise ValueError( + f"expected Zone (object_type=1), got {hdr.object_type}" + ) + if len(payload) < 8 + 15: + raise ValueError( + f"ZoneProperties payload too short: {len(payload)} bytes" + ) + return cls( + index=hdr.object_number, + status=payload[3], + loop=payload[4], + zone_type=payload[5], + area=payload[6], + options=payload[7], + name=_decode_name(payload[8 : 8 + 15]), + ) + + +# -------------------------------------------------------------------------- +# UnitProperties +# -------------------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class UnitProperties: + """Parsed Properties (opcode 33) reply for a Unit object. + + Wire layout (clsOL2MsgProperties.cs, ObjectType=Unit): + 0 object type (Unit = 2) + 1..2 object number (BE ushort) + 3 unit status + 4..5 unit time (BE ushort) + 6 unit type (enuOL2UnitType) + 7..18 12-byte name + 19 unit areas bitfield (Data[21] in the C# class — that's + Data[1+offset], so payload[20] in zero-based offset, but + the C# accessor reads Data[21] which corresponds to our + payload[20] when we strip the opcode byte). + """ + + index: int + name: str + unit_type: int + status: int + time: int + areas: int + + @classmethod + def parse(cls, payload: bytes) -> Self: + hdr = _PropertiesHeader.from_payload(payload) + if hdr.object_type != 2: + raise ValueError( + f"expected Unit (object_type=2), got {hdr.object_type}" + ) + if len(payload) < 7 + 12: + raise ValueError( + f"UnitProperties payload too short: {len(payload)} bytes" + ) + # In the C#, Data[0]=opcode, Data[1]=type, Data[2..3]=number, + # Data[4]=status, Data[5..6]=time, Data[7]=unit_type, + # Data[8..19]=12-byte name, Data[21]=areas. + # Our payload[i] == C# Data[i+1], so: status=payload[3], + # time=payload[4..5], unit_type=payload[6], name=payload[7..18], + # areas=payload[20]. + areas = payload[20] if len(payload) > 20 else 0 + return cls( + index=hdr.object_number, + status=payload[3], + time=(payload[4] << 8) | payload[5], + unit_type=payload[6], + name=_decode_name(payload[7 : 7 + 12]), + areas=areas, + ) + + +# -------------------------------------------------------------------------- +# AreaProperties +# -------------------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class AreaProperties: + """Parsed Properties (opcode 33) reply for an Area object. + + Wire layout (clsOL2MsgProperties.cs, ObjectType=Area): + payload[0] object type (Area = 5) + payload[1..2] object number + payload[3] area mode (enuSecurityMode) + payload[4] area alarms bitfield + payload[5] entry timer + payload[6] exit timer + payload[7] enabled flag + payload[8] exit delay + payload[9] entry delay + payload[10..21] 12-byte name + """ + + index: int + name: str + mode: int + alarms: int + enabled: bool + entry_delay: int + exit_delay: int + + @classmethod + def parse(cls, payload: bytes) -> Self: + hdr = _PropertiesHeader.from_payload(payload) + if hdr.object_type != 5: + raise ValueError( + f"expected Area (object_type=5), got {hdr.object_type}" + ) + if len(payload) < 10 + 12: + raise ValueError( + f"AreaProperties payload too short: {len(payload)} bytes" + ) + return cls( + index=hdr.object_number, + mode=payload[3], + alarms=payload[4], + enabled=payload[7] != 0, + exit_delay=payload[8], + entry_delay=payload[9], + name=_decode_name(payload[10 : 10 + 12]), + ) + + +# -------------------------------------------------------------------------- +# Convenience union for callers that don't know the type at compile time +# -------------------------------------------------------------------------- + +PropertiesReply = ZoneProperties | UnitProperties | AreaProperties diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..ada0015 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,179 @@ +"""Unit tests for omni_pca.client — typed request methods. + +The fixture is a tiny in-process asyncio server that completes the +handshake then serves whichever opcode the test wants. No mock_panel +dependency. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import struct +from collections.abc import Awaitable, Callable + +import pytest + +from omni_pca.client import ObjectType, OmniClient +from omni_pca.crypto import ( + decrypt_message_payload, + encrypt_message_payload, +) +from omni_pca.message import Message, encode_v2 +from omni_pca.opcodes import OmniLink2MessageType, PacketType + +from .test_connection import ( # reuse handshake helpers + CONTROLLER_KEY, + SESSION_ID, + _do_full_handshake, + _pack_header, + _start_server, +) + + +def _name_field(name: str, width: int) -> bytes: + encoded = name.encode("latin-1") + return encoded + b"\x00" * (width - len(encoded)) + + +def _build_system_information_payload() -> bytes: + return bytes([16, 2, 12, 1]) + _name_field("415-555-1212", 24) + + +def _build_zone_properties_payload(index: int, name: str) -> bytes: + return ( + bytes([1]) + + struct.pack(">H", index) + + bytes([0, 0, 0, 1, 0]) + + _name_field(name, 15) + ) + + +async def _read_one_request( + reader: asyncio.StreamReader, session_key: bytes +) -> tuple[int, Message]: + """Read one OmniLink2Message packet from the client; return (seq, inner Message).""" + header = await reader.readexactly(4) + seq = (header[0] << 8) | header[1] + type_byte = header[2] + assert type_byte == int(PacketType.OmniLink2Message) + first = await reader.readexactly(16) + plain_first = decrypt_message_payload(first, seq, session_key) + msg_len = plain_first[1] + remaining_inner = msg_len + 4 - 16 + if remaining_inner <= 0: + extra = 0 + else: + pad = (-remaining_inner) % 16 + extra = remaining_inner + pad + rest = await reader.readexactly(extra) if extra else b"" + full_ct = first + rest + full_plain = decrypt_message_payload(full_ct, seq, session_key) + inner = Message.decode(full_plain) + return seq, inner + + +def _send_reply( + writer: asyncio.StreamWriter, + seq: int, + opcode: OmniLink2MessageType, + payload: bytes, + session_key: bytes, +) -> None: + inner = encode_v2(opcode, payload) + ct = encrypt_message_payload(inner.encode(), seq, session_key) + writer.write(_pack_header(seq, int(PacketType.OmniLink2Message)) + ct) + + +async def _serve_one_reply( + handler_replies: dict[int, tuple[OmniLink2MessageType, bytes]], +) -> Callable[[asyncio.StreamReader, asyncio.StreamWriter], Awaitable[None]]: + """Build a handler that does the handshake then replies once per opcode received.""" + + async def handler(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None: + try: + sk = await _do_full_handshake(r, w) + for _ in range(len(handler_replies)): + seq, inner = await _read_one_request(r, sk) + opcode = inner.opcode + if opcode not in handler_replies: + return + reply_op, reply_payload = handler_replies[opcode] + _send_reply(w, seq, reply_op, reply_payload, sk) + await w.drain() + with contextlib.suppress(TimeoutError): + await asyncio.wait_for(asyncio.Event().wait(), timeout=2.0) + finally: + w.close() + + return handler + + +@pytest.mark.asyncio +async def test_client_get_system_information_round_trip() -> None: + handler = await _serve_one_reply( + { + int(OmniLink2MessageType.RequestSystemInformation): ( + OmniLink2MessageType.SystemInformation, + _build_system_information_payload(), + ) + } + ) + server, host, port = await _start_server(handler) + try: + async with OmniClient(host=host, port=port, controller_key=CONTROLLER_KEY) as c: + info = await c.get_system_information() + assert info.model_byte == 16 + assert info.model_name == "Omni Pro II" + assert info.firmware_version == "2.12r1" + assert info.local_phone == "415-555-1212" + finally: + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_client_get_zone_properties_round_trip() -> None: + handler = await _serve_one_reply( + { + int(OmniLink2MessageType.RequestProperties): ( + OmniLink2MessageType.Properties, + _build_zone_properties_payload(7, "Front Door"), + ) + } + ) + server, host, port = await _start_server(handler) + try: + async with OmniClient(host=host, port=port, controller_key=CONTROLLER_KEY) as c: + zone = await c.get_object_properties(ObjectType.ZONE, 7) + assert zone.index == 7 + assert zone.name == "Front Door" + finally: + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_client_get_object_properties_eod_raises_value_error() -> None: + """A ``EOD`` reply means the panel has no object at that index.""" + handler = await _serve_one_reply( + { + int(OmniLink2MessageType.RequestProperties): ( + OmniLink2MessageType.EOD, + b"", + ) + } + ) + server, host, port = await _start_server(handler) + try: + async with OmniClient(host=host, port=port, controller_key=CONTROLLER_KEY) as c: + with pytest.raises(ValueError, match="no ZONE"): + await c.get_object_properties(ObjectType.ZONE, 999) + finally: + server.close() + await server.wait_closed() + + +# Keep `SESSION_ID` reachable so ruff doesn't complain about unused +# imports — it's used implicitly by `_do_full_handshake`. +assert SESSION_ID diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 0000000..897b944 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,297 @@ +"""Unit tests for omni_pca.connection. + +These spin up tiny ``asyncio.start_server`` mock controllers inside the +test, byte-for-byte; nothing depends on a real panel or the (parallel) +mock_panel module. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import struct +from collections.abc import Awaitable, Callable + +import pytest + +from omni_pca.connection import ( + ConnectionState, + HandshakeError, + InvalidEncryptionKeyError, + OmniConnection, + RequestTimeoutError, +) +from omni_pca.crypto import ( + decrypt_message_payload, + derive_session_key, + encrypt_message_payload, +) +from omni_pca.message import encode_v2 +from omni_pca.opcodes import OmniLink2MessageType, PacketType + +# A canonical 16-byte ControllerKey for tests. +CONTROLLER_KEY = bytes.fromhex("000102030405060708090a0b0c0d0e0f") +SESSION_ID = bytes([0x10, 0x11, 0x12, 0x13, 0x14]) + + +def _pack_header(seq: int, type_byte: int) -> bytes: + return struct.pack(">HBB", seq, type_byte, 0) + + +async def _read_packet(reader: asyncio.StreamReader) -> tuple[int, int, bytes]: + """Read one outer-frame packet (full bytes), returning (seq, type, data). + + For the cleartext control packets we know the payload size from the + type. For OmniLink2Message we mirror the client-side + decrypt-first-block-to-learn-length dance. + """ + header = await reader.readexactly(4) + seq = (header[0] << 8) | header[1] + type_byte = header[2] + if type_byte == int(PacketType.ClientRequestNewSession): + return seq, type_byte, b"" + if type_byte == int(PacketType.ClientRequestSecureSession): + return seq, type_byte, await reader.readexactly(16) + if type_byte == int(PacketType.ClientSessionTerminated): + return seq, type_byte, b"" + if type_byte == int(PacketType.OmniLink2Message): + # Read first block, decrypt to learn length. + first = await reader.readexactly(16) + # Caller knows the session key at this point — but for the tests + # we just return the ciphertext + raw bytes; the test will + # decrypt manually if it cares. + return seq, type_byte, first + raise AssertionError(f"unexpected client packet type {type_byte}") + + +# ---- handshake ----------------------------------------------------------- + + +async def _start_server( + handler: Callable[[asyncio.StreamReader, asyncio.StreamWriter], Awaitable[None]], +) -> tuple[asyncio.Server, str, int]: + server = await asyncio.start_server(handler, "127.0.0.1", 0) + sockets = server.sockets + assert sockets, "server has no listening sockets" + host, port = sockets[0].getsockname()[:2] + return server, host, port + + +async def _do_full_handshake( + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + *, + controller_key: bytes = CONTROLLER_KEY, + session_id: bytes = SESSION_ID, +) -> bytes: + """Server-side: complete the 4-step handshake, return derived session key.""" + # Step 1: client sends empty ClientRequestNewSession. + seq1, type1, _ = await _read_packet(reader) + assert type1 == int(PacketType.ClientRequestNewSession) + assert seq1 == 1 + + # Step 2: send ControllerAckNewSession back, echoing seq=1. + proto = bytes([0x00, 0x01]) + writer.write(_pack_header(seq1, int(PacketType.ControllerAckNewSession)) + proto + session_id) + await writer.drain() + + # Step 3: client sends ClientRequestSecureSession (encrypted). + seq3, type3, ct = await _read_packet(reader) + assert type3 == int(PacketType.ClientRequestSecureSession) + assert seq3 == 2 + session_key = derive_session_key(controller_key, session_id) + plain = decrypt_message_payload(ct, seq3, session_key) + assert plain[:5] == session_id + + # Step 4: send ControllerAckSecureSession back (encrypted echo). + echo = encrypt_message_payload(session_id, seq3, session_key) + writer.write(_pack_header(seq3, int(PacketType.ControllerAckSecureSession)) + echo) + await writer.drain() + return session_key + + +@pytest.mark.asyncio +async def test_connection_handshake_flow_with_canned_server() -> None: + """The 4-step handshake completes; state == ONLINE; session key matches.""" + server_done = asyncio.Event() + server_session_key: list[bytes] = [] + + async def handler(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None: + try: + sk = await _do_full_handshake(r, w) + server_session_key.append(sk) + server_done.set() + # Hold the connection open so the client can close cleanly. + with contextlib.suppress(TimeoutError): + await asyncio.wait_for(asyncio.Event().wait(), timeout=5.0) + finally: + w.close() + + server, host, port = await _start_server(handler) + try: + conn = OmniConnection(host=host, port=port, controller_key=CONTROLLER_KEY) + await conn.connect() + try: + await asyncio.wait_for(server_done.wait(), timeout=2.0) + assert conn.state is ConnectionState.ONLINE + assert conn.session_key == server_session_key[0] + assert conn.session_key == derive_session_key(CONTROLLER_KEY, SESSION_ID) + finally: + await conn.close() + finally: + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_handshake_wrong_key_raises_invalid_encryption_key() -> None: + """Server sends ControllerSessionTerminated after step 3 -> InvalidEncryptionKeyError.""" + + async def handler(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None: + try: + seq1, type1, _ = await _read_packet(r) + assert type1 == int(PacketType.ClientRequestNewSession) + # Step 2: legitimate-looking ack. + w.write( + _pack_header(seq1, int(PacketType.ControllerAckNewSession)) + + bytes([0x00, 0x01]) + + SESSION_ID + ) + await w.drain() + # Step 3: client sends encrypted secure-session req — we read + # and ignore, then send ControllerSessionTerminated. + seq3, _, _ = await _read_packet(r) + w.write(_pack_header(seq3, int(PacketType.ControllerSessionTerminated))) + await w.drain() + finally: + w.close() + + server, host, port = await _start_server(handler) + try: + conn = OmniConnection(host=host, port=port, controller_key=CONTROLLER_KEY) + with pytest.raises(InvalidEncryptionKeyError): + await conn.connect() + finally: + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_handshake_unsupported_proto_version_raises() -> None: + """A non-(00,01) protocol version in step 2 produces HandshakeError.""" + + async def handler(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None: + try: + seq1, _, _ = await _read_packet(r) + w.write( + _pack_header(seq1, int(PacketType.ControllerAckNewSession)) + + bytes([0x00, 0x02]) # wrong proto + + SESSION_ID + ) + await w.drain() + finally: + w.close() + + server, host, port = await _start_server(handler) + try: + conn = OmniConnection(host=host, port=port, controller_key=CONTROLLER_KEY) + with pytest.raises(HandshakeError): + await conn.connect() + finally: + server.close() + await server.wait_closed() + + +# ---- sequence numbers ---------------------------------------------------- + + +def test_sequence_number_increments_per_request() -> None: + """Direct unit test of the seq allocator (no I/O).""" + conn = OmniConnection("0", 1, controller_key=CONTROLLER_KEY) + seqs = [conn._claim_seq() for _ in range(5)] + assert seqs == [1, 2, 3, 4, 5] + + +def test_sequence_number_skips_zero_on_wraparound() -> None: + """After 0xFFFF we go to 1, not 0 (0 is reserved for unsolicited).""" + conn = OmniConnection("0", 1, controller_key=CONTROLLER_KEY) + conn._next_seq = 0xFFFE + a = conn._claim_seq() + b = conn._claim_seq() + c = conn._claim_seq() + assert a == 0xFFFE + assert b == 0xFFFF + assert c == 1 + + +# ---- request / unsolicited dispatch -------------------------------------- + + +@pytest.mark.asyncio +async def test_unsolicited_packet_lands_in_iterator_not_request_future() -> None: + """An inbound packet with seq=0 goes to the unsolicited queue, not any pending future.""" + + async def handler(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None: + try: + sk = await _do_full_handshake(r, w) + # Push an unsolicited Properties-shaped message at seq=0. + inner = encode_v2(OmniLink2MessageType.SystemEvents, b"\x00\x01") + ct = encrypt_message_payload(inner.encode(), 0, sk) + w.write(_pack_header(0, int(PacketType.OmniLink2Message)) + ct) + await w.drain() + # Hold open. + with contextlib.suppress(TimeoutError): + await asyncio.wait_for(asyncio.Event().wait(), timeout=2.0) + finally: + w.close() + + server, host, port = await _start_server(handler) + try: + conn = OmniConnection(host=host, port=port, controller_key=CONTROLLER_KEY) + await conn.connect() + try: + received: list[int] = [] + + async def consume() -> None: + async for msg in conn.unsolicited(): + received.append(msg.opcode) + return + + await asyncio.wait_for(consume(), timeout=2.0) + assert received == [int(OmniLink2MessageType.SystemEvents)] + finally: + await conn.close() + finally: + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_request_timeout() -> None: + """If the server stays silent, request() raises RequestTimeoutError.""" + + async def handler(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None: + try: + await _do_full_handshake(r, w) + # Read whatever the client sends after handshake but never reply. + with contextlib.suppress(TimeoutError, asyncio.IncompleteReadError): + await asyncio.wait_for(r.readexactly(4 + 16), timeout=2.0) + with contextlib.suppress(TimeoutError): + await asyncio.wait_for(asyncio.Event().wait(), timeout=2.0) + finally: + w.close() + + server, host, port = await _start_server(handler) + try: + conn = OmniConnection(host=host, port=port, controller_key=CONTROLLER_KEY) + await conn.connect() + try: + with pytest.raises(RequestTimeoutError): + await conn.request( + OmniLink2MessageType.RequestSystemInformation, timeout=0.2 + ) + finally: + await conn.close() + finally: + server.close() + await server.wait_closed() diff --git a/tests/test_e2e_client_mock.py b/tests/test_e2e_client_mock.py new file mode 100644 index 0000000..cecd2f5 --- /dev/null +++ b/tests/test_e2e_client_mock.py @@ -0,0 +1,101 @@ +"""End-to-end: OmniClient drives a real MockPanel over a real TCP socket. + +This is the integration smoke test that proves the protocol stack actually +roundtrips. Both sides built independently — if framing, sequence numbers, +session-key derivation, or per-block whitening disagree, the handshake fails. +""" + +from __future__ import annotations + +import secrets + +import pytest + +from omni_pca.client import ObjectType, OmniClient +from omni_pca.connection import HandshakeError +from omni_pca.mock_panel import MockPanel, MockState +from omni_pca.models import AreaProperties, UnitProperties, ZoneProperties + +CONTROLLER_KEY = bytes.fromhex("6ba7b4e9b4656de3cd7edd4c650cdb09") + + +@pytest.fixture +def seeded_state() -> MockState: + return MockState( + model_byte=16, + firmware_major=2, + firmware_minor=12, + firmware_revision=1, + zones={1: "FRONT DOOR", 2: "GARAGE ENTRY", 7: "MASTER BED MOT"}, + units={1: "FRONT PORCH", 2: "STAIRS"}, + areas={1: "Main", 2: "Guest"}, + ) + + +async def test_e2e_handshake_then_system_information(seeded_state: MockState) -> None: + panel = MockPanel(controller_key=CONTROLLER_KEY, state=seeded_state) + async with ( + panel.serve() as (host, port), + OmniClient(host=host, port=port, controller_key=CONTROLLER_KEY) as cli, + ): + info = await cli.get_system_information() + assert info.model_byte == 16 + assert info.model_name == "Omni Pro II" + assert info.firmware_version.startswith("2.12") + assert panel.session_count == 1 + + +async def test_e2e_get_zone_properties(seeded_state: MockState) -> None: + panel = MockPanel(controller_key=CONTROLLER_KEY, state=seeded_state) + async with ( + panel.serve() as (host, port), + OmniClient(host=host, port=port, controller_key=CONTROLLER_KEY) as cli, + ): + zone = await cli.get_object_properties(ObjectType.ZONE, 1) + assert isinstance(zone, ZoneProperties) + assert zone.index == 1 + assert zone.name == "FRONT DOOR" + + +async def test_e2e_get_unit_properties(seeded_state: MockState) -> None: + panel = MockPanel(controller_key=CONTROLLER_KEY, state=seeded_state) + async with ( + panel.serve() as (host, port), + OmniClient(host=host, port=port, controller_key=CONTROLLER_KEY) as cli, + ): + unit = await cli.get_object_properties(ObjectType.UNIT, 2) + assert isinstance(unit, UnitProperties) + assert unit.index == 2 + assert unit.name == "STAIRS" + + +async def test_e2e_get_area_properties(seeded_state: MockState) -> None: + panel = MockPanel(controller_key=CONTROLLER_KEY, state=seeded_state) + async with ( + panel.serve() as (host, port), + OmniClient(host=host, port=port, controller_key=CONTROLLER_KEY) as cli, + ): + area = await cli.get_object_properties(ObjectType.AREA, 1) + assert isinstance(area, AreaProperties) + assert area.index == 1 + assert area.name == "Main" + + +async def test_e2e_list_zone_names(seeded_state: MockState) -> None: + panel = MockPanel(controller_key=CONTROLLER_KEY, state=seeded_state) + async with ( + panel.serve() as (host, port), + OmniClient(host=host, port=port, controller_key=CONTROLLER_KEY) as cli, + ): + names = await cli.list_zone_names() + assert names == {1: "FRONT DOOR", 2: "GARAGE ENTRY", 7: "MASTER BED MOT"} + + +async def test_e2e_wrong_key_fails_with_handshake_error() -> None: + panel = MockPanel(controller_key=CONTROLLER_KEY) + wrong_key = secrets.token_bytes(16) + async with panel.serve() as (host, port): + # pytest.raises is sync; can't combine into the async with above. + with pytest.raises(HandshakeError): + async with OmniClient(host=host, port=port, controller_key=wrong_key) as cli: + await cli.get_system_information() diff --git a/tests/test_mock_panel.py b/tests/test_mock_panel.py new file mode 100644 index 0000000..eaf300b --- /dev/null +++ b/tests/test_mock_panel.py @@ -0,0 +1,282 @@ +"""Unit tests for omni_pca.mock_panel. + +These tests drive the mock with raw primitives only (Packet / Message / +crypto.*) so they double as a sanity check on the handshake and on the +inner-message encoding. Do NOT import the in-progress OmniClient here — +the point is to keep the mock testable independently. +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from omni_pca.crypto import ( + BLOCK_SIZE, + decrypt_message_payload, + derive_session_key, + encrypt_message_payload, +) +from omni_pca.message import Message, crc16_modbus, encode_v2 +from omni_pca.mock_panel import MockPanel, MockState +from omni_pca.opcodes import OmniLink2MessageType, PacketType +from omni_pca.packet import Packet + +CONTROLLER_KEY = bytes.fromhex("00112233445566778899aabbccddeeff") +KNOWN_SID = bytes.fromhex("0102030405") + + +async def _readexact(reader: asyncio.StreamReader, n: int) -> bytes: + return await asyncio.wait_for(reader.readexactly(n), timeout=2.0) + + +async def _do_handshake( + reader: asyncio.StreamReader, writer: asyncio.StreamWriter, session_id: bytes +) -> bytes: + """Run NewSession + SecureSession; return the derived session key.""" + # Step 1: client sends ClientRequestNewSession (seq=2, no payload). + writer.write(Packet(seq=2, type=PacketType.ClientRequestNewSession).encode()) + await writer.drain() + + # Step 2: read 4-byte header + 7-byte payload. + header = await _readexact(reader, 4) + assert header[2] == int(PacketType.ControllerAckNewSession) + payload = await _readexact(reader, 7) + assert payload[0] == 0x00 + assert payload[1] == 0x01 + assert payload[2:7] == session_id + + session_key = derive_session_key(CONTROLLER_KEY, session_id) + + # Step 3: encrypt the SessionID and send ClientRequestSecureSession (seq=3). + ciphertext = encrypt_message_payload(session_id, 3, session_key) + writer.write( + Packet( + seq=3, type=PacketType.ClientRequestSecureSession, data=ciphertext + ).encode() + ) + await writer.drain() + + # Step 4: read 4-byte header + 16-byte payload, decrypt, verify echo. + header = await _readexact(reader, 4) + assert header[2] == int(PacketType.ControllerAckSecureSession) + body = await _readexact(reader, BLOCK_SIZE) + plain = decrypt_message_payload(body, 3, session_key) + assert plain[: len(session_id)] == session_id + return session_key + + +async def _send_v2( + writer: asyncio.StreamWriter, + seq: int, + opcode: OmniLink2MessageType | int, + payload: bytes, + session_key: bytes, +) -> None: + msg = encode_v2(opcode, payload) + ciphertext = encrypt_message_payload(msg.encode(), seq, session_key) + writer.write( + Packet(seq=seq, type=PacketType.OmniLink2Message, data=ciphertext).encode() + ) + await writer.drain() + + +async def _recv_v2( + reader: asyncio.StreamReader, seq: int, session_key: bytes +) -> Message: + """Read one v2 reply using the same two-step framing as the real client.""" + header = await _readexact(reader, 4) + assert header[2] == int(PacketType.OmniLink2Message) + first = await _readexact(reader, BLOCK_SIZE) + first_plain = decrypt_message_payload(first, seq, session_key) + msg_length = first_plain[1] + extra_needed = max(0, msg_length + 4 - BLOCK_SIZE) + rem = (-extra_needed) % BLOCK_SIZE + extra_aligned = extra_needed + rem + if extra_aligned: + extra = await _readexact(reader, extra_aligned) + ciphertext = first + extra + else: + ciphertext = first + plain = decrypt_message_payload(ciphertext, seq, session_key) + return Message.decode(plain) + + +@pytest.fixture +def known_sid_panel() -> MockPanel: + return MockPanel( + controller_key=CONTROLLER_KEY, + session_id_provider=lambda: KNOWN_SID, + ) + + +async def test_handshake_completes_with_known_session_id(known_sid_panel: MockPanel) -> None: + async with known_sid_panel.serve() as (host, port): + reader, writer = await asyncio.open_connection(host, port) + try: + session_key = await _do_handshake(reader, writer, KNOWN_SID) + assert session_key == derive_session_key(CONTROLLER_KEY, KNOWN_SID) + assert known_sid_panel.session_count == 1 + finally: + writer.close() + await writer.wait_closed() + + +async def test_request_system_information_returns_model_byte() -> None: + state = MockState( + model_byte=16, firmware_major=2, firmware_minor=12, firmware_revision=1 + ) + panel = MockPanel( + controller_key=CONTROLLER_KEY, + state=state, + session_id_provider=lambda: KNOWN_SID, + ) + async with panel.serve() as (host, port): + reader, writer = await asyncio.open_connection(host, port) + try: + session_key = await _do_handshake(reader, writer, KNOWN_SID) + await _send_v2( + writer, 4, OmniLink2MessageType.RequestSystemInformation, b"", session_key + ) + reply = await _recv_v2(reader, 4, session_key) + assert reply.opcode == int(OmniLink2MessageType.SystemInformation) + assert reply.payload[0] == 16 # model byte + assert reply.payload[1] == 2 # major + assert reply.payload[2] == 12 # minor + assert reply.payload[3] == 1 # revision + assert panel.last_request_opcode == int( + OmniLink2MessageType.RequestSystemInformation + ) + finally: + writer.close() + await writer.wait_closed() + + +async def test_request_properties_for_a_zone() -> None: + state = MockState(zones={1: "FRONT DOOR"}) + panel = MockPanel( + controller_key=CONTROLLER_KEY, + state=state, + session_id_provider=lambda: KNOWN_SID, + ) + async with panel.serve() as (host, port): + reader, writer = await asyncio.open_connection(host, port) + try: + session_key = await _do_handshake(reader, writer, KNOWN_SID) + # ObjectType=Zone(1), IndexNumber=1, RelativeDirection=0, three filter zeros. + req_payload = bytes([1, 0x00, 0x01, 0, 0, 0, 0]) + await _send_v2( + writer, 4, OmniLink2MessageType.RequestProperties, req_payload, session_key + ) + reply = await _recv_v2(reader, 4, session_key) + assert reply.opcode == int(OmniLink2MessageType.Properties) + data = reply.payload # everything after the opcode + assert data[0] == 1 # ObjectType=Zone + assert (data[1] << 8) | data[2] == 1 # ObjectNumber + name_bytes = data[8:23] + assert name_bytes.rstrip(b"\x00").decode("ascii") == "FRONT DOOR" + finally: + writer.close() + await writer.wait_closed() + + +async def test_unknown_opcode_returns_nak() -> None: + panel = MockPanel( + controller_key=CONTROLLER_KEY, session_id_provider=lambda: KNOWN_SID + ) + async with panel.serve() as (host, port): + reader, writer = await asyncio.open_connection(host, port) + try: + session_key = await _do_handshake(reader, writer, KNOWN_SID) + # Pick something obviously unimplemented in the mock. + await _send_v2( + writer, 4, OmniLink2MessageType.RequestEventLogItem, b"\x00\x00\x00", + session_key, + ) + reply = await _recv_v2(reader, 4, session_key) + assert reply.opcode == int(OmniLink2MessageType.Nak) + finally: + writer.close() + await writer.wait_closed() + + +async def test_bad_crc_returns_nak_or_disconnect() -> None: + panel = MockPanel( + controller_key=CONTROLLER_KEY, session_id_provider=lambda: KNOWN_SID + ) + async with panel.serve() as (host, port): + reader, writer = await asyncio.open_connection(host, port) + try: + session_key = await _do_handshake(reader, writer, KNOWN_SID) + # Build a v2 message manually with a corrupted CRC. + opcode = int(OmniLink2MessageType.RequestSystemInformation) + length = 1 + body = bytes([0x21, length, opcode]) + good_crc = crc16_modbus(bytes([length, opcode])) + bad_crc = good_crc ^ 0xFFFF + wire = body + bytes([bad_crc & 0xFF, (bad_crc >> 8) & 0xFF]) + ciphertext = encrypt_message_payload(wire, 4, session_key) + writer.write( + Packet(seq=4, type=PacketType.OmniLink2Message, data=ciphertext).encode() + ) + await writer.drain() + # Either we get a Nak back or the panel hangs up. Both are acceptable. + try: + reply = await _recv_v2(reader, 4, session_key) + except (asyncio.IncompleteReadError, ConnectionError): + return + assert reply.opcode == int(OmniLink2MessageType.Nak) + finally: + writer.close() + await writer.wait_closed() + + +async def test_unencrypted_request_new_session_does_not_require_encryption() -> None: + # The first packet of the handshake MUST work with no crypto in scope. + panel = MockPanel( + controller_key=CONTROLLER_KEY, session_id_provider=lambda: KNOWN_SID + ) + async with panel.serve() as (host, port): + reader, writer = await asyncio.open_connection(host, port) + try: + writer.write( + Packet(seq=2, type=PacketType.ClientRequestNewSession).encode() + ) + await writer.drain() + header = await _readexact(reader, 4) + assert header[2] == int(PacketType.ControllerAckNewSession) + payload = await _readexact(reader, 7) + assert payload[:2] == b"\x00\x01" + assert payload[2:] == KNOWN_SID + finally: + writer.close() + await writer.wait_closed() + + +async def test_request_properties_for_a_unit() -> None: + state = MockState(units={2: "PORCH LIGHT"}) + panel = MockPanel( + controller_key=CONTROLLER_KEY, + state=state, + session_id_provider=lambda: KNOWN_SID, + ) + async with panel.serve() as (host, port): + reader, writer = await asyncio.open_connection(host, port) + try: + session_key = await _do_handshake(reader, writer, KNOWN_SID) + # ObjectType=Unit(2), IndexNumber=2. + req_payload = bytes([2, 0x00, 0x02, 0, 0, 0, 0]) + await _send_v2( + writer, 4, OmniLink2MessageType.RequestProperties, req_payload, session_key + ) + reply = await _recv_v2(reader, 4, session_key) + data = reply.payload + assert data[0] == 2 # Unit + assert (data[1] << 8) | data[2] == 2 # ObjectNumber + # Per clsOL2MsgProperties.cs: Unit name is at Data[8..19], i.e. payload[7..18]. + unit_name = data[7:19].rstrip(b"\x00").decode("ascii") + assert unit_name == "PORCH LIGHT" + finally: + writer.close() + await writer.wait_closed() diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..9f4e6de --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,206 @@ +"""Unit tests for omni_pca.models — payload parsers, no I/O.""" + +from __future__ import annotations + +import pytest + +from omni_pca.models import ( + MODEL_NAMES, + AreaProperties, + SystemInformation, + SystemStatus, + UnitProperties, + ZoneProperties, +) + + +def _name_field(name: str, width: int) -> bytes: + """Pack a name into a fixed-width NUL-padded ASCII field.""" + encoded = name.encode("latin-1") + if len(encoded) > width: + raise ValueError("name too long for field") + return encoded + b"\x00" * (width - len(encoded)) + + +# ---- SystemInformation ---------------------------------------------------- + + +def test_models_system_information_parse() -> None: + payload = bytes([ + 16, # model byte = OMNI_PRO_II + 2, # firmware major + 12, # firmware minor + 1, # firmware revision (positive => release "rN") + ]) + _name_field("415-555-1212", 24) + + info = SystemInformation.parse(payload) + + assert info.model_byte == 16 + assert info.model_name == "Omni Pro II" + assert info.firmware_major == 2 + assert info.firmware_minor == 12 + assert info.firmware_revision == 1 + assert info.firmware_version == "2.12r1" + assert info.local_phone == "415-555-1212" + + +def test_models_system_information_beta_revision() -> None: + """Negative sbyte revision indicates a beta build.""" + payload = bytes([30, 4, 0, 0xFD]) + _name_field("", 24) + info = SystemInformation.parse(payload) + assert info.firmware_revision == -3 + assert info.firmware_version == "4.0b3" + assert info.model_name == "Omni IIe" + + +def test_models_system_information_unknown_model() -> None: + payload = bytes([99, 1, 0, 0]) + _name_field("", 24) + info = SystemInformation.parse(payload) + assert info.model_name.startswith("Unknown") + + +def test_models_system_information_short_payload_rejected() -> None: + with pytest.raises(ValueError): + SystemInformation.parse(b"\x10\x02") + + +def test_models_model_name_table_covers_required() -> None: + for byte, expected_substr in [ + (16, "Omni Pro II"), + (30, "Omni IIe"), + (38, "Omni LTe"), + (36, "Lumina"), + (37, "Lumina Pro"), + ]: + assert MODEL_NAMES[byte] == expected_substr + + +# ---- SystemStatus --------------------------------------------------------- + + +def test_models_system_status_parse() -> None: + # date 2025-12-31 14:30:45, sunrise 06:45, sunset 17:20, battery 0xE0 + payload = bytes([ + 1, # time valid + 25, # year (offset 2000) + 12, + 31, + 4, # day-of-week (Wed-ish; ignored in the dataclass) + 14, + 30, + 45, + 0, # daylight flag + 6, + 45, + 17, + 20, + 0xE0, + ]) + status = SystemStatus.parse(payload) + assert status.time_valid is True + assert status.panel_time is not None + assert status.panel_time.year == 2025 + assert status.panel_time.month == 12 + assert status.panel_time.day == 31 + assert status.panel_time.hour == 14 + assert status.panel_time.minute == 30 + assert status.panel_time.second == 45 + assert status.sunrise_hour == 6 + assert status.sunset_minute == 20 + assert status.battery_reading == 0xE0 + assert status.battery_ok is True + assert status.ac_ok is True + assert status.communication_ok is True + assert status.troubles == () + + +def test_models_system_status_low_battery_flagged() -> None: + payload = bytes([1, 25, 1, 1, 1, 0, 0, 0, 0, 6, 0, 18, 0, 0x10]) + status = SystemStatus.parse(payload) + assert status.battery_ok is False + assert "battery_low" in status.troubles + + +def test_models_system_status_alarm_pairs_extracted() -> None: + base = bytes([1, 25, 1, 1, 1, 0, 0, 0, 0, 6, 0, 18, 0, 0xC0]) + alarms_data = bytes([0x01, 0x02, 0x10, 0x20]) + status = SystemStatus.parse(base + alarms_data) + assert status.area_alarms == ((0x01, 0x02), (0x10, 0x20)) + + +def test_models_system_status_short_payload_rejected() -> None: + with pytest.raises(ValueError): + SystemStatus.parse(b"\x00\x00\x00") + + +# ---- ZoneProperties ------------------------------------------------------- + + +def test_models_zone_properties_parse() -> None: + # object_type=Zone(1), object_number=42, status=0, loop=0, + # zone_type=0 (EntryExit), area=1, options=0, name="Front Door" + payload = ( + bytes([1]) # object type = Zone + + bytes([0, 42]) # object number = 42 (BE) + + bytes([0, 0]) # status, loop + + bytes([0, 1, 0]) # zone type, area, options + + _name_field("Front Door", 15) + ) + zone = ZoneProperties.parse(payload) + assert zone.index == 42 + assert zone.name == "Front Door" + assert zone.zone_type == 0 + assert zone.area == 1 + assert zone.options == 0 + + +def test_models_zone_properties_wrong_object_type_rejected() -> None: + payload = bytes([2, 0, 1, 0, 0, 0, 0, 0]) + _name_field("X", 15) + with pytest.raises(ValueError, match="expected Zone"): + ZoneProperties.parse(payload) + + +# ---- UnitProperties ------------------------------------------------------- + + +def test_models_unit_properties_parse() -> None: + payload = ( + bytes([2]) # Unit + + bytes([0, 7]) # index 7 + + bytes([0]) # status + + bytes([0, 0]) # time + + bytes([1]) # unit_type = Standard + + _name_field("Lamp", 12) + + bytes([0]) # gap byte (Data[20] in C# offset) + + bytes([0x05]) # areas + ) + unit = UnitProperties.parse(payload) + assert unit.index == 7 + assert unit.name == "Lamp" + assert unit.unit_type == 1 + assert unit.areas == 0x05 + + +# ---- AreaProperties ------------------------------------------------------- + + +def test_models_area_properties_parse() -> None: + payload = ( + bytes([5]) # Area + + bytes([0, 1]) # index 1 + + bytes([0]) # mode = Off + + bytes([0]) # alarms + + bytes([0]) # entry timer + + bytes([0]) # exit timer + + bytes([1]) # enabled + + bytes([60]) # exit delay + + bytes([30]) # entry delay + + _name_field("Main", 12) + ) + area = AreaProperties.parse(payload) + assert area.index == 1 + assert area.name == "Main" + assert area.mode == 0 + assert area.enabled is True + assert area.exit_delay == 60 + assert area.entry_delay == 30