From d17037f2a14df61e47d5e29419e84d4d0b0f22cd Mon Sep 17 00:00:00 2001 From: Ryan Malloy Date: Sun, 15 Feb 2026 16:36:25 -0700 Subject: [PATCH] Add SWD/DAP subsystem with DP/AP register access and AP enumeration 10th subsystem: session.swd provides DAP discovery, DP/AP register read/write, AP enumeration, and convenience methods (dpidr, target_id). Includes SWDError, DAPInfo, APInfo types, input validation, ADIv5/v6 AP classification, and 24 mock-only tests covering happy and error paths. --- src/openocd/__init__.py | 6 + src/openocd/errors.py | 4 + src/openocd/session.py | 41 +++++ src/openocd/swd/__init__.py | 5 + src/openocd/swd/controller.py | 148 ++++++++++++++++++ src/openocd/swd/dap.py | 199 +++++++++++++++++++++++ src/openocd/types.py | 33 ++++ tests/mock_server.py | 46 ++++-- tests/test_session.py | 8 + tests/test_swd.py | 286 ++++++++++++++++++++++++++++++++++ 10 files changed, 763 insertions(+), 13 deletions(-) create mode 100644 src/openocd/swd/__init__.py create mode 100644 src/openocd/swd/controller.py create mode 100644 src/openocd/swd/dap.py create mode 100644 tests/test_swd.py diff --git a/src/openocd/__init__.py b/src/openocd/__init__.py index 145dfcd..03ed44e 100644 --- a/src/openocd/__init__.py +++ b/src/openocd/__init__.py @@ -7,14 +7,17 @@ from openocd.errors import ( OpenOCDError, ProcessError, SVDError, + SWDError, TargetError, TargetNotHaltedError, TimeoutError, ) from openocd.session import Session, SyncSession from openocd.types import ( + APInfo, BitField, Breakpoint, + DAPInfo, DecodedRegister, FlashBank, FlashSector, @@ -32,8 +35,10 @@ __all__ = [ "Session", "SyncSession", # Types + "APInfo", "BitField", "Breakpoint", + "DAPInfo", "DecodedRegister", "FlashBank", "FlashSector", @@ -51,6 +56,7 @@ __all__ = [ "OpenOCDError", "ProcessError", "SVDError", + "SWDError", "TargetError", "TargetNotHaltedError", "TimeoutError", diff --git a/src/openocd/errors.py b/src/openocd/errors.py index 8805287..efe22da 100644 --- a/src/openocd/errors.py +++ b/src/openocd/errors.py @@ -39,5 +39,9 @@ class SVDError(OpenOCDError): """SVD file not found, failed to parse, or lookup error.""" +class SWDError(OpenOCDError): + """Raised when an SWD/DAP operation fails.""" + + class ProcessError(OpenOCDError): """OpenOCD subprocess failed to start or exited unexpectedly.""" diff --git a/src/openocd/session.py b/src/openocd/session.py index 33b86e4..0a808f0 100644 --- a/src/openocd/session.py +++ b/src/openocd/session.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from openocd.registers import Registers, SyncRegisters from openocd.rtt import RTTManager from openocd.svd import SVDManager, SyncSVDManager + from openocd.swd import SWDController, SyncSWDController from openocd.target import SyncTarget, Target from openocd.transport import Transport @@ -40,6 +41,7 @@ class Session: self._registers: Registers | None = None self._flash: Flash | None = None self._jtag: JTAGController | None = None + self._swd: SWDController | None = None self._breakpoints: BreakpointManager | None = None self._rtt: RTTManager | None = None self._svd: SVDManager | None = None @@ -145,6 +147,7 @@ class Session: def target(self) -> Target: if self._target is None: from openocd.target import Target + self._target = Target(self._conn) return self._target @@ -152,6 +155,7 @@ class Session: def memory(self) -> Memory: if self._memory is None: from openocd.memory import Memory + self._memory = Memory(self._conn) return self._memory @@ -159,6 +163,7 @@ class Session: def registers(self) -> Registers: if self._registers is None: from openocd.registers import Registers + self._registers = Registers(self._conn) return self._registers @@ -166,6 +171,7 @@ class Session: def flash(self) -> Flash: if self._flash is None: from openocd.flash import Flash + self._flash = Flash(self._conn) return self._flash @@ -173,13 +179,23 @@ class Session: def jtag(self) -> JTAGController: if self._jtag is None: from openocd.jtag import JTAGController + self._jtag = JTAGController(self._conn) return self._jtag + @property + def swd(self) -> SWDController: + if self._swd is None: + from openocd.swd import SWDController + + self._swd = SWDController(self._conn) + return self._swd + @property def breakpoints(self) -> BreakpointManager: if self._breakpoints is None: from openocd.breakpoints import BreakpointManager + self._breakpoints = BreakpointManager(self._conn) return self._breakpoints @@ -187,6 +203,7 @@ class Session: def rtt(self) -> RTTManager: if self._rtt is None: from openocd.rtt import RTTManager + self._rtt = RTTManager(self._conn) return self._rtt @@ -194,6 +211,7 @@ class Session: def svd(self) -> SVDManager: if self._svd is None: from openocd.svd import SVDManager + self._svd = SVDManager(self._conn, self.memory) return self._svd @@ -201,6 +219,7 @@ class Session: def transport(self) -> Transport: if self._transport is None: from openocd.transport import Transport + self._transport = Transport(self._conn) return self._transport @@ -210,16 +229,20 @@ class Session: def on_halt(self, callback: Callable[[str], None]) -> None: """Register a callback for target halt events.""" + def _filter(msg: str) -> None: if "halted" in msg.lower(): callback(msg) + self._conn.on_notification(_filter) def on_reset(self, callback: Callable[[str], None]) -> None: """Register a callback for target reset events.""" + def _filter(msg: str) -> None: if "reset" in msg.lower(): callback(msg) + self._conn.on_notification(_filter) @@ -227,6 +250,7 @@ class Session: # Sync wrapper # ====================================================================== + class SyncSession: """Wraps an async Session for synchronous use.""" @@ -238,6 +262,7 @@ class SyncSession: self._registers: SyncRegisters | None = None self._flash: SyncFlash | None = None self._jtag: SyncJTAGController | None = None + self._swd: SyncSWDController | None = None self._breakpoints: SyncBreakpointManager | None = None self._svd: SyncSVDManager | None = None @@ -254,6 +279,7 @@ class SyncSession: def target(self) -> SyncTarget: if self._target is None: from openocd.target import SyncTarget + self._target = SyncTarget(self._session.target, self._loop) return self._target @@ -261,6 +287,7 @@ class SyncSession: def memory(self) -> SyncMemory: if self._memory is None: from openocd.memory import SyncMemory + self._memory = SyncMemory(self._session.memory, self._loop) return self._memory @@ -268,6 +295,7 @@ class SyncSession: def registers(self) -> SyncRegisters: if self._registers is None: from openocd.registers import SyncRegisters + self._registers = SyncRegisters(self._session.registers, self._loop) return self._registers @@ -275,6 +303,7 @@ class SyncSession: def flash(self) -> SyncFlash: if self._flash is None: from openocd.flash import SyncFlash + self._flash = SyncFlash(self._session.flash, self._loop) return self._flash @@ -282,13 +311,23 @@ class SyncSession: def jtag(self) -> SyncJTAGController: if self._jtag is None: from openocd.jtag import SyncJTAGController + self._jtag = SyncJTAGController(self._session.jtag, self._loop) return self._jtag + @property + def swd(self) -> SyncSWDController: + if self._swd is None: + from openocd.swd import SyncSWDController + + self._swd = SyncSWDController(self._session.swd, self._loop) + return self._swd + @property def breakpoints(self) -> SyncBreakpointManager: if self._breakpoints is None: from openocd.breakpoints import SyncBreakpointManager + self._breakpoints = SyncBreakpointManager(self._session.breakpoints, self._loop) return self._breakpoints @@ -296,6 +335,7 @@ class SyncSession: def svd(self) -> SyncSVDManager: if self._svd is None: from openocd.svd import SyncSVDManager + self._svd = SyncSVDManager(self._session.svd, self._loop) return self._svd @@ -304,6 +344,7 @@ class SyncSession: # Helpers # ====================================================================== + def _get_or_create_loop() -> asyncio.AbstractEventLoop: """Get or create an event loop for synchronous usage. diff --git a/src/openocd/swd/__init__.py b/src/openocd/swd/__init__.py new file mode 100644 index 0000000..7d32b45 --- /dev/null +++ b/src/openocd/swd/__init__.py @@ -0,0 +1,5 @@ +"""SWD/DAP operations: DP/AP register access and DAP discovery.""" + +from openocd.swd.controller import SWDController, SyncSWDController + +__all__ = ["SWDController", "SyncSWDController"] diff --git a/src/openocd/swd/controller.py b/src/openocd/swd/controller.py new file mode 100644 index 0000000..a16564f --- /dev/null +++ b/src/openocd/swd/controller.py @@ -0,0 +1,148 @@ +"""SWDController — unified facade for SWD/DAP operations.""" + +from __future__ import annotations + +import asyncio +import logging + +from openocd.connection.base import Connection +from openocd.errors import SWDError +from openocd.swd import dap as _dap +from openocd.types import APInfo, DAPInfo + +log = logging.getLogger(__name__) + + +class SWDController: + """High-level async interface to SWD/DAP operations. + + Most boards have a single DAP. When *dap* is ``None``, the controller + auto-discovers via ``dap names`` and uses the first (or only) DAP. + Multi-DAP boards (e.g. STM32H7 dual-core) pass the DAP name explicitly. + """ + + def __init__(self, conn: Connection) -> None: + self._conn = conn + self._cached_dap: str | None = None + + # -- DAP name resolution ----------------------------------------------- + + async def _resolve_dap(self, dap: str | None) -> str: + """Return the DAP name to use: explicit or auto-discovered. + + When *dap* is ``None``, uses the auto-discovered DAP (first result + from ``dap names``). Once resolved, the name is cached for the + lifetime of this controller unless :meth:`invalidate_cache` is called. + """ + if dap is not None: + return dap + if self._cached_dap is not None: + return self._cached_dap + + names = await _dap.dap_names(self._conn) + if not names: + raise SWDError("No DAP instances found (is the transport set to SWD?)") + self._cached_dap = names[0] + log.debug("Auto-resolved DAP: %s", self._cached_dap) + return self._cached_dap + + def invalidate_cache(self) -> None: + """Clear the cached DAP name. + + Call after transport changes, probe reconnection, or target + reconfiguration that may change which DAPs are available. + """ + self._cached_dap = None + log.debug("DAP cache invalidated") + + # -- DAP discovery ----------------------------------------------------- + + async def info(self, dap: str | None = None) -> DAPInfo: + """Query DAP information.""" + name = await self._resolve_dap(dap) + return await _dap.dap_info(self._conn, name) + + async def list_aps(self, dap: str | None = None) -> list[APInfo]: + """Enumerate Access Ports on the DAP.""" + name = await self._resolve_dap(dap) + return await _dap.enumerate_aps(self._conn, name) + + # -- DP register access ------------------------------------------------ + + async def dpreg(self, address: int, value: int | None = None, *, dap: str | None = None) -> int: + """Read or write a DP register. + + When *value* is ``None``, performs a read and returns the value. + When *value* is provided, performs a write and returns the written value. + """ + name = await self._resolve_dap(dap) + if value is None: + return await _dap.dpreg_read(self._conn, name, address) + await _dap.dpreg_write(self._conn, name, address, value) + return value + + # -- AP register access ------------------------------------------------ + + async def apreg( + self, ap: int, address: int, value: int | None = None, *, dap: str | None = None + ) -> int: + """Read or write an AP register. + + When *value* is ``None``, performs a read and returns the value. + When *value* is provided, performs a write and returns the written value. + """ + name = await self._resolve_dap(dap) + if value is None: + return await _dap.apreg_read(self._conn, name, ap, address) + await _dap.apreg_write(self._conn, name, ap, address, value) + return value + + # -- Convenience: well-known DP registers ------------------------------ + + async def dpidr(self, dap: str | None = None) -> int: + """Read the DP IDR (address 0x0) — identifies the debug port.""" + return await self.dpreg(0x0, dap=dap) + + async def target_id(self, dap: str | None = None) -> int: + """Read the TARGETID register (DP address 0x24, DPv2+).""" + return await self.dpreg(0x24, dap=dap) + + +# ====================================================================== +# SyncSWDController — blocking wrappers +# ====================================================================== + + +class SyncSWDController: + """Synchronous wrapper around :class:`SWDController`. + + Every async method is exposed with the same signature but runs + through ``loop.run_until_complete``. + """ + + def __init__(self, ctrl: SWDController, loop: asyncio.AbstractEventLoop) -> None: + self._ctrl = ctrl + self._loop = loop + + def info(self, dap: str | None = None) -> DAPInfo: + return self._loop.run_until_complete(self._ctrl.info(dap)) + + def list_aps(self, dap: str | None = None) -> list[APInfo]: + return self._loop.run_until_complete(self._ctrl.list_aps(dap)) + + def dpreg(self, address: int, value: int | None = None, *, dap: str | None = None) -> int: + return self._loop.run_until_complete(self._ctrl.dpreg(address, value, dap=dap)) + + def apreg( + self, ap: int, address: int, value: int | None = None, *, dap: str | None = None + ) -> int: + return self._loop.run_until_complete(self._ctrl.apreg(ap, address, value, dap=dap)) + + def dpidr(self, dap: str | None = None) -> int: + return self._loop.run_until_complete(self._ctrl.dpidr(dap)) + + def target_id(self, dap: str | None = None) -> int: + return self._loop.run_until_complete(self._ctrl.target_id(dap)) + + def invalidate_cache(self) -> None: + self._ctrl.invalidate_cache() diff --git a/src/openocd/swd/dap.py b/src/openocd/swd/dap.py new file mode 100644 index 0000000..e340059 --- /dev/null +++ b/src/openocd/swd/dap.py @@ -0,0 +1,199 @@ +"""Low-level DAP functions for SWD/DAP register access. + +All functions take a connection and a DAP name, then issue the +corresponding OpenOCD ```` sub-commands. Parsing is defensive +because OpenOCD output varies between versions. +""" + +from __future__ import annotations + +import logging +import re + +from openocd.connection.base import Connection +from openocd.errors import SWDError +from openocd.types import APInfo, DAPInfo + +log = logging.getLogger(__name__) + +# Match a hex value anywhere in the response (OpenOCD returns "0x2ba01477\n") +_HEX_RE = re.compile(r"0x([0-9a-fA-F]+)") + +# Count APs in dap info output — looks for "AP # " lines +_AP_NUM_RE = re.compile(r"AP\s*#?\s*(\d+)") + +# DPIDR line in dap info output +_DPIDR_RE = re.compile(r"DPIDR\s*[:=]?\s*(0x[0-9a-fA-F]+)", re.IGNORECASE) + +# OpenOCD error patterns: match the structure of actual error responses, +# not arbitrary English words. Avoids false positives on output like +# "error detection enabled" or register descriptions containing "invalid". +_ERROR_RE = re.compile( + r"^Error:|^invalid command|^invalid|command not found", + re.IGNORECASE | re.MULTILINE, +) + +_U32_MAX = 0xFFFFFFFF +_AP_MAX = 255 + + +def _validate_u32(value: int, name: str) -> None: + """Ensure value is a valid unsigned 32-bit integer.""" + if not isinstance(value, int) or value < 0 or value > _U32_MAX: + raise SWDError(f"{name} must be 0..0xFFFFFFFF, got {value!r}") + + +def _validate_ap_num(ap_num: int) -> None: + """Ensure AP number is in the valid range (0-255 per ARM ADI spec).""" + if not isinstance(ap_num, int) or ap_num < 0 or ap_num > _AP_MAX: + raise SWDError(f"AP number must be 0..255, got {ap_num!r}") + + +def _parse_hex(resp: str, context: str) -> int: + """Extract the first hex value from an OpenOCD response string.""" + m = _HEX_RE.search(resp) + if m is None: + raise SWDError(f"{context}: no hex value in response: {resp.strip()!r}") + return int(m.group(1), 16) + + +def _check_error(resp: str, context: str) -> None: + """Raise SWDError if the response indicates a failure. + + Matches OpenOCD's actual error response patterns (``Error:``, + ``invalid command``) rather than naive substring matching, to avoid + false positives on legitimate output containing words like "error". + """ + if _ERROR_RE.search(resp): + raise SWDError(f"{context}: {resp.strip()}") + + +async def dap_names(conn: Connection) -> list[str]: + """Return the list of DAP instance names known to OpenOCD.""" + resp = await conn.send("dap names") + _check_error(resp, "dap names") + names = [n.strip() for n in resp.strip().splitlines() if n.strip()] + return names + + +async def dap_info(conn: Connection, dap_name: str) -> DAPInfo: + """Query full DAP info and return a structured DAPInfo.""" + resp = await conn.send(f"{dap_name} info") + _check_error(resp, f"{dap_name} info") + + # Extract DPIDR + dpidr = 0 + m = _DPIDR_RE.search(resp) + if m: + dpidr = int(m.group(1), 16) + else: + log.warning( + "Could not parse DPIDR from '%s info' output — " + "OpenOCD format may have changed. Raw: %.200s", + dap_name, + resp, + ) + + # Count APs mentioned + ap_indices = set(_AP_NUM_RE.findall(resp)) + ap_count = len(ap_indices) + + return DAPInfo( + name=dap_name, + dpidr=dpidr, + ap_count=ap_count, + raw_info=resp.strip(), + ) + + +async def dpreg_read(conn: Connection, dap_name: str, address: int) -> int: + """Read a DP register at *address* via `` dpreg ``.""" + _validate_u32(address, "DP register address") + cmd = f"{dap_name} dpreg {address:#x}" + resp = await conn.send(cmd) + _check_error(resp, cmd) + return _parse_hex(resp, cmd) + + +async def dpreg_write(conn: Connection, dap_name: str, address: int, value: int) -> None: + """Write *value* to DP register at *address*.""" + _validate_u32(address, "DP register address") + _validate_u32(value, "DP register value") + cmd = f"{dap_name} dpreg {address:#x} {value:#x}" + resp = await conn.send(cmd) + _check_error(resp, cmd) + + +async def apreg_read(conn: Connection, dap_name: str, ap_num: int, address: int) -> int: + """Read an AP register: `` apreg ``.""" + _validate_ap_num(ap_num) + _validate_u32(address, "AP register address") + cmd = f"{dap_name} apreg {ap_num} {address:#x}" + resp = await conn.send(cmd) + _check_error(resp, cmd) + return _parse_hex(resp, cmd) + + +async def apreg_write( + conn: Connection, dap_name: str, ap_num: int, address: int, value: int +) -> None: + """Write *value* to AP register: `` apreg ``.""" + _validate_ap_num(ap_num) + _validate_u32(address, "AP register address") + _validate_u32(value, "AP register value") + cmd = f"{dap_name} apreg {ap_num} {address:#x} {value:#x}" + resp = await conn.send(cmd) + _check_error(resp, cmd) + + +def _classify_ap(idr: int) -> str: + """Classify an AP by its IDR value. + + The AP IDR Class field (bits 16:13) indicates the AP type per ARM ADI: + 0x0 = no AP / reserved + 0x1 = COM-AP (deprecated MEM-AP variant, ADIv5) + 0x8 = MEM-AP (ADIv5) + 0x9 = MEM-AP (ADIv6) + The Type field (bits 3:0) further distinguishes variants. + """ + if idr == 0: + return "unknown" + class_field = (idr >> 13) & 0xF + if class_field in (0x1, 0x8, 0x9): + return "MEM-AP" + type_field = idr & 0xF + if type_field == 0x0: + return "JTAG-AP" + return "unknown" + + +async def enumerate_aps(conn: Connection, dap_name: str, max_aps: int = 256) -> list[APInfo]: + """Probe APs by reading IDR (offset 0xFC) until we get 0 or hit *max_aps*. + + Each AP with a non-zero IDR is included. We also read the BASE register + (offset 0xF8) to capture the ROM table address. + """ + aps: list[APInfo] = [] + for idx in range(max_aps): + try: + idr = await apreg_read(conn, dap_name, idx, 0xFC) + except SWDError as exc: + log.warning("AP enumeration stopped at index %d due to error: %s", idx, exc) + break + if idr == 0: + break + + try: + base = await apreg_read(conn, dap_name, idx, 0xF8) + except SWDError: + base = 0 + + aps.append( + APInfo( + index=idx, + idr=idr, + base=base, + ap_type=_classify_ap(idr), + ) + ) + return aps diff --git a/src/openocd/types.py b/src/openocd/types.py index 6ab8bad..89848f4 100644 --- a/src/openocd/types.py +++ b/src/openocd/types.py @@ -10,6 +10,7 @@ from typing import Literal # Target # --------------------------------------------------------------------------- + @dataclass(frozen=True) class TargetState: """Snapshot of target execution state.""" @@ -23,6 +24,7 @@ class TargetState: # Registers # --------------------------------------------------------------------------- + @dataclass(frozen=True) class Register: """A single CPU register.""" @@ -38,6 +40,7 @@ class Register: # Flash # --------------------------------------------------------------------------- + @dataclass(frozen=True) class FlashSector: """One sector inside a flash bank.""" @@ -66,6 +69,7 @@ class FlashBank: # JTAG # --------------------------------------------------------------------------- + @dataclass(frozen=True) class TAPInfo: """One TAP discovered on the JTAG chain.""" @@ -103,6 +107,7 @@ class JTAGState(str, Enum): # Memory # --------------------------------------------------------------------------- + @dataclass(frozen=True) class MemoryRegion: """A chunk of memory read from the target.""" @@ -116,6 +121,7 @@ class MemoryRegion: # SVD # --------------------------------------------------------------------------- + @dataclass(frozen=True) class BitField: """One decoded bitfield inside a register.""" @@ -150,6 +156,7 @@ class DecodedRegister: # Breakpoints # --------------------------------------------------------------------------- + @dataclass(frozen=True) class Breakpoint: """An active breakpoint.""" @@ -175,6 +182,7 @@ class Watchpoint: # RTT # --------------------------------------------------------------------------- + @dataclass(frozen=True) class RTTChannel: """An RTT channel descriptor.""" @@ -183,3 +191,28 @@ class RTTChannel: name: str size: int direction: Literal["up", "down"] + + +# --------------------------------------------------------------------------- +# SWD / DAP +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class DAPInfo: + """Debug Access Port information returned by ``dap info``.""" + + name: str # DAP instance name (e.g. "stm32f1x.dap") + dpidr: int # DP ID Register value + ap_count: int # Number of access ports discovered + raw_info: str # Full ``dap info`` output for detailed parsing + + +@dataclass(frozen=True) +class APInfo: + """Access Port descriptor discovered during AP enumeration.""" + + index: int # AP number (0, 1, 2...) + idr: int # AP ID Register (from apreg 0xfc) + base: int # ROM table base address (from apreg 0xf8) + ap_type: str # "MEM-AP", "JTAG-AP", "CTRL-AP", or "unknown" diff --git a/tests/mock_server.py b/tests/mock_server.py index 08aca98..db9d9d4 100644 --- a/tests/mock_server.py +++ b/tests/mock_server.py @@ -7,6 +7,7 @@ An asyncio TCP server that speaks the OpenOCD TCL RPC framing protocol: Supports exact-match and regex-based command routing with pre-loaded responses that mirror real OpenOCD output. """ + from __future__ import annotations import asyncio @@ -58,8 +59,7 @@ REG_ALL_RESPONSE = """\ READ_MEMORY_RESPONSE = "20005000 080001a1 080001ab 080001ad" FLASH_BANKS_RESPONSE = ( - "#0 : stm32f1x.flash (stm32f1x) at 0x08000000," - " size 0x00020000, buswidth 0, chipwidth 0" + "#0 : stm32f1x.flash (stm32f1x) at 0x08000000, size 0x00020000, buswidth 0, chipwidth 0" ) SCAN_CHAIN_RESPONSE = """\ @@ -82,6 +82,26 @@ TRANSPORT_SELECT_RESPONSE = "swd" TRANSPORT_LIST_RESPONSE = "jtag swd" ADAPTER_SPEED_RESPONSE = "4000" +# -- SWD/DAP --------------------------------------------------------------- +DAP_NAMES_RESPONSE = "stm32f1x.dap" + +DAP_INFO_RESPONSE = """\ +AP # 0 + AP ID register 0x04770031 + Type is MEM-AP AHB3 + MEM-AP BASE 0xe00ff003 + Valid ROM table present + Component base address 0xe00ff000 + Peripheral ID 0x04c0010471 + Designer is 0x4bb, ST Microelectronics +DPIDR: 0x2ba01477""" + +DPREG_0_RESPONSE = "0x2ba01477" +DPREG_24_RESPONSE = "0x00000477" +APREG_0_FC_RESPONSE = "0x04770031" +APREG_0_F8_RESPONSE = "0xe00ff003" +APREG_1_FC_RESPONSE = "0x00000000" + def _build_default_responses() -> list[tuple[re.Pattern[str], str | Callable[[str], str]]]: """Build the default command-to-response routing table. @@ -97,7 +117,6 @@ def _build_default_responses() -> list[tuple[re.Pattern[str], str | Callable[[st (re.compile(r"^step"), ""), (re.compile(r"^reset\s+"), ""), (re.compile(r"^wait_halt"), ""), - # individual register reads (must come before bare "reg") (re.compile(r"^reg\s+pc$"), REG_PC_RESPONSE), (re.compile(r"^reg\s+sp$"), REG_SP_RESPONSE), @@ -107,24 +126,30 @@ def _build_default_responses() -> list[tuple[re.Pattern[str], str | Callable[[st (re.compile(r"^reg\s+\S+\s+0x"), ""), # bare "reg" -> full listing (re.compile(r"^reg$"), REG_ALL_RESPONSE), - # memory (re.compile(r"^read_memory\s+0x8000000\s+32\s+4$"), READ_MEMORY_RESPONSE), # generic read_memory -- return zeros for widths/counts we haven't mapped (re.compile(r"^read_memory\s+"), _generic_read_memory), (re.compile(r"^write_memory\s+"), ""), - # flash (re.compile(r"^flash banks$"), FLASH_BANKS_RESPONSE), (re.compile(r"^flash\s+"), ""), - + # SWD/DAP + (re.compile(r"^dap names$"), DAP_NAMES_RESPONSE), + (re.compile(r"^stm32f1x\.dap info$"), DAP_INFO_RESPONSE), + (re.compile(r"^stm32f1x\.dap dpreg 0x0$"), DPREG_0_RESPONSE), + (re.compile(r"^stm32f1x\.dap dpreg 0x24$"), DPREG_24_RESPONSE), + (re.compile(r"^stm32f1x\.dap dpreg 0x0 0x"), ""), + (re.compile(r"^stm32f1x\.dap apreg 0 0xfc$"), APREG_0_FC_RESPONSE), + (re.compile(r"^stm32f1x\.dap apreg 0 0xf8$"), APREG_0_F8_RESPONSE), + (re.compile(r"^stm32f1x\.dap apreg 1 0xfc$"), APREG_1_FC_RESPONSE), + (re.compile(r"^stm32f1x\.dap apreg 0 0x0 0x"), ""), # JTAG (re.compile(r"^scan_chain$"), SCAN_CHAIN_RESPONSE), (re.compile(r"^irscan\s+"), "0x01"), (re.compile(r"^drscan\s+"), "0xDEADBEEF"), (re.compile(r"^runtest\s+"), ""), (re.compile(r"^pathmove\s+"), ""), - # breakpoints (re.compile(r"^bp\s+0x"), ""), (re.compile(r"^bp$"), BP_LIST_RESPONSE), @@ -132,14 +157,12 @@ def _build_default_responses() -> list[tuple[re.Pattern[str], str | Callable[[st (re.compile(r"^wp\s+0x"), ""), (re.compile(r"^wp$"), ""), (re.compile(r"^rwp\s+"), ""), - # transport / adapter (re.compile(r"^transport\s+select$"), TRANSPORT_SELECT_RESPONSE), (re.compile(r"^transport\s+list$"), TRANSPORT_LIST_RESPONSE), (re.compile(r"^adapter\s+speed$"), ADAPTER_SPEED_RESPONSE), (re.compile(r"^adapter\s+speed\s+\d+"), ADAPTER_SPEED_RESPONSE), (re.compile(r"^adapter\s+name$"), "cmsis-dap"), - # RTT (re.compile(r"^rtt\s+channels$"), RTT_CHANNELS_RESPONSE), (re.compile(r"^rtt\s+setup\s+"), ""), @@ -147,7 +170,6 @@ def _build_default_responses() -> list[tuple[re.Pattern[str], str | Callable[[st (re.compile(r"^rtt\s+stop$"), ""), (re.compile(r"^rtt\s+channelread\s+"), "hello from target"), (re.compile(r"^rtt\s+channelwrite\s+"), ""), - # notifications (re.compile(r"^tcl_notifications\s+"), ""), ] @@ -201,9 +223,7 @@ class MockOpenOCDServer: self._routes.insert(0, (re.compile(pattern), response)) async def start(self) -> None: - self._server = await asyncio.start_server( - self._handle_client, self._host, self._port - ) + self._server = await asyncio.start_server(self._handle_client, self._host, self._port) await self._server.start_serving() async def stop(self) -> None: diff --git a/tests/test_session.py b/tests/test_session.py index 2943bec..191395d 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,4 +1,5 @@ """Tests for the Session class.""" + from __future__ import annotations import pytest @@ -11,6 +12,7 @@ from openocd.registers import Registers from openocd.rtt import RTTManager from openocd.session import Session from openocd.svd import SVDManager +from openocd.swd import SWDController from openocd.target import Target from openocd.transport import Transport @@ -38,6 +40,7 @@ async def test_context_manager(mock_ocd): # After exiting the context, the connection is closed. # Attempting to send should raise. from openocd.errors import ConnectionError + with pytest.raises(ConnectionError): await sess.command("targets") @@ -67,6 +70,11 @@ async def test_subsystem_jtag_type(session): assert isinstance(session.jtag, JTAGController) +async def test_subsystem_swd_type(session): + """session.swd should return an SWDController instance.""" + assert isinstance(session.swd, SWDController) + + async def test_subsystem_breakpoints_type(session): """session.breakpoints should return a BreakpointManager instance.""" assert isinstance(session.breakpoints, BreakpointManager) diff --git a/tests/test_swd.py b/tests/test_swd.py new file mode 100644 index 0000000..bdbc61e --- /dev/null +++ b/tests/test_swd.py @@ -0,0 +1,286 @@ +"""Tests for the SWD/DAP subsystem.""" + +from __future__ import annotations + +import pytest + +from openocd.errors import SWDError +from openocd.types import APInfo, DAPInfo + + +async def test_dap_info(session): + """info() should return a DAPInfo with parsed DPIDR and AP count.""" + info = await session.swd.info() + assert isinstance(info, DAPInfo) + assert info.name == "stm32f1x.dap" + assert info.dpidr == 0x2BA01477 + assert info.ap_count == 1 + assert "MEM-AP" in info.raw_info + + +async def test_dap_info_frozen(session): + """DAPInfo should be immutable (frozen dataclass).""" + info = await session.swd.info() + with pytest.raises(AttributeError): + info.name = "something_else" # type: ignore[misc] + + +async def test_dpreg_read(session): + """dpreg() without a value should read and return a DP register.""" + result = await session.swd.dpreg(0x0) + assert isinstance(result, int) + assert result == 0x2BA01477 + + +async def test_dpreg_write(session, mock_ocd): + """dpreg() with a value should write and return the written value.""" + result = await session.swd.dpreg(0x0, value=0x12345678) + assert result == 0x12345678 + # Verify the mock received the write command + _, _, server = mock_ocd + write_cmds = [c for c in server.received_commands if "dpreg 0x0 0x" in c] + assert len(write_cmds) >= 1 + + +async def test_apreg_read(session): + """apreg() without a value should read an AP register.""" + result = await session.swd.apreg(0, 0xFC) + assert isinstance(result, int) + assert result == 0x04770031 + + +async def test_apreg_write(session, mock_ocd): + """apreg() with a value should write and return the written value.""" + result = await session.swd.apreg(0, 0x0, value=0xAABBCCDD) + assert result == 0xAABBCCDD + _, _, server = mock_ocd + write_cmds = [c for c in server.received_commands if "apreg 0 0x0 0x" in c] + assert len(write_cmds) >= 1 + + +async def test_enumerate_aps(session): + """list_aps() should discover APs by probing IDR until zero.""" + aps = await session.swd.list_aps() + assert isinstance(aps, list) + assert len(aps) == 1 + + ap = aps[0] + assert isinstance(ap, APInfo) + assert ap.index == 0 + assert ap.idr == 0x04770031 + assert ap.base == 0xE00FF003 + assert ap.ap_type == "MEM-AP" + + +async def test_ap_info_frozen(session): + """APInfo should be immutable (frozen dataclass).""" + aps = await session.swd.list_aps() + with pytest.raises(AttributeError): + aps[0].index = 99 # type: ignore[misc] + + +async def test_dpidr_convenience(session): + """dpidr() should read DP address 0x0.""" + result = await session.swd.dpidr() + assert result == 0x2BA01477 + + +async def test_target_id(session): + """target_id() should read DP address 0x24.""" + result = await session.swd.target_id() + assert result == 0x00000477 + + +async def test_auto_resolve_dap(session, mock_ocd): + """With no explicit dap name, the controller should auto-discover.""" + # First call triggers dap names lookup + await session.swd.dpidr() + _, _, server = mock_ocd + assert "dap names" in server.received_commands + + # Second call should use the cached name (no extra dap names) + count_before = server.received_commands.count("dap names") + await session.swd.dpidr() + count_after = server.received_commands.count("dap names") + assert count_after == count_before + + +async def test_explicit_dap_name(session, mock_ocd): + """Passing dap= explicitly should skip auto-discovery.""" + result = await session.swd.dpreg(0x0, dap="stm32f1x.dap") + assert result == 0x2BA01477 + # Should NOT have called "dap names" + _, _, server = mock_ocd + assert "dap names" not in server.received_commands + + +async def test_swd_error_on_bad_response(mock_ocd): + """SWDError should be raised when response matches OpenOCD error patterns.""" + from openocd.swd.dap import _check_error + + with pytest.raises(SWDError): + _check_error("Error: invalid DAP", "test") + + with pytest.raises(SWDError): + _check_error("invalid command name", "test") + + with pytest.raises(SWDError): + _check_error("command not found", "test") + + # Clean responses should not raise + _check_error("0x2ba01477", "test") + _check_error("", "test") + + # Legitimate output containing "error" as a substring should NOT raise. + # This is the false-positive prevention fix (C1 from code review). + _check_error("error detection enabled in CTRL register", "test") + _check_error("AP ID register 0x04770031", "test") + + +async def test_swd_error_no_hex_value(mock_ocd): + """SWDError should be raised when no hex value found in read response.""" + from openocd.swd.dap import _parse_hex + + with pytest.raises(SWDError, match="no hex value"): + _parse_hex("no numbers here", "test read") + + +def test_sync_wrapper(): + """SyncSWDController should expose the same API synchronously. + + The sync API blocks with run_until_complete, so the mock server must + run on a separate thread to accept connections concurrently. + """ + import asyncio + import threading + + from openocd.session import Session + from tests.mock_server import MockOpenOCDServer + + # Run mock server in a background thread with its own event loop. + bg_loop = asyncio.new_event_loop() + server = MockOpenOCDServer() + bg_loop.run_until_complete(server.start()) + host, port = server.address + + thread = threading.Thread(target=bg_loop.run_forever, daemon=True) + thread.start() + + try: + with Session.connect_sync(host, port, timeout=5.0) as sync_sess: + result = sync_sess.swd.dpidr() + assert result == 0x2BA01477 + + info = sync_sess.swd.info() + assert isinstance(info, DAPInfo) + assert info.name == "stm32f1x.dap" + + aps = sync_sess.swd.list_aps() + assert len(aps) == 1 + finally: + bg_loop.call_soon_threadsafe(bg_loop.stop) + thread.join(timeout=5) + bg_loop.run_until_complete(server.stop()) + bg_loop.close() + + +def test_classify_ap(): + """AP classification should identify MEM-AP, JTAG-AP, and unknown types.""" + from openocd.swd.dap import _classify_ap + + # MEM-AP ADIv5 (class field 0x8) + assert _classify_ap(0x04770031) == "MEM-AP" + # Zero IDR = unknown + assert _classify_ap(0x00000000) == "unknown" + # Class field 0x1 (COM-AP / legacy MEM-AP) + assert _classify_ap(0x00002000) == "MEM-AP" + # MEM-AP ADIv6 (class field 0x9) + assert _classify_ap(0x00012000) == "MEM-AP" + # JTAG-AP: non-zero IDR, class not MEM-AP, type field 0x0 + assert _classify_ap(0x00000010) == "JTAG-AP" # bits[3:0]=0x0, class=0 + # Unknown: non-zero IDR, class not MEM-AP, type field != 0 + assert _classify_ap(0x00000001) == "unknown" # bits[3:0]=0x1, class=0 + + +# ====================================================================== +# Error-path tests (from code review findings I6) +# ====================================================================== + + +async def test_no_dap_found(mock_ocd): + """SWDError should be raised when dap names returns empty.""" + from openocd.session import Session + + host, port, server = mock_ocd + # Override dap names to return empty + server.add_response(r"^dap names$", "") + + sess = await Session.connect(host, port, timeout=5.0) + try: + with pytest.raises(SWDError, match="No DAP instances found"): + await sess.swd.dpidr() + finally: + await sess.close() + + +async def test_invalidate_cache(session, mock_ocd): + """invalidate_cache() should force re-discovery on next call.""" + _, _, server = mock_ocd + + # First call populates the cache + await session.swd.dpidr() + count_after_first = server.received_commands.count("dap names") + assert count_after_first == 1 + + # Invalidate and call again + session.swd.invalidate_cache() + await session.swd.dpidr() + count_after_invalidate = server.received_commands.count("dap names") + assert count_after_invalidate == 2 + + +async def test_dpreg_negative_address_rejected(session): + """Negative addresses should be rejected before reaching OpenOCD.""" + with pytest.raises(SWDError, match="must be 0"): + await session.swd.dpreg(-1) + + +async def test_dpreg_overflow_address_rejected(session): + """Addresses > 0xFFFFFFFF should be rejected.""" + with pytest.raises(SWDError, match="must be 0"): + await session.swd.dpreg(0x1_0000_0000) + + +async def test_apreg_negative_ap_rejected(session): + """Negative AP numbers should be rejected.""" + with pytest.raises(SWDError, match="AP number must be"): + await session.swd.apreg(-1, 0xFC) + + +async def test_apreg_ap_over_255_rejected(session): + """AP numbers > 255 should be rejected per ARM ADI spec.""" + with pytest.raises(SWDError, match="AP number must be"): + await session.swd.apreg(256, 0xFC) + + +async def test_dpreg_write_negative_value_rejected(session): + """Negative values should be rejected for DP register writes.""" + with pytest.raises(SWDError, match="must be 0"): + await session.swd.dpreg(0x0, value=-1) + + +async def test_dap_info_unparseable_dpidr(mock_ocd): + """When dap info output has no DPIDR line, dpidr should be 0 with warning.""" + from openocd.session import Session + + host, port, server = mock_ocd + # Override dap info to return output with no DPIDR line + server.add_response(r"^stm32f1x\.dap info$", "AP # 0\n Some AP info\n No DPIDR here") + + sess = await Session.connect(host, port, timeout=5.0) + try: + info = await sess.swd.info() + assert info.dpidr == 0 # Falls back to 0 with a logged warning + assert info.ap_count == 1 # AP # 0 is still counted + finally: + await sess.close()