Ryan Malloy ea22f2f9db Fix async/await bugs found by headless E2E test
- Make get_client() sync (was async but did no async work). Callers
  that omitted await silently got a coroutine object instead of the
  SerialClient, causing "'coroutine' object has no attribute 'connect'"
  errors on every tool call.

- Fix esp32_connect: use get_client_or_none() for init check and
  client.event_queue.wait_for() for boot event (wait_event() didn't
  exist on SerialClient).

- Normalise Response.data to dict at parse time — firmware returns
  bare strings on some error paths, which broke .get() calls in tool
  error handlers.

- Remove stale await from ble.py (9 calls) and classic.py (4 calls).

Tested with dual-MCP headless claude session: 26/27 PASS.
2026-02-02 15:54:36 -07:00

226 lines
6.0 KiB
Python

"""Python protocol layer mirroring the ESP32 firmware's protocol.h — NDJSON over UART."""
from __future__ import annotations
import json
import threading
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class MsgType(StrEnum):
CMD = "cmd"
RESP = "resp"
EVENT = "event"
class Status(StrEnum):
OK = "ok"
ERROR = "error"
class IOCapability(StrEnum):
DISPLAY_ONLY = "display_only"
DISPLAY_YESNO = "display_yesno"
KEYBOARD_ONLY = "keyboard_only"
NO_IO = "no_io"
KEYBOARD_DISPLAY = "keyboard_display"
class Transport(StrEnum):
CLASSIC = "classic"
BLE = "ble"
BOTH = "both"
# ---------------------------------------------------------------------------
# Command strings (mirror firmware #defines)
# ---------------------------------------------------------------------------
# Command names, used as the ``cmd`` field of a Command. Each value must stay
# in step with the matching firmware #define.

# -- System --
CMD_PING: str = "ping"
CMD_RESET: str = "reset"
CMD_GET_INFO: str = "get_info"
CMD_GET_STATUS: str = "get_status"

# -- Configuration --
CMD_CONFIGURE: str = "configure"
CMD_LOAD_PERSONA: str = "load_persona"
CMD_LIST_PERSONAS: str = "list_personas"
CMD_CLASSIC_SET_SSP_MODE: str = "classic_set_ssp_mode"

# -- Classic BT --
CMD_CLASSIC_ENABLE: str = "classic_enable"
CMD_CLASSIC_DISABLE: str = "classic_disable"
CMD_CLASSIC_SET_DISCOVERABLE: str = "classic_set_discoverable"
CMD_CLASSIC_PAIR_RESPOND: str = "classic_pair_respond"

# -- BLE --
CMD_BLE_ENABLE: str = "ble_enable"
CMD_BLE_DISABLE: str = "ble_disable"
CMD_BLE_ADVERTISE: str = "ble_advertise"
CMD_BLE_SET_ADV_DATA: str = "ble_set_adv_data"

# -- GATT --
CMD_GATT_ADD_SERVICE: str = "gatt_add_service"
CMD_GATT_ADD_CHARACTERISTIC: str = "gatt_add_characteristic"
CMD_GATT_SET_VALUE: str = "gatt_set_value"
CMD_GATT_NOTIFY: str = "gatt_notify"
CMD_GATT_CLEAR: str = "gatt_clear"
# Event names, matched against the ``event`` field of an Event message.
EVT_BOOT: str = "boot"
EVT_PAIR_REQUEST: str = "pair_request"
EVT_PAIR_COMPLETE: str = "pair_complete"
EVT_CONNECT: str = "connect"
EVT_DISCONNECT: str = "disconnect"
EVT_GATT_READ: str = "gatt_read"
EVT_GATT_WRITE: str = "gatt_write"
EVT_GATT_SUBSCRIBE: str = "gatt_subscribe"
# ---------------------------------------------------------------------------
# Protocol constants
# ---------------------------------------------------------------------------
# UART speed for the serial link.
# NOTE(review): presumably must match the firmware's UART config -- confirm.
BAUD_RATE: int = 115_200
# Upper bound on one NDJSON line.
# NOTE(review): bytes vs. characters is not established in this module.
MAX_LINE_LENGTH: int = 2_048
# ---------------------------------------------------------------------------
# Monotonic ID generator (thread-safe)
# ---------------------------------------------------------------------------
_id_counter: int = 0
_id_lock = threading.Lock()
def _next_id() -> str:
"""Return a monotonically increasing string ID like '1', '2', ..."""
global _id_counter
with _id_lock:
_id_counter += 1
return str(_id_counter)
# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------
@dataclass(slots=True)
class Command:
type: MsgType
id: str
cmd: str
params: dict[str, Any] = field(default_factory=dict)
def to_json(self) -> str:
"""Serialize to a single NDJSON line (no trailing newline)."""
obj: dict[str, Any] = {"type": self.type, "id": self.id, "cmd": self.cmd}
if self.params:
obj["params"] = self.params
return json.dumps(obj, separators=(",", ":"))
@dataclass(slots=True)
class Response:
type: MsgType
id: str
status: Status
data: dict[str, Any] = field(default_factory=dict)
@classmethod
def from_json(cls, line: str) -> Response:
"""Parse a JSON line known to be a response."""
obj = json.loads(line)
raw_data = obj.get("data", {})
# Firmware may return a bare string on some error paths — normalise to dict
if isinstance(raw_data, str):
raw_data = {"error": raw_data}
return cls(
type=MsgType(obj["type"]),
id=obj["id"],
status=Status(obj["status"]),
data=raw_data if isinstance(raw_data, dict) else {},
)
@dataclass(slots=True)
class Event:
type: MsgType
event: str
data: dict[str, Any] = field(default_factory=dict)
ts: int | None = None
@classmethod
def from_json(cls, line: str) -> Event:
"""Parse a JSON line known to be an event."""
obj = json.loads(line)
return cls(
type=MsgType(obj["type"]),
event=obj["event"],
data=obj.get("data", {}),
ts=obj.get("ts"),
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def build_command(
    cmd: str, params: dict[str, Any] | None = None, cmd_id: str | None = None
) -> Command:
    """Create a Command of type ``MsgType.CMD``.

    Args:
        cmd: Command name (e.g. one of the ``CMD_*`` constants).
        params: Optional parameters; ``None`` or an empty dict yields an
            empty ``params`` dict on the result.
        cmd_id: Explicit message ID. When falsy (``None`` or ``""``) a fresh
            monotonic ID is drawn from ``_next_id()``.

    Returns:
        The assembled Command.
    """
    identifier = cmd_id if cmd_id else _next_id()
    payload = params if params else {}
    return Command(type=MsgType.CMD, id=identifier, cmd=cmd, params=payload)
def parse_message(line: str) -> Command | Response | Event:
    """Parse one NDJSON line into a Command, Response, or Event.

    The message class is chosen from the top-level ``"type"`` field.
    Response ``data`` is normalised the same way as ``Response.from_json``:
    a bare string becomes ``{"error": <string>}`` and any other non-dict
    value becomes ``{}``, so callers can always use ``.data.get(...)``.

    Args:
        line: A single JSON document (one NDJSON line).

    Returns:
        The parsed Command, Response, or Event.

    Raises:
        ValueError: If the line is not valid JSON, is valid JSON but not an
            object, has an unknown ``"type"``, or (for responses) carries an
            unrecognised ``"status"``.
        KeyError: If a field required by the message type is missing.
    """
    try:
        obj = json.loads(line)
    except json.JSONDecodeError as exc:
        raise ValueError(f"invalid JSON: {exc}") from exc

    # Arrays, strings, numbers and null are valid JSON but not protocol
    # messages; reject them here rather than failing on ``.get`` below.
    if not isinstance(obj, dict):
        raise ValueError(f"expected a JSON object, got {type(obj).__name__}")

    msg_type = obj.get("type")

    if msg_type == MsgType.CMD:
        return Command(
            type=MsgType.CMD,
            id=obj["id"],
            cmd=obj["cmd"],
            params=obj.get("params", {}),
        )

    if msg_type == MsgType.RESP:
        raw_data = obj.get("data", {})
        # Firmware may return a bare string on some error paths -- normalise
        # to a dict, matching Response.from_json.
        if isinstance(raw_data, str):
            raw_data = {"error": raw_data}
        elif not isinstance(raw_data, dict):
            raw_data = {}
        return Response(
            type=MsgType.RESP,
            id=obj["id"],
            status=Status(obj["status"]),
            data=raw_data,
        )

    if msg_type == MsgType.EVENT:
        return Event(
            type=MsgType.EVENT,
            event=obj["event"],
            data=obj.get("data", {}),
            ts=obj.get("ts"),
        )

    raise ValueError(f"unknown message type: {msg_type!r}")