commit 7e1eac5e2db79739daab4837e28c6a33ec10ebdf Author: Ryan Malloy Date: Thu Feb 12 17:55:58 2026 -0700 Add openocd-python: typed async-first Python bindings for OpenOCD Standalone PyPI package providing structured access to the full OpenOCD command surface via the TCL RPC protocol (port 6666). Async-first API with sync wrappers for every method. Subsystems: target control, memory read/write, CPU registers, flash programming, JTAG chain/scan/boundary, breakpoints/watchpoints, SVD peripheral decoding, RTT channels, transport/adapter config. 79 tests passing against a mock TCL RPC server. diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fc68fc7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +*.pyc +.venv/ +*.egg-info/ +dist/ +build/ +.pytest_cache/ +.ruff_cache/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..2bb3738 --- /dev/null +++ b/README.md @@ -0,0 +1,46 @@ +# openocd-python + +Typed, async-first Python bindings for the full OpenOCD command surface. + +## Install + +```bash +pip install openocd-python +``` + +## Quick Start + +```python +from openocd import Session + +# Connect to a running OpenOCD instance +async with Session.connect() as ocd: + state = await ocd.target.halt() + pc = await ocd.registers.pc() + mem = await ocd.memory.read_u32(0x08000000, 4) + print(f"PC: {pc:#x}") + +# Or spawn OpenOCD and connect +async with Session.start("interface/cmsis-dap.cfg -f target/stm32f1x.cfg") as ocd: + await ocd.target.halt() + regs = await ocd.registers.read_all() + +# Synchronous API available too +with Session.start_sync("interface/cmsis-dap.cfg") as ocd: + ocd.target.halt() + print(f"PC: {ocd.registers.pc():#x}") +``` + +## Features + +- **Async-first** with sync wrappers for every method +- **Typed returns** — dataclasses, not raw strings +- **Full OpenOCD surface**: target control, memory, registers, flash, JTAG, breakpoints, RTT +- **SVD decoding** — read a peripheral register and get named bitfields +- **Process management** — spawn and manage OpenOCD subprocesses +- **Dual transport** — TCL RPC (primary) and telnet (fallback) + +## Requirements + +- Python 3.11+ +- OpenOCD installed and on PATH (or pass `openocd_bin=`) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..3c3c9ab --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,61 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "openocd-python" +version = "2025.02.12" +description = "Typed, async-first Python bindings for the full OpenOCD command surface" +readme = "README.md" +license = "MIT" +requires-python = ">=3.11" +authors = [ + {name = "Ryan Malloy", email = "ryan@supported.systems"}, +] +keywords = ["openocd", "jtag", "swd", "embedded", "debugging", "hardware"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Software Development :: Debuggers", + "Topic :: Software Development :: Embedded Systems", + "Topic :: System :: Hardware", + "Typing :: Typed", +] +dependencies = [ + "cmsis-svd>=0.4", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.24", + "ruff>=0.8", +] + +[project.scripts] +openocd-python = "openocd.cli:main" + +[project.urls] +Homepage = "https://github.com/ryanmalloy/openocd-python" +Issues = "https://github.com/ryanmalloy/openocd-python/issues" + +[tool.hatch.build.targets.wheel] +packages = ["src/openocd"] + +[tool.ruff] +target-version = "py311" +line-length = 100 + +[tool.ruff.lint] +select = ["E", "F", "I", "UP", "B", "SIM"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +markers = [ + "hardware: requires physical DAP-Link hardware (deselect with '-m not hardware')", +] diff --git a/src/openocd/__init__.py b/src/openocd/__init__.py new file mode 100644 index 0000000..145dfcd --- /dev/null +++ b/src/openocd/__init__.py @@ -0,0 +1,64 @@ +"""openocd-python — typed, async-first Python bindings for OpenOCD.""" + +from openocd.errors import ( + ConnectionError, + FlashError, + JTAGError, + OpenOCDError, + ProcessError, + SVDError, + TargetError, + TargetNotHaltedError, + TimeoutError, +) +from openocd.session import Session, SyncSession +from openocd.types import ( + BitField, + Breakpoint, + DecodedRegister, + FlashBank, + FlashSector, + JTAGState, + MemoryRegion, + Register, + RTTChannel, + TAPInfo, + TargetState, + Watchpoint, +) + +__all__ = [ + # Session + "Session", + "SyncSession", + # Types + "BitField", + "Breakpoint", + "DecodedRegister", + "FlashBank", + "FlashSector", + "JTAGState", + "MemoryRegion", + "RTTChannel", + "Register", + "TAPInfo", + "TargetState", + "Watchpoint", + # Errors + "ConnectionError", + "FlashError", + "JTAGError", + "OpenOCDError", + "ProcessError", + "SVDError", + "TargetError", + "TargetNotHaltedError", + "TimeoutError", +] + +try: + from importlib.metadata import version + + __version__ = version("openocd-python") +except Exception: + __version__ = "0.0.0" diff --git a/src/openocd/breakpoints.py b/src/openocd/breakpoints.py new file mode 100644 index 0000000..7da6ead --- /dev/null +++ b/src/openocd/breakpoints.py @@ -0,0 +1,234 @@ +"""Breakpoint and watchpoint management. + +Wraps OpenOCD's ``bp``, ``rbp``, ``wp``, and ``rwp`` commands for +setting, removing, and listing hardware/software breakpoints and +data watchpoints. +""" + +from __future__ import annotations + +import asyncio +import logging +import re +from typing import Literal + +from openocd.connection.base import Connection +from openocd.errors import OpenOCDError +from openocd.types import Breakpoint, Watchpoint + +log = logging.getLogger(__name__) + + +class BreakpointError(OpenOCDError): + """A breakpoint or watchpoint operation failed.""" + + +# --------------------------------------------------------------------------- +# Parsers +# --------------------------------------------------------------------------- + +# Breakpoint(IVA): 0x08001234, 0x2, 1 (hw=1) or 0 (sw) +_BP_RE = re.compile( + r"Breakpoint\([^)]*\):\s*(?P0x[0-9a-fA-F]+),\s*" + r"(?P0x[0-9a-fA-F]+),\s*(?P\d+)" +) + +# Watchpoint output varies across OpenOCD versions. Common formats: +# address: 0x20000000, len: 0x4, r/w/a: 2 (access), value: ... +# Watchpoint(DWT): 0x20000000, 0x4, 2 +_WP_RE = re.compile( + r"(?:address:\s*(?P0x[0-9a-fA-F]+).*?len:\s*(?P0x[0-9a-fA-F]+).*?r/w/a:\s*(?P\d+))" + r"|" + r"(?:Watchpoint\([^)]*\):\s*(?P0x[0-9a-fA-F]+),\s*(?P0x[0-9a-fA-F]+),\s*(?P\d+))" +) + +# OpenOCD watchpoint access type encoding +_WP_ACCESS_MAP = {0: "r", 1: "w", 2: "rw"} +_WP_ACCESS_CMD = {"r": "r", "w": "w", "rw": "a"} + + +def _check_error(response: str, context: str) -> None: + """Raise BreakpointError if the response indicates failure.""" + if "error" in response.lower(): + raise BreakpointError(f"{context}: {response.strip()}") + + +def _parse_breakpoint_list(text: str) -> list[Breakpoint]: + """Parse the output of ``bp`` (no arguments) into Breakpoint objects.""" + breakpoints: list[Breakpoint] = [] + for idx, m in enumerate(_BP_RE.finditer(text)): + hw_flag = int(m.group("hw")) + breakpoints.append( + Breakpoint( + number=idx, + type="hw" if hw_flag else "sw", + address=int(m.group("addr"), 16), + length=int(m.group("len"), 16), + enabled=True, + ) + ) + return breakpoints + + +def _parse_watchpoint_list(text: str) -> list[Watchpoint]: + """Parse watchpoint listing output.""" + watchpoints: list[Watchpoint] = [] + for idx, m in enumerate(_WP_RE.finditer(text)): + # Match could come from either alternative in the regex + if m.group("addr1") is not None: + addr = int(m.group("addr1"), 16) + length = int(m.group("len1"), 16) + rwa = int(m.group("rwa1")) + else: + addr = int(m.group("addr2"), 16) + length = int(m.group("len2"), 16) + rwa = int(m.group("rwa2")) + + watchpoints.append( + Watchpoint( + number=idx, + address=addr, + length=length, + access=_WP_ACCESS_MAP.get(rwa, "rw"), + ) + ) + return watchpoints + + +class BreakpointManager: + """Manage breakpoints and watchpoints via OpenOCD. + + Breakpoints can be either software (patching the instruction) or + hardware (using on-chip comparators). Watchpoints trigger on data + access to a given address range. + """ + + def __init__(self, conn: Connection) -> None: + self._conn = conn + + # ------------------------------------------------------------------ + # Breakpoints + # ------------------------------------------------------------------ + + async def add(self, address: int, length: int = 2, hw: bool = False) -> None: + """Set a breakpoint at the given address. + + Args: + address: Instruction address for the breakpoint. + length: Breakpoint length in bytes (2 for Thumb, 4 for ARM). + hw: Request a hardware breakpoint. If False, OpenOCD uses a + software breakpoint when possible. + """ + cmd = f"bp 0x{address:08X} {length}" + if hw: + cmd += " hw" + resp = await self._conn.send(cmd) + _check_error(resp, f"bp 0x{address:08X}") + log.info("Breakpoint set at 0x%08X (len=%d, hw=%s)", address, length, hw) + + async def remove(self, address: int) -> None: + """Remove a breakpoint at the given address. + + Args: + address: Address of the breakpoint to remove. + """ + cmd = f"rbp 0x{address:08X}" + resp = await self._conn.send(cmd) + _check_error(resp, f"rbp 0x{address:08X}") + log.info("Breakpoint removed at 0x%08X", address) + + async def list(self) -> list[Breakpoint]: + """List all active breakpoints. + + Returns: + A list of Breakpoint objects describing each active breakpoint. + """ + resp = await self._conn.send("bp") + # An empty response or no matches means no breakpoints set + if not resp.strip(): + return [] + return _parse_breakpoint_list(resp) + + # ------------------------------------------------------------------ + # Watchpoints + # ------------------------------------------------------------------ + + async def add_watchpoint( + self, + address: int, + length: int, + access: Literal["r", "w", "rw"] = "rw", + ) -> None: + """Set a data watchpoint. + + Args: + address: Memory address to watch. + length: Number of bytes to watch (must be power of 2 on most targets). + access: Access type -- ``"r"`` for read, ``"w"`` for write, + ``"rw"`` for read/write (access). + """ + access_flag = _WP_ACCESS_CMD.get(access, "a") + cmd = f"wp 0x{address:08X} {length} {access_flag}" + resp = await self._conn.send(cmd) + _check_error(resp, f"wp 0x{address:08X}") + log.info("Watchpoint set at 0x%08X (len=%d, access=%s)", address, length, access) + + async def remove_watchpoint(self, address: int) -> None: + """Remove a watchpoint at the given address. + + Args: + address: Address of the watchpoint to remove. + """ + cmd = f"rwp 0x{address:08X}" + resp = await self._conn.send(cmd) + _check_error(resp, f"rwp 0x{address:08X}") + log.info("Watchpoint removed at 0x%08X", address) + + async def list_watchpoints(self) -> list[Watchpoint]: + """List all active watchpoints. + + Returns: + A list of Watchpoint objects describing each active watchpoint. + """ + # OpenOCD doesn't have a dedicated "list watchpoints" command + # but 'wp' with no arguments on some builds returns the list. + # The more reliable approach is using the TCL command. + resp = await self._conn.send("wp") + if not resp.strip(): + return [] + return _parse_watchpoint_list(resp) + + +# ====================================================================== +# Sync wrapper +# ====================================================================== + +class SyncBreakpointManager: + """Synchronous wrapper around BreakpointManager.""" + + def __init__(self, bp_manager: BreakpointManager, loop: asyncio.AbstractEventLoop) -> None: + self._bp = bp_manager + self._loop = loop + + def add(self, address: int, length: int = 2, hw: bool = False) -> None: + self._loop.run_until_complete(self._bp.add(address, length=length, hw=hw)) + + def remove(self, address: int) -> None: + self._loop.run_until_complete(self._bp.remove(address)) + + def list(self) -> list[Breakpoint]: + return self._loop.run_until_complete(self._bp.list()) + + def add_watchpoint( + self, + address: int, + length: int, + access: Literal["r", "w", "rw"] = "rw", + ) -> None: + self._loop.run_until_complete(self._bp.add_watchpoint(address, length, access=access)) + + def remove_watchpoint(self, address: int) -> None: + self._loop.run_until_complete(self._bp.remove_watchpoint(address)) + + def list_watchpoints(self) -> list[Watchpoint]: + return self._loop.run_until_complete(self._bp.list_watchpoints()) diff --git a/src/openocd/cli.py b/src/openocd/cli.py new file mode 100644 index 0000000..53fb0cb --- /dev/null +++ b/src/openocd/cli.py @@ -0,0 +1,161 @@ +"""CLI entry point for openocd-python. + +Provides quick diagnostics and a REPL for interactive use: + + $ openocd-python --help + $ openocd-python info # probe detection + target info + $ openocd-python repl # interactive command REPL + $ openocd-python read 0x08000000 16 # quick memory read +""" + +from __future__ import annotations + +import argparse +import asyncio +import sys + + +def main() -> None: + try: + from importlib.metadata import version + + pkg_version = version("openocd-python") + except Exception: + pkg_version = "dev" + + parser = argparse.ArgumentParser( + prog="openocd-python", + description=f"OpenOCD Python bindings v{pkg_version}", + ) + parser.add_argument( + "--version", action="version", version=f"openocd-python {pkg_version}" + ) + parser.add_argument( + "--host", default="localhost", help="OpenOCD host (default: localhost)" + ) + parser.add_argument( + "--port", type=int, default=6666, help="OpenOCD TCL RPC port (default: 6666)" + ) + + sub = parser.add_subparsers(dest="command") + + sub.add_parser("info", help="Show target and adapter information") + + repl_parser = sub.add_parser("repl", help="Interactive OpenOCD command REPL") + repl_parser.add_argument( + "--timeout", type=float, default=10.0, help="Command timeout in seconds" + ) + + read_parser = sub.add_parser("read", help="Read memory and display as hexdump") + read_parser.add_argument("address", help="Start address (hex, e.g. 0x08000000)") + read_parser.add_argument( + "size", type=int, nargs="?", default=64, help="Bytes to read (default: 64)" + ) + + sub.add_parser("scan", help="Scan the JTAG chain") + + args = parser.parse_args() + + if args.command is None: + parser.print_help() + sys.exit(0) + + asyncio.run(_dispatch(args)) + + +async def _dispatch(args: argparse.Namespace) -> None: + from openocd.session import Session + + async with Session.connect(host=args.host, port=args.port) as ocd: + if args.command == "info": + await _cmd_info(ocd) + elif args.command == "repl": + await _cmd_repl(ocd, timeout=args.timeout) + elif args.command == "read": + await _cmd_read(ocd, args.address, args.size) + elif args.command == "scan": + await _cmd_scan(ocd) + + +async def _cmd_info(ocd) -> None: + """Display target state and adapter information.""" + from openocd.errors import OpenOCDError + + print("=== OpenOCD Target Info ===\n") + + try: + state = await ocd.target.state() + print(f" Target: {state.name}") + print(f" State: {state.state}") + if state.current_pc is not None: + print(f" PC: 0x{state.current_pc:08X}") + except OpenOCDError as e: + print(f" Target: (error: {e})") + + print() + + try: + transport_name = await ocd.transport.select() + print(f" Transport: {transport_name}") + except OpenOCDError: + pass + + try: + adapter = await ocd.transport.adapter_info() + print(f" Adapter: {adapter}") + except OpenOCDError: + pass + + try: + speed = await ocd.transport.adapter_speed() + print(f" Speed: {speed} kHz") + except OpenOCDError: + pass + + +async def _cmd_repl(ocd, timeout: float = 10.0) -> None: + """Interactive command REPL.""" + print("OpenOCD REPL (type 'quit' or Ctrl-D to exit)\n") + while True: + try: + cmd = input("ocd> ") + except (EOFError, KeyboardInterrupt): + print() + break + if cmd.strip().lower() in ("quit", "exit", "q"): + break + if not cmd.strip(): + continue + try: + result = await ocd.command(cmd) + if result.strip(): + print(result) + except Exception as e: + print(f"Error: {e}") + + +async def _cmd_read(ocd, address_str: str, size: int) -> None: + """Read memory and display as hexdump.""" + addr = int(address_str, 0) + dump = await ocd.memory.hexdump(addr, size) + print(dump) + + +async def _cmd_scan(ocd) -> None: + """Scan the JTAG chain.""" + taps = await ocd.jtag.scan_chain() + if not taps: + print("No TAPs found on the JTAG chain.") + return + + print(f"{'TAP Name':<25s} {'IDCODE':>10s} IR Enabled") + print("-" * 50) + for tap in taps: + print( + f"{tap.name:<25s} 0x{tap.idcode:08X} {tap.ir_length:>2d} " + f"{'yes' if tap.enabled else 'no'}" + ) + + +if __name__ == "__main__": + main() diff --git a/src/openocd/connection/__init__.py b/src/openocd/connection/__init__.py new file mode 100644 index 0000000..62a3170 --- /dev/null +++ b/src/openocd/connection/__init__.py @@ -0,0 +1,7 @@ +"""Connection backends for communicating with OpenOCD.""" + +from openocd.connection.base import Connection +from openocd.connection.tcl_rpc import TclRpcConnection +from openocd.connection.telnet import TelnetConnection + +__all__ = ["Connection", "TclRpcConnection", "TelnetConnection"] diff --git a/src/openocd/connection/base.py b/src/openocd/connection/base.py new file mode 100644 index 0000000..c2847b5 --- /dev/null +++ b/src/openocd/connection/base.py @@ -0,0 +1,30 @@ +"""Abstract base class for OpenOCD connection backends.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable + + +class Connection(ABC): + """Protocol-agnostic interface to an OpenOCD instance.""" + + @abstractmethod + async def connect(self, host: str, port: int) -> None: + """Open a connection to the given host and port.""" + + @abstractmethod + async def send(self, command: str) -> str: + """Send a command string and return the response.""" + + @abstractmethod + async def close(self) -> None: + """Close the connection.""" + + @abstractmethod + async def enable_notifications(self) -> None: + """Enable asynchronous event notifications from OpenOCD.""" + + @abstractmethod + def on_notification(self, callback: Callable[[str], None]) -> None: + """Register a callback for incoming notifications.""" diff --git a/src/openocd/connection/tcl_rpc.py b/src/openocd/connection/tcl_rpc.py new file mode 100644 index 0000000..313057b --- /dev/null +++ b/src/openocd/connection/tcl_rpc.py @@ -0,0 +1,163 @@ +"""TCL RPC client for OpenOCD (port 6666). + +OpenOCD's TCL RPC uses a simple framing protocol: + - Client sends: command_string + \\x1a + - Server replies: response_string + \\x1a +The \\x1a (ASCII SUB / Ctrl-Z) byte acts as an unambiguous delimiter. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +from collections.abc import Callable + +from openocd.connection.base import Connection +from openocd.errors import ConnectionError +from openocd.errors import TimeoutError as OcdTimeoutError + +log = logging.getLogger(__name__) + +SEPARATOR = b"\x1a" +DEFAULT_TIMEOUT = 10.0 + + +class TclRpcConnection(Connection): + """Async TCP client speaking OpenOCD's TCL RPC protocol.""" + + def __init__(self, timeout: float = DEFAULT_TIMEOUT) -> None: + self._reader: asyncio.StreamReader | None = None + self._writer: asyncio.StreamWriter | None = None + self._timeout = timeout + self._notification_callbacks: list[Callable[[str], None]] = [] + self._notification_task: asyncio.Task[None] | None = None + self._lock = asyncio.Lock() + self._host: str = "" + self._port: int = 0 + + # ------------------------------------------------------------------ + # Connection lifecycle + # ------------------------------------------------------------------ + + async def connect(self, host: str = "localhost", port: int = 6666) -> None: + self._host = host + self._port = port + try: + self._reader, self._writer = await asyncio.wait_for( + asyncio.open_connection(host, port), + timeout=self._timeout, + ) + except OSError as exc: + raise ConnectionError( + f"Cannot connect to OpenOCD TCL RPC at {host}:{port}: {exc}" + ) from exc + except TimeoutError as exc: + raise OcdTimeoutError( + f"Timed out connecting to OpenOCD TCL RPC at {host}:{port}" + ) from exc + log.debug("Connected to OpenOCD TCL RPC at %s:%d", host, port) + + async def close(self) -> None: + if self._notification_task and not self._notification_task.done(): + self._notification_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._notification_task + self._notification_task = None + + if self._writer: + self._writer.close() + with contextlib.suppress(OSError): + await self._writer.wait_closed() + self._writer = None + self._reader = None + log.debug("TCL RPC connection closed") + + # ------------------------------------------------------------------ + # Command send/receive + # ------------------------------------------------------------------ + + async def send(self, command: str) -> str: + """Send a command and return the response string. + + The protocol appends \\x1a after the command and reads until + \\x1a appears in the response stream. + """ + if not self._writer or not self._reader: + raise ConnectionError("Not connected — call connect() first") + + async with self._lock: + payload = command.encode("utf-8") + SEPARATOR + self._writer.write(payload) + await self._writer.drain() + log.debug("TX: %s", command) + + try: + raw = await asyncio.wait_for( + self._read_until_separator(), + timeout=self._timeout, + ) + except TimeoutError as exc: + raise OcdTimeoutError( + f"Timed out waiting for response to: {command}" + ) from exc + + response = raw.decode("utf-8", errors="replace") + log.debug("RX: %s", response[:200]) + return response + + async def _read_until_separator(self) -> bytes: + """Read from the stream until the \\x1a separator is found.""" + assert self._reader is not None + buf = bytearray() + while True: + chunk = await self._reader.read(4096) + if not chunk: + raise ConnectionError("OpenOCD closed the connection") + buf.extend(chunk) + idx = buf.find(SEPARATOR) + if idx != -1: + return bytes(buf[:idx]) + + # ------------------------------------------------------------------ + # Notifications (async events from OpenOCD) + # ------------------------------------------------------------------ + + async def enable_notifications(self) -> None: + """Enable TCL event notifications and start the listener loop. + + Sends ``tcl_notifications on`` which causes OpenOCD to push + target-state-change events over the same socket. + """ + await self.send("tcl_notifications on") + self._notification_task = asyncio.create_task(self._notification_loop()) + + async def _notification_loop(self) -> None: + """Background task that reads unsolicited notifications.""" + assert self._reader is not None + buf = bytearray() + try: + while True: + chunk = await self._reader.read(4096) + if not chunk: + break + buf.extend(chunk) + while True: + idx = buf.find(SEPARATOR) + if idx == -1: + break + msg = buf[:idx].decode("utf-8", errors="replace") + buf = buf[idx + 1 :] + log.debug("Notification: %s", msg) + for cb in self._notification_callbacks: + try: + cb(msg) + except Exception: + log.exception("Notification callback error") + except asyncio.CancelledError: + return + except Exception: + log.exception("Notification loop crashed") + + def on_notification(self, callback: Callable[[str], None]) -> None: + self._notification_callbacks.append(callback) diff --git a/src/openocd/connection/telnet.py b/src/openocd/connection/telnet.py new file mode 100644 index 0000000..80b0155 --- /dev/null +++ b/src/openocd/connection/telnet.py @@ -0,0 +1,100 @@ +"""Telnet connection to OpenOCD (port 4444) — fallback transport. + +The telnet interface is human-oriented and its output formatting varies +between OpenOCD versions. Prefer TclRpcConnection where possible. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +from collections.abc import Callable + +from openocd.connection.base import Connection +from openocd.errors import ConnectionError +from openocd.errors import TimeoutError as OcdTimeoutError + +log = logging.getLogger(__name__) + +PROMPT = b"> " +DEFAULT_TIMEOUT = 10.0 + + +class TelnetConnection(Connection): + """Async telnet client for OpenOCD port 4444.""" + + def __init__(self, timeout: float = DEFAULT_TIMEOUT) -> None: + self._reader: asyncio.StreamReader | None = None + self._writer: asyncio.StreamWriter | None = None + self._timeout = timeout + self._lock = asyncio.Lock() + + async def connect(self, host: str = "localhost", port: int = 4444) -> None: + try: + self._reader, self._writer = await asyncio.wait_for( + asyncio.open_connection(host, port), + timeout=self._timeout, + ) + except OSError as exc: + raise ConnectionError( + f"Cannot connect to OpenOCD telnet at {host}:{port}: {exc}" + ) from exc + except TimeoutError as exc: + raise OcdTimeoutError( + f"Timed out connecting to OpenOCD telnet at {host}:{port}" + ) from exc + + # Consume the initial banner / prompt + with contextlib.suppress(TimeoutError): + await asyncio.wait_for(self._read_until_prompt(), timeout=self._timeout) + log.debug("Connected to OpenOCD telnet at %s:%d", host, port) + + async def close(self) -> None: + if self._writer: + self._writer.close() + with contextlib.suppress(OSError): + await self._writer.wait_closed() + self._writer = None + self._reader = None + + async def send(self, command: str) -> str: + if not self._writer or not self._reader: + raise ConnectionError("Not connected") + + async with self._lock: + self._writer.write((command + "\n").encode("utf-8")) + await self._writer.drain() + + try: + raw = await asyncio.wait_for( + self._read_until_prompt(), + timeout=self._timeout, + ) + except TimeoutError as exc: + raise OcdTimeoutError(f"Timed out waiting for: {command}") from exc + + # Strip the echoed command and trailing prompt + text = raw.decode("utf-8", errors="replace") + lines = text.splitlines() + # First line is often the echoed command, last line is the prompt + if lines and lines[0].strip() == command.strip(): + lines = lines[1:] + return "\n".join(lines).strip() + + async def _read_until_prompt(self) -> bytes: + assert self._reader is not None + buf = bytearray() + while True: + chunk = await self._reader.read(4096) + if not chunk: + raise ConnectionError("OpenOCD closed the connection") + buf.extend(chunk) + if buf.endswith(PROMPT): + return bytes(buf[: -len(PROMPT)]) + + async def enable_notifications(self) -> None: + log.warning("Telnet transport does not support async notifications") + + def on_notification(self, callback: Callable[[str], None]) -> None: + log.warning("Telnet transport does not support notifications") diff --git a/src/openocd/errors.py b/src/openocd/errors.py new file mode 100644 index 0000000..8805287 --- /dev/null +++ b/src/openocd/errors.py @@ -0,0 +1,43 @@ +"""Exception hierarchy for openocd-python. + +All exceptions inherit from OpenOCDError so callers can catch broadly +or narrowly as needed. +""" + +from __future__ import annotations + + +class OpenOCDError(Exception): + """Base exception for all openocd-python errors.""" + + +class ConnectionError(OpenOCDError): + """Cannot connect to the OpenOCD TCL RPC or telnet interface.""" + + +class TimeoutError(OpenOCDError): + """A command or wait operation exceeded its deadline.""" + + +class TargetError(OpenOCDError): + """The target is not responding or returned an error.""" + + +class TargetNotHaltedError(TargetError): + """Operation requires a halted target but it is currently running.""" + + +class FlashError(OpenOCDError): + """A flash read/write/erase/verify operation failed.""" + + +class JTAGError(OpenOCDError): + """JTAG communication or chain error.""" + + +class SVDError(OpenOCDError): + """SVD file not found, failed to parse, or lookup error.""" + + +class ProcessError(OpenOCDError): + """OpenOCD subprocess failed to start or exited unexpectedly.""" diff --git a/src/openocd/events.py b/src/openocd/events.py new file mode 100644 index 0000000..d9f3de9 --- /dev/null +++ b/src/openocd/events.py @@ -0,0 +1,112 @@ +"""Async event system for OpenOCD target state changes. + +OpenOCD can push TCL notifications when target state changes occur +(halt, resume, reset, etc.). This module provides a typed callback +interface on top of the raw notification stream. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from openocd.connection.base import Connection + +log = logging.getLogger(__name__) + +# Known event types emitted by OpenOCD's TCL notification system +EVENT_HALTED = "halted" +EVENT_RESUMED = "resumed" +EVENT_RESET = "reset" +EVENT_GDB_ATTACHED = "gdb-attached" +EVENT_GDB_DETACHED = "gdb-detached" + + +class EventManager: + """Manages callbacks for target state change events. + + Usage:: + + events = EventManager(conn) + await events.enable() + + events.on("halted", lambda msg: print(f"Target halted: {msg}")) + events.on("reset", handle_reset) + + # Later... + events.off("halted", that_callback) + """ + + def __init__(self, conn: Connection) -> None: + self._conn = conn + self._callbacks: dict[str, list[Callable[[str], None]]] = {} + self._enabled = False + + @property + def enabled(self) -> bool: + """Whether TCL notifications have been turned on.""" + return self._enabled + + async def enable(self) -> None: + """Enable TCL notifications and start event dispatch. + + Sends ``tcl_notifications on`` to OpenOCD and registers an + internal handler that routes incoming messages to typed callbacks. + + Raises: + ConnectionError: If the connection is not open. + """ + if self._enabled: + return + + await self._conn.enable_notifications() + self._conn.on_notification(self._dispatch) + self._enabled = True + log.info("Event notifications enabled") + + def on(self, event_type: str, callback: Callable[[str], None]) -> None: + """Register a callback for a specific event type. + + Args: + event_type: Event name to match (e.g. "halted", "reset", "resumed"). + Matching is case-insensitive substring — a notification + containing "halted" anywhere triggers "halted" callbacks. + callback: Function to call with the full notification message. + """ + key = event_type.lower() + if key not in self._callbacks: + self._callbacks[key] = [] + if callback not in self._callbacks[key]: + self._callbacks[key].append(callback) + log.debug("Registered callback for event '%s'", key) + + def off(self, event_type: str, callback: Callable[[str], None]) -> None: + """Unregister a callback. + + Args: + event_type: The event type the callback was registered under. + callback: The callback to remove. + """ + key = event_type.lower() + handlers = self._callbacks.get(key, []) + try: + handlers.remove(callback) + log.debug("Unregistered callback for event '%s'", key) + except ValueError: + pass + + def _dispatch(self, message: str) -> None: + """Route an incoming notification to matching callbacks.""" + msg_lower = message.lower() + for event_type, handlers in self._callbacks.items(): + if event_type in msg_lower: + for handler in handlers: + try: + handler(message) + except Exception: + log.exception( + "Error in event callback for '%s'", + event_type, + ) diff --git a/src/openocd/flash.py b/src/openocd/flash.py new file mode 100644 index 0000000..08e7f73 --- /dev/null +++ b/src/openocd/flash.py @@ -0,0 +1,381 @@ +"""Flash memory programming operations. + +Wraps OpenOCD's ``flash`` command family for reading, writing, erasing, +and verifying on-chip flash memory banks. +""" + +from __future__ import annotations + +import asyncio +import logging +import re +import tempfile +from pathlib import Path + +from openocd.connection.base import Connection +from openocd.errors import FlashError +from openocd.types import FlashBank, FlashSector + +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Regex patterns for parsing OpenOCD flash output +# --------------------------------------------------------------------------- + +# flash banks: #0 : stm32f1x.flash (stm32f1x) at 0x08000000, size 0x00020000, ... +_BANK_LIST_RE = re.compile( + r"#(?P\d+)\s*:\s*(?P\S+)\s+\((?P\S+)\)\s+" + r"at\s+(?P0x[0-9a-fA-F]+),\s*size\s+(?P0x[0-9a-fA-F]+),\s*" + r"buswidth\s+(?P\d+),\s*chipwidth\s+(?P\d+)" +) + +# flash info header: #0 : stm32f1x at 0x08000000, size 0x00020000, ... +_INFO_HEADER_RE = re.compile( + r"#(?P\d+)\s*:\s*(?P\S+)\s+at\s+(?P0x[0-9a-fA-F]+),\s*" + r"size\s+(?P0x[0-9a-fA-F]+),\s*" + r"buswidth\s+(?P\d+),\s*chipwidth\s+(?P\d+)" +) + +# flash info sector: # 0: 0x00000000 (0x400 1kB) not protected +_SECTOR_RE = re.compile( + r"#\s*(?P\d+):\s*(?P0x[0-9a-fA-F]+)\s+" + r"\((?P0x[0-9a-fA-F]+)\s+[^)]*\)\s+" + r"(?Pprotected|not protected)" +) + + +def _check_error(response: str, context: str) -> None: + """Raise FlashError if the response indicates failure.""" + if "error" in response.lower(): + raise FlashError(f"{context}: {response.strip()}") + + +def _parse_bank_list(text: str) -> list[FlashBank]: + """Parse the output of ``flash banks``.""" + banks: list[FlashBank] = [] + for m in _BANK_LIST_RE.finditer(text): + banks.append( + FlashBank( + index=int(m.group("idx")), + name=m.group("name"), + base=int(m.group("base"), 16), + size=int(m.group("size"), 16), + bus_width=int(m.group("bw")), + chip_width=int(m.group("cw")), + target=m.group("driver"), + ) + ) + return banks + + +def _parse_bank_info(text: str) -> FlashBank: + """Parse the output of ``flash info `` into a FlashBank with sectors.""" + header = _INFO_HEADER_RE.search(text) + if not header: + raise FlashError(f"Cannot parse flash info output: {text[:200]}") + + sectors: list[FlashSector] = [] + for m in _SECTOR_RE.finditer(text): + sectors.append( + FlashSector( + index=int(m.group("idx")), + offset=int(m.group("offset"), 16), + size=int(m.group("size"), 16), + protected=m.group("prot") == "protected", + ) + ) + + return FlashBank( + index=int(header.group("idx")), + name=header.group("name"), + base=int(header.group("base"), 16), + size=int(header.group("size"), 16), + bus_width=int(header.group("bw")), + chip_width=int(header.group("cw")), + target=( + header.group("name").split(".")[0] + if "." in header.group("name") + else header.group("name") + ), + sectors=sectors, + ) + + +class Flash: + """Flash memory programming via OpenOCD. + + All methods are async and use the underlying TCL RPC connection to + issue ``flash`` commands. + """ + + def __init__(self, conn: Connection) -> None: + self._conn = conn + + # ------------------------------------------------------------------ + # Bank enumeration + # ------------------------------------------------------------------ + + async def banks(self) -> list[FlashBank]: + """List all configured flash banks. + + Returns: + A list of FlashBank descriptors (without detailed sector info). + """ + resp = await self._conn.send("flash banks") + _check_error(resp, "flash banks") + return _parse_bank_list(resp) + + async def info(self, bank: int = 0) -> FlashBank: + """Get detailed information about a flash bank, including sectors. + + Args: + bank: Flash bank number (default 0). + + Returns: + A FlashBank with populated ``sectors`` list. + """ + resp = await self._conn.send(f"flash info {bank}") + _check_error(resp, f"flash info {bank}") + return _parse_bank_info(resp) + + # ------------------------------------------------------------------ + # Read operations + # ------------------------------------------------------------------ + + async def read(self, bank: int, offset: int, size: int) -> bytes: + """Read raw bytes from a flash bank. + + Uses a temporary file as an intermediary since OpenOCD's + ``flash read_bank`` writes to a file, then reads the file content + back via TCL. + + Args: + bank: Flash bank number. + offset: Byte offset within the bank. + size: Number of bytes to read. + + Returns: + The raw flash contents as bytes. + """ + with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as tmp: + tmp_path = tmp.name + + try: + cmd = f"flash read_bank {bank} {tmp_path} {offset} {size}" + resp = await self._conn.send(cmd) + _check_error(resp, f"flash read_bank {bank}") + + # Read the file back through TCL to handle remote OpenOCD instances. + # Use ocd_find + binary read if available, otherwise fall back to + # reading the local file. + tcl_read = ( + f"set fp [open {tmp_path} rb]; " + f"set data [read $fp]; " + f"close $fp; " + f"set data" + ) + try: + raw = await self._conn.send(tcl_read) + # TCL returns binary as string; try base64 approach if garbled + return raw.encode("latin-1") + except Exception: + # Fallback: read the local file directly + return Path(tmp_path).read_bytes() + finally: + Path(tmp_path).unlink(missing_ok=True) + + async def read_to_file(self, bank: int, path: Path) -> None: + """Read an entire flash bank to a local file. + + Args: + bank: Flash bank number. + path: Destination file path. + """ + cmd = f"flash read_bank {bank} {path}" + resp = await self._conn.send(cmd) + _check_error(resp, f"flash read_bank {bank} to {path}") + log.info("Flash bank %d read to %s", bank, path) + + # ------------------------------------------------------------------ + # Write operations + # ------------------------------------------------------------------ + + async def write(self, bank: int, offset: int, data: bytes) -> None: + """Write raw bytes to a flash bank at the given offset. + + Writes data through a temporary file since OpenOCD's + ``flash write_bank`` reads from a file. + + Args: + bank: Flash bank number. + offset: Byte offset within the bank. + data: Bytes to write. + """ + with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as tmp: + tmp.write(data) + tmp_path = tmp.name + + try: + cmd = f"flash write_bank {bank} {tmp_path} {offset}" + resp = await self._conn.send(cmd) + _check_error(resp, f"flash write_bank {bank}") + log.info("Wrote %d bytes to flash bank %d at offset 0x%X", len(data), bank, offset) + finally: + Path(tmp_path).unlink(missing_ok=True) + + async def write_image( + self, + path: Path, + erase: bool = True, + verify: bool = True, + ) -> None: + """Program a firmware image into flash. + + This is the high-level "flash and go" command that handles erase, + write, and optional verification in a single operation. + + Args: + path: Path to the firmware image (.bin, .hex, .elf, etc.). + erase: Erase affected sectors before writing (default True). + verify: Verify flash contents after writing (default True). + """ + parts = ["flash", "write_image"] + if erase: + parts.append("erase") + parts.append(str(path)) + + cmd = " ".join(parts) + resp = await self._conn.send(cmd) + _check_error(resp, f"flash write_image {path}") + log.info("Flash image written: %s (erase=%s)", path, erase) + + if verify: + verify_cmd = f"verify_image {path}" + vresp = await self._conn.send(verify_cmd) + if "error" in vresp.lower() or "mismatch" in vresp.lower(): + raise FlashError(f"Post-write verification failed: {vresp.strip()}") + log.info("Flash verification passed for %s", path) + + # ------------------------------------------------------------------ + # Erase operations + # ------------------------------------------------------------------ + + async def erase_sector(self, bank: int, first: int, last: int) -> None: + """Erase a range of sectors within a flash bank. + + Args: + bank: Flash bank number. + first: First sector number to erase (inclusive). + last: Last sector number to erase (inclusive). + """ + if first > last: + raise FlashError(f"Invalid sector range: first ({first}) > last ({last})") + + cmd = f"flash erase_sector {bank} {first} {last}" + resp = await self._conn.send(cmd) + _check_error(resp, f"flash erase_sector {bank} {first}-{last}") + log.info("Erased sectors %d-%d in flash bank %d", first, last, bank) + + async def erase_all(self, bank: int = 0) -> None: + """Erase all sectors in a flash bank. + + Queries the bank info to determine the last sector, then erases + the full range. + + Args: + bank: Flash bank number (default 0). + """ + bank_info = await self.info(bank) + if not bank_info.sectors: + # Fall back to erasing sector 0 through "last" using the count + # OpenOCD also accepts "last" as a keyword + cmd = f"flash erase_sector {bank} 0 last" + resp = await self._conn.send(cmd) + _check_error(resp, f"flash erase_all bank {bank}") + else: + last_sector = bank_info.sectors[-1].index + await self.erase_sector(bank, 0, last_sector) + log.info("Erased all sectors in flash bank %d", bank) + + # ------------------------------------------------------------------ + # Protection + # ------------------------------------------------------------------ + + async def protect(self, bank: int, first: int, last: int, on: bool) -> None: + """Set or clear write protection on a range of sectors. + + Args: + bank: Flash bank number. + first: First sector number (inclusive). + last: Last sector number (inclusive). + on: True to enable protection, False to disable. + """ + flag = "on" if on else "off" + cmd = f"flash protect {bank} {first} {last} {flag}" + resp = await self._conn.send(cmd) + _check_error(resp, f"flash protect {bank} {first}-{last} {flag}") + log.info("Flash bank %d sectors %d-%d protection: %s", bank, first, last, flag) + + # ------------------------------------------------------------------ + # Verify + # ------------------------------------------------------------------ + + async def verify(self, bank: int, path: Path) -> bool: + """Verify flash bank contents against a file. + + Args: + bank: Flash bank number. + path: Path to the reference binary file. + + Returns: + True if the flash contents match the file, False otherwise. + """ + cmd = f"flash verify_bank {bank} {path}" + resp = await self._conn.send(cmd) + if "error" in resp.lower() or "mismatch" in resp.lower(): + log.warning("Flash verify failed for bank %d against %s: %s", bank, path, resp.strip()) + return False + log.info("Flash bank %d verified against %s", bank, path) + return True + + +# ====================================================================== +# Sync wrapper +# ====================================================================== + +class SyncFlash: + """Synchronous wrapper around Flash for use outside async contexts.""" + + def __init__(self, flash: Flash, loop: asyncio.AbstractEventLoop) -> None: + self._flash = flash + self._loop = loop + + def banks(self) -> list[FlashBank]: + return self._loop.run_until_complete(self._flash.banks()) + + def info(self, bank: int = 0) -> FlashBank: + return self._loop.run_until_complete(self._flash.info(bank)) + + def read(self, bank: int, offset: int, size: int) -> bytes: + return self._loop.run_until_complete(self._flash.read(bank, offset, size)) + + def read_to_file(self, bank: int, path: Path) -> None: + self._loop.run_until_complete(self._flash.read_to_file(bank, path)) + + def write(self, bank: int, offset: int, data: bytes) -> None: + self._loop.run_until_complete(self._flash.write(bank, offset, data)) + + def write_image(self, path: Path, erase: bool = True, verify: bool = True) -> None: + self._loop.run_until_complete(self._flash.write_image(path, erase=erase, verify=verify)) + + def erase_sector(self, bank: int, first: int, last: int) -> None: + self._loop.run_until_complete(self._flash.erase_sector(bank, first, last)) + + def erase_all(self, bank: int = 0) -> None: + self._loop.run_until_complete(self._flash.erase_all(bank)) + + def protect(self, bank: int, first: int, last: int, on: bool) -> None: + self._loop.run_until_complete(self._flash.protect(bank, first, last, on=on)) + + def verify(self, bank: int, path: Path) -> bool: + return self._loop.run_until_complete(self._flash.verify(bank, path)) diff --git a/src/openocd/jtag/__init__.py b/src/openocd/jtag/__init__.py new file mode 100644 index 0000000..08a5c54 --- /dev/null +++ b/src/openocd/jtag/__init__.py @@ -0,0 +1,5 @@ +"""JTAG chain control, scan operations, and boundary scan.""" + +from openocd.jtag.chain import JTAGController, SyncJTAGController + +__all__ = ["JTAGController", "SyncJTAGController"] diff --git a/src/openocd/jtag/boundary.py b/src/openocd/jtag/boundary.py new file mode 100644 index 0000000..2a17d28 --- /dev/null +++ b/src/openocd/jtag/boundary.py @@ -0,0 +1,52 @@ +"""SVF and XSVF boundary-scan file execution.""" + +from __future__ import annotations + +from pathlib import Path + +from openocd.connection.base import Connection +from openocd.errors import JTAGError + + +async def svf( + conn: Connection, + path: Path, + *, + tap: str | None = None, + quiet: bool = False, + progress: bool = True, +) -> None: + """Execute an SVF (Serial Vector Format) file. + + Args: + path: Path to the ``.svf`` file. + tap: Restrict operations to this TAP. When ``None``, OpenOCD + applies vectors to whatever TAP is appropriate. + quiet: Suppress per-statement logging inside OpenOCD. + progress: Show a progress indicator (default on). + """ + parts = ["svf", str(path)] + if tap is not None: + parts.extend(["-tap", tap]) + if quiet: + parts.append("quiet") + if progress: + parts.append("progress") + resp = await conn.send(" ".join(parts)) + _check_error(resp, "svf") + + +async def xsvf(conn: Connection, tap: str, path: Path) -> None: + """Execute an XSVF file against the given TAP. + + Args: + tap: TAP to target (e.g. ``stm32f1x.cpu``). + path: Path to the ``.xsvf`` file. + """ + resp = await conn.send(f"xsvf {tap} {path}") + _check_error(resp, "xsvf") + + +def _check_error(response: str, command: str) -> None: + if "Error" in response or "error" in response.split("\n")[0]: + raise JTAGError(f"{command} failed: {response.strip()}") diff --git a/src/openocd/jtag/chain.py b/src/openocd/jtag/chain.py new file mode 100644 index 0000000..9b52cb6 --- /dev/null +++ b/src/openocd/jtag/chain.py @@ -0,0 +1,218 @@ +"""JTAG chain enumeration and the main JTAGController facade.""" + +from __future__ import annotations + +import asyncio +import logging +import re +from pathlib import Path + +from openocd.connection.base import Connection +from openocd.errors import JTAGError +from openocd.jtag import boundary as _boundary +from openocd.jtag import scan as _scan +from openocd.jtag import state as _state +from openocd.types import JTAGState, TAPInfo + +log = logging.getLogger(__name__) + +# Regex for one row of ``scan_chain`` output. +# Example line: +# 0 stm32f1x.cpu Y 0x3ba00477 0x3ba00477 4 0x01 0x0f +_CHAIN_ROW_RE = re.compile( + r"^\s*\d+\s+" # index + r"(\S+)\s+" # tap name (chip.tap) + r"([YN])\s+" # enabled + r"(0x[0-9a-fA-F]+)\s+" # idcode + r"(0x[0-9a-fA-F]+)\s+" # expected + r"(\d+)", # ir_length +) + + +async def scan_chain(conn: Connection) -> list[TAPInfo]: + """Query the JTAG scan chain and return a list of discovered TAPs.""" + resp = await conn.send("scan_chain") + if "Error" in resp: + raise JTAGError(f"scan_chain failed: {resp.strip()}") + return _parse_scan_chain(resp) + + +async def new_tap( + conn: Connection, + chip: str, + tap: str, + ir_len: int, + expected_id: int | None = None, +) -> None: + """Declare a new TAP on the JTAG chain. + + Args: + chip: Chip name (first part of the dotted TAP name). + tap: TAP name (second part, e.g. ``cpu``, ``bs``). + ir_len: Instruction register length in bits. + expected_id: Expected IDCODE. When ``None``, OpenOCD skips + the IDCODE verification. + """ + parts = ["jtag", "newtap", chip, tap, "-irlen", str(ir_len)] + if expected_id is not None: + parts.extend(["-expected-id", f"0x{expected_id:08x}"]) + resp = await conn.send(" ".join(parts)) + if "Error" in resp: + raise JTAGError(f"newtap failed: {resp.strip()}") + + +def _parse_scan_chain(raw: str) -> list[TAPInfo]: + """Parse the tabular output of ``scan_chain``.""" + taps: list[TAPInfo] = [] + for line in raw.splitlines(): + m = _CHAIN_ROW_RE.match(line) + if not m: + continue + full_name = m.group(1) + # Split "chip.tap" into components + dot = full_name.find(".") + if dot == -1: + chip_part, tap_part = full_name, "" + else: + chip_part, tap_part = full_name[:dot], full_name[dot + 1 :] + + taps.append( + TAPInfo( + name=full_name, + chip=chip_part, + tap_name=tap_part, + idcode=int(m.group(3), 16), + ir_length=int(m.group(5)), + enabled=m.group(2) == "Y", + ) + ) + return taps + + +# ====================================================================== +# JTAGController — unified facade +# ====================================================================== + +class JTAGController: + """High-level async interface to all JTAG operations. + + Delegates to helper functions in the ``scan``, ``state``, and + ``boundary`` sub-modules so each method stays concise. + """ + + def __init__(self, conn: Connection) -> None: + self._conn = conn + + # -- Chain enumeration ------------------------------------------------- + + async def scan_chain(self) -> list[TAPInfo]: + """Return every TAP discovered on the JTAG chain.""" + return await scan_chain(self._conn) + + async def new_tap( + self, + chip: str, + tap: str, + ir_len: int, + expected_id: int | None = None, + ) -> None: + """Declare a new TAP on the chain.""" + await new_tap(self._conn, chip, tap, ir_len, expected_id) + + # -- Scan operations --------------------------------------------------- + + async def irscan(self, tap: str, instruction: int) -> int: + """Shift *instruction* into the TAP instruction register.""" + return await _scan.irscan(self._conn, tap, instruction) + + async def drscan(self, tap: str, bits: int, value: int) -> int: + """Shift *value* (of width *bits*) through the data register.""" + return await _scan.drscan(self._conn, tap, bits, value) + + async def runtest(self, cycles: int) -> None: + """Clock *cycles* TCK pulses in the Run-Test/Idle state.""" + await _scan.runtest(self._conn, cycles) + + # -- TAP state machine ------------------------------------------------- + + async def pathmove(self, states: list[JTAGState]) -> None: + """Walk the TAP controller through an explicit state sequence.""" + await _state.pathmove(self._conn, states) + + # -- Boundary scan (SVF / XSVF) ---------------------------------------- + + async def svf( + self, + path: Path, + tap: str | None = None, + *, + quiet: bool = False, + progress: bool = True, + ) -> None: + """Execute an SVF boundary-scan file.""" + await _boundary.svf(self._conn, path, tap=tap, quiet=quiet, progress=progress) + + async def xsvf(self, tap: str, path: Path) -> None: + """Execute an XSVF boundary-scan file against *tap*.""" + await _boundary.xsvf(self._conn, tap, path) + + +# ====================================================================== +# SyncJTAGController — blocking wrappers +# ====================================================================== + +class SyncJTAGController: + """Synchronous wrapper around :class:`JTAGController`. + + Every async method is exposed with the same signature but runs + through ``loop.run_until_complete``. + """ + + def __init__(self, ctrl: JTAGController, loop: asyncio.AbstractEventLoop) -> None: + self._ctrl = ctrl + self._loop = loop + + # -- Chain ------------------------------------------------------------- + + def scan_chain(self) -> list[TAPInfo]: + return self._loop.run_until_complete(self._ctrl.scan_chain()) + + def new_tap( + self, + chip: str, + tap: str, + ir_len: int, + expected_id: int | None = None, + ) -> None: + self._loop.run_until_complete(self._ctrl.new_tap(chip, tap, ir_len, expected_id)) + + # -- Scan -------------------------------------------------------------- + + def irscan(self, tap: str, instruction: int) -> int: + return self._loop.run_until_complete(self._ctrl.irscan(tap, instruction)) + + def drscan(self, tap: str, bits: int, value: int) -> int: + return self._loop.run_until_complete(self._ctrl.drscan(tap, bits, value)) + + def runtest(self, cycles: int) -> None: + self._loop.run_until_complete(self._ctrl.runtest(cycles)) + + # -- State ------------------------------------------------------------- + + def pathmove(self, states: list[JTAGState]) -> None: + self._loop.run_until_complete(self._ctrl.pathmove(states)) + + # -- Boundary ---------------------------------------------------------- + + def svf( + self, + path: Path, + tap: str | None = None, + *, + quiet: bool = False, + progress: bool = True, + ) -> None: + self._loop.run_until_complete(self._ctrl.svf(path, tap, quiet=quiet, progress=progress)) + + def xsvf(self, tap: str, path: Path) -> None: + self._loop.run_until_complete(self._ctrl.xsvf(tap, path)) diff --git a/src/openocd/jtag/scan.py b/src/openocd/jtag/scan.py new file mode 100644 index 0000000..777f524 --- /dev/null +++ b/src/openocd/jtag/scan.py @@ -0,0 +1,58 @@ +"""IR/DR scan and TCK run-test operations.""" + +from __future__ import annotations + +from openocd.connection.base import Connection +from openocd.errors import JTAGError + + +async def irscan(conn: Connection, tap: str, instruction: int) -> int: + """Shift an instruction into the TAP instruction register. + + Returns the value shifted out of the IR during the operation. + """ + resp = await conn.send(f"irscan {tap} 0x{instruction:x}") + _check_error(resp, "irscan") + # OpenOCD returns the shifted-out value as a hex string + cleaned = resp.strip() + if not cleaned: + return 0 + try: + return int(cleaned, 16) + except ValueError: + return 0 + + +async def drscan(conn: Connection, tap: str, bits: int, value: int) -> int: + """Shift data through the TAP data register. + + Args: + tap: TAP name (e.g. ``stm32f1x.cpu``). + bits: Number of bits to shift. + value: Value to shift in. + + Returns the value shifted out of the DR. + """ + resp = await conn.send(f"drscan {tap} {bits} 0x{value:x}") + _check_error(resp, "drscan") + cleaned = resp.strip() + if not cleaned: + return 0 + try: + return int(cleaned, 16) + except ValueError: + return 0 + + +async def runtest(conn: Connection, cycles: int) -> None: + """Execute *cycles* TCK clocks in the Run-Test/Idle state.""" + if cycles < 0: + raise JTAGError(f"runtest cycles must be non-negative, got {cycles}") + resp = await conn.send(f"runtest {cycles}") + _check_error(resp, "runtest") + + +def _check_error(response: str, command: str) -> None: + """Raise JTAGError if OpenOCD reported an error.""" + if "Error" in response or "error" in response.split("\n")[0]: + raise JTAGError(f"{command} failed: {response.strip()}") diff --git a/src/openocd/jtag/state.py b/src/openocd/jtag/state.py new file mode 100644 index 0000000..6fe9994 --- /dev/null +++ b/src/openocd/jtag/state.py @@ -0,0 +1,26 @@ +"""TAP state-machine path movement.""" + +from __future__ import annotations + +from openocd.connection.base import Connection +from openocd.errors import JTAGError +from openocd.types import JTAGState + + +async def pathmove(conn: Connection, states: list[JTAGState]) -> None: + """Move the TAP controller through a sequence of states. + + Each state must be a legal single-step transition from the previous one + in the IEEE 1149.1 state machine. OpenOCD validates the path and will + report an error for illegal transitions. + """ + if not states: + raise JTAGError("pathmove requires at least one target state") + state_names = " ".join(s.value for s in states) + resp = await conn.send(f"pathmove {state_names}") + _check_error(resp, "pathmove") + + +def _check_error(response: str, command: str) -> None: + if "Error" in response or "error" in response.split("\n")[0]: + raise JTAGError(f"{command} failed: {response.strip()}") diff --git a/src/openocd/memory.py b/src/openocd/memory.py new file mode 100644 index 0000000..dd5ebd5 --- /dev/null +++ b/src/openocd/memory.py @@ -0,0 +1,243 @@ +"""Memory read/write operations via OpenOCD TCL API. + +Uses the ``read_memory`` and ``write_memory`` TCL commands for reliable +structured I/O, falling back to ``mdb``/``mdw`` style commands only +where the TCL API is unavailable. +""" + +from __future__ import annotations + +import asyncio +import logging +from pathlib import Path + +from openocd.connection.tcl_rpc import TclRpcConnection +from openocd.errors import TargetError + +log = logging.getLogger(__name__) + +# Width constants for read_memory / write_memory +_WIDTH_8 = 8 +_WIDTH_16 = 16 +_WIDTH_32 = 32 +_WIDTH_64 = 64 + +# Hexdump formatting +_HEXDUMP_BYTES_PER_LINE = 16 + + +class Memory: + """Read and write target memory.""" + + def __init__(self, conn: TclRpcConnection) -> None: + self._conn = conn + + # ------------------------------------------------------------------ + # Typed reads + # ------------------------------------------------------------------ + + async def read_u8(self, addr: int, count: int = 1) -> list[int]: + """Read 8-bit values starting at *addr*.""" + return await self._read(addr, _WIDTH_8, count) + + async def read_u16(self, addr: int, count: int = 1) -> list[int]: + """Read 16-bit values starting at *addr*.""" + return await self._read(addr, _WIDTH_16, count) + + async def read_u32(self, addr: int, count: int = 1) -> list[int]: + """Read 32-bit values starting at *addr*.""" + return await self._read(addr, _WIDTH_32, count) + + async def read_u64(self, addr: int, count: int = 1) -> list[int]: + """Read 64-bit values starting at *addr*.""" + return await self._read(addr, _WIDTH_64, count) + + async def read_bytes(self, addr: int, size: int) -> bytes: + """Read *size* bytes starting at *addr* and return as a bytes object.""" + values = await self._read(addr, _WIDTH_8, size) + return bytes(values) + + # ------------------------------------------------------------------ + # Typed writes + # ------------------------------------------------------------------ + + async def write_u8(self, addr: int, values: int | list[int]) -> None: + """Write one or more 8-bit values starting at *addr*.""" + await self._write(addr, _WIDTH_8, values) + + async def write_u16(self, addr: int, values: int | list[int]) -> None: + """Write one or more 16-bit values starting at *addr*.""" + await self._write(addr, _WIDTH_16, values) + + async def write_u32(self, addr: int, values: int | list[int]) -> None: + """Write one or more 32-bit values starting at *addr*.""" + await self._write(addr, _WIDTH_32, values) + + async def write_bytes(self, addr: int, data: bytes) -> None: + """Write raw bytes to memory starting at *addr*.""" + await self._write(addr, _WIDTH_8, list(data)) + + # ------------------------------------------------------------------ + # Utilities + # ------------------------------------------------------------------ + + async def search(self, pattern: bytes, start: int, end: int) -> list[int]: + """Search for *pattern* in memory between *start* and *end*. + + Reads the region in chunks and returns a list of addresses where + the pattern was found. This is done client-side since OpenOCD + has no native memory-search command. + """ + if not pattern: + return [] + + region_size = end - start + if region_size <= 0: + return [] + + chunk_size = 4096 + overlap = len(pattern) - 1 + results: list[int] = [] + offset = 0 + + while offset < region_size: + read_len = min(chunk_size + overlap, region_size - offset) + data = await self.read_bytes(start + offset, read_len) + + # Scan for the pattern within this chunk + search_start = 0 + while True: + idx = data.find(pattern, search_start) + if idx == -1: + break + results.append(start + offset + idx) + search_start = idx + 1 + + # Advance past the non-overlapping portion + offset += chunk_size + + return results + + async def dump(self, addr: int, size: int, path: Path) -> None: + """Read *size* bytes from *addr* and write them to a file.""" + data = await self.read_bytes(addr, size) + path.write_bytes(data) + log.info("Dumped %d bytes from 0x%08X to %s", size, addr, path) + + async def hexdump(self, addr: int, size: int) -> str: + """Read *size* bytes and return a formatted hex+ASCII dump. + + Output format (16 bytes per line):: + + 08000000: 00 50 00 20 A1 01 00 08 AB 01 00 08 AD 01 00 08 |.P. ............| + """ + data = await self.read_bytes(addr, size) + lines: list[str] = [] + + for offset in range(0, len(data), _HEXDUMP_BYTES_PER_LINE): + chunk = data[offset : offset + _HEXDUMP_BYTES_PER_LINE] + line_addr = addr + offset + + # Hex portion — two groups of 8 bytes separated by an extra space + hex_parts: list[str] = [] + for i, b in enumerate(chunk): + hex_parts.append(f"{b:02X}") + if i == 7: + hex_parts.append("") # extra gap between byte 7 and 8 + hex_str = " ".join(hex_parts) + + # Pad to consistent width (3 chars * 16 bytes + 1 extra gap = 49 chars) + # 16 hex pairs = 16*2=32 hex chars, 15 spaces + 1 gap space = 16 = 49 + hex_str = hex_str.ljust(49) + + # ASCII portion + ascii_str = "".join( + chr(b) if 0x20 <= b < 0x7F else "." for b in chunk + ) + + lines.append(f"{line_addr:08X}: {hex_str} |{ascii_str}|") + + return "\n".join(lines) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + async def _read(self, addr: int, width: int, count: int) -> list[int]: + """Read *count* values of *width* bits using the TCL ``read_memory`` API. + + Command: ``read_memory `` + Response: space-separated hex values. + """ + cmd = f"read_memory 0x{addr:x} {width} {count}" + resp = await self._conn.send(cmd) + + if "error" in resp.lower(): + raise TargetError(f"read_memory failed: {resp}") + + tokens = resp.strip().split() + try: + return [int(t, 16) for t in tokens] + except ValueError as exc: + raise TargetError( + f"Cannot parse read_memory response: {resp!r}" + ) from exc + + async def _write(self, addr: int, width: int, values: int | list[int]) -> None: + """Write values of *width* bits using the TCL ``write_memory`` API. + + Command: ``write_memory {val1 val2 ...}`` + """ + if isinstance(values, int): + values = [values] + + val_str = " ".join(f"0x{v:x}" for v in values) + cmd = f"write_memory 0x{addr:x} {width} {{{val_str}}}" + resp = await self._conn.send(cmd) + + if "error" in resp.lower(): + raise TargetError(f"write_memory failed: {resp}") + + +class SyncMemory: + """Synchronous wrapper around Memory.""" + + def __init__(self, memory: Memory, loop: asyncio.AbstractEventLoop) -> None: + self._memory = memory + self._loop = loop + + def read_u8(self, addr: int, count: int = 1) -> list[int]: + return self._loop.run_until_complete(self._memory.read_u8(addr, count)) + + def read_u16(self, addr: int, count: int = 1) -> list[int]: + return self._loop.run_until_complete(self._memory.read_u16(addr, count)) + + def read_u32(self, addr: int, count: int = 1) -> list[int]: + return self._loop.run_until_complete(self._memory.read_u32(addr, count)) + + def read_u64(self, addr: int, count: int = 1) -> list[int]: + return self._loop.run_until_complete(self._memory.read_u64(addr, count)) + + def read_bytes(self, addr: int, size: int) -> bytes: + return self._loop.run_until_complete(self._memory.read_bytes(addr, size)) + + def write_u8(self, addr: int, values: int | list[int]) -> None: + self._loop.run_until_complete(self._memory.write_u8(addr, values)) + + def write_u16(self, addr: int, values: int | list[int]) -> None: + self._loop.run_until_complete(self._memory.write_u16(addr, values)) + + def write_u32(self, addr: int, values: int | list[int]) -> None: + self._loop.run_until_complete(self._memory.write_u32(addr, values)) + + def write_bytes(self, addr: int, data: bytes) -> None: + self._loop.run_until_complete(self._memory.write_bytes(addr, data)) + + def search(self, pattern: bytes, start: int, end: int) -> list[int]: + return self._loop.run_until_complete(self._memory.search(pattern, start, end)) + + def dump(self, addr: int, size: int, path: Path) -> None: + self._loop.run_until_complete(self._memory.dump(addr, size, path)) + + def hexdump(self, addr: int, size: int) -> str: + return self._loop.run_until_complete(self._memory.hexdump(addr, size)) diff --git a/src/openocd/process.py b/src/openocd/process.py new file mode 100644 index 0000000..94ae8cf --- /dev/null +++ b/src/openocd/process.py @@ -0,0 +1,137 @@ +"""OpenOCD subprocess management. + +Spawns an OpenOCD process, waits for the TCL RPC port to become +available, and provides clean shutdown. +""" + +from __future__ import annotations + +import asyncio +import logging +import shutil + +from openocd.errors import ProcessError +from openocd.errors import TimeoutError as OpenOCDTimeoutError + +log = logging.getLogger(__name__) + +DEFAULT_TCL_PORT = 6666 +READY_POLL_INTERVAL = 0.25 + + +class OpenOCDProcess: + """Spawn and manage an OpenOCD subprocess.""" + + def __init__(self) -> None: + self._proc: asyncio.subprocess.Process | None = None + self._tcl_port: int = DEFAULT_TCL_PORT + + @property + def pid(self) -> int | None: + return self._proc.pid if self._proc else None + + @property + def running(self) -> bool: + return self._proc is not None and self._proc.returncode is None + + @property + def tcl_port(self) -> int: + return self._tcl_port + + async def start( + self, + config: str, + extra_args: list[str] | None = None, + tcl_port: int = DEFAULT_TCL_PORT, + openocd_bin: str | None = None, + ) -> None: + """Start OpenOCD with the given configuration. + + Args: + config: Config file path or inline ``-f`` / ``-c`` arguments. + Multiple files can be separated by spaces with ``-f`` prefixes, + e.g. ``"interface/cmsis-dap.cfg -f target/stm32f1x.cfg"``. + extra_args: Additional CLI arguments. + tcl_port: TCL RPC port (default 6666). + openocd_bin: Path to OpenOCD binary (auto-detected if None). + """ + self._tcl_port = tcl_port + binary = openocd_bin or shutil.which("openocd") + if not binary: + raise ProcessError( + "OpenOCD binary not found. Install it or pass openocd_bin=" + ) + + args = [binary] + # Parse the config string — support both bare paths and -f/-c flags + config_parts = config.split() + i = 0 + while i < len(config_parts): + part = config_parts[i] + if part in ("-f", "-c"): + args.extend([part, config_parts[i + 1]]) + i += 2 + else: + args.extend(["-f", part]) + i += 1 + + args.extend(["-c", f"tcl_port {tcl_port}"]) + + if extra_args: + args.extend(extra_args) + + log.info("Starting OpenOCD: %s", " ".join(args)) + try: + self._proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + except FileNotFoundError as exc: + raise ProcessError(f"OpenOCD binary not found at {binary}") from exc + except OSError as exc: + raise ProcessError(f"Failed to start OpenOCD: {exc}") from exc + + async def wait_ready(self, timeout: float = 10.0) -> None: + """Poll until the TCL RPC port is accepting connections.""" + deadline = asyncio.get_event_loop().time() + timeout + while asyncio.get_event_loop().time() < deadline: + # Check if the process died + if self._proc and self._proc.returncode is not None: + stderr = "" + if self._proc.stderr: + raw = await self._proc.stderr.read() + stderr = raw.decode("utf-8", errors="replace") + raise ProcessError( + f"OpenOCD exited with code {self._proc.returncode}: {stderr[-500:]}" + ) + + try: + _, writer = await asyncio.wait_for( + asyncio.open_connection("localhost", self._tcl_port), + timeout=1.0, + ) + writer.close() + await writer.wait_closed() + log.info("OpenOCD ready on TCL port %d", self._tcl_port) + return + except (OSError, TimeoutError): + await asyncio.sleep(READY_POLL_INTERVAL) + + raise OpenOCDTimeoutError( + f"OpenOCD did not become ready within {timeout}s" + ) + + async def stop(self) -> None: + """Terminate the OpenOCD process.""" + if not self._proc: + return + if self._proc.returncode is None: + self._proc.terminate() + try: + await asyncio.wait_for(self._proc.wait(), timeout=5.0) + except TimeoutError: + self._proc.kill() + await self._proc.wait() + log.info("OpenOCD process stopped (pid=%d)", self._proc.pid) + self._proc = None diff --git a/src/openocd/registers.py b/src/openocd/registers.py new file mode 100644 index 0000000..d8ae9cd --- /dev/null +++ b/src/openocd/registers.py @@ -0,0 +1,186 @@ +"""CPU register access via OpenOCD. + +Wraps the ``reg`` command family to read and write individual registers, +list all registers, and provide ARM Cortex-M convenience accessors. +""" + +from __future__ import annotations + +import asyncio +import logging +import re + +from openocd.connection.tcl_rpc import TclRpcConnection +from openocd.errors import TargetError, TargetNotHaltedError +from openocd.types import Register + +log = logging.getLogger(__name__) + +# Matches "reg " output: "pc (/32): 0x08001234" +_REG_VALUE_RE = re.compile( + r"(\S+)\s+\(/(\d+)\):\s*(0x[0-9a-fA-F]+)" +) + +# Matches a row in "reg" (list all) output. +# Typical formats: +# "(0) r0 (/32): 0x00000000" +# "(123) xPSR (/32): 0x61000000 (dirty)" +_REG_LIST_RE = re.compile( + r"\((\d+)\)\s+" # register number + r"(\S+)\s+" # register name + r"\(/(\d+)\):\s*" # bit width + r"(0x[0-9a-fA-F]+)" # value + r"(?:\s+\(dirty\))?" # optional dirty flag +) + + +class Registers: + """Read and write CPU registers.""" + + def __init__(self, conn: TclRpcConnection) -> None: + self._conn = conn + + async def read(self, name: str) -> int: + """Read a single register by name and return its value. + + Args: + name: Register name, e.g. ``"pc"``, ``"r0"``, ``"xPSR"``. + + Returns: + The register value as an integer. + + Raises: + TargetNotHaltedError: Target must be halted for register access. + TargetError: Register not found or command failed. + """ + resp = await self._conn.send(f"reg {name}") + self._check_halted(resp) + + m = _REG_VALUE_RE.search(resp) + if not m: + raise TargetError(f"Cannot parse register '{name}' from: {resp}") + + return int(m.group(3), 16) + + async def write(self, name: str, value: int) -> None: + """Write a value to a register. + + Args: + name: Register name. + value: Value to write. + """ + resp = await self._conn.send(f"reg {name} 0x{value:x}") + self._check_halted(resp) + + if "error" in resp.lower() and "not halted" not in resp.lower(): + raise TargetError(f"reg write failed: {resp}") + + async def read_all(self) -> dict[str, Register]: + """Read all registers and return them as a dict keyed by name. + + Returns: + Mapping of register name to Register dataclass. + """ + resp = await self._conn.send("reg") + self._check_halted(resp) + + registers: dict[str, Register] = {} + for line in resp.splitlines(): + m = _REG_LIST_RE.search(line) + if m: + number = int(m.group(1)) + name = m.group(2) + size = int(m.group(3)) + value = int(m.group(4), 16) + dirty = "(dirty)" in line + + registers[name] = Register( + name=name, + number=number, + value=value, + size=size, + dirty=dirty, + ) + + return registers + + async def read_many(self, names: list[str]) -> dict[str, int]: + """Read several registers by name. + + Args: + names: List of register names. + + Returns: + Mapping of register name to value. + """ + results: dict[str, int] = {} + for name in names: + results[name] = await self.read(name) + return results + + # ------------------------------------------------------------------ + # ARM Cortex-M convenience accessors + # ------------------------------------------------------------------ + + async def pc(self) -> int: + """Read the program counter.""" + return await self.read("pc") + + async def sp(self) -> int: + """Read the stack pointer.""" + return await self.read("sp") + + async def lr(self) -> int: + """Read the link register.""" + return await self.read("lr") + + async def xpsr(self) -> int: + """Read the xPSR (combined program status register).""" + return await self.read("xPSR") + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _check_halted(resp: str) -> None: + """Raise TargetNotHaltedError if the response indicates the target + is not halted (register access requires a halted target). + """ + lower = resp.lower() + if "not halted" in lower or "target not halted" in lower: + raise TargetNotHaltedError( + "Target must be halted to access registers" + ) + + +class SyncRegisters: + """Synchronous wrapper around Registers.""" + + def __init__(self, registers: Registers, loop: asyncio.AbstractEventLoop) -> None: + self._registers = registers + self._loop = loop + + def read(self, name: str) -> int: + return self._loop.run_until_complete(self._registers.read(name)) + + def write(self, name: str, value: int) -> None: + self._loop.run_until_complete(self._registers.write(name, value)) + + def read_all(self) -> dict[str, Register]: + return self._loop.run_until_complete(self._registers.read_all()) + + def read_many(self, names: list[str]) -> dict[str, int]: + return self._loop.run_until_complete(self._registers.read_many(names)) + + def pc(self) -> int: + return self._loop.run_until_complete(self._registers.pc()) + + def sp(self) -> int: + return self._loop.run_until_complete(self._registers.sp()) + + def lr(self) -> int: + return self._loop.run_until_complete(self._registers.lr()) + + def xpsr(self) -> int: + return self._loop.run_until_complete(self._registers.xpsr()) diff --git a/src/openocd/rtt.py b/src/openocd/rtt.py new file mode 100644 index 0000000..bd5cc70 --- /dev/null +++ b/src/openocd/rtt.py @@ -0,0 +1,226 @@ +"""Real-Time Transfer (RTT) support via OpenOCD. + +SEGGER RTT provides high-speed bidirectional communication between +a debug host and an embedded target using shared memory in RAM. +OpenOCD exposes RTT through its ``rtt`` command family. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +from typing import TYPE_CHECKING + +from openocd.errors import OpenOCDError +from openocd.types import RTTChannel + +if TYPE_CHECKING: + from openocd.connection.base import Connection + +log = logging.getLogger(__name__) + + +class RTTManager: + """Control and use SEGGER RTT channels via OpenOCD. + + Typical flow:: + + rtt = RTTManager(conn) + await rtt.setup(address=0x20000000, size=0x1000) + await rtt.start() + channels = await rtt.channels() + data = await rtt.read(0) + await rtt.write(0, "hello\\n") + await rtt.stop() + """ + + def __init__(self, conn: Connection) -> None: + self._conn = conn + + async def setup( + self, + address: int, + size: int, + id_string: str = "SEGGER RTT", + ) -> None: + """Configure RTT control block search parameters. + + Args: + address: Start address of the RAM region to search. + size: Size of the search region in bytes. + id_string: RTT control block identifier (default "SEGGER RTT"). + + Raises: + OpenOCDError: If the setup command fails. + """ + cmd = f'rtt setup 0x{address:X} 0x{size:X} "{id_string}"' + response = await self._conn.send(cmd) + _check_rtt_response(response, cmd) + log.info( + "RTT setup: search 0x%08X +0x%X id=%r", + address, + size, + id_string, + ) + + async def start(self) -> None: + """Start RTT — searches for the control block and activates channels. + + Raises: + OpenOCDError: If the control block is not found or start fails. + """ + response = await self._conn.send("rtt start") + _check_rtt_response(response, "rtt start") + log.info("RTT started") + + async def stop(self) -> None: + """Stop RTT communication. + + Raises: + OpenOCDError: If the stop command fails. + """ + response = await self._conn.send("rtt stop") + _check_rtt_response(response, "rtt stop") + log.info("RTT stopped") + + async def channels(self) -> list[RTTChannel]: + """List available RTT channels. + + Returns: + List of RTTChannel descriptors (index, name, size, direction). + + Raises: + OpenOCDError: If the channels command fails. + """ + response = await self._conn.send("rtt channels") + _check_rtt_response(response, "rtt channels") + return _parse_channels(response) + + async def read(self, channel: int) -> str: + """Read pending data from an RTT up-channel. + + Args: + channel: Channel index (typically 0 for Terminal). + + Returns: + The data read as a string (may be empty if nothing pending). + + Raises: + OpenOCDError: If the read command fails. + """ + cmd = f"rtt channelread {channel}" + response = await self._conn.send(cmd) + _check_rtt_response(response, cmd) + return response + + async def write(self, channel: int, data: str) -> None: + """Write data to an RTT down-channel. + + Args: + channel: Channel index (typically 0 for Terminal). + data: String data to send to the target. + + Raises: + OpenOCDError: If the write command fails. + """ + cmd = f'rtt channelwrite {channel} "{data}"' + response = await self._conn.send(cmd) + _check_rtt_response(response, cmd) + + +class SyncRTTManager: + """Synchronous wrapper around RTTManager.""" + + def __init__(self, manager: RTTManager, loop: asyncio.AbstractEventLoop) -> None: + self._manager = manager + self._loop = loop + + def setup( + self, + address: int, + size: int, + id_string: str = "SEGGER RTT", + ) -> None: + self._loop.run_until_complete( + self._manager.setup(address, size, id_string) + ) + + def start(self) -> None: + self._loop.run_until_complete(self._manager.start()) + + def stop(self) -> None: + self._loop.run_until_complete(self._manager.stop()) + + def channels(self) -> list[RTTChannel]: + return self._loop.run_until_complete(self._manager.channels()) + + def read(self, channel: int) -> str: + return self._loop.run_until_complete(self._manager.read(channel)) + + def write(self, channel: int, data: str) -> None: + self._loop.run_until_complete(self._manager.write(channel, data)) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _check_rtt_response(response: str, command: str) -> None: + """Raise on error responses from RTT commands.""" + if response and "error" in response.lower(): + raise OpenOCDError(f"RTT command failed ({command}): {response}") + + +def _parse_channels(response: str) -> list[RTTChannel]: + """Parse the output of ``rtt channels`` into RTTChannel objects. + + OpenOCD typically outputs lines like:: + + Up-channels: + 0: Terminal 1024 + Down-channels: + 0: Terminal 16 + + The exact format may vary by OpenOCD version; this parser is + intentionally lenient. + """ + channels: list[RTTChannel] = [] + direction = "up" + + for line in response.splitlines(): + stripped = line.strip() + lower = stripped.lower() + + if "up-channel" in lower or lower.startswith("up"): + direction = "up" + continue + if "down-channel" in lower or lower.startswith("down"): + direction = "down" + continue + + # Try to parse lines like "0: Terminal 1024" + if ":" in stripped and stripped[0].isdigit(): + parts = stripped.split(":", 1) + try: + index = int(parts[0].strip()) + except ValueError: + continue + + rest = parts[1].strip().split() + name = rest[0] if rest else f"channel_{index}" + size = 0 + if len(rest) >= 2: + with contextlib.suppress(ValueError): + size = int(rest[-1]) + + channels.append( + RTTChannel( + index=index, + name=name, + size=size, + direction=direction, + ) + ) + + return channels diff --git a/src/openocd/session.py b/src/openocd/session.py new file mode 100644 index 0000000..72ba36e --- /dev/null +++ b/src/openocd/session.py @@ -0,0 +1,301 @@ +"""Session — the main entry point for openocd-python. + +Manages the connection lifecycle and provides access to all subsystems +(target, memory, registers, flash, JTAG, SVD, etc.). +""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Callable +from pathlib import Path +from typing import TYPE_CHECKING + +from openocd.connection.tcl_rpc import TclRpcConnection +from openocd.process import OpenOCDProcess + +if TYPE_CHECKING: + from openocd.breakpoints import BreakpointManager, SyncBreakpointManager + from openocd.flash import Flash, SyncFlash + from openocd.jtag import JTAGController, SyncJTAGController + from openocd.memory import Memory, SyncMemory + from openocd.registers import Registers, SyncRegisters + from openocd.rtt import RTTManager + from openocd.svd import SVDManager, SyncSVDManager + from openocd.target import SyncTarget, Target + from openocd.transport import Transport + +log = logging.getLogger(__name__) + + +class Session: + """Main entry point. Manages connection and provides access to subsystems.""" + + def __init__(self, connection: TclRpcConnection, process: OpenOCDProcess | None = None) -> None: + self._conn = connection + self._process = process + self._target: Target | None = None + self._memory: Memory | None = None + self._registers: Registers | None = None + self._flash: Flash | None = None + self._jtag: JTAGController | None = None + self._breakpoints: BreakpointManager | None = None + self._rtt: RTTManager | None = None + self._svd: SVDManager | None = None + self._transport: Transport | None = None + + # ------------------------------------------------------------------ + # Factory methods + # ------------------------------------------------------------------ + + @classmethod + async def start( + cls, + config: str | Path, + *, + tcl_port: int = 6666, + openocd_bin: str | None = None, + timeout: float = 10.0, + extra_args: list[str] | None = None, + ) -> Session: + """Spawn an OpenOCD process and connect to it. + + Args: + config: Config file path or ``-f``/``-c`` flags string. + tcl_port: TCL RPC port. + openocd_bin: Custom OpenOCD binary path. + timeout: Seconds to wait for OpenOCD readiness. + extra_args: Additional CLI arguments for OpenOCD. + """ + proc = OpenOCDProcess() + await proc.start( + str(config), extra_args=extra_args, tcl_port=tcl_port, openocd_bin=openocd_bin + ) + await proc.wait_ready(timeout=timeout) + + conn = TclRpcConnection(timeout=timeout) + await conn.connect("localhost", tcl_port) + + return cls(connection=conn, process=proc) + + @classmethod + async def connect( + cls, + host: str = "localhost", + port: int = 6666, + timeout: float = 10.0, + ) -> Session: + """Connect to an already-running OpenOCD instance.""" + conn = TclRpcConnection(timeout=timeout) + await conn.connect(host, port) + return cls(connection=conn) + + # ------------------------------------------------------------------ + # Sync factory wrappers + # ------------------------------------------------------------------ + + @classmethod + def start_sync(cls, config: str | Path, **kwargs) -> SyncSession: + """Synchronous version of start(). Returns a SyncSession.""" + loop = _get_or_create_loop() + session = loop.run_until_complete(cls.start(config, **kwargs)) + return SyncSession(session, loop) + + @classmethod + def connect_sync(cls, host: str = "localhost", port: int = 6666, **kwargs) -> SyncSession: + """Synchronous version of connect(). Returns a SyncSession.""" + loop = _get_or_create_loop() + session = loop.run_until_complete(cls.connect(host, port, **kwargs)) + return SyncSession(session, loop) + + # ------------------------------------------------------------------ + # Context manager + # ------------------------------------------------------------------ + + async def __aenter__(self) -> Session: + return self + + async def __aexit__(self, *exc) -> None: + await self.close() + + async def close(self) -> None: + """Close the connection and stop the subprocess if we spawned it.""" + await self._conn.close() + if self._process: + await self._process.stop() + + # ------------------------------------------------------------------ + # Raw command escape hatch + # ------------------------------------------------------------------ + + async def command(self, cmd: str) -> str: + """Send a raw OpenOCD command and return the response string.""" + return await self._conn.send(cmd) + + # ------------------------------------------------------------------ + # Subsystem accessors (lazy-initialized) + # ------------------------------------------------------------------ + + @property + def target(self) -> Target: + if self._target is None: + from openocd.target import Target + self._target = Target(self._conn) + return self._target + + @property + def memory(self) -> Memory: + if self._memory is None: + from openocd.memory import Memory + self._memory = Memory(self._conn) + return self._memory + + @property + def registers(self) -> Registers: + if self._registers is None: + from openocd.registers import Registers + self._registers = Registers(self._conn) + return self._registers + + @property + def flash(self) -> Flash: + if self._flash is None: + from openocd.flash import Flash + self._flash = Flash(self._conn) + return self._flash + + @property + def jtag(self) -> JTAGController: + if self._jtag is None: + from openocd.jtag import JTAGController + self._jtag = JTAGController(self._conn) + return self._jtag + + @property + def breakpoints(self) -> BreakpointManager: + if self._breakpoints is None: + from openocd.breakpoints import BreakpointManager + self._breakpoints = BreakpointManager(self._conn) + return self._breakpoints + + @property + def rtt(self) -> RTTManager: + if self._rtt is None: + from openocd.rtt import RTTManager + self._rtt = RTTManager(self._conn) + return self._rtt + + @property + def svd(self) -> SVDManager: + if self._svd is None: + from openocd.svd import SVDManager + self._svd = SVDManager(self._conn, self.memory) + return self._svd + + @property + def transport(self) -> Transport: + if self._transport is None: + from openocd.transport import Transport + self._transport = Transport(self._conn) + return self._transport + + # ------------------------------------------------------------------ + # Event shortcuts + # ------------------------------------------------------------------ + + def on_halt(self, callback: Callable[[str], None]) -> None: + """Register a callback for target halt events.""" + def _filter(msg: str) -> None: + if "halted" in msg.lower(): + callback(msg) + self._conn.on_notification(_filter) + + def on_reset(self, callback: Callable[[str], None]) -> None: + """Register a callback for target reset events.""" + def _filter(msg: str) -> None: + if "reset" in msg.lower(): + callback(msg) + self._conn.on_notification(_filter) + + +# ====================================================================== +# Sync wrapper +# ====================================================================== + +class SyncSession: + """Wraps an async Session for synchronous use.""" + + def __init__(self, session: Session, loop: asyncio.AbstractEventLoop) -> None: + self._session = session + self._loop = loop + + def __enter__(self) -> SyncSession: + return self + + def __exit__(self, *exc) -> None: + self._loop.run_until_complete(self._session.close()) + + def command(self, cmd: str) -> str: + return self._loop.run_until_complete(self._session.command(cmd)) + + @property + def target(self) -> SyncTarget: + from openocd.target import SyncTarget + return SyncTarget(self._session.target, self._loop) + + @property + def memory(self) -> SyncMemory: + from openocd.memory import SyncMemory + return SyncMemory(self._session.memory, self._loop) + + @property + def registers(self) -> SyncRegisters: + from openocd.registers import SyncRegisters + return SyncRegisters(self._session.registers, self._loop) + + @property + def flash(self) -> SyncFlash: + from openocd.flash import SyncFlash + return SyncFlash(self._session.flash, self._loop) + + @property + def jtag(self) -> SyncJTAGController: + from openocd.jtag import SyncJTAGController + return SyncJTAGController(self._session.jtag, self._loop) + + @property + def breakpoints(self) -> SyncBreakpointManager: + from openocd.breakpoints import SyncBreakpointManager + return SyncBreakpointManager(self._session.breakpoints, self._loop) + + @property + def svd(self) -> SyncSVDManager: + from openocd.svd import SyncSVDManager + return SyncSVDManager(self._session.svd, self._loop) + + +# ====================================================================== +# Helpers +# ====================================================================== + +def _get_or_create_loop() -> asyncio.AbstractEventLoop: + """Get the running event loop, or create a new one if there isn't one.""" + try: + loop = asyncio.get_running_loop() + # If we're already in an async context we can't use run_until_complete + raise RuntimeError( + "Cannot use sync API from an async context. " + "Use the async Session.start()/connect() instead." + ) + except RuntimeError: + pass + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop diff --git a/src/openocd/svd/__init__.py b/src/openocd/svd/__init__.py new file mode 100644 index 0000000..23facff --- /dev/null +++ b/src/openocd/svd/__init__.py @@ -0,0 +1,5 @@ +"""SVD (System View Description) integration for peripheral/register decoding.""" + +from openocd.svd.peripheral import SVDManager, SyncSVDManager + +__all__ = ["SVDManager", "SyncSVDManager"] diff --git a/src/openocd/svd/decoder.py b/src/openocd/svd/decoder.py new file mode 100644 index 0000000..638a690 --- /dev/null +++ b/src/openocd/svd/decoder.py @@ -0,0 +1,54 @@ +"""Register value decoding using SVD metadata. + +Takes a raw integer read from hardware and splits it into named bitfields +using the field definitions from a parsed SVD file. +""" + +from __future__ import annotations + +from typing import Any + +from openocd.types import BitField, DecodedRegister + + +def decode_register( + peripheral_obj: Any, + register_obj: Any, + raw_value: int, +) -> DecodedRegister: + """Decode a raw register value into named bitfields using SVD metadata. + + Args: + peripheral_obj: cmsis_svd peripheral (used for base_address and name). + register_obj: cmsis_svd register (used for fields, address_offset, name). + raw_value: The 32-bit value read from hardware. + + Returns: + A DecodedRegister with all fields extracted and annotated. + """ + address = peripheral_obj.base_address + register_obj.address_offset + fields: list[BitField] = [] + + for svd_field in register_obj.fields or []: + mask = ((1 << svd_field.bit_width) - 1) << svd_field.bit_offset + value = (raw_value & mask) >> svd_field.bit_offset + fields.append( + BitField( + name=svd_field.name, + offset=svd_field.bit_offset, + width=svd_field.bit_width, + value=value, + description=svd_field.description or "", + ) + ) + + # Sort fields by bit offset (low to high) for consistent display + fields.sort(key=lambda f: f.offset) + + return DecodedRegister( + peripheral=peripheral_obj.name, + register=register_obj.name, + address=address, + raw_value=raw_value, + fields=fields, + ) diff --git a/src/openocd/svd/parser.py b/src/openocd/svd/parser.py new file mode 100644 index 0000000..f3bfc93 --- /dev/null +++ b/src/openocd/svd/parser.py @@ -0,0 +1,128 @@ +"""SVD file loading and peripheral/register lookup. + +Wraps cmsis_svd to parse CMSIS-SVD XML files and provide indexed access +to peripherals and their registers. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from cmsis_svd import SVDParser + +from openocd.errors import SVDError + +log = logging.getLogger(__name__) + + +class SVDParserWrapper: + """Load and cache parsed SVD device data.""" + + def __init__(self) -> None: + self._device: Any = None + self._peripherals: dict[str, Any] = {} + + @property + def loaded(self) -> bool: + """Whether an SVD file has been parsed.""" + return self._device is not None + + def load(self, svd_path: Path) -> None: + """Parse an SVD file and index peripherals/registers. + + Args: + svd_path: Path to the .svd file on disk. + + Raises: + SVDError: If the file cannot be found or parsed. + """ + path = Path(svd_path) + if not path.exists(): + raise SVDError(f"SVD file not found: {path}") + + try: + parser = SVDParser.for_xml_file(str(path)) + self._device = parser.get_device() + except Exception as exc: + raise SVDError(f"Failed to parse SVD file {path}: {exc}") from exc + + self._peripherals = {p.name: p for p in self._device.get_peripherals()} + log.info( + "Loaded SVD for %s — %d peripherals", + getattr(self._device, "name", "unknown"), + len(self._peripherals), + ) + + def _require_loaded(self) -> None: + if not self.loaded: + raise SVDError("No SVD file loaded — call load() first") + + def get_peripheral(self, name: str) -> Any: + """Look up a peripheral by name. + + Args: + name: Peripheral name (case-sensitive, e.g. "GPIOA", "USART1"). + + Returns: + The cmsis_svd peripheral object. + + Raises: + SVDError: If no SVD is loaded or the peripheral is not found. + """ + self._require_loaded() + periph = self._peripherals.get(name) + if periph is None: + raise SVDError( + f"Peripheral '{name}' not found. " + f"Available: {', '.join(sorted(self._peripherals))}" + ) + return periph + + def get_register(self, peripheral: str, register: str) -> Any: + """Look up a register within a peripheral. + + Args: + peripheral: Peripheral name. + register: Register name (e.g. "CR1", "SR"). + + Returns: + The cmsis_svd register object. + + Raises: + SVDError: If the peripheral or register is not found. + """ + periph = self.get_peripheral(peripheral) + registers = periph.registers or [] + for reg in registers: + if reg.name == register: + return reg + + available = [r.name for r in registers] + raise SVDError( + f"Register '{register}' not found in {peripheral}. " + f"Available: {', '.join(sorted(available))}" + ) + + def list_peripherals(self) -> list[str]: + """Return sorted names of all peripherals in the SVD. + + Raises: + SVDError: If no SVD is loaded. + """ + self._require_loaded() + return sorted(self._peripherals.keys()) + + def list_registers(self, peripheral: str) -> list[str]: + """Return sorted register names for a peripheral. + + Args: + peripheral: Peripheral name. + + Raises: + SVDError: If the peripheral is not found. + """ + periph = self.get_peripheral(peripheral) + registers = periph.registers or [] + return sorted(r.name for r in registers) diff --git a/src/openocd/svd/peripheral.py b/src/openocd/svd/peripheral.py new file mode 100644 index 0000000..e68763b --- /dev/null +++ b/src/openocd/svd/peripheral.py @@ -0,0 +1,186 @@ +"""SVDManager — combines SVD parsing, register decoding, and hardware reads. + +This is the primary interface for SVD-based register inspection. It ties +the SVD parser, bitfield decoder, and the Memory subsystem together so +callers can do things like: + + decoded = await svd.read_register("GPIOA", "ODR") + print(decoded) +""" + +from __future__ import annotations + +import asyncio +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +from openocd.svd.decoder import decode_register +from openocd.svd.parser import SVDParserWrapper +from openocd.types import DecodedRegister + +if TYPE_CHECKING: + from openocd.connection.base import Connection + from openocd.memory import Memory + +log = logging.getLogger(__name__) + + +class SVDManager: + """High-level SVD register access: parse, read, decode.""" + + def __init__(self, conn: Connection, memory: Memory) -> None: + self._conn = conn + self._memory = memory + self._parser = SVDParserWrapper() + + @property + def loaded(self) -> bool: + """Whether an SVD file has been loaded.""" + return self._parser.loaded + + async def load(self, svd_path: Path) -> None: + """Parse an SVD file and make its peripherals available. + + This is a synchronous file parse wrapped in the async interface + for consistency with the rest of the API. + + Args: + svd_path: Path to the .svd XML file. + + Raises: + SVDError: If the file is missing or unparseable. + """ + self._parser.load(svd_path) + + def list_peripherals(self) -> list[str]: + """Return sorted peripheral names from the loaded SVD. + + Raises: + SVDError: If no SVD is loaded. + """ + return self._parser.list_peripherals() + + def list_registers(self, peripheral: str) -> list[str]: + """Return sorted register names for a peripheral. + + Args: + peripheral: Peripheral name (e.g. "GPIOA"). + + Raises: + SVDError: If no SVD is loaded or peripheral not found. + """ + return self._parser.list_registers(peripheral) + + async def read_register(self, peripheral: str, register: str) -> DecodedRegister: + """Read a register from hardware and decode it using SVD metadata. + + This is the primary method: it computes the register's memory-mapped + address from the SVD, reads 32 bits from the target, and returns + a fully decoded result with named bitfields. + + Args: + peripheral: Peripheral name (e.g. "GPIOA"). + register: Register name (e.g. "ODR"). + + Returns: + DecodedRegister with address, raw value, and decoded fields. + + Raises: + SVDError: If peripheral/register not found. + TargetError: If the memory read fails. + """ + periph_obj = self._parser.get_peripheral(peripheral) + reg_obj = self._parser.get_register(peripheral, register) + address = periph_obj.base_address + reg_obj.address_offset + + values = await self._memory.read_u32(address) + raw = values[0] + return decode_register(periph_obj, reg_obj, raw) + + async def read_peripheral(self, peripheral: str) -> dict[str, DecodedRegister]: + """Read and decode every register in a peripheral. + + Args: + peripheral: Peripheral name. + + Returns: + Dict mapping register name to its DecodedRegister. + + Raises: + SVDError: If peripheral not found. + TargetError: If any memory read fails. + """ + periph_obj = self._parser.get_peripheral(peripheral) + registers = periph_obj.registers or [] + result: dict[str, DecodedRegister] = {} + + for reg_obj in registers: + address = periph_obj.base_address + reg_obj.address_offset + try: + values = await self._memory.read_u32(address) + raw = values[0] + result[reg_obj.name] = decode_register(periph_obj, reg_obj, raw) + except Exception as exc: + log.warning( + "Failed to read %s.%s @ 0x%08X: %s", + peripheral, + reg_obj.name, + address, + exc, + ) + # Skip unreadable registers (write-only, reserved, etc.) + + return result + + def decode(self, peripheral: str, register: str, value: int) -> DecodedRegister: + """Decode a raw value without reading hardware. + + Useful when you already have the register value (from a log, + a previous read, or a known reset value). + + Args: + peripheral: Peripheral name. + register: Register name. + value: Raw 32-bit register value. + + Returns: + DecodedRegister with the decoded bitfields. + """ + periph_obj = self._parser.get_peripheral(peripheral) + reg_obj = self._parser.get_register(peripheral, register) + return decode_register(periph_obj, reg_obj, value) + + +class SyncSVDManager: + """Synchronous wrapper around SVDManager.""" + + def __init__(self, manager: SVDManager, loop: asyncio.AbstractEventLoop) -> None: + self._manager = manager + self._loop = loop + + @property + def loaded(self) -> bool: + return self._manager.loaded + + def load(self, svd_path: Path) -> None: + self._loop.run_until_complete(self._manager.load(svd_path)) + + def list_peripherals(self) -> list[str]: + return self._manager.list_peripherals() + + def list_registers(self, peripheral: str) -> list[str]: + return self._manager.list_registers(peripheral) + + def read_register(self, peripheral: str, register: str) -> DecodedRegister: + return self._loop.run_until_complete( + self._manager.read_register(peripheral, register) + ) + + def read_peripheral(self, peripheral: str) -> dict[str, DecodedRegister]: + return self._loop.run_until_complete( + self._manager.read_peripheral(peripheral) + ) + + def decode(self, peripheral: str, register: str, value: int) -> DecodedRegister: + return self._manager.decode(peripheral, register, value) diff --git a/src/openocd/target.py b/src/openocd/target.py new file mode 100644 index 0000000..fdab1b0 --- /dev/null +++ b/src/openocd/target.py @@ -0,0 +1,164 @@ +"""Target state control — halt, resume, step, reset, and state queries. + +Wraps the OpenOCD target commands: halt, resume, step, reset, +wait_halt, and targets (for state inspection). +""" + +from __future__ import annotations + +import asyncio +import logging +import re +from typing import Literal + +from openocd.connection.tcl_rpc import TclRpcConnection +from openocd.errors import TargetError, TimeoutError +from openocd.types import TargetState + +log = logging.getLogger(__name__) + +# Matches a target row from "targets" output, e.g.: +# " 0* stm32f1x.cpu cortex_m little stm32f1x.cpu halted" +_TARGET_ROW_RE = re.compile( + r"^\s*\d+\*?\s+" # index, optional current marker + r"(\S+)\s+" # target name + r"\S+\s+" # type + r"\S+\s+" # endian + r"\S+\s+" # tap name + r"(\S+)" # state +) + + +class Target: + """Target execution control — halt, resume, step, reset.""" + + def __init__(self, conn: TclRpcConnection) -> None: + self._conn = conn + + async def halt(self) -> TargetState: + """Halt the target and return the resulting state.""" + resp = await self._conn.send("halt") + if "error" in resp.lower() and "already" not in resp.lower(): + raise TargetError(f"halt failed: {resp}") + return await self._parse_state() + + async def resume(self, address: int | None = None) -> None: + """Resume execution, optionally from a specific address.""" + cmd = "resume" + if address is not None: + cmd = f"resume 0x{address:x}" + resp = await self._conn.send(cmd) + if "error" in resp.lower(): + raise TargetError(f"resume failed: {resp}") + + async def step(self, address: int | None = None) -> TargetState: + """Single-step and return the resulting state.""" + cmd = "step" + if address is not None: + cmd = f"step 0x{address:x}" + resp = await self._conn.send(cmd) + if "error" in resp.lower(): + raise TargetError(f"step failed: {resp}") + return await self._parse_state() + + async def reset(self, mode: Literal["run", "halt", "init"] = "halt") -> None: + """Reset the target. + + Args: + mode: Reset mode — "run" resumes after reset, "halt" stops at + the reset vector, "init" runs init scripts after reset. + """ + resp = await self._conn.send(f"reset {mode}") + if "error" in resp.lower(): + raise TargetError(f"reset failed: {resp}") + + async def wait_halt(self, timeout_ms: int = 5000) -> TargetState: + """Block until the target halts or the timeout expires. + + Args: + timeout_ms: Maximum wait time in milliseconds. + + Raises: + TimeoutError: Target did not halt within the deadline. + """ + resp = await self._conn.send(f"wait_halt {timeout_ms}") + if "timed out" in resp.lower() or "time out" in resp.lower(): + raise TimeoutError(f"Target did not halt within {timeout_ms}ms") + if "error" in resp.lower(): + raise TargetError(f"wait_halt failed: {resp}") + return await self._parse_state() + + async def state(self) -> TargetState: + """Query and return the current target state.""" + return await self._parse_state() + + async def _parse_state(self) -> TargetState: + """Parse the ``targets`` command output into a TargetState. + + The output looks like:: + + TargetName Type Endian TapName State + -- ------------------ ---------- ------ ------------------ ------------ + 0* stm32f1x.cpu cortex_m little stm32f1x.cpu halted + + If the target is halted, also reads the program counter via ``reg pc``. + """ + resp = await self._conn.send("targets") + + name = "unknown" + raw_state = "unknown" + + for line in resp.splitlines(): + m = _TARGET_ROW_RE.match(line) + if m: + name = m.group(1) + raw_state = m.group(2).lower() + break + + # Normalize to our known state literals + if raw_state not in ("running", "halted", "reset", "debug-running"): + raw_state = "unknown" + + pc: int | None = None + if raw_state == "halted": + try: + pc = await self._read_pc() + except Exception: + log.debug("Could not read PC while halted", exc_info=True) + + return TargetState(name=name, state=raw_state, current_pc=pc) + + async def _read_pc(self) -> int: + """Read the program counter from the halted target.""" + resp = await self._conn.send("reg pc") + # Output: "pc (/32): 0x08001234" + m = re.search(r":\s*(0x[0-9a-fA-F]+)", resp) + if not m: + raise TargetError(f"Cannot parse PC from: {resp}") + return int(m.group(1), 16) + + +class SyncTarget: + """Synchronous wrapper around Target.""" + + def __init__(self, target: Target, loop: asyncio.AbstractEventLoop) -> None: + self._target = target + self._loop = loop + + def halt(self) -> TargetState: + return self._loop.run_until_complete(self._target.halt()) + + def resume(self, address: int | None = None) -> None: + self._loop.run_until_complete(self._target.resume(address)) + + def step(self, address: int | None = None) -> TargetState: + return self._loop.run_until_complete(self._target.step(address)) + + def reset(self, mode: Literal["run", "halt", "init"] = "halt") -> None: + self._loop.run_until_complete(self._target.reset(mode)) + + def wait_halt(self, timeout_ms: int = 5000) -> TargetState: + return self._loop.run_until_complete(self._target.wait_halt(timeout_ms)) + + def state(self) -> TargetState: + return self._loop.run_until_complete(self._target.state()) diff --git a/src/openocd/transport.py b/src/openocd/transport.py new file mode 100644 index 0000000..5c9a015 --- /dev/null +++ b/src/openocd/transport.py @@ -0,0 +1,149 @@ +"""Transport selection and debug adapter configuration. + +OpenOCD supports multiple debug transports (JTAG, SWD, SWIM, etc.) +and various adapter interfaces (CMSIS-DAP, ST-Link, J-Link, etc.). +This module provides access to transport and adapter state. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from openocd.errors import OpenOCDError + +if TYPE_CHECKING: + from openocd.connection.base import Connection + +log = logging.getLogger(__name__) + + +class Transport: + """Query and configure the debug transport and adapter. + + Usage:: + + transport = Transport(conn) + current = await transport.select() # e.g. "swd" + available = await transport.list() # e.g. ["jtag", "swd"] + speed = await transport.adapter_speed() # current kHz + await transport.adapter_speed(4000) # set to 4 MHz + """ + + def __init__(self, conn: Connection) -> None: + self._conn = conn + + async def select(self) -> str: + """Get the currently selected transport. + + Returns: + Transport name string (e.g. "jtag", "swd", "swim"). + + Raises: + OpenOCDError: If the command fails. + """ + response = await self._conn.send("transport select") + response = response.strip() + if not response: + raise OpenOCDError("Empty response from 'transport select'") + return response + + async def list(self) -> list[str]: + """List transports available for the current adapter. + + Returns: + List of transport name strings. + + Raises: + OpenOCDError: If the command fails. + """ + response = await self._conn.send("transport list") + response = response.strip() + if not response: + raise OpenOCDError("Empty response from 'transport list'") + + # OpenOCD may return a Tcl list like "jtag swd" or one per line + transports: list[str] = [] + for line in response.splitlines(): + for token in line.split(): + cleaned = token.strip("{}") + if cleaned: + transports.append(cleaned) + return transports + + async def adapter_info(self) -> str: + """Get adapter/interface information. + + Tries ``adapter name`` first (newer OpenOCD), falls back to + ``adapter info`` for older versions. + + Returns: + Adapter description string. + """ + # "adapter name" is the preferred command in OpenOCD >= 0.12 + response = await self._conn.send("adapter name") + response = response.strip() + + if not response or "invalid" in response.lower() or "error" in response.lower(): + response = await self._conn.send("adapter info") + response = response.strip() + + if not response: + raise OpenOCDError("Could not determine adapter info") + return response + + async def adapter_speed(self, khz: int | None = None) -> int: + """Get or set the adapter clock speed. + + Args: + khz: If provided, set the adapter speed to this value in kHz. + If None, just query the current speed. + + Returns: + The current (or newly set) adapter speed in kHz. + + Raises: + OpenOCDError: If the command fails or response is not parseable. + """ + cmd = f"adapter speed {khz}" if khz is not None else "adapter speed" + + response = await self._conn.send(cmd) + response = response.strip() + + # OpenOCD response is typically just a number, or + # "adapter speed: 4000 kHz" depending on the interface + speed = _parse_speed(response) + if speed is None: + raise OpenOCDError(f"Cannot parse adapter speed from: {response!r}") + + if khz is not None: + log.info("Adapter speed set to %d kHz", speed) + return speed + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _parse_speed(response: str) -> int | None: + """Extract a numeric kHz value from an adapter speed response. + + Handles formats like: + "4000" + "adapter speed: 4000 kHz" + "4000 kHz" + """ + # Try the whole thing as a plain integer + try: + return int(response) + except ValueError: + pass + + # Pull out the first integer-looking token + for token in response.replace(":", " ").split(): + try: + return int(token) + except ValueError: + continue + + return None diff --git a/src/openocd/types.py b/src/openocd/types.py new file mode 100644 index 0000000..0b29be1 --- /dev/null +++ b/src/openocd/types.py @@ -0,0 +1,185 @@ +"""Shared dataclasses and enums used across the openocd-python package.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Literal + +# --------------------------------------------------------------------------- +# Target +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class TargetState: + """Snapshot of target execution state.""" + + name: str + state: Literal["running", "halted", "reset", "debug-running", "unknown"] + current_pc: int | None = None + + +# --------------------------------------------------------------------------- +# Registers +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class Register: + """A single CPU register.""" + + name: str + number: int + value: int + size: int # bits + dirty: bool = False + + +# --------------------------------------------------------------------------- +# Flash +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class FlashSector: + """One sector inside a flash bank.""" + + index: int + offset: int + size: int + protected: bool + + +@dataclass(frozen=True) +class FlashBank: + """A flash bank reported by OpenOCD.""" + + index: int + name: str + base: int + size: int + bus_width: int + chip_width: int + target: str + sectors: list[FlashSector] = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# JTAG +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class TAPInfo: + """One TAP discovered on the JTAG chain.""" + + name: str + chip: str + tap_name: str + idcode: int + ir_length: int + enabled: bool + + +class JTAGState(str, Enum): + """IEEE 1149.1 TAP controller states.""" + + RESET = "RESET" + IDLE = "IDLE" + DRSELECT = "DRSELECT" + DRCAPTURE = "DRCAPTURE" + DRSHIFT = "DRSHIFT" + DREXIT1 = "DREXIT1" + DRPAUSE = "DRPAUSE" + DREXIT2 = "DREXIT2" + DRUPDATE = "DRUPDATE" + IRSELECT = "IRSELECT" + IRCAPTURE = "IRCAPTURE" + IRSHIFT = "IRSHIFT" + IREXIT1 = "IREXIT1" + IRPAUSE = "IRPAUSE" + IREXIT2 = "IREXIT2" + IRUPDATE = "IRUPDATE" + + +# --------------------------------------------------------------------------- +# Memory +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class MemoryRegion: + """A chunk of memory read from the target.""" + + address: int + size: int + data: bytes + + +# --------------------------------------------------------------------------- +# SVD +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class BitField: + """One decoded bitfield inside a register.""" + + name: str + offset: int + width: int + value: int + description: str + + +@dataclass +class DecodedRegister: + """A register value decoded into named bitfields via SVD.""" + + peripheral: str + register: str + address: int + raw_value: int + fields: list[BitField] = field(default_factory=list) + + def __str__(self) -> str: + header = f"{self.peripheral}.{self.register}" + lines = [f"{header} @ 0x{self.address:08X} = 0x{self.raw_value:08X}"] + for f in self.fields: + bits = f"{f.offset + f.width - 1}:{f.offset}" if f.width > 1 else str(f.offset) + lines.append(f" [{bits:>5s}] {f.name:<20s} = 0x{f.value:X} {f.description}") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Breakpoints +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class Breakpoint: + """An active breakpoint.""" + + number: int + type: Literal["hw", "sw"] + address: int + length: int + enabled: bool + + +@dataclass(frozen=True) +class Watchpoint: + """An active watchpoint.""" + + number: int + address: int + length: int + access: Literal["r", "w", "rw"] + + +# --------------------------------------------------------------------------- +# RTT +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class RTTChannel: + """An RTT channel descriptor.""" + + index: int + name: str + size: int + direction: Literal["up", "down"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..279ebf6 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,37 @@ +"""Shared pytest fixtures for openocd-python tests.""" +from __future__ import annotations + +import pytest + +from openocd.connection.tcl_rpc import TclRpcConnection +from openocd.session import Session +from tests.mock_server import MockOpenOCDServer + + +@pytest.fixture +async def mock_ocd(): + """Start a MockOpenOCDServer, yield (host, port), stop on teardown.""" + server = MockOpenOCDServer() + await server.start() + host, port = server.address + yield host, port, server + await server.stop() + + +@pytest.fixture +async def connection(mock_ocd): + """A TclRpcConnection connected to the mock server.""" + host, port, _server = mock_ocd + conn = TclRpcConnection(timeout=5.0) + await conn.connect(host, port) + yield conn + await conn.close() + + +@pytest.fixture +async def session(mock_ocd): + """A Session connected to the mock server via Session.connect().""" + host, port, _server = mock_ocd + sess = await Session.connect(host, port, timeout=5.0) + yield sess + await sess.close() diff --git a/tests/mock_server.py b/tests/mock_server.py new file mode 100644 index 0000000..08aca98 --- /dev/null +++ b/tests/mock_server.py @@ -0,0 +1,255 @@ +"""Fake OpenOCD TCL RPC server for testing. + +An asyncio TCP server that speaks the OpenOCD TCL RPC framing protocol: + client sends: command_string + \\x1a + server replies: response_string + \\x1a + +Supports exact-match and regex-based command routing with pre-loaded +responses that mirror real OpenOCD output. +""" +from __future__ import annotations + +import asyncio +import contextlib +import re +from collections.abc import Callable + +SEPARATOR = b"\x1a" + + +# -- Canned OpenOCD responses ------------------------------------------------ + +TARGETS_RESPONSE = """\ + TargetName Type Endian TapName State +-- ------------------ ---------- ------ ------------------ ------------ + 0* stm32f1x.cpu cortex_m little stm32f1x.cpu halted""" + +REG_PC_RESPONSE = "pc (/32): 0x08001234" +REG_SP_RESPONSE = "sp (/32): 0x20005000" +REG_LR_RESPONSE = "lr (/32): 0x08000100" +REG_XPSR_RESPONSE = "xPSR (/32): 0x61000000" + +REG_ALL_RESPONSE = """\ +===== ARM registers +(0) r0 (/32): 0x00000000 +(1) r1 (/32): 0x00000001 +(2) r2 (/32): 0x20001000 +(3) r3 (/32): 0x00000003 +(4) r4 (/32): 0x00000000 +(5) r5 (/32): 0x00000000 +(6) r6 (/32): 0x00000000 +(7) r7 (/32): 0x20004FF0 +(8) r8 (/32): 0x00000000 +(9) r9 (/32): 0x00000000 +(10) r10 (/32): 0x00000000 +(11) r11 (/32): 0x00000000 +(12) r12 (/32): 0x00000000 +(13) sp (/32): 0x20005000 +(14) lr (/32): 0x08000100 +(15) pc (/32): 0x08001234 +(16) xPSR (/32): 0x61000000 +(17) msp (/32): 0x20005000 +(18) psp (/32): 0x00000000 +(19) primask (/1): 0x00 +(20) basepri (/8): 0x00 +(21) faultmask (/1): 0x00 +(22) control (/3): 0x00 (dirty)""" + +READ_MEMORY_RESPONSE = "20005000 080001a1 080001ab 080001ad" + +FLASH_BANKS_RESPONSE = ( + "#0 : stm32f1x.flash (stm32f1x) at 0x08000000," + " size 0x00020000, buswidth 0, chipwidth 0" +) + +SCAN_CHAIN_RESPONSE = """\ + TapName Enabled IdCode Expected IrLen IrCap IrMask +-- ------------------- -------- ---------- ---------- ----- ----- ------ + 0 stm32f1x.cpu Y 0x3ba00477 0x3ba00477 4 0x01 0x0f""" + +BP_LIST_RESPONSE = """\ +Breakpoint(IVA): 0x08001234, 0x2, 1 +Breakpoint(IVA): 0x08001300, 0x2, 0""" + +RTT_CHANNELS_RESPONSE = """\ +Up-channels: + 0: Terminal 1024 + 1: Log 512 +Down-channels: + 0: Terminal 16""" + +TRANSPORT_SELECT_RESPONSE = "swd" +TRANSPORT_LIST_RESPONSE = "jtag swd" +ADAPTER_SPEED_RESPONSE = "4000" + + +def _build_default_responses() -> list[tuple[re.Pattern[str], str | Callable[[str], str]]]: + """Build the default command-to-response routing table. + + Returns a list of (compiled_regex, response) pairs. The first match wins. + Response can be a string or a callable that receives the full command. + """ + routes: list[tuple[re.Pattern[str], str | Callable[[str], str]]] = [ + # target state control + (re.compile(r"^targets$"), TARGETS_RESPONSE), + (re.compile(r"^halt$"), ""), + (re.compile(r"^resume"), ""), + (re.compile(r"^step"), ""), + (re.compile(r"^reset\s+"), ""), + (re.compile(r"^wait_halt"), ""), + + # individual register reads (must come before bare "reg") + (re.compile(r"^reg\s+pc$"), REG_PC_RESPONSE), + (re.compile(r"^reg\s+sp$"), REG_SP_RESPONSE), + (re.compile(r"^reg\s+lr$"), REG_LR_RESPONSE), + (re.compile(r"^reg\s+xPSR$"), REG_XPSR_RESPONSE), + # register write (reg ) + (re.compile(r"^reg\s+\S+\s+0x"), ""), + # bare "reg" -> full listing + (re.compile(r"^reg$"), REG_ALL_RESPONSE), + + # memory + (re.compile(r"^read_memory\s+0x8000000\s+32\s+4$"), READ_MEMORY_RESPONSE), + # generic read_memory -- return zeros for widths/counts we haven't mapped + (re.compile(r"^read_memory\s+"), _generic_read_memory), + (re.compile(r"^write_memory\s+"), ""), + + # flash + (re.compile(r"^flash banks$"), FLASH_BANKS_RESPONSE), + (re.compile(r"^flash\s+"), ""), + + # JTAG + (re.compile(r"^scan_chain$"), SCAN_CHAIN_RESPONSE), + (re.compile(r"^irscan\s+"), "0x01"), + (re.compile(r"^drscan\s+"), "0xDEADBEEF"), + (re.compile(r"^runtest\s+"), ""), + (re.compile(r"^pathmove\s+"), ""), + + # breakpoints + (re.compile(r"^bp\s+0x"), ""), + (re.compile(r"^bp$"), BP_LIST_RESPONSE), + (re.compile(r"^rbp\s+"), ""), + (re.compile(r"^wp\s+0x"), ""), + (re.compile(r"^wp$"), ""), + (re.compile(r"^rwp\s+"), ""), + + # transport / adapter + (re.compile(r"^transport\s+select$"), TRANSPORT_SELECT_RESPONSE), + (re.compile(r"^transport\s+list$"), TRANSPORT_LIST_RESPONSE), + (re.compile(r"^adapter\s+speed$"), ADAPTER_SPEED_RESPONSE), + (re.compile(r"^adapter\s+speed\s+\d+"), ADAPTER_SPEED_RESPONSE), + (re.compile(r"^adapter\s+name$"), "cmsis-dap"), + + # RTT + (re.compile(r"^rtt\s+channels$"), RTT_CHANNELS_RESPONSE), + (re.compile(r"^rtt\s+setup\s+"), ""), + (re.compile(r"^rtt\s+start$"), ""), + (re.compile(r"^rtt\s+stop$"), ""), + (re.compile(r"^rtt\s+channelread\s+"), "hello from target"), + (re.compile(r"^rtt\s+channelwrite\s+"), ""), + + # notifications + (re.compile(r"^tcl_notifications\s+"), ""), + ] + return routes + + +def _generic_read_memory(cmd: str) -> str: + """Generate a plausible response for an arbitrary read_memory command. + + Parses the count from the command and returns that many hex zeros. + """ + parts = cmd.split() + # read_memory + count = 1 + if len(parts) >= 4: + with contextlib.suppress(ValueError): + count = int(parts[3]) + return " ".join(["00"] * count) + + +class MockOpenOCDServer: + """Asyncio TCP server that fakes OpenOCD TCL RPC responses. + + Usage:: + + server = MockOpenOCDServer() + await server.start() + host, port = server.address + # ... connect and test ... + await server.stop() + """ + + def __init__(self) -> None: + self._server: asyncio.Server | None = None + self._routes = _build_default_responses() + self._host = "127.0.0.1" + self._port = 0 # OS picks a free port + # Track raw commands received, useful for assertions + self.received_commands: list[str] = [] + + @property + def address(self) -> tuple[str, int]: + """Return (host, port) the server is listening on.""" + if self._server is None: + raise RuntimeError("Server not started") + sock = self._server.sockets[0] + return sock.getsockname()[:2] + + def add_response(self, pattern: str, response: str | Callable[[str], str]) -> None: + """Prepend a custom response rule (takes priority over defaults).""" + self._routes.insert(0, (re.compile(pattern), response)) + + async def start(self) -> None: + self._server = await asyncio.start_server( + self._handle_client, self._host, self._port + ) + await self._server.start_serving() + + async def stop(self) -> None: + if self._server: + self._server.close() + await self._server.wait_closed() + self._server = None + + async def _handle_client( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + """Handle one client connection, reading commands and sending responses.""" + buf = bytearray() + try: + while True: + chunk = await reader.read(4096) + if not chunk: + break + buf.extend(chunk) + + # Process all complete commands in the buffer + while True: + idx = buf.find(SEPARATOR) + if idx == -1: + break + command = bytes(buf[:idx]).decode("utf-8", errors="replace") + buf = buf[idx + 1 :] + + self.received_commands.append(command) + response = self._resolve(command) + + writer.write(response.encode("utf-8") + SEPARATOR) + await writer.drain() + except (asyncio.CancelledError, ConnectionResetError, BrokenPipeError): + pass + finally: + writer.close() + with contextlib.suppress(OSError): + await writer.wait_closed() + + def _resolve(self, command: str) -> str: + """Find the first matching route and return its response.""" + for pattern, response in self._routes: + if pattern.search(command): + if callable(response): + return response(command) + return response + # Unrecognized command returns empty (success) + return "" diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 0000000..d149c36 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,113 @@ +"""Tests for the TclRpcConnection class.""" +from __future__ import annotations + +import asyncio + +import pytest + +from openocd.connection.tcl_rpc import TclRpcConnection +from openocd.errors import ConnectionError, TimeoutError + + +async def test_connect_to_mock_server(mock_ocd): + """Verify we can open a connection to the mock server.""" + host, port, _server = mock_ocd + conn = TclRpcConnection(timeout=5.0) + await conn.connect(host, port) + assert conn._writer is not None + assert conn._reader is not None + await conn.close() + + +async def test_send_and_receive(connection): + """Send a command and verify we get the expected response.""" + resp = await connection.send("targets") + assert "stm32f1x.cpu" in resp + assert "halted" in resp + + +async def test_separator_framing(mock_ocd): + """Verify the \\x1a framing works for multiple sequential commands.""" + host, port, _server = mock_ocd + conn = TclRpcConnection(timeout=5.0) + await conn.connect(host, port) + + # Send several commands in sequence; each should get its own response + resp1 = await conn.send("halt") + resp2 = await conn.send("reg pc") + resp3 = await conn.send("targets") + + # halt returns empty + assert resp1 == "" + # reg pc returns a value + assert "0x08001234" in resp2 + # targets returns the state table + assert "stm32f1x.cpu" in resp3 + + await conn.close() + + +async def test_connection_error_no_server(): + """Connecting to a port with no listener should raise ConnectionError.""" + conn = TclRpcConnection(timeout=1.0) + with pytest.raises(ConnectionError): + await conn.connect("127.0.0.1", 1) # port 1 is unlikely to be open + + +async def test_send_before_connect_raises(): + """Sending a command before connect() should raise ConnectionError.""" + conn = TclRpcConnection() + with pytest.raises(ConnectionError, match="Not connected"): + await conn.send("targets") + + +async def test_timeout_on_hung_server(): + """A server that never sends \\x1a should trigger a TimeoutError.""" + # Start a server that accepts connections but never responds + async def _hang(reader, writer): + # Read the command but never reply + await reader.read(4096) + await asyncio.sleep(60) + + server = await asyncio.start_server(_hang, "127.0.0.1", 0) + await server.start_serving() + sock = server.sockets[0] + host, port = sock.getsockname()[:2] + + conn = TclRpcConnection(timeout=0.3) + await conn.connect(host, port) + + with pytest.raises(TimeoutError): + await conn.send("targets") + + await conn.close() + server.close() + await server.wait_closed() + + +async def test_close_idempotent(connection): + """Calling close() multiple times should not raise.""" + await connection.close() + await connection.close() # second call is a no-op + + +async def test_concurrent_commands(mock_ocd): + """Multiple coroutines sharing one connection should serialize properly.""" + host, port, _server = mock_ocd + conn = TclRpcConnection(timeout=5.0) + await conn.connect(host, port) + + async def _do_command(cmd: str) -> str: + return await conn.send(cmd) + + results = await asyncio.gather( + _do_command("reg pc"), + _do_command("reg sp"), + _do_command("reg lr"), + ) + + assert "0x08001234" in results[0] + assert "0x20005000" in results[1] + assert "0x08000100" in results[2] + + await conn.close() diff --git a/tests/test_jtag.py b/tests/test_jtag.py new file mode 100644 index 0000000..b639257 --- /dev/null +++ b/tests/test_jtag.py @@ -0,0 +1,93 @@ +"""Tests for the JTAG subsystem.""" +from __future__ import annotations + +import pytest + +from openocd.types import TAPInfo + + +async def test_scan_chain(session): + """scan_chain() should return a list of TAPInfo objects.""" + taps = await session.jtag.scan_chain() + assert isinstance(taps, list) + assert len(taps) == 1 + + +async def test_scan_chain_tap_fields(session): + """The returned TAPInfo should have all fields populated correctly.""" + taps = await session.jtag.scan_chain() + tap = taps[0] + assert isinstance(tap, TAPInfo) + assert tap.name == "stm32f1x.cpu" + assert tap.chip == "stm32f1x" + assert tap.tap_name == "cpu" + assert tap.idcode == 0x3BA00477 + assert tap.ir_length == 4 + assert tap.enabled is True + + +async def test_scan_chain_frozen(session): + """TAPInfo should be immutable (frozen dataclass).""" + taps = await session.jtag.scan_chain() + tap = taps[0] + with pytest.raises(AttributeError): + tap.name = "something_else" # type: ignore[misc] + + +async def test_irscan(session): + """irscan should return the shifted-out value as an int.""" + result = await session.jtag.irscan("stm32f1x.cpu", 0x0E) + assert isinstance(result, int) + assert result == 0x01 + + +async def test_drscan(session): + """drscan should return the shifted-out value as an int.""" + result = await session.jtag.drscan("stm32f1x.cpu", 32, 0x00000000) + assert isinstance(result, int) + assert result == 0xDEADBEEF + + +async def test_runtest(session): + """runtest should complete without error.""" + await session.jtag.runtest(100) + + +async def test_scan_chain_parsing_multiple_taps(mock_ocd): + """Verify the parser handles multiple TAPs in scan_chain output.""" + from openocd.jtag.chain import _parse_scan_chain + + raw = """\ + TapName Enabled IdCode Expected IrLen IrCap IrMask +-- ------------------- -------- ---------- ---------- ----- ----- ------ + 0 stm32f1x.cpu Y 0x3ba00477 0x3ba00477 4 0x01 0x0f + 1 stm32f1x.bs N 0x06433041 0x06433041 5 0x01 0x1f""" + + taps = _parse_scan_chain(raw) + assert len(taps) == 2 + + assert taps[0].name == "stm32f1x.cpu" + assert taps[0].enabled is True + assert taps[0].idcode == 0x3BA00477 + assert taps[0].ir_length == 4 + + assert taps[1].name == "stm32f1x.bs" + assert taps[1].chip == "stm32f1x" + assert taps[1].tap_name == "bs" + assert taps[1].enabled is False + assert taps[1].idcode == 0x06433041 + assert taps[1].ir_length == 5 + + +def test_parse_scan_chain_empty(): + """An empty scan_chain output should return an empty list.""" + from openocd.jtag.chain import _parse_scan_chain + + result = _parse_scan_chain("") + assert result == [] + + result = _parse_scan_chain( + " TapName Enabled IdCode Expected IrLen IrCap IrMask\n" + "-- ------------------- -------- ---------- ---------- ----- ----- ------\n" + ) + assert result == [] diff --git a/tests/test_memory.py b/tests/test_memory.py new file mode 100644 index 0000000..b02ecba --- /dev/null +++ b/tests/test_memory.py @@ -0,0 +1,93 @@ +"""Tests for the Memory subsystem.""" +from __future__ import annotations + + +async def test_read_u32(session): + """read_u32 should return a list[int] with correctly parsed hex values.""" + values = await session.memory.read_u32(0x8000000, 4) + assert isinstance(values, list) + assert len(values) == 4 + assert values[0] == 0x20005000 + assert values[1] == 0x080001A1 + assert values[2] == 0x080001AB + assert values[3] == 0x080001AD + + +async def test_read_u32_single(session): + """read_u32 with count=1 should return a single-element list.""" + values = await session.memory.read_u32(0x20000000, 1) + assert isinstance(values, list) + assert len(values) == 1 + + +async def test_read_u8(session): + """read_u8 should return a list of 8-bit values.""" + values = await session.memory.read_u8(0x20000000, 4) + assert isinstance(values, list) + assert len(values) == 4 + # Generic mock returns zeros for unregistered addresses + assert all(isinstance(v, int) for v in values) + + +async def test_read_u16(session): + """read_u16 should return a list of 16-bit values.""" + values = await session.memory.read_u16(0x20000000, 2) + assert isinstance(values, list) + assert len(values) == 2 + + +async def test_read_bytes(session): + """read_bytes should return a bytes object of the requested size.""" + data = await session.memory.read_bytes(0x20000000, 8) + assert isinstance(data, bytes) + assert len(data) == 8 + + +async def test_write_u32(session): + """write_u32 should complete without error.""" + await session.memory.write_u32(0x20000000, [0xDEADBEEF, 0xCAFEBABE]) + + +async def test_write_u32_single_value(session): + """write_u32 with a single int should complete without error.""" + await session.memory.write_u32(0x20000000, 0x12345678) + + +async def test_write_bytes(session): + """write_bytes should complete without error.""" + await session.memory.write_bytes(0x20000000, b"\x00\x01\x02\x03") + + +async def test_hexdump_format(session): + """hexdump should return a properly formatted hex+ASCII dump.""" + dump = await session.memory.hexdump(0x20000000, 32) + assert isinstance(dump, str) + lines = dump.strip().splitlines() + assert len(lines) == 2 # 32 bytes / 16 bytes per line = 2 lines + + # Each line should start with an address + assert lines[0].startswith("20000000:") + assert lines[1].startswith("20000010:") + + # Each line should contain the ASCII column delimited by pipes + for line in lines: + assert "|" in line + + +async def test_hexdump_ascii_column(session): + """Hexdump ASCII column should use dots for non-printable bytes.""" + dump = await session.memory.hexdump(0x20000000, 16) + # The mock returns all zeros, which are non-printable + assert "|" in dump + # Extract the ASCII portion between the pipes + ascii_part = dump.split("|")[1] + # All-zero bytes map to dots + assert all(c == "." for c in ascii_part) + + +async def test_read_u32_returns_ints(session): + """All values from read_u32 should be Python ints.""" + values = await session.memory.read_u32(0x8000000, 4) + for v in values: + assert isinstance(v, int) + assert v >= 0 diff --git a/tests/test_registers.py b/tests/test_registers.py new file mode 100644 index 0000000..1616e51 --- /dev/null +++ b/tests/test_registers.py @@ -0,0 +1,107 @@ +"""Tests for the Registers subsystem.""" +from __future__ import annotations + +from openocd.types import Register + + +async def test_read_pc(session): + """read('pc') should return the correct value from the mock.""" + val = await session.registers.read("pc") + assert val == 0x08001234 + + +async def test_read_sp(session): + """read('sp') should return the correct value.""" + val = await session.registers.read("sp") + assert val == 0x20005000 + + +async def test_read_lr(session): + """read('lr') should return the correct value.""" + val = await session.registers.read("lr") + assert val == 0x08000100 + + +async def test_read_xpsr(session): + """read('xPSR') should return the correct value.""" + val = await session.registers.read("xPSR") + assert val == 0x61000000 + + +async def test_read_all(session): + """read_all() should return a dict of Register objects keyed by name.""" + regs = await session.registers.read_all() + assert isinstance(regs, dict) + assert len(regs) > 0 + + # Spot-check a few registers + assert "pc" in regs + assert "sp" in regs + assert "r0" in regs + assert "xPSR" in regs + + +async def test_read_all_register_type(session): + """Each value in read_all() should be a Register dataclass.""" + regs = await session.registers.read_all() + for name, reg in regs.items(): + assert isinstance(reg, Register) + assert reg.name == name + assert isinstance(reg.number, int) + assert isinstance(reg.value, int) + assert isinstance(reg.size, int) + assert isinstance(reg.dirty, bool) + + +async def test_read_all_pc_value(session): + """The pc register from read_all() should have the correct value.""" + regs = await session.registers.read_all() + pc = regs["pc"] + assert pc.value == 0x08001234 + assert pc.size == 32 + assert pc.number == 15 + + +async def test_read_all_dirty_flag(session): + """The control register should have dirty=True in our mock data.""" + regs = await session.registers.read_all() + control = regs["control"] + assert control.dirty is True + + +async def test_convenience_pc(session): + """The pc() convenience method should match read('pc').""" + val = await session.registers.pc() + assert val == 0x08001234 + + +async def test_convenience_sp(session): + """The sp() convenience method should match read('sp').""" + val = await session.registers.sp() + assert val == 0x20005000 + + +async def test_convenience_lr(session): + """The lr() convenience method should match read('lr').""" + val = await session.registers.lr() + assert val == 0x08000100 + + +async def test_convenience_xpsr(session): + """The xpsr() convenience method should match read('xPSR').""" + val = await session.registers.xpsr() + assert val == 0x61000000 + + +async def test_write(session): + """write() should complete without error.""" + await session.registers.write("r0", 0xDEADBEEF) + + +async def test_read_many(session): + """read_many() should return values for all requested registers.""" + results = await session.registers.read_many(["pc", "sp", "lr"]) + assert len(results) == 3 + assert results["pc"] == 0x08001234 + assert results["sp"] == 0x20005000 + assert results["lr"] == 0x08000100 diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..2943bec --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,98 @@ +"""Tests for the Session class.""" +from __future__ import annotations + +import pytest + +from openocd.breakpoints import BreakpointManager +from openocd.flash import Flash +from openocd.jtag import JTAGController +from openocd.memory import Memory +from openocd.registers import Registers +from openocd.rtt import RTTManager +from openocd.session import Session +from openocd.svd import SVDManager +from openocd.target import Target +from openocd.transport import Transport + + +async def test_connect_to_mock(mock_ocd): + """Session.connect() should successfully connect to the mock server.""" + host, port, _server = mock_ocd + sess = await Session.connect(host, port, timeout=5.0) + assert sess is not None + await sess.close() + + +async def test_raw_command(session): + """session.command() should pass through to the underlying connection.""" + resp = await session.command("targets") + assert "stm32f1x.cpu" in resp + + +async def test_context_manager(mock_ocd): + """Session should work as an async context manager.""" + host, port, _server = mock_ocd + async with await Session.connect(host, port, timeout=5.0) as sess: + resp = await sess.command("halt") + assert resp == "" + # After exiting the context, the connection is closed. + # Attempting to send should raise. + from openocd.errors import ConnectionError + with pytest.raises(ConnectionError): + await sess.command("targets") + + +async def test_subsystem_target_type(session): + """session.target should return a Target instance.""" + assert isinstance(session.target, Target) + + +async def test_subsystem_memory_type(session): + """session.memory should return a Memory instance.""" + assert isinstance(session.memory, Memory) + + +async def test_subsystem_registers_type(session): + """session.registers should return a Registers instance.""" + assert isinstance(session.registers, Registers) + + +async def test_subsystem_flash_type(session): + """session.flash should return a Flash instance.""" + assert isinstance(session.flash, Flash) + + +async def test_subsystem_jtag_type(session): + """session.jtag should return a JTAGController instance.""" + assert isinstance(session.jtag, JTAGController) + + +async def test_subsystem_breakpoints_type(session): + """session.breakpoints should return a BreakpointManager instance.""" + assert isinstance(session.breakpoints, BreakpointManager) + + +async def test_subsystem_rtt_type(session): + """session.rtt should return an RTTManager instance.""" + assert isinstance(session.rtt, RTTManager) + + +async def test_subsystem_svd_type(session): + """session.svd should return an SVDManager instance.""" + assert isinstance(session.svd, SVDManager) + + +async def test_subsystem_transport_type(session): + """session.transport should return a Transport instance.""" + assert isinstance(session.transport, Transport) + + +async def test_subsystem_lazy_initialization(session): + """Accessing the same property twice should return the same object.""" + t1 = session.target + t2 = session.target + assert t1 is t2 + + m1 = session.memory + m2 = session.memory + assert m1 is m2 diff --git a/tests/test_svd.py b/tests/test_svd.py new file mode 100644 index 0000000..8c759cb --- /dev/null +++ b/tests/test_svd.py @@ -0,0 +1,249 @@ +"""Tests for SVD decoding (no hardware required). + +These tests exercise the bitfield decoder and DecodedRegister formatting +using synthetic data, without needing an SVD file or a mock server. +""" +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + +from openocd.svd.decoder import decode_register +from openocd.types import BitField, DecodedRegister + +# -- Fake SVD objects to avoid needing a real .svd file ----------------------- + +@dataclass +class FakeSVDField: + name: str + bit_offset: int + bit_width: int + description: str + + +@dataclass +class FakeSVDRegister: + name: str + address_offset: int + fields: list[FakeSVDField] + + +@dataclass +class FakeSVDPeripheral: + name: str + base_address: int + registers: list[FakeSVDRegister] + + +@pytest.fixture +def gpioa_odr(): + """A fake GPIOA.ODR register with two bitfields.""" + fields = [ + FakeSVDField( + name="ODR0", bit_offset=0, bit_width=1, + description="Port output data bit 0", + ), + FakeSVDField( + name="ODR1", bit_offset=1, bit_width=1, + description="Port output data bit 1", + ), + FakeSVDField( + name="ODR15_2", bit_offset=2, bit_width=14, + description="Port output data bits 15:2", + ), + ] + register = FakeSVDRegister(name="ODR", address_offset=0x14, fields=fields) + peripheral = FakeSVDPeripheral( + name="GPIOA", base_address=0x40010800, registers=[register] + ) + return peripheral, register + + +@pytest.fixture +def usart_cr1(): + """A fake USART1.CR1 register with multiple bitfields.""" + fields = [ + FakeSVDField(name="UE", bit_offset=0, bit_width=1, description="USART enable"), + FakeSVDField(name="RE", bit_offset=2, bit_width=1, description="Receiver enable"), + FakeSVDField(name="TE", bit_offset=3, bit_width=1, description="Transmitter enable"), + FakeSVDField( + name="RXNEIE", bit_offset=5, bit_width=1, + description="RXNE interrupt enable", + ), + FakeSVDField( + name="TCIE", bit_offset=6, bit_width=1, + description="Transmission complete IE", + ), + FakeSVDField( + name="TXEIE", bit_offset=7, bit_width=1, + description="TXE interrupt enable", + ), + FakeSVDField(name="M", bit_offset=12, bit_width=1, description="Word length"), + FakeSVDField( + name="OVER8", bit_offset=15, bit_width=1, + description="Oversampling mode", + ), + ] + register = FakeSVDRegister(name="CR1", address_offset=0x0C, fields=fields) + peripheral = FakeSVDPeripheral( + name="USART1", base_address=0x40013800, registers=[register] + ) + return peripheral, register + + +def test_decode_register_basic(gpioa_odr): + """decode_register should extract bitfield values from a raw integer.""" + peripheral, register = gpioa_odr + # Set ODR0=1, ODR1=0, ODR15_2 = 0x0005 (bits 2..15 = 0b00000000010100 = 0x14 shifted) + raw = 0b0000000000010101 # ODR0=1, ODR1=0, ODR15_2=0x0005 + decoded = decode_register(peripheral, register, raw) + + assert isinstance(decoded, DecodedRegister) + assert decoded.peripheral == "GPIOA" + assert decoded.register == "ODR" + assert decoded.address == 0x40010800 + 0x14 + assert decoded.raw_value == raw + + +def test_decode_bitfield_extraction(gpioa_odr): + """Individual bitfield values should be correctly masked and shifted.""" + peripheral, register = gpioa_odr + raw = 0b0000000000010101 + decoded = decode_register(peripheral, register, raw) + + fields_by_name = {f.name: f for f in decoded.fields} + + assert fields_by_name["ODR0"].value == 1 + assert fields_by_name["ODR1"].value == 0 + assert fields_by_name["ODR15_2"].value == 0x0005 + + +def test_decode_all_ones(gpioa_odr): + """All-ones value should set all fields to their max values.""" + peripheral, register = gpioa_odr + raw = 0xFFFF + decoded = decode_register(peripheral, register, raw) + + fields_by_name = {f.name: f for f in decoded.fields} + + assert fields_by_name["ODR0"].value == 1 + assert fields_by_name["ODR1"].value == 1 + assert fields_by_name["ODR15_2"].value == (1 << 14) - 1 # 0x3FFF + + +def test_decode_all_zeros(gpioa_odr): + """All-zeros value should yield all-zero fields.""" + peripheral, register = gpioa_odr + decoded = decode_register(peripheral, register, 0x0000) + + for field in decoded.fields: + assert field.value == 0 + + +def test_bitfield_type(gpioa_odr): + """Each field in a DecodedRegister should be a BitField dataclass.""" + peripheral, register = gpioa_odr + decoded = decode_register(peripheral, register, 0xAAAA) + + for field in decoded.fields: + assert isinstance(field, BitField) + assert isinstance(field.name, str) + assert isinstance(field.offset, int) + assert isinstance(field.width, int) + assert isinstance(field.value, int) + assert isinstance(field.description, str) + + +def test_fields_sorted_by_offset(gpioa_odr): + """Decoded fields should be sorted by bit offset (low to high).""" + peripheral, register = gpioa_odr + decoded = decode_register(peripheral, register, 0x1234) + + offsets = [f.offset for f in decoded.fields] + assert offsets == sorted(offsets) + + +def test_decoded_register_str(gpioa_odr): + """__str__ should produce a multi-line representation with field details.""" + peripheral, register = gpioa_odr + raw = 0b0000000000010101 + decoded = decode_register(peripheral, register, raw) + + text = str(decoded) + assert "GPIOA.ODR" in text + assert "0X40010814" in text.upper() + assert "ODR0" in text + assert "ODR1" in text + assert "ODR15_2" in text + + +def test_decoded_register_str_shows_values(gpioa_odr): + """The string representation should show each field's hex value.""" + peripheral, register = gpioa_odr + decoded = decode_register(peripheral, register, 0x0001) + + text = str(decoded) + # ODR0 = 1 should appear as "0x1" + assert "0x1" in text + + +def test_decode_complex_register(usart_cr1): + """Decode a multi-field register and verify specific field values.""" + peripheral, register = usart_cr1 + # UE=1, RE=1, TE=1 -> bits 0,2,3 set -> raw = 0x000D + raw = 0x000D + decoded = decode_register(peripheral, register, raw) + + fields_by_name = {f.name: f for f in decoded.fields} + + assert fields_by_name["UE"].value == 1 + assert fields_by_name["RE"].value == 1 + assert fields_by_name["TE"].value == 1 + assert fields_by_name["RXNEIE"].value == 0 + assert fields_by_name["M"].value == 0 + assert fields_by_name["OVER8"].value == 0 + + +def test_decode_address_calculation(usart_cr1): + """The decoded address should be base + offset.""" + peripheral, register = usart_cr1 + decoded = decode_register(peripheral, register, 0) + assert decoded.address == 0x40013800 + 0x0C + + +def test_decoded_register_fields_list(gpioa_odr): + """fields should be a plain list, not some other iterable.""" + peripheral, register = gpioa_odr + decoded = decode_register(peripheral, register, 0) + assert isinstance(decoded.fields, list) + assert len(decoded.fields) == 3 + + +def test_bitfield_frozen(): + """BitField should be immutable (frozen dataclass).""" + bf = BitField(name="TEST", offset=0, width=1, value=1, description="test") + with pytest.raises(AttributeError): + bf.value = 2 # type: ignore[misc] + + +def test_decoded_register_str_single_bit_range(gpioa_odr): + """Single-bit fields should show just the bit number, not a range.""" + peripheral, register = gpioa_odr + decoded = decode_register(peripheral, register, 0x0001) + text = str(decoded) + # ODR0 is at offset 0, width 1 -> should show "[ 0]" not "[0:0]" + lines = text.strip().splitlines() + # Find the ODR0 line + odr0_line = [ln for ln in lines if "ODR0 " in ln or "ODR0" in ln.split()][0] + assert "0]" in odr0_line + + +def test_decoded_register_str_multi_bit_range(gpioa_odr): + """Multi-bit fields should show a bit range like [15:2].""" + peripheral, register = gpioa_odr + decoded = decode_register(peripheral, register, 0xFFFF) + text = str(decoded) + lines = text.strip().splitlines() + odr15_2_line = [ln for ln in lines if "ODR15_2" in ln][0] + assert "15:2" in odr15_2_line diff --git a/tests/test_target.py b/tests/test_target.py new file mode 100644 index 0000000..3788f29 --- /dev/null +++ b/tests/test_target.py @@ -0,0 +1,74 @@ +"""Tests for the Target subsystem.""" +from __future__ import annotations + +import pytest + +from openocd.types import TargetState + + +async def test_state_returns_target_state(session): + """target.state() should return a TargetState with correct fields.""" + state = await session.target.state() + assert isinstance(state, TargetState) + assert state.name == "stm32f1x.cpu" + assert state.state == "halted" + # When halted, the mock returns pc = 0x08001234 + assert state.current_pc == 0x08001234 + + +async def test_halt(session): + """target.halt() should return a TargetState.""" + state = await session.target.halt() + assert isinstance(state, TargetState) + assert state.state == "halted" + + +async def test_resume(session): + """target.resume() should complete without error.""" + await session.target.resume() + + +async def test_resume_with_address(session): + """target.resume(address=...) should complete without error.""" + await session.target.resume(address=0x08000000) + + +async def test_step(session): + """target.step() should return a TargetState.""" + state = await session.target.step() + assert isinstance(state, TargetState) + + +async def test_step_with_address(session): + """target.step(address=...) should complete without error.""" + state = await session.target.step(address=0x08001234) + assert isinstance(state, TargetState) + + +async def test_reset_halt(session): + """target.reset('halt') should complete without error.""" + await session.target.reset("halt") + + +async def test_reset_run(session): + """target.reset('run') should complete without error.""" + await session.target.reset("run") + + +async def test_reset_init(session): + """target.reset('init') should complete without error.""" + await session.target.reset("init") + + +async def test_state_pc_field(session): + """When halted, current_pc should be populated from reg pc.""" + state = await session.target.state() + assert state.current_pc is not None + assert state.current_pc == 0x08001234 + + +async def test_state_frozen_dataclass(session): + """TargetState should be immutable (frozen dataclass).""" + state = await session.target.state() + with pytest.raises(AttributeError): + state.name = "something_else" # type: ignore[misc]