src/omni_pca/connection.py — low-level OmniConnection
- 4-step secure-session handshake (NewSession, SecureSession)
- Per-direction monotonic seq with 0xFFFF -> 1 wraparound (skips 0)
- TCP framing: read first 16-byte block, decrypt, learn length, read rest
- Reader task dispatches solicited replies to Future, unsolicited to queue
- Custom exceptions: HandshakeError, InvalidEncryptionKeyError, ProtocolError,
RequestTimeoutError
src/omni_pca/models.py — typed response objects
- SystemInformation (with model_name lookup), SystemStatus, ZoneProperties,
UnitProperties, AreaProperties — all frozen+slots dataclasses with
.parse(payload) classmethods
src/omni_pca/client.py — high-level OmniClient
- get_system_information / get_system_status / get_object_properties
- list_{zone,unit,area}_names walks via RequestProperties rel=1
- subscribe(callback) for unsolicited messages
src/omni_pca/mock_panel.py — async TCP server emulating an Omni Pro II
- Full handshake (controller side), seedable MockState
- Implements RequestSystemInformation, RequestSystemStatus,
RequestProperties (Zone/Unit/Area, both absolute and rel=1 iteration
with EOD termination); Nak for everything else
- 'omni-pca mock-panel' CLI subcommand
tests/ — 85 passed, 1 skipped (live fixture)
- 23 unit tests for connection/models/client (canned-server fixtures)
- 7 unit tests for mock panel (raw protocol drive)
- 6 e2e tests: real OmniClient over real TCP to real MockPanel,
proves handshake + AES + whitening + sequencing all agree
298 lines · 11 KiB · Python
"""Unit tests for omni_pca.connection.

These spin up tiny ``asyncio.start_server`` mock controllers inside each
test that speak the protocol byte-for-byte; nothing depends on a real panel
or on the (parallel) mock_panel module.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import struct
|
|
from collections.abc import Awaitable, Callable
|
|
|
|
import pytest
|
|
|
|
from omni_pca.connection import (
|
|
ConnectionState,
|
|
HandshakeError,
|
|
InvalidEncryptionKeyError,
|
|
OmniConnection,
|
|
RequestTimeoutError,
|
|
)
|
|
from omni_pca.crypto import (
|
|
decrypt_message_payload,
|
|
derive_session_key,
|
|
encrypt_message_payload,
|
|
)
|
|
from omni_pca.message import encode_v2
|
|
from omni_pca.opcodes import OmniLink2MessageType, PacketType
|
|
|
|
# Canonical 16-byte ControllerKey shared by every test: the bytes 0x00..0x0f.
CONTROLLER_KEY = bytes(range(0x10))

# Five-byte session id the canned controllers hand out in handshake step 2.
SESSION_ID = bytes(range(0x10, 0x15))
|
|
|
|
|
|
def _pack_header(seq: int, type_byte: int) -> bytes:
|
|
return struct.pack(">HBB", seq, type_byte, 0)
|
|
|
|
|
|
async def _read_packet(reader: asyncio.StreamReader) -> tuple[int, int, bytes]:
    """Read one client packet from *reader* and return ``(seq, type, data)``.

    The 4-byte outer header is always consumed: a big-endian u16 seq, the
    packet-type byte, and one further byte that is read but ignored. What is
    read after the header depends on the packet type:

    * ``ClientRequestNewSession`` / ``ClientSessionTerminated``: nothing;
      ``data`` is ``b""``.
    * ``ClientRequestSecureSession``: the single 16-byte encrypted block.
    * ``OmniLink2Message``: only the FIRST 16-byte ciphertext block. This
      helper has no session key, so it cannot decrypt that block to learn the
      real message length. Any further blocks are left unread on the stream,
      and callers that care about the plaintext decrypt ``data`` themselves.

    Raises:
        AssertionError: for any other packet type (fails the mock handler).
        asyncio.IncompleteReadError: if the stream ends mid-packet.
    """
    # Same ">HBB" layout that _pack_header writes.
    seq, type_byte, _reserved = struct.unpack(">HBB", await reader.readexactly(4))

    header_only = (
        int(PacketType.ClientRequestNewSession),
        int(PacketType.ClientSessionTerminated),
    )
    one_block = (
        int(PacketType.ClientRequestSecureSession),
        int(PacketType.OmniLink2Message),
    )
    if type_byte in header_only:
        return seq, type_byte, b""
    if type_byte in one_block:
        return seq, type_byte, await reader.readexactly(16)
    raise AssertionError(f"unexpected client packet type {type_byte}")
|
|
|
|
|
|
# ---- handshake -----------------------------------------------------------
|
|
|
|
|
|
async def _start_server(
|
|
handler: Callable[[asyncio.StreamReader, asyncio.StreamWriter], Awaitable[None]],
|
|
) -> tuple[asyncio.Server, str, int]:
|
|
server = await asyncio.start_server(handler, "127.0.0.1", 0)
|
|
sockets = server.sockets
|
|
assert sockets, "server has no listening sockets"
|
|
host, port = sockets[0].getsockname()[:2]
|
|
return server, host, port
|
|
|
|
|
|
async def _do_full_handshake(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    *,
    controller_key: bytes = CONTROLLER_KEY,
    session_id: bytes = SESSION_ID,
) -> bytes:
    """Play the controller side of the 4-step handshake and return the session key.

    Any deviation by the client (wrong packet type, wrong seq, wrong
    decrypted session id) trips an ``assert`` inside the handler.
    """
    # 1. The client opens with an empty ClientRequestNewSession at seq 1.
    new_seq, new_type, _ = await _read_packet(reader)
    assert new_type == int(PacketType.ClientRequestNewSession)
    assert new_seq == 1

    # 2. Reply with ControllerAckNewSession on the same seq: protocol
    #    version 0x00 0x01 followed by the session id, all in cleartext.
    writer.write(
        b"".join(
            (
                _pack_header(new_seq, int(PacketType.ControllerAckNewSession)),
                b"\x00\x01",
                session_id,
            )
        )
    )
    await writer.drain()

    # 3. The client answers with an encrypted ClientRequestSecureSession at
    #    seq 2 whose plaintext must start with the session id.
    secure_seq, secure_type, ciphertext = await _read_packet(reader)
    assert secure_type == int(PacketType.ClientRequestSecureSession)
    assert secure_seq == 2
    key = derive_session_key(controller_key, session_id)
    assert decrypt_message_payload(ciphertext, secure_seq, key)[:5] == session_id

    # 4. Echo the session id back, encrypted under the derived key.
    writer.write(
        _pack_header(secure_seq, int(PacketType.ControllerAckSecureSession))
        + encrypt_message_payload(session_id, secure_seq, key)
    )
    await writer.drain()
    return key
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connection_handshake_flow_with_canned_server() -> None:
    """The 4-step handshake completes; state == ONLINE; session key matches."""
    server_done = asyncio.Event()
    server_session_key: list[bytes] = []

    async def handler(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
        try:
            sk = await _do_full_handshake(r, w)
            server_session_key.append(sk)
            server_done.set()
            # Hold the connection open until the client disconnects (EOF) so
            # the client can close cleanly. Draining to EOF, rather than
            # sleeping on a private Event, lets the handler exit as soon as the
            # client is gone. Newer Python releases make server.wait_closed()
            # wait for open connections, so a fixed sleep would stall teardown.
            with contextlib.suppress(TimeoutError, ConnectionError):
                await asyncio.wait_for(r.read(), timeout=5.0)
        finally:
            w.close()

    server, host, port = await _start_server(handler)
    try:
        conn = OmniConnection(host=host, port=port, controller_key=CONTROLLER_KEY)
        await conn.connect()
        try:
            await asyncio.wait_for(server_done.wait(), timeout=2.0)
            assert conn.state is ConnectionState.ONLINE
            assert conn.session_key == server_session_key[0]
            assert conn.session_key == derive_session_key(CONTROLLER_KEY, SESSION_ID)
        finally:
            await conn.close()
    finally:
        server.close()
        await server.wait_closed()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handshake_wrong_key_raises_invalid_encryption_key() -> None:
    """A ControllerSessionTerminated reply to step 3 surfaces as InvalidEncryptionKeyError."""

    async def reject_after_step3(
        reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        try:
            opening_seq, opening_type, _ = await _read_packet(reader)
            assert opening_type == int(PacketType.ClientRequestNewSession)

            # Step 2 looks perfectly legitimate.
            ack = b"".join(
                (
                    _pack_header(opening_seq, int(PacketType.ControllerAckNewSession)),
                    b"\x00\x01",
                    SESSION_ID,
                )
            )
            writer.write(ack)
            await writer.drain()

            # Step 3: discard the client's encrypted secure-session request
            # and terminate the session instead of acknowledging it.
            secure_seq, _, _ = await _read_packet(reader)
            writer.write(
                _pack_header(secure_seq, int(PacketType.ControllerSessionTerminated))
            )
            await writer.drain()
        finally:
            writer.close()

    server, host, port = await _start_server(reject_after_step3)
    try:
        conn = OmniConnection(host=host, port=port, controller_key=CONTROLLER_KEY)
        with pytest.raises(InvalidEncryptionKeyError):
            await conn.connect()
    finally:
        server.close()
        await server.wait_closed()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handshake_unsupported_proto_version_raises() -> None:
    """Step 2 advertising protocol version 00 02 (not 00 01) must raise HandshakeError."""

    async def bad_version_controller(
        reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        try:
            opening_seq, _, _ = await _read_packet(reader)
            header = _pack_header(opening_seq, int(PacketType.ControllerAckNewSession))
            bogus_version = b"\x00\x02"  # anything other than 00 01 is unsupported
            writer.write(header + bogus_version + SESSION_ID)
            await writer.drain()
        finally:
            writer.close()

    server, host, port = await _start_server(bad_version_controller)
    try:
        conn = OmniConnection(host=host, port=port, controller_key=CONTROLLER_KEY)
        with pytest.raises(HandshakeError):
            await conn.connect()
    finally:
        server.close()
        await server.wait_closed()
|
|
|
|
|
|
# ---- sequence numbers ----------------------------------------------------
|
|
|
|
|
|
def test_sequence_number_increments_per_request() -> None:
    """A fresh connection's seq allocator yields 1, 2, 3, ... (pure unit test, no I/O)."""
    conn = OmniConnection("0", 1, controller_key=CONTROLLER_KEY)
    assert [conn._claim_seq() for _ in range(5)] == list(range(1, 6))
|
|
|
|
|
|
def test_sequence_number_skips_zero_on_wraparound() -> None:
    """Past 0xFFFF the allocator wraps to 1, never 0 (seq 0 is reserved for unsolicited)."""
    conn = OmniConnection("0", 1, controller_key=CONTROLLER_KEY)
    conn._next_seq = 0xFFFE  # park the allocator just below the wrap point
    observed = tuple(conn._claim_seq() for _ in range(3))
    assert observed == (0xFFFE, 0xFFFF, 1)
|
|
|
|
|
|
# ---- request / unsolicited dispatch --------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_unsolicited_packet_lands_in_iterator_not_request_future() -> None:
    """An inbound packet with seq=0 goes to the unsolicited queue, not any pending future."""

    async def handler(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
        try:
            sk = await _do_full_handshake(r, w)
            # Push an unsolicited SystemEvents message at seq=0 (the seq
            # reserved for controller-initiated traffic).
            inner = encode_v2(OmniLink2MessageType.SystemEvents, b"\x00\x01")
            ct = encrypt_message_payload(inner.encode(), 0, sk)
            w.write(_pack_header(0, int(PacketType.OmniLink2Message)) + ct)
            await w.drain()
            # Hold the connection open until the client disconnects (EOF)
            # instead of sleeping for a fixed time. The handler then exits as
            # soon as the client closes, so server teardown does not wait out
            # the whole timeout.
            with contextlib.suppress(TimeoutError, ConnectionError):
                await asyncio.wait_for(r.read(), timeout=2.0)
        finally:
            w.close()

    server, host, port = await _start_server(handler)
    try:
        conn = OmniConnection(host=host, port=port, controller_key=CONTROLLER_KEY)
        await conn.connect()
        try:
            received: list[int] = []

            async def consume() -> None:
                # Stop after the first unsolicited message.
                async for msg in conn.unsolicited():
                    received.append(msg.opcode)
                    return

            await asyncio.wait_for(consume(), timeout=2.0)
            assert received == [int(OmniLink2MessageType.SystemEvents)]
        finally:
            await conn.close()
    finally:
        server.close()
        await server.wait_closed()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_request_timeout() -> None:
    """If the server stays silent, request() raises RequestTimeoutError."""

    async def handler(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
        try:
            await _do_full_handshake(r, w)
            # Swallow whatever the client sends after the handshake (its
            # request, however many blocks long) and never reply. Return as
            # soon as the client disconnects (EOF) so server teardown is not
            # held up by a fixed sleep.
            with contextlib.suppress(TimeoutError, ConnectionError):
                await asyncio.wait_for(r.read(), timeout=4.0)
        finally:
            w.close()

    server, host, port = await _start_server(handler)
    try:
        conn = OmniConnection(host=host, port=port, controller_key=CONTROLLER_KEY)
        await conn.connect()
        try:
            with pytest.raises(RequestTimeoutError):
                await conn.request(
                    OmniLink2MessageType.RequestSystemInformation, timeout=0.2
                )
        finally:
            await conn.close()
    finally:
        server.close()
        await server.wait_closed()
|