Apply ruff format to existing source and test files

This commit is contained in:
Ryan Malloy 2026-02-15 16:36:12 -07:00
parent 988234f4c3
commit 7a893cb328
22 changed files with 89 additions and 133 deletions

View File

@ -203,6 +203,7 @@ class BreakpointManager:
# Sync wrapper # Sync wrapper
# ====================================================================== # ======================================================================
class SyncBreakpointManager: class SyncBreakpointManager:
"""Synchronous wrapper around BreakpointManager.""" """Synchronous wrapper around BreakpointManager."""

View File

@ -27,12 +27,8 @@ def main() -> None:
prog="openocd-python", prog="openocd-python",
description=f"OpenOCD Python bindings v{pkg_version}", description=f"OpenOCD Python bindings v{pkg_version}",
) )
parser.add_argument( parser.add_argument("--version", action="version", version=f"openocd-python {pkg_version}")
"--version", action="version", version=f"openocd-python {pkg_version}" parser.add_argument("--host", default="localhost", help="OpenOCD host (default: localhost)")
)
parser.add_argument(
"--host", default="localhost", help="OpenOCD host (default: localhost)"
)
parser.add_argument( parser.add_argument(
"--port", type=int, default=6666, help="OpenOCD TCL RPC port (default: 6666)" "--port", type=int, default=6666, help="OpenOCD TCL RPC port (default: 6666)"
) )

View File

@ -131,9 +131,7 @@ class TclRpcConnection(Connection):
timeout=self._timeout, timeout=self._timeout,
) )
except TimeoutError as exc: except TimeoutError as exc:
raise OcdTimeoutError( raise OcdTimeoutError(f"Timed out waiting for response to: {command}") from exc
f"Timed out waiting for response to: {command}"
) from exc
response = raw.decode("utf-8", errors="replace") response = raw.decode("utf-8", errors="replace")
log.debug("RX: %s", response[:200]) log.debug("RX: %s", response[:200])
@ -209,13 +207,9 @@ class TclRpcConnection(Connection):
# Read and discard the acknowledgement # Read and discard the acknowledgement
ack_buf = bytearray() ack_buf = bytearray()
while True: while True:
chunk = await asyncio.wait_for( chunk = await asyncio.wait_for(self._notif_reader.read(4096), timeout=self._timeout)
self._notif_reader.read(4096), timeout=self._timeout
)
if not chunk: if not chunk:
raise ConnectionError( raise ConnectionError("Notification connection closed during setup")
"Notification connection closed during setup"
)
ack_buf.extend(chunk) ack_buf.extend(chunk)
if ack_buf.find(SEPARATOR) != -1: if ack_buf.find(SEPARATOR) != -1:
break break

View File

@ -92,9 +92,7 @@ class TelnetConnection(Connection):
raise ConnectionError("OpenOCD closed the connection") raise ConnectionError("OpenOCD closed the connection")
buf.extend(chunk) buf.extend(chunk)
if len(buf) > MAX_RESPONSE_SIZE: if len(buf) > MAX_RESPONSE_SIZE:
raise ConnectionError( raise ConnectionError(f"Response exceeded {MAX_RESPONSE_SIZE} bytes without prompt")
f"Response exceeded {MAX_RESPONSE_SIZE} bytes without prompt"
)
if buf.endswith(PROMPT): if buf.endswith(PROMPT):
return bytes(buf[: -len(PROMPT)]) return bytes(buf[: -len(PROMPT)])

View File

@ -168,12 +168,7 @@ class Flash:
# Read the file back through TCL to handle remote OpenOCD instances. # Read the file back through TCL to handle remote OpenOCD instances.
# Use ocd_find + binary read if available, otherwise fall back to # Use ocd_find + binary read if available, otherwise fall back to
# reading the local file. # reading the local file.
tcl_read = ( tcl_read = f"set fp [open {tmp_path} rb]; set data [read $fp]; close $fp; set data"
f"set fp [open {tmp_path} rb]; "
f"set data [read $fp]; "
f"close $fp; "
f"set data"
)
try: try:
raw = await self._conn.send(tcl_read) raw = await self._conn.send(tcl_read)
# TCL returns binary as string; try base64 approach if garbled # TCL returns binary as string; try base64 approach if garbled
@ -343,6 +338,7 @@ class Flash:
# Sync wrapper # Sync wrapper
# ====================================================================== # ======================================================================
class SyncFlash: class SyncFlash:
"""Synchronous wrapper around Flash for use outside async contexts.""" """Synchronous wrapper around Flash for use outside async contexts."""

View File

@ -20,12 +20,12 @@ log = logging.getLogger(__name__)
# Example line: # Example line:
# 0 stm32f1x.cpu Y 0x3ba00477 0x3ba00477 4 0x01 0x0f # 0 stm32f1x.cpu Y 0x3ba00477 0x3ba00477 4 0x01 0x0f
_CHAIN_ROW_RE = re.compile( _CHAIN_ROW_RE = re.compile(
r"^\s*\d+\s+" # index r"^\s*\d+\s+" # index
r"(\S+)\s+" # tap name (chip.tap) r"(\S+)\s+" # tap name (chip.tap)
r"([YN])\s+" # enabled r"([YN])\s+" # enabled
r"(0x[0-9a-fA-F]+)\s+" # idcode r"(0x[0-9a-fA-F]+)\s+" # idcode
r"(0x[0-9a-fA-F]+)\s+" # expected r"(0x[0-9a-fA-F]+)\s+" # expected
r"(\d+)", # ir_length r"(\d+)", # ir_length
) )
@ -93,6 +93,7 @@ def _parse_scan_chain(raw: str) -> list[TAPInfo]:
# JTAGController — unified facade # JTAGController — unified facade
# ====================================================================== # ======================================================================
class JTAGController: class JTAGController:
"""High-level async interface to all JTAG operations. """High-level async interface to all JTAG operations.
@ -161,6 +162,7 @@ class JTAGController:
# SyncJTAGController — blocking wrappers # SyncJTAGController — blocking wrappers
# ====================================================================== # ======================================================================
class SyncJTAGController: class SyncJTAGController:
"""Synchronous wrapper around :class:`JTAGController`. """Synchronous wrapper around :class:`JTAGController`.

View File

@ -151,9 +151,7 @@ class Memory:
hex_str = hex_str.ljust(49) hex_str = hex_str.ljust(49)
# ASCII portion # ASCII portion
ascii_str = "".join( ascii_str = "".join(chr(b) if 0x20 <= b < 0x7F else "." for b in chunk)
chr(b) if 0x20 <= b < 0x7F else "." for b in chunk
)
lines.append(f"{line_addr:08X}: {hex_str} |{ascii_str}|") lines.append(f"{line_addr:08X}: {hex_str} |{ascii_str}|")
@ -179,9 +177,7 @@ class Memory:
try: try:
return [int(t, 16) for t in tokens] return [int(t, 16) for t in tokens]
except ValueError as exc: except ValueError as exc:
raise TargetError( raise TargetError(f"Cannot parse read_memory response: {resp!r}") from exc
f"Cannot parse read_memory response: {resp!r}"
) from exc
async def _write(self, addr: int, width: int, values: int | list[int]) -> None: async def _write(self, addr: int, width: int, values: int | list[int]) -> None:
"""Write values of *width* bits using the TCL ``write_memory`` API. """Write values of *width* bits using the TCL ``write_memory`` API.

View File

@ -59,9 +59,7 @@ class OpenOCDProcess:
self._tcl_port = tcl_port self._tcl_port = tcl_port
binary = openocd_bin or shutil.which("openocd") binary = openocd_bin or shutil.which("openocd")
if not binary: if not binary:
raise ProcessError( raise ProcessError("OpenOCD binary not found. Install it or pass openocd_bin=")
"OpenOCD binary not found. Install it or pass openocd_bin="
)
args = [binary] args = [binary]
@ -77,9 +75,7 @@ class OpenOCDProcess:
part = config_parts[i] part = config_parts[i]
if part in ("-f", "-c"): if part in ("-f", "-c"):
if i + 1 >= len(config_parts): if i + 1 >= len(config_parts):
raise ProcessError( raise ProcessError(f"Config flag '{part}' requires an argument")
f"Config flag '{part}' requires an argument"
)
args.extend([part, config_parts[i + 1]]) args.extend([part, config_parts[i + 1]])
i += 2 i += 2
else: else:
@ -129,9 +125,7 @@ class OpenOCDProcess:
except (OSError, TimeoutError): except (OSError, TimeoutError):
await asyncio.sleep(READY_POLL_INTERVAL) await asyncio.sleep(READY_POLL_INTERVAL)
raise OpenOCDTimeoutError( raise OpenOCDTimeoutError(f"OpenOCD did not become ready within {timeout}s")
f"OpenOCD did not become ready within {timeout}s"
)
async def stop(self) -> None: async def stop(self) -> None:
"""Terminate the OpenOCD process.""" """Terminate the OpenOCD process."""

View File

@ -17,20 +17,18 @@ from openocd.types import Register
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Matches "reg <name>" output: "pc (/32): 0x08001234" # Matches "reg <name>" output: "pc (/32): 0x08001234"
_REG_VALUE_RE = re.compile( _REG_VALUE_RE = re.compile(r"(\S+)\s+\(/(\d+)\):\s*(0x[0-9a-fA-F]+)")
r"(\S+)\s+\(/(\d+)\):\s*(0x[0-9a-fA-F]+)"
)
# Matches a row in "reg" (list all) output. # Matches a row in "reg" (list all) output.
# Typical formats: # Typical formats:
# "(0) r0 (/32): 0x00000000" # "(0) r0 (/32): 0x00000000"
# "(123) xPSR (/32): 0x61000000 (dirty)" # "(123) xPSR (/32): 0x61000000 (dirty)"
_REG_LIST_RE = re.compile( _REG_LIST_RE = re.compile(
r"\((\d+)\)\s+" # register number r"\((\d+)\)\s+" # register number
r"(\S+)\s+" # register name r"(\S+)\s+" # register name
r"\(/(\d+)\):\s*" # bit width r"\(/(\d+)\):\s*" # bit width
r"(0x[0-9a-fA-F]+)" # value r"(0x[0-9a-fA-F]+)" # value
r"(?:\s+\(dirty\))?" # optional dirty flag r"(?:\s+\(dirty\))?" # optional dirty flag
) )
@ -149,9 +147,7 @@ class Registers:
""" """
lower = resp.lower() lower = resp.lower()
if "not halted" in lower or "target not halted" in lower: if "not halted" in lower or "target not halted" in lower:
raise TargetNotHaltedError( raise TargetNotHaltedError("Target must be halted to access registers")
"Target must be halted to access registers"
)
class SyncRegisters: class SyncRegisters:

View File

@ -126,10 +126,7 @@ class RTTManager:
""" """
# Escape TCL special characters to prevent injection # Escape TCL special characters to prevent injection
escaped = ( escaped = (
data.replace("\\", "\\\\") data.replace("\\", "\\\\").replace('"', '\\"').replace("[", "\\[").replace("$", "\\$")
.replace('"', '\\"')
.replace("[", "\\[")
.replace("$", "\\$")
) )
cmd = f'rtt channelwrite {channel} "{escaped}"' cmd = f'rtt channelwrite {channel} "{escaped}"'
response = await self._conn.send(cmd) response = await self._conn.send(cmd)
@ -149,9 +146,7 @@ class SyncRTTManager:
size: int, size: int,
id_string: str = "SEGGER RTT", id_string: str = "SEGGER RTT",
) -> None: ) -> None:
self._loop.run_until_complete( self._loop.run_until_complete(self._manager.setup(address, size, id_string))
self._manager.setup(address, size, id_string)
)
def start(self) -> None: def start(self) -> None:
self._loop.run_until_complete(self._manager.start()) self._loop.run_until_complete(self._manager.start())
@ -173,6 +168,7 @@ class SyncRTTManager:
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _check_rtt_response(response: str, command: str) -> None: def _check_rtt_response(response: str, command: str) -> None:
"""Raise on error responses from RTT commands.""" """Raise on error responses from RTT commands."""
if response and "error" in response.lower(): if response and "error" in response.lower():

View File

@ -75,8 +75,7 @@ class SVDParserWrapper:
periph = self._peripherals.get(name) periph = self._peripherals.get(name)
if periph is None: if periph is None:
raise SVDError( raise SVDError(
f"Peripheral '{name}' not found. " f"Peripheral '{name}' not found. Available: {', '.join(sorted(self._peripherals))}"
f"Available: {', '.join(sorted(self._peripherals))}"
) )
return periph return periph

View File

@ -173,14 +173,10 @@ class SyncSVDManager:
return self._manager.list_registers(peripheral) return self._manager.list_registers(peripheral)
def read_register(self, peripheral: str, register: str) -> DecodedRegister: def read_register(self, peripheral: str, register: str) -> DecodedRegister:
return self._loop.run_until_complete( return self._loop.run_until_complete(self._manager.read_register(peripheral, register))
self._manager.read_register(peripheral, register)
)
def read_peripheral(self, peripheral: str) -> dict[str, DecodedRegister]: def read_peripheral(self, peripheral: str) -> dict[str, DecodedRegister]:
return self._loop.run_until_complete( return self._loop.run_until_complete(self._manager.read_peripheral(peripheral))
self._manager.read_peripheral(peripheral)
)
def decode(self, peripheral: str, register: str, value: int) -> DecodedRegister: def decode(self, peripheral: str, register: str, value: int) -> DecodedRegister:
return self._manager.decode(peripheral, register, value) return self._manager.decode(peripheral, register, value)

View File

@ -20,12 +20,12 @@ log = logging.getLogger(__name__)
# Matches a target row from "targets" output, e.g.: # Matches a target row from "targets" output, e.g.:
# " 0* stm32f1x.cpu cortex_m little stm32f1x.cpu halted" # " 0* stm32f1x.cpu cortex_m little stm32f1x.cpu halted"
_TARGET_ROW_RE = re.compile( _TARGET_ROW_RE = re.compile(
r"^\s*\d+\*?\s+" # index, optional current marker r"^\s*\d+\*?\s+" # index, optional current marker
r"(\S+)\s+" # target name r"(\S+)\s+" # target name
r"\S+\s+" # type r"\S+\s+" # type
r"\S+\s+" # endian r"\S+\s+" # endian
r"\S+\s+" # tap name r"\S+\s+" # tap name
r"(\S+)" # state r"(\S+)" # state
) )

View File

@ -125,6 +125,7 @@ class Transport:
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _parse_speed(response: str) -> int | None: def _parse_speed(response: str) -> int | None:
"""Extract a numeric kHz value from an adapter speed response. """Extract a numeric kHz value from an adapter speed response.

View File

@ -1,4 +1,5 @@
"""Shared pytest fixtures for openocd-python tests.""" """Shared pytest fixtures for openocd-python tests."""
from __future__ import annotations from __future__ import annotations
import pytest import pytest

View File

@ -1,4 +1,5 @@
"""Tests for the TclRpcConnection class.""" """Tests for the TclRpcConnection class."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -63,6 +64,7 @@ async def test_send_before_connect_raises():
async def test_timeout_on_hung_server(): async def test_timeout_on_hung_server():
"""A server that never sends \\x1a should trigger a TimeoutError.""" """A server that never sends \\x1a should trigger a TimeoutError."""
# Start a server that accepts connections but never responds # Start a server that accepts connections but never responds
async def _hang(reader, writer): async def _hang(reader, writer):
# Read the command but never reply # Read the command but never reply

View File

@ -8,6 +8,7 @@ Each test configures a mock server to return error responses (or
misbehave at the protocol level) and asserts that the correct misbehave at the protocol level) and asserts that the correct
exception type is raised with a meaningful message. exception type is raised with a meaningful message.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -289,9 +290,7 @@ class TestMemoryErrors:
async def test_read_u32_target_not_halted(self, error_server, error_conn): async def test_read_u32_target_not_halted(self, error_server, error_conn):
"""read_u32 with 'target not halted' in response raises TargetError.""" """read_u32 with 'target not halted' in response raises TargetError."""
error_server.add_response( error_server.add_response(r"^read_memory\s+", "error: target not halted")
r"^read_memory\s+", "error: target not halted"
)
mem = Memory(error_conn) mem = Memory(error_conn)
with pytest.raises(TargetError, match="read_memory failed"): with pytest.raises(TargetError, match="read_memory failed"):
@ -299,9 +298,7 @@ class TestMemoryErrors:
async def test_write_u32_error_response(self, error_server, error_conn): async def test_write_u32_error_response(self, error_server, error_conn):
"""write_u32 with an error response raises TargetError.""" """write_u32 with an error response raises TargetError."""
error_server.add_response( error_server.add_response(r"^write_memory\s+", "error: target not halted")
r"^write_memory\s+", "error: target not halted"
)
mem = Memory(error_conn) mem = Memory(error_conn)
with pytest.raises(TargetError, match="write_memory failed"): with pytest.raises(TargetError, match="write_memory failed"):
@ -309,9 +306,7 @@ class TestMemoryErrors:
async def test_read_u32_non_hex_tokens(self, error_server, error_conn): async def test_read_u32_non_hex_tokens(self, error_server, error_conn):
"""read_memory returning non-hex garbage raises TargetError.""" """read_memory returning non-hex garbage raises TargetError."""
error_server.add_response( error_server.add_response(r"^read_memory\s+", "not_a_hex_value xyz !!!")
r"^read_memory\s+", "not_a_hex_value xyz !!!"
)
mem = Memory(error_conn) mem = Memory(error_conn)
with pytest.raises(TargetError, match="Cannot parse read_memory"): with pytest.raises(TargetError, match="Cannot parse read_memory"):
@ -319,9 +314,7 @@ class TestMemoryErrors:
async def test_read_u8_error_response(self, error_server, error_conn): async def test_read_u8_error_response(self, error_server, error_conn):
"""read_u8 with an error response raises TargetError.""" """read_u8 with an error response raises TargetError."""
error_server.add_response( error_server.add_response(r"^read_memory\s+", "error: bus fault during memory read")
r"^read_memory\s+", "error: bus fault during memory read"
)
mem = Memory(error_conn) mem = Memory(error_conn)
with pytest.raises(TargetError, match="read_memory failed"): with pytest.raises(TargetError, match="read_memory failed"):
@ -329,9 +322,7 @@ class TestMemoryErrors:
async def test_write_bytes_error_response(self, error_server, error_conn): async def test_write_bytes_error_response(self, error_server, error_conn):
"""write_bytes with an error response raises TargetError.""" """write_bytes with an error response raises TargetError."""
error_server.add_response( error_server.add_response(r"^write_memory\s+", "error: write access violation")
r"^write_memory\s+", "error: write access violation"
)
mem = Memory(error_conn) mem = Memory(error_conn)
with pytest.raises(TargetError, match="write_memory failed"): with pytest.raises(TargetError, match="write_memory failed"):
@ -339,9 +330,7 @@ class TestMemoryErrors:
async def test_read_u16_error_response(self, error_server, error_conn): async def test_read_u16_error_response(self, error_server, error_conn):
"""read_u16 with an error response raises TargetError.""" """read_u16 with an error response raises TargetError."""
error_server.add_response( error_server.add_response(r"^read_memory\s+", "error: alignment fault")
r"^read_memory\s+", "error: alignment fault"
)
mem = Memory(error_conn) mem = Memory(error_conn)
with pytest.raises(TargetError, match="read_memory failed"): with pytest.raises(TargetError, match="read_memory failed"):
@ -358,9 +347,7 @@ class TestRegisterErrors:
async def test_read_not_halted(self, error_server, error_conn): async def test_read_not_halted(self, error_server, error_conn):
"""read('pc') when target is not halted raises TargetNotHaltedError.""" """read('pc') when target is not halted raises TargetNotHaltedError."""
error_server.add_response( error_server.add_response(r"^reg\s+pc$", "target not halted")
r"^reg\s+pc$", "target not halted"
)
regs = Registers(error_conn) regs = Registers(error_conn)
with pytest.raises(TargetNotHaltedError, match="halted"): with pytest.raises(TargetNotHaltedError, match="halted"):
@ -368,9 +355,7 @@ class TestRegisterErrors:
async def test_read_nonexistent_register(self, error_server, error_conn): async def test_read_nonexistent_register(self, error_server, error_conn):
"""read('nonexistent') with unparseable response raises TargetError.""" """read('nonexistent') with unparseable response raises TargetError."""
error_server.add_response( error_server.add_response(r"^reg\s+nonexistent$", 'invalid command name "nonexistent"')
r"^reg\s+nonexistent$", "invalid command name \"nonexistent\""
)
regs = Registers(error_conn) regs = Registers(error_conn)
with pytest.raises(TargetError, match="Cannot parse register"): with pytest.raises(TargetError, match="Cannot parse register"):
@ -378,9 +363,7 @@ class TestRegisterErrors:
async def test_write_not_halted(self, error_server, error_conn): async def test_write_not_halted(self, error_server, error_conn):
"""write('pc', val) when target is not halted raises TargetNotHaltedError.""" """write('pc', val) when target is not halted raises TargetNotHaltedError."""
error_server.add_response( error_server.add_response(r"^reg\s+pc\s+0x", "target not halted")
r"^reg\s+pc\s+0x", "target not halted"
)
regs = Registers(error_conn) regs = Registers(error_conn)
with pytest.raises(TargetNotHaltedError, match="halted"): with pytest.raises(TargetNotHaltedError, match="halted"):
@ -388,9 +371,7 @@ class TestRegisterErrors:
async def test_write_generic_error(self, error_server, error_conn): async def test_write_generic_error(self, error_server, error_conn):
"""write() with a non-halted-related error raises TargetError.""" """write() with a non-halted-related error raises TargetError."""
error_server.add_response( error_server.add_response(r"^reg\s+r0\s+0x", "error: register write failed")
r"^reg\s+r0\s+0x", "error: register write failed"
)
regs = Registers(error_conn) regs = Registers(error_conn)
with pytest.raises(TargetError, match="reg write failed"): with pytest.raises(TargetError, match="reg write failed"):
@ -455,20 +436,17 @@ class TestFlashErrors:
async def test_write_image_error(self, error_server, error_conn): async def test_write_image_error(self, error_server, error_conn):
"""write_image with error from server raises FlashError.""" """write_image with error from server raises FlashError."""
error_server.add_response( error_server.add_response(r"^flash write_image\s+", "error: flash write failed")
r"^flash write_image\s+", "error: flash write failed"
)
flash = Flash(error_conn) flash = Flash(error_conn)
with pytest.raises(FlashError, match="flash write_image"): with pytest.raises(FlashError, match="flash write_image"):
from pathlib import Path from pathlib import Path
await flash.write_image(Path("/tmp/fake_firmware.bin"), verify=False) await flash.write_image(Path("/tmp/fake_firmware.bin"), verify=False)
async def test_protect_error(self, error_server, error_conn): async def test_protect_error(self, error_server, error_conn):
"""flash.protect() with error response raises FlashError.""" """flash.protect() with error response raises FlashError."""
error_server.add_response( error_server.add_response(r"^flash protect\s+", "error: protection change not supported")
r"^flash protect\s+", "error: protection change not supported"
)
flash = Flash(error_conn) flash = Flash(error_conn)
with pytest.raises(FlashError, match="flash protect"): with pytest.raises(FlashError, match="flash protect"):
@ -495,9 +473,7 @@ class TestBreakpointErrors:
async def test_remove_breakpoint_error(self, error_server, error_conn): async def test_remove_breakpoint_error(self, error_server, error_conn):
"""remove() with error response raises BreakpointError.""" """remove() with error response raises BreakpointError."""
error_server.add_response( error_server.add_response(r"^rbp\s+", "error: no breakpoint at address")
r"^rbp\s+", "error: no breakpoint at address"
)
bp = BreakpointManager(error_conn) bp = BreakpointManager(error_conn)
with pytest.raises(BreakpointError, match="rbp 0x"): with pytest.raises(BreakpointError, match="rbp 0x"):
@ -505,9 +481,7 @@ class TestBreakpointErrors:
async def test_add_watchpoint_error(self, error_server, error_conn): async def test_add_watchpoint_error(self, error_server, error_conn):
"""add_watchpoint() with error response raises BreakpointError.""" """add_watchpoint() with error response raises BreakpointError."""
error_server.add_response( error_server.add_response(r"^wp\s+0x", "error: no free watchpoint comparator")
r"^wp\s+0x", "error: no free watchpoint comparator"
)
bp = BreakpointManager(error_conn) bp = BreakpointManager(error_conn)
with pytest.raises(BreakpointError, match="wp 0x"): with pytest.raises(BreakpointError, match="wp 0x"):
@ -515,9 +489,7 @@ class TestBreakpointErrors:
async def test_remove_watchpoint_error(self, error_server, error_conn): async def test_remove_watchpoint_error(self, error_server, error_conn):
"""remove_watchpoint() with error response raises BreakpointError.""" """remove_watchpoint() with error response raises BreakpointError."""
error_server.add_response( error_server.add_response(r"^rwp\s+", "error: no watchpoint at address")
r"^rwp\s+", "error: no watchpoint at address"
)
bp = BreakpointManager(error_conn) bp = BreakpointManager(error_conn)
with pytest.raises(BreakpointError, match="rwp 0x"): with pytest.raises(BreakpointError, match="rwp 0x"):

View File

@ -1,4 +1,5 @@
"""Tests for the JTAG subsystem.""" """Tests for the JTAG subsystem."""
from __future__ import annotations from __future__ import annotations
import pytest import pytest

View File

@ -1,4 +1,5 @@
"""Tests for the Memory subsystem.""" """Tests for the Memory subsystem."""
from __future__ import annotations from __future__ import annotations

View File

@ -1,4 +1,5 @@
"""Tests for the Registers subsystem.""" """Tests for the Registers subsystem."""
from __future__ import annotations from __future__ import annotations
from openocd.types import Register from openocd.types import Register

View File

@ -3,6 +3,7 @@
These tests exercise the bitfield decoder and DecodedRegister formatting These tests exercise the bitfield decoder and DecodedRegister formatting
using synthetic data, without needing an SVD file or a mock server. using synthetic data, without needing an SVD file or a mock server.
""" """
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
@ -14,6 +15,7 @@ from openocd.types import BitField, DecodedRegister
# -- Fake SVD objects to avoid needing a real .svd file ----------------------- # -- Fake SVD objects to avoid needing a real .svd file -----------------------
@dataclass @dataclass
class FakeSVDField: class FakeSVDField:
name: str name: str
@ -41,22 +43,26 @@ def gpioa_odr():
"""A fake GPIOA.ODR register with two bitfields.""" """A fake GPIOA.ODR register with two bitfields."""
fields = [ fields = [
FakeSVDField( FakeSVDField(
name="ODR0", bit_offset=0, bit_width=1, name="ODR0",
bit_offset=0,
bit_width=1,
description="Port output data bit 0", description="Port output data bit 0",
), ),
FakeSVDField( FakeSVDField(
name="ODR1", bit_offset=1, bit_width=1, name="ODR1",
bit_offset=1,
bit_width=1,
description="Port output data bit 1", description="Port output data bit 1",
), ),
FakeSVDField( FakeSVDField(
name="ODR15_2", bit_offset=2, bit_width=14, name="ODR15_2",
bit_offset=2,
bit_width=14,
description="Port output data bits 15:2", description="Port output data bits 15:2",
), ),
] ]
register = FakeSVDRegister(name="ODR", address_offset=0x14, fields=fields) register = FakeSVDRegister(name="ODR", address_offset=0x14, fields=fields)
peripheral = FakeSVDPeripheral( peripheral = FakeSVDPeripheral(name="GPIOA", base_address=0x40010800, registers=[register])
name="GPIOA", base_address=0x40010800, registers=[register]
)
return peripheral, register return peripheral, register
@ -68,27 +74,33 @@ def usart_cr1():
FakeSVDField(name="RE", bit_offset=2, bit_width=1, description="Receiver enable"), FakeSVDField(name="RE", bit_offset=2, bit_width=1, description="Receiver enable"),
FakeSVDField(name="TE", bit_offset=3, bit_width=1, description="Transmitter enable"), FakeSVDField(name="TE", bit_offset=3, bit_width=1, description="Transmitter enable"),
FakeSVDField( FakeSVDField(
name="RXNEIE", bit_offset=5, bit_width=1, name="RXNEIE",
bit_offset=5,
bit_width=1,
description="RXNE interrupt enable", description="RXNE interrupt enable",
), ),
FakeSVDField( FakeSVDField(
name="TCIE", bit_offset=6, bit_width=1, name="TCIE",
bit_offset=6,
bit_width=1,
description="Transmission complete IE", description="Transmission complete IE",
), ),
FakeSVDField( FakeSVDField(
name="TXEIE", bit_offset=7, bit_width=1, name="TXEIE",
bit_offset=7,
bit_width=1,
description="TXE interrupt enable", description="TXE interrupt enable",
), ),
FakeSVDField(name="M", bit_offset=12, bit_width=1, description="Word length"), FakeSVDField(name="M", bit_offset=12, bit_width=1, description="Word length"),
FakeSVDField( FakeSVDField(
name="OVER8", bit_offset=15, bit_width=1, name="OVER8",
bit_offset=15,
bit_width=1,
description="Oversampling mode", description="Oversampling mode",
), ),
] ]
register = FakeSVDRegister(name="CR1", address_offset=0x0C, fields=fields) register = FakeSVDRegister(name="CR1", address_offset=0x0C, fields=fields)
peripheral = FakeSVDPeripheral( peripheral = FakeSVDPeripheral(name="USART1", base_address=0x40013800, registers=[register])
name="USART1", base_address=0x40013800, registers=[register]
)
return peripheral, register return peripheral, register

View File

@ -1,4 +1,5 @@
"""Tests for the Target subsystem.""" """Tests for the Target subsystem."""
from __future__ import annotations from __future__ import annotations
import pytest import pytest