diff --git a/src/openocd/connection/tcl_rpc.py b/src/openocd/connection/tcl_rpc.py index 313057b..6b0a9a6 100644 --- a/src/openocd/connection/tcl_rpc.py +++ b/src/openocd/connection/tcl_rpc.py @@ -4,6 +4,13 @@ OpenOCD's TCL RPC uses a simple framing protocol: - Client sends: command_string + \\x1a - Server replies: response_string + \\x1a The \\x1a (ASCII SUB / Ctrl-Z) byte acts as an unambiguous delimiter. + +Notifications use a **separate connection** to avoid dual-reader race +conditions on the command stream. When ``enable_notifications()`` is +called, a second TCP connection is opened to the same host:port. That +connection sends ``tcl_notifications on`` and then exclusively reads +unsolicited events, leaving the primary connection free for +request/response commands. """ from __future__ import annotations @@ -21,21 +28,34 @@ log = logging.getLogger(__name__) SEPARATOR = b"\x1a" DEFAULT_TIMEOUT = 10.0 +MAX_RESPONSE_SIZE = 10 * 1024 * 1024 # 10 MB — guard against runaway reads class TclRpcConnection(Connection): - """Async TCP client speaking OpenOCD's TCL RPC protocol.""" + """Async TCP client speaking OpenOCD's TCL RPC protocol. + + The command connection and the notification connection are kept on + **separate sockets** so that unsolicited events never corrupt + the request/response stream. + """ def __init__(self, timeout: float = DEFAULT_TIMEOUT) -> None: + # Primary command connection self._reader: asyncio.StreamReader | None = None self._writer: asyncio.StreamWriter | None = None self._timeout = timeout - self._notification_callbacks: list[Callable[[str], None]] = [] - self._notification_task: asyncio.Task[None] | None = None self._lock = asyncio.Lock() + self._remainder = bytearray() # leftover bytes after separator self._host: str = "" self._port: int = 0 + # Notification connection (separate socket) + self._notif_reader: asyncio.StreamReader | None = None + self._notif_writer: asyncio.StreamWriter | None = None + self._notification_callbacks: list[Callable[[str], None]] = [] + self._notification_task: asyncio.Task[None] | None = None + self._notification_failed: bool = False + # ------------------------------------------------------------------ # Connection lifecycle # ------------------------------------------------------------------ @@ -59,18 +79,28 @@ class TclRpcConnection(Connection): log.debug("Connected to OpenOCD TCL RPC at %s:%d", host, port) async def close(self) -> None: + # Tear down notification connection first if self._notification_task and not self._notification_task.done(): self._notification_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._notification_task self._notification_task = None + if self._notif_writer: + self._notif_writer.close() + with contextlib.suppress(OSError): + await self._notif_writer.wait_closed() + self._notif_writer = None + self._notif_reader = None + + # Tear down primary command connection if self._writer: self._writer.close() with contextlib.suppress(OSError): await self._writer.wait_closed() self._writer = None self._reader = None + self._remainder.clear() log.debug("TCL RPC connection closed") # ------------------------------------------------------------------ @@ -86,6 +116,9 @@ class TclRpcConnection(Connection): if not self._writer or not self._reader: raise ConnectionError("Not connected — call connect() first") + if self._notification_failed: + log.warning("Notification listener has stopped — events may be missed") + async with self._lock: payload = command.encode("utf-8") + SEPARATOR self._writer.write(payload) @@ -107,39 +140,100 @@ class TclRpcConnection(Connection): return response async def _read_until_separator(self) -> bytes: - """Read from the stream until the \\x1a separator is found.""" + """Read from the command stream until the \\x1a separator is found. + + Preserves any bytes received after the separator for the next call. + Raises ``ConnectionError`` if the response exceeds ``MAX_RESPONSE_SIZE``. + """ assert self._reader is not None - buf = bytearray() + buf = self._remainder + self._remainder = bytearray() + + # Check if remainder already contains a complete response + idx = buf.find(SEPARATOR) + if idx != -1: + result = bytes(buf[:idx]) + self._remainder = bytearray(buf[idx + 1 :]) + return result + while True: chunk = await self._reader.read(4096) if not chunk: raise ConnectionError("OpenOCD closed the connection") buf.extend(chunk) + if len(buf) > MAX_RESPONSE_SIZE: + raise ConnectionError( + f"Response exceeded {MAX_RESPONSE_SIZE} bytes without separator — " + "is this an OpenOCD TCL RPC port?" + ) idx = buf.find(SEPARATOR) if idx != -1: - return bytes(buf[:idx]) + result = bytes(buf[:idx]) + self._remainder = bytearray(buf[idx + 1 :]) + return result # ------------------------------------------------------------------ - # Notifications (async events from OpenOCD) + # Notifications (separate connection) # ------------------------------------------------------------------ async def enable_notifications(self) -> None: - """Enable TCL event notifications and start the listener loop. + """Open a dedicated notification connection and start the listener. - Sends ``tcl_notifications on`` which causes OpenOCD to push - target-state-change events over the same socket. + A **separate TCP connection** to the same OpenOCD instance is + used for notifications. This avoids the dual-reader race + condition that would occur if notifications and command + responses shared the same stream. """ - await self.send("tcl_notifications on") + if not self._host: + raise ConnectionError("Not connected — call connect() first") + + try: + self._notif_reader, self._notif_writer = await asyncio.wait_for( + asyncio.open_connection(self._host, self._port), + timeout=self._timeout, + ) + except OSError as exc: + raise ConnectionError( + f"Cannot open notification connection to {self._host}:{self._port}: {exc}" + ) from exc + except TimeoutError as exc: + raise OcdTimeoutError( + f"Timed out opening notification connection to {self._host}:{self._port}" + ) from exc + + # Enable notifications on the dedicated connection + enable_cmd = b"tcl_notifications on" + SEPARATOR + self._notif_writer.write(enable_cmd) + await self._notif_writer.drain() + + # Read and discard the acknowledgement + ack_buf = bytearray() + while True: + chunk = await asyncio.wait_for( + self._notif_reader.read(4096), timeout=self._timeout + ) + if not chunk: + raise ConnectionError( + "Notification connection closed during setup" + ) + ack_buf.extend(chunk) + if ack_buf.find(SEPARATOR) != -1: + break + + log.debug("Notification connection established to %s:%d", self._host, self._port) + self._notification_failed = False self._notification_task = asyncio.create_task(self._notification_loop()) async def _notification_loop(self) -> None: - """Background task that reads unsolicited notifications.""" - assert self._reader is not None + """Background task that reads unsolicited notifications from + the dedicated notification connection.""" + assert self._notif_reader is not None buf = bytearray() try: while True: - chunk = await self._reader.read(4096) + chunk = await self._notif_reader.read(4096) if not chunk: + log.warning("Notification connection closed by OpenOCD") break buf.extend(chunk) while True: @@ -158,6 +252,8 @@ class TclRpcConnection(Connection): return except Exception: log.exception("Notification loop crashed") + finally: + self._notification_failed = True def on_notification(self, callback: Callable[[str], None]) -> None: self._notification_callbacks.append(callback) diff --git a/src/openocd/connection/telnet.py b/src/openocd/connection/telnet.py index 80b0155..a5c6806 100644 --- a/src/openocd/connection/telnet.py +++ b/src/openocd/connection/telnet.py @@ -19,6 +19,7 @@ log = logging.getLogger(__name__) PROMPT = b"> " DEFAULT_TIMEOUT = 10.0 +MAX_RESPONSE_SIZE = 10 * 1024 * 1024 # 10 MB class TelnetConnection(Connection): @@ -90,6 +91,10 @@ class TelnetConnection(Connection): if not chunk: raise ConnectionError("OpenOCD closed the connection") buf.extend(chunk) + if len(buf) > MAX_RESPONSE_SIZE: + raise ConnectionError( + f"Response exceeded {MAX_RESPONSE_SIZE} bytes without prompt" + ) if buf.endswith(PROMPT): return bytes(buf[: -len(PROMPT)]) diff --git a/src/openocd/jtag/boundary.py b/src/openocd/jtag/boundary.py index 2a17d28..8ded55c 100644 --- a/src/openocd/jtag/boundary.py +++ b/src/openocd/jtag/boundary.py @@ -48,5 +48,5 @@ async def xsvf(conn: Connection, tap: str, path: Path) -> None: def _check_error(response: str, command: str) -> None: - if "Error" in response or "error" in response.split("\n")[0]: + if "error" in response.lower(): raise JTAGError(f"{command} failed: {response.strip()}") diff --git a/src/openocd/jtag/chain.py b/src/openocd/jtag/chain.py index 9b52cb6..babbd94 100644 --- a/src/openocd/jtag/chain.py +++ b/src/openocd/jtag/chain.py @@ -32,7 +32,7 @@ _CHAIN_ROW_RE = re.compile( async def scan_chain(conn: Connection) -> list[TAPInfo]: """Query the JTAG scan chain and return a list of discovered TAPs.""" resp = await conn.send("scan_chain") - if "Error" in resp: + if "error" in resp.lower(): raise JTAGError(f"scan_chain failed: {resp.strip()}") return _parse_scan_chain(resp) @@ -57,7 +57,7 @@ async def new_tap( if expected_id is not None: parts.extend(["-expected-id", f"0x{expected_id:08x}"]) resp = await conn.send(" ".join(parts)) - if "Error" in resp: + if "error" in resp.lower(): raise JTAGError(f"newtap failed: {resp.strip()}") diff --git a/src/openocd/jtag/scan.py b/src/openocd/jtag/scan.py index 777f524..19d5002 100644 --- a/src/openocd/jtag/scan.py +++ b/src/openocd/jtag/scan.py @@ -54,5 +54,5 @@ async def runtest(conn: Connection, cycles: int) -> None: def _check_error(response: str, command: str) -> None: """Raise JTAGError if OpenOCD reported an error.""" - if "Error" in response or "error" in response.split("\n")[0]: + if "error" in response.lower(): raise JTAGError(f"{command} failed: {response.strip()}") diff --git a/src/openocd/jtag/state.py b/src/openocd/jtag/state.py index 6fe9994..50dce93 100644 --- a/src/openocd/jtag/state.py +++ b/src/openocd/jtag/state.py @@ -22,5 +22,5 @@ async def pathmove(conn: Connection, states: list[JTAGState]) -> None: def _check_error(response: str, command: str) -> None: - if "Error" in response or "error" in response.split("\n")[0]: + if "error" in response.lower(): raise JTAGError(f"{command} failed: {response.strip()}") diff --git a/src/openocd/process.py b/src/openocd/process.py index 94ae8cf..0351579 100644 --- a/src/openocd/process.py +++ b/src/openocd/process.py @@ -40,7 +40,7 @@ class OpenOCDProcess: async def start( self, - config: str, + config: str | list[str], extra_args: list[str] | None = None, tcl_port: int = DEFAULT_TCL_PORT, openocd_bin: str | None = None, @@ -48,9 +48,10 @@ class OpenOCDProcess: """Start OpenOCD with the given configuration. Args: - config: Config file path or inline ``-f`` / ``-c`` arguments. - Multiple files can be separated by spaces with ``-f`` prefixes, - e.g. ``"interface/cmsis-dap.cfg -f target/stm32f1x.cfg"``. + config: Config file path, inline ``-f`` / ``-c`` arguments string, + or a list of arguments (preferred for paths with spaces). + String form: ``"interface/cmsis-dap.cfg -f target/stm32f1x.cfg"`` + List form: ``["-f", "my config/board.cfg", "-f", "target/stm32f1x.cfg"]`` extra_args: Additional CLI arguments. tcl_port: TCL RPC port (default 6666). openocd_bin: Path to OpenOCD binary (auto-detected if None). @@ -63,17 +64,27 @@ class OpenOCDProcess: ) args = [binary] - # Parse the config string — support both bare paths and -f/-c flags - config_parts = config.split() - i = 0 - while i < len(config_parts): - part = config_parts[i] - if part in ("-f", "-c"): - args.extend([part, config_parts[i + 1]]) - i += 2 - else: - args.extend(["-f", part]) - i += 1 + + # Accept either a pre-split list or a string to parse + if isinstance(config, list): + args.extend(config) + else: + config_parts = config.split() + if not config_parts: + raise ProcessError("Empty config string") + i = 0 + while i < len(config_parts): + part = config_parts[i] + if part in ("-f", "-c"): + if i + 1 >= len(config_parts): + raise ProcessError( + f"Config flag '{part}' requires an argument" + ) + args.extend([part, config_parts[i + 1]]) + i += 2 + else: + args.extend(["-f", part]) + i += 1 args.extend(["-c", f"tcl_port {tcl_port}"]) @@ -84,7 +95,7 @@ class OpenOCDProcess: try: self._proc = await asyncio.create_subprocess_exec( *args, - stdout=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.PIPE, ) except FileNotFoundError as exc: diff --git a/src/openocd/py.typed b/src/openocd/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/openocd/rtt.py b/src/openocd/rtt.py index bd5cc70..727f753 100644 --- a/src/openocd/rtt.py +++ b/src/openocd/rtt.py @@ -124,7 +124,14 @@ class RTTManager: Raises: OpenOCDError: If the write command fails. """ - cmd = f'rtt channelwrite {channel} "{data}"' + # Escape TCL special characters to prevent injection + escaped = ( + data.replace("\\", "\\\\") + .replace('"', '\\"') + .replace("[", "\\[") + .replace("$", "\\$") + ) + cmd = f'rtt channelwrite {channel} "{escaped}"' response = await self._conn.send(cmd) _check_rtt_response(response, cmd) diff --git a/src/openocd/session.py b/src/openocd/session.py index 72ba36e..33b86e4 100644 --- a/src/openocd/session.py +++ b/src/openocd/session.py @@ -74,8 +74,12 @@ class Session: ) await proc.wait_ready(timeout=timeout) - conn = TclRpcConnection(timeout=timeout) - await conn.connect("localhost", tcl_port) + try: + conn = TclRpcConnection(timeout=timeout) + await conn.connect("localhost", tcl_port) + except Exception: + await proc.stop() + raise return cls(connection=conn, process=proc) @@ -229,6 +233,13 @@ class SyncSession: def __init__(self, session: Session, loop: asyncio.AbstractEventLoop) -> None: self._session = session self._loop = loop + self._target: SyncTarget | None = None + self._memory: SyncMemory | None = None + self._registers: SyncRegisters | None = None + self._flash: SyncFlash | None = None + self._jtag: SyncJTAGController | None = None + self._breakpoints: SyncBreakpointManager | None = None + self._svd: SyncSVDManager | None = None def __enter__(self) -> SyncSession: return self @@ -241,38 +252,52 @@ class SyncSession: @property def target(self) -> SyncTarget: - from openocd.target import SyncTarget - return SyncTarget(self._session.target, self._loop) + if self._target is None: + from openocd.target import SyncTarget + self._target = SyncTarget(self._session.target, self._loop) + return self._target @property def memory(self) -> SyncMemory: - from openocd.memory import SyncMemory - return SyncMemory(self._session.memory, self._loop) + if self._memory is None: + from openocd.memory import SyncMemory + self._memory = SyncMemory(self._session.memory, self._loop) + return self._memory @property def registers(self) -> SyncRegisters: - from openocd.registers import SyncRegisters - return SyncRegisters(self._session.registers, self._loop) + if self._registers is None: + from openocd.registers import SyncRegisters + self._registers = SyncRegisters(self._session.registers, self._loop) + return self._registers @property def flash(self) -> SyncFlash: - from openocd.flash import SyncFlash - return SyncFlash(self._session.flash, self._loop) + if self._flash is None: + from openocd.flash import SyncFlash + self._flash = SyncFlash(self._session.flash, self._loop) + return self._flash @property def jtag(self) -> SyncJTAGController: - from openocd.jtag import SyncJTAGController - return SyncJTAGController(self._session.jtag, self._loop) + if self._jtag is None: + from openocd.jtag import SyncJTAGController + self._jtag = SyncJTAGController(self._session.jtag, self._loop) + return self._jtag @property def breakpoints(self) -> SyncBreakpointManager: - from openocd.breakpoints import SyncBreakpointManager - return SyncBreakpointManager(self._session.breakpoints, self._loop) + if self._breakpoints is None: + from openocd.breakpoints import SyncBreakpointManager + self._breakpoints = SyncBreakpointManager(self._session.breakpoints, self._loop) + return self._breakpoints @property def svd(self) -> SyncSVDManager: - from openocd.svd import SyncSVDManager - return SyncSVDManager(self._session.svd, self._loop) + if self._svd is None: + from openocd.svd import SyncSVDManager + self._svd = SyncSVDManager(self._session.svd, self._loop) + return self._svd # ====================================================================== @@ -280,16 +305,20 @@ class SyncSession: # ====================================================================== def _get_or_create_loop() -> asyncio.AbstractEventLoop: - """Get the running event loop, or create a new one if there isn't one.""" + """Get or create an event loop for synchronous usage. + + Raises RuntimeError if called from within an already-running async + context (where ``run_until_complete`` would deadlock). + """ try: - loop = asyncio.get_running_loop() - # If we're already in an async context we can't use run_until_complete + asyncio.get_running_loop() + except RuntimeError: + pass # No running loop — this is the expected path for sync usage + else: raise RuntimeError( "Cannot use sync API from an async context. " "Use the async Session.start()/connect() instead." ) - except RuntimeError: - pass try: loop = asyncio.get_event_loop() if loop.is_closed(): diff --git a/src/openocd/svd/peripheral.py b/src/openocd/svd/peripheral.py index e68763b..2e806f2 100644 --- a/src/openocd/svd/peripheral.py +++ b/src/openocd/svd/peripheral.py @@ -51,7 +51,7 @@ class SVDManager: Raises: SVDError: If the file is missing or unparseable. """ - self._parser.load(svd_path) + await asyncio.to_thread(self._parser.load, svd_path) def list_peripherals(self) -> list[str]: """Return sorted peripheral names from the loaded SVD. diff --git a/src/openocd/types.py b/src/openocd/types.py index 0b29be1..6ab8bad 100644 --- a/src/openocd/types.py +++ b/src/openocd/types.py @@ -127,7 +127,7 @@ class BitField: description: str -@dataclass +@dataclass(frozen=True) class DecodedRegister: """A register value decoded into named bitfields via SVD.""" diff --git a/tests/test_error_paths.py b/tests/test_error_paths.py new file mode 100644 index 0000000..ba6b3c4 --- /dev/null +++ b/tests/test_error_paths.py @@ -0,0 +1,626 @@ +"""Error-path tests for openocd-python. + +Exercises every error condition and exception branch across the +connection, target, memory, register, flash, breakpoint, session, +and process subsystems. + +Each test configures a mock server to return error responses (or +misbehave at the protocol level) and asserts that the correct +exception type is raised with a meaningful message. +""" +from __future__ import annotations + +import asyncio + +import pytest + +from openocd.breakpoints import BreakpointError, BreakpointManager +from openocd.connection.tcl_rpc import TclRpcConnection +from openocd.errors import ( + ConnectionError, + FlashError, + ProcessError, + TargetError, + TargetNotHaltedError, + TimeoutError, +) +from openocd.flash import Flash +from openocd.memory import Memory +from openocd.process import OpenOCDProcess +from openocd.registers import Registers +from openocd.session import Session, _get_or_create_loop +from openocd.target import Target +from tests.mock_server import MockOpenOCDServer + +# ====================================================================== +# Helpers: error-returning mock servers +# ====================================================================== + + +@pytest.fixture +async def error_server(): + """A MockOpenOCDServer pre-wired to return error strings. + + The default response table is left intact so that "targets" and + similar plumbing commands still work. Individual tests prepend + error-producing routes via ``server.add_response()``. + """ + server = MockOpenOCDServer() + await server.start() + yield server + await server.stop() + + +@pytest.fixture +async def error_conn(error_server): + """A TclRpcConnection wired to the error mock server.""" + host, port = error_server.address + conn = TclRpcConnection(timeout=5.0) + await conn.connect(host, port) + yield conn + await conn.close() + + +# ====================================================================== +# 1. Connection error paths +# ====================================================================== + + +class TestConnectionErrors: + """Errors at the transport / framing layer.""" + + async def test_send_on_closed_connection(self): + """send() after close() raises ConnectionError.""" + server = MockOpenOCDServer() + await server.start() + host, port = server.address + + conn = TclRpcConnection(timeout=5.0) + await conn.connect(host, port) + await conn.close() + + with pytest.raises(ConnectionError, match="Not connected"): + await conn.send("targets") + + await server.stop() + + async def test_send_before_connect(self): + """send() without a prior connect() raises ConnectionError.""" + conn = TclRpcConnection() + with pytest.raises(ConnectionError, match="Not connected"): + await conn.send("targets") + + async def test_timeout_when_server_never_responds(self): + """A server that reads but never sends \\x1a triggers TimeoutError.""" + hang_event = asyncio.Event() + + async def _black_hole(reader, writer): + await reader.read(4096) + # never send a response -- wait until test signals us to stop + try: + await hang_event.wait() + except asyncio.CancelledError: + pass + finally: + writer.close() + + srv = await asyncio.start_server(_black_hole, "127.0.0.1", 0) + await srv.start_serving() + host, port = srv.sockets[0].getsockname()[:2] + + conn = TclRpcConnection(timeout=0.3) + await conn.connect(host, port) + + with pytest.raises(TimeoutError): + await conn.send("targets") + + await conn.close() + hang_event.set() + srv.close() + await srv.wait_closed() + + async def test_server_closes_connection_mid_stream(self): + """Server closing the socket without a separator raises ConnectionError.""" + + async def _close_immediately(reader, writer): + await reader.read(4096) + # send partial data with no separator then close + writer.write(b"partial response") + await writer.drain() + writer.close() + await writer.wait_closed() + + srv = await asyncio.start_server(_close_immediately, "127.0.0.1", 0) + await srv.start_serving() + host, port = srv.sockets[0].getsockname()[:2] + + conn = TclRpcConnection(timeout=2.0) + await conn.connect(host, port) + + with pytest.raises(ConnectionError, match="closed the connection"): + await conn.send("targets") + + await conn.close() + srv.close() + await srv.wait_closed() + + async def test_bounded_read_rejects_oversized_response(self): + """Response exceeding MAX_RESPONSE_SIZE without separator raises ConnectionError.""" + + async def _flood(reader, writer): + await reader.read(4096) + # send a lot of data with no separator + chunk = b"A" * 65536 + try: + while True: + writer.write(chunk) + await writer.drain() + except (BrokenPipeError, ConnectionResetError): + pass + + srv = await asyncio.start_server(_flood, "127.0.0.1", 0) + await srv.start_serving() + host, port = srv.sockets[0].getsockname()[:2] + + conn = TclRpcConnection(timeout=10.0) + await conn.connect(host, port) + + with pytest.raises(ConnectionError, match="exceeded"): + await conn.send("targets") + + await conn.close() + srv.close() + await srv.wait_closed() + + async def test_connect_refused(self): + """Connecting to a port with nothing listening raises ConnectionError.""" + conn = TclRpcConnection(timeout=1.0) + with pytest.raises(ConnectionError): + await conn.connect("127.0.0.1", 1) + + +# ====================================================================== +# 2. Target error paths +# ====================================================================== + + +class TestTargetErrors: + """Error conditions from the Target subsystem.""" + + async def test_halt_error_response(self, error_server, error_conn): + """halt() with an error response raises TargetError.""" + error_server.add_response(r"^halt$", "error: target not responding") + target = Target(error_conn) + + with pytest.raises(TargetError, match="halt failed"): + await target.halt() + + async def test_halt_already_halted_is_not_error(self, error_server, error_conn): + """halt() when response says 'already halted' should NOT raise.""" + error_server.add_response(r"^halt$", "error: target already halted") + target = Target(error_conn) + # "already halted" is a benign condition -- halt() checks for it + state = await target.halt() + assert state.state in ("halted", "unknown") + + async def test_resume_error_response(self, error_server, error_conn): + """resume() with an error response raises TargetError.""" + error_server.add_response(r"^resume", "error: cannot resume target") + target = Target(error_conn) + + with pytest.raises(TargetError, match="resume failed"): + await target.resume() + + async def test_step_error_response(self, error_server, error_conn): + """step() with an error response raises TargetError.""" + error_server.add_response(r"^step", "error: step failed on target") + target = Target(error_conn) + + with pytest.raises(TargetError, match="step failed"): + await target.step() + + async def test_wait_halt_timeout(self, error_server, error_conn): + """wait_halt receiving 'timed out' raises TimeoutError.""" + error_server.add_response(r"^wait_halt", "timed out while waiting for target") + target = Target(error_conn) + + with pytest.raises(TimeoutError, match="did not halt"): + await target.wait_halt(timeout_ms=100) + + async def test_wait_halt_time_out_variant(self, error_server, error_conn): + """wait_halt receiving 'time out' (two words) also raises TimeoutError.""" + error_server.add_response(r"^wait_halt", "time out waiting for halt") + target = Target(error_conn) + + with pytest.raises(TimeoutError, match="did not halt"): + await target.wait_halt(timeout_ms=100) + + async def test_wait_halt_generic_error(self, error_server, error_conn): + """wait_halt with a generic error (not timeout) raises TargetError.""" + error_server.add_response(r"^wait_halt", "error: target communication failure") + target = Target(error_conn) + + with pytest.raises(TargetError, match="wait_halt failed"): + await target.wait_halt(timeout_ms=100) + + async def test_state_unexpected_format(self, error_server, error_conn): + """targets returning garbage still produces a TargetState with 'unknown'.""" + error_server.add_response(r"^targets$", "this is not valid target output") + # Also need to suppress the reg pc call that _parse_state makes + error_server.add_response(r"^reg\s+pc$", "no such register") + target = Target(error_conn) + + state = await target.state() + assert state.name == "unknown" + assert state.state == "unknown" + assert state.current_pc is None + + async def test_state_unrecognized_state_string(self, error_server, error_conn): + """A target row with a bizarre state string normalizes to 'unknown'.""" + weird_table = ( + " TargetName Type Endian TapName State\n" + "-- ------------------ ---------- ------ ------------------ ------------\n" + " 0* stm32f1x.cpu cortex_m little stm32f1x.cpu exploding" + ) + error_server.add_response(r"^targets$", weird_table) + target = Target(error_conn) + + state = await target.state() + assert state.name == "stm32f1x.cpu" + assert state.state == "unknown" + assert state.current_pc is None + + async def test_reset_error_response(self, error_server, error_conn): + """reset() with an error response raises TargetError.""" + error_server.add_response(r"^reset\s+", "error: reset failed, adapter not found") + target = Target(error_conn) + + with pytest.raises(TargetError, match="reset failed"): + await target.reset("halt") + + +# ====================================================================== +# 3. Memory error paths +# ====================================================================== + + +class TestMemoryErrors: + """Error conditions from the Memory subsystem.""" + + async def test_read_u32_target_not_halted(self, error_server, error_conn): + """read_u32 with 'target not halted' in response raises TargetError.""" + error_server.add_response( + r"^read_memory\s+", "error: target not halted" + ) + mem = Memory(error_conn) + + with pytest.raises(TargetError, match="read_memory failed"): + await mem.read_u32(0x20000000, 1) + + async def test_write_u32_error_response(self, error_server, error_conn): + """write_u32 with an error response raises TargetError.""" + error_server.add_response( + r"^write_memory\s+", "error: target not halted" + ) + mem = Memory(error_conn) + + with pytest.raises(TargetError, match="write_memory failed"): + await mem.write_u32(0x20000000, 0xDEADBEEF) + + async def test_read_u32_non_hex_tokens(self, error_server, error_conn): + """read_memory returning non-hex garbage raises TargetError.""" + error_server.add_response( + r"^read_memory\s+", "not_a_hex_value xyz !!!" + ) + mem = Memory(error_conn) + + with pytest.raises(TargetError, match="Cannot parse read_memory"): + await mem.read_u32(0x20000000, 1) + + async def test_read_u8_error_response(self, error_server, error_conn): + """read_u8 with an error response raises TargetError.""" + error_server.add_response( + r"^read_memory\s+", "error: bus fault during memory read" + ) + mem = Memory(error_conn) + + with pytest.raises(TargetError, match="read_memory failed"): + await mem.read_u8(0xFFFFFFFF, 4) + + async def test_write_bytes_error_response(self, error_server, error_conn): + """write_bytes with an error response raises TargetError.""" + error_server.add_response( + r"^write_memory\s+", "error: write access violation" + ) + mem = Memory(error_conn) + + with pytest.raises(TargetError, match="write_memory failed"): + await mem.write_bytes(0x00000000, b"\x01\x02\x03") + + async def test_read_u16_error_response(self, error_server, error_conn): + """read_u16 with an error response raises TargetError.""" + error_server.add_response( + r"^read_memory\s+", "error: alignment fault" + ) + mem = Memory(error_conn) + + with pytest.raises(TargetError, match="read_memory failed"): + await mem.read_u16(0x20000001, 1) + + +# ====================================================================== +# 4. Register error paths +# ====================================================================== + + +class TestRegisterErrors: + """Error conditions from the Registers subsystem.""" + + async def test_read_not_halted(self, error_server, error_conn): + """read('pc') when target is not halted raises TargetNotHaltedError.""" + error_server.add_response( + r"^reg\s+pc$", "target not halted" + ) + regs = Registers(error_conn) + + with pytest.raises(TargetNotHaltedError, match="halted"): + await regs.read("pc") + + async def test_read_nonexistent_register(self, error_server, error_conn): + """read('nonexistent') with unparseable response raises TargetError.""" + error_server.add_response( + r"^reg\s+nonexistent$", "invalid command name \"nonexistent\"" + ) + regs = Registers(error_conn) + + with pytest.raises(TargetError, match="Cannot parse register"): + await regs.read("nonexistent") + + async def test_write_not_halted(self, error_server, error_conn): + """write('pc', val) when target is not halted raises TargetNotHaltedError.""" + error_server.add_response( + r"^reg\s+pc\s+0x", "target not halted" + ) + regs = Registers(error_conn) + + with pytest.raises(TargetNotHaltedError, match="halted"): + await regs.write("pc", 0x1234) + + async def test_write_generic_error(self, error_server, error_conn): + """write() with a non-halted-related error raises TargetError.""" + error_server.add_response( + r"^reg\s+r0\s+0x", "error: register write failed" + ) + regs = Registers(error_conn) + + with pytest.raises(TargetError, match="reg write failed"): + await regs.write("r0", 0xDEAD) + + async def test_read_all_not_halted(self, error_server, error_conn): + """read_all() when target is not halted raises TargetNotHaltedError.""" + error_server.add_response(r"^reg$", "target not halted") + regs = Registers(error_conn) + + with pytest.raises(TargetNotHaltedError, match="halted"): + await regs.read_all() + + async def test_read_many_partial_failure(self, error_server, error_conn): + """read_many() should propagate the first register read failure.""" + # pc succeeds, but sp returns not-halted + error_server.add_response(r"^reg\s+sp$", "target not halted") + regs = Registers(error_conn) + + with pytest.raises(TargetNotHaltedError): + await regs.read_many(["pc", "sp"]) + + +# ====================================================================== +# 5. Flash error paths +# ====================================================================== + + +class TestFlashErrors: + """Error conditions from the Flash subsystem.""" + + async def test_banks_error(self, error_server, error_conn): + """flash.banks() with an error response raises FlashError.""" + error_server.add_response(r"^flash banks$", "error: no flash banks configured") + flash = Flash(error_conn) + + with pytest.raises(FlashError, match="flash banks"): + await flash.banks() + + async def test_info_error(self, error_server, error_conn): + """flash.info() with an error response raises FlashError.""" + error_server.add_response(r"^flash info\s+", "error: invalid bank number") + flash = Flash(error_conn) + + with pytest.raises(FlashError, match="flash info"): + await flash.info(bank=99) + + async def test_erase_sector_invalid_range(self, error_server, error_conn): + """erase_sector with first > last raises FlashError locally.""" + flash = Flash(error_conn) + + with pytest.raises(FlashError, match="Invalid sector range"): + await flash.erase_sector(bank=0, first=10, last=5) + + async def test_erase_sector_error_response(self, error_server, error_conn): + """erase_sector with error from server raises FlashError.""" + error_server.add_response(r"^flash erase_sector\s+", "error: erase failed") + flash = Flash(error_conn) + + with pytest.raises(FlashError, match="flash erase_sector"): + await flash.erase_sector(bank=0, first=0, last=3) + + async def test_write_image_error(self, error_server, error_conn): + """write_image with error from server raises FlashError.""" + error_server.add_response( + r"^flash write_image\s+", "error: flash write failed" + ) + flash = Flash(error_conn) + + with pytest.raises(FlashError, match="flash write_image"): + from pathlib import Path + await flash.write_image(Path("/tmp/fake_firmware.bin"), verify=False) + + async def test_protect_error(self, error_server, error_conn): + """flash.protect() with error response raises FlashError.""" + error_server.add_response( + r"^flash protect\s+", "error: protection change not supported" + ) + flash = Flash(error_conn) + + with pytest.raises(FlashError, match="flash protect"): + await flash.protect(bank=0, first=0, last=3, on=True) + + +# ====================================================================== +# 6. Breakpoint error paths +# ====================================================================== + + +class TestBreakpointErrors: + """Error conditions from the BreakpointManager subsystem.""" + + async def test_add_breakpoint_error(self, error_server, error_conn): + """add() with error response raises BreakpointError.""" + error_server.add_response( + r"^bp\s+0x", "error: can not add breakpoint, resource not available" + ) + bp = BreakpointManager(error_conn) + + with pytest.raises(BreakpointError, match="bp 0x"): + await bp.add(0x08001234) + + async def test_remove_breakpoint_error(self, error_server, error_conn): + """remove() with error response raises BreakpointError.""" + error_server.add_response( + r"^rbp\s+", "error: no breakpoint at address" + ) + bp = BreakpointManager(error_conn) + + with pytest.raises(BreakpointError, match="rbp 0x"): + await bp.remove(0x08001234) + + async def test_add_watchpoint_error(self, error_server, error_conn): + """add_watchpoint() with error response raises BreakpointError.""" + error_server.add_response( + r"^wp\s+0x", "error: no free watchpoint comparator" + ) + bp = BreakpointManager(error_conn) + + with pytest.raises(BreakpointError, match="wp 0x"): + await bp.add_watchpoint(0x20000000, 4) + + async def test_remove_watchpoint_error(self, error_server, error_conn): + """remove_watchpoint() with error response raises BreakpointError.""" + error_server.add_response( + r"^rwp\s+", "error: no watchpoint at address" + ) + bp = BreakpointManager(error_conn) + + with pytest.raises(BreakpointError, match="rwp 0x"): + await bp.remove_watchpoint(0x20000000) + + +# ====================================================================== +# 7. Session error paths +# ====================================================================== + + +class TestSessionErrors: + """Error conditions from the Session layer.""" + + async def test_get_or_create_loop_from_async_context(self): + """_get_or_create_loop() inside a running loop raises RuntimeError.""" + with pytest.raises(RuntimeError, match="Cannot use sync API"): + _get_or_create_loop() + + async def test_command_on_closed_session(self, mock_ocd): + """session.command() after close() raises ConnectionError.""" + host, port, _server = mock_ocd + sess = await Session.connect(host, port, timeout=5.0) + await sess.close() + + with pytest.raises(ConnectionError, match="Not connected"): + await sess.command("targets") + + async def test_connect_to_nonexistent_host(self): + """Session.connect() to a bogus address raises ConnectionError.""" + with pytest.raises(ConnectionError): + await Session.connect("127.0.0.1", 1, timeout=1.0) + + async def test_double_close_is_safe(self, mock_ocd): + """Calling close() twice on a session should not raise.""" + host, port, _server = mock_ocd + sess = await Session.connect(host, port, timeout=5.0) + await sess.close() + await sess.close() # should be a no-op + + +# ====================================================================== +# 8. Process error paths +# ====================================================================== + + +class TestProcessErrors: + """Error conditions from the OpenOCDProcess manager.""" + + async def test_start_empty_config(self): + """start() with an empty config string raises ProcessError.""" + proc = OpenOCDProcess() + with pytest.raises(ProcessError, match="[Ee]mpty config"): + # Use /bin/true as a stand-in binary so we reach config validation + await proc.start("", openocd_bin="/bin/true") + + async def test_start_dangling_flag(self): + """start() with a trailing -f and no argument raises ProcessError.""" + proc = OpenOCDProcess() + with pytest.raises(ProcessError, match="requires an argument"): + await proc.start("-f", openocd_bin="/bin/true") + + async def test_start_dangling_c_flag(self): + """start() with a trailing -c and no argument raises ProcessError.""" + proc = OpenOCDProcess() + with pytest.raises(ProcessError, match="requires an argument"): + await proc.start("-c", openocd_bin="/bin/true") + + async def test_start_nonexistent_binary(self): + """start() with a nonexistent binary path raises ProcessError.""" + proc = OpenOCDProcess() + with pytest.raises(ProcessError): + await proc.start( + "interface/cmsis-dap.cfg", + openocd_bin="/nonexistent/path/to/openocd", + ) + + async def test_pid_is_none_before_start(self): + """pid property is None before start().""" + proc = OpenOCDProcess() + assert proc.pid is None + + async def test_running_is_false_before_start(self): + """running property is False before start().""" + proc = OpenOCDProcess() + assert proc.running is False + + async def test_stop_before_start_is_safe(self): + """stop() before start() should not raise.""" + proc = OpenOCDProcess() + await proc.stop() # no-op, no exception + + +# ====================================================================== +# 9. Notification connection error paths +# ====================================================================== + + +class TestNotificationErrors: + """Error conditions for the notification subsystem.""" + + async def test_enable_notifications_before_connect(self): + """enable_notifications() before connect() raises ConnectionError.""" + conn = TclRpcConnection(timeout=1.0) + with pytest.raises(ConnectionError, match="Not connected"): + await conn.enable_notifications()