diff --git a/src/mcserial/server.py b/src/mcserial/server.py index eae1a91..4c65c5b 100644 --- a/src/mcserial/server.py +++ b/src/mcserial/server.py @@ -2,6 +2,7 @@ from __future__ import annotations +import atexit import os from dataclasses import dataclass, field from typing import Literal @@ -23,12 +24,34 @@ class SerialConnection: port: str connection: serial.Serial buffer: bytes = field(default_factory=bytes) + mode: Literal["rs232", "rs485"] = "rs232" # Active connections registry +# NOTE: This dict is NOT thread-safe. The MCP server is designed for single-threaded +# async operation (FastMCP uses async/await). If you need thread-safe access, +# wrap operations with threading.Lock or use asyncio.Lock for async contexts. _connections: dict[str, SerialConnection] = {} +def _cleanup_connections() -> None: + """Close all open serial connections on exit. + + Registered with atexit to ensure ports are released when the server stops, + preventing port lockups that would require manual intervention. 
+ """ + for _port, conn in list(_connections.items()): + try: + if conn.connection.is_open: + conn.connection.close() + except Exception: + pass # Best effort - don't raise during shutdown + _connections.clear() + + +atexit.register(_cleanup_connections) + + def _detect_baud_rate_internal( port: str, probe: str | None = None, @@ -116,21 +139,19 @@ def _detect_baud_rate_internal( for rate in rates_to_try: try: - conn = serial.Serial(port=port, baudrate=rate, timeout=timeout_per_rate) + with serial.Serial(port=port, baudrate=rate, timeout=timeout_per_rate) as conn: + if probe: + conn.write(probe.encode()) + conn.flush() + time.sleep(0.05) - if probe: - conn.write(probe.encode()) - conn.flush() - time.sleep(0.05) + time.sleep(timeout_per_rate) + available = conn.in_waiting + data = conn.read(available) if available else b"" - time.sleep(timeout_per_rate) - available = conn.in_waiting - data = conn.read(available) if available else b"" - conn.close() - - score_result = score_data(data) - if score_result["bytes_received"] > 0: - results.append({"baudrate": rate, **score_result}) + score_result = score_data(data) + if score_result["bytes_received"] > 0: + results.append({"baudrate": rate, **score_result}) except serial.SerialException: continue @@ -145,11 +166,57 @@ def _detect_baud_rate_internal( } +def _require_mode(port: str, required_mode: Literal["rs232", "rs485"]) -> dict | None: + """Check if port is in the required mode. Returns error dict if wrong mode.""" + if port not in _connections: + return {"error": f"Port {port} is not open", "success": False} + + current_mode = _connections[port].mode + if current_mode != required_mode: + other_mode = "RS-232" if required_mode == "rs485" else "RS-485" + return { + "error": f"Port {port} is in {current_mode.upper()} mode. 
" + f"Use set_port_mode(port='{port}', mode='{required_mode}') to switch to {required_mode.upper()} mode first.", + "success": False, + "current_mode": current_mode, + "required_mode": required_mode, + "hint": f"This tool is for {required_mode.upper()} operations. " + f"Switch modes or use {other_mode} tools instead.", + } + return None # Mode is correct + + mcp = FastMCP( name="mcserial", - instructions="""Serial port MCP server. Use tools to open/close/write to serial ports. -Use resources to read data from open ports (serial://{port}/data). -Always list_serial_ports first to discover available ports.""", + instructions="""Serial port MCP server for RS-232/RS-485 communication and file transfers. + +## Quick Start +1. list_serial_ports() - Discover available ports +2. open_serial_port(port) - Open a connection (defaults to RS-232 mode) +3. Read/write data using tools or resources + +## Modes +Ports open in RS-232 mode by default. Use set_port_mode() to switch: + +**RS-232 mode** (default): Point-to-point serial with modem control lines. + - Tools: get_modem_lines, set_modem_lines, pulse_line, send_break + +**RS-485 mode**: Half-duplex multi-drop bus communication. + - Tools: set_rs485_mode, rs485_transact, rs485_scan_addresses + +## File Transfers (X/Y/ZMODEM) +Transfer files using classic protocols. Works in both modes. 
+ - file_transfer_send(port, file, protocol="zmodem") + - file_transfer_receive(port, save_path, protocol="zmodem") + - file_transfer_send_batch(port, files) - YMODEM/ZMODEM only + +Protocols: xmodem (128B blocks), xmodem1k, ymodem (batch), zmodem (streaming, recommended) + +## Resources +- serial://ports - List available ports +- serial://{port}/data - Read data from open port +- serial://{port}/status - Port configuration and mode +- serial://{port}/raw - Read as hex dump""", ) @@ -283,6 +350,7 @@ def open_serial_port( result = { "success": True, "port": port, + "mode": "rs232", # Default mode "baudrate": baudrate, "bytesize": bytesize, "parity": parity, @@ -292,6 +360,7 @@ def open_serial_port( "dsrdtr": dsrdtr, "exclusive": exclusive, "resource_uri": f"serial://{port}/data", + "mode_hint": "Use set_port_mode() to switch to RS-485 mode if needed.", } if detected_info: result["autobaud"] = detected_info @@ -360,6 +429,14 @@ def write_serial_bytes(port: str, data: list[int]) -> dict: if port not in _connections: return {"error": f"Port {port} is not open", "success": False} + # Validate byte range before conversion - bytes() silently truncates values > 255 + invalid_values = [b for b in data if not (0 <= b <= 255)] + if invalid_values: + return { + "error": f"Byte values must be 0-255, got invalid values: {invalid_values[:5]}{'...' if len(invalid_values) > 5 else ''}", + "success": False, + } + try: conn = _connections[port].connection raw_bytes = bytes(data) @@ -520,7 +597,7 @@ def get_connection_status() -> dict: """Get status of all open serial connections. 
Returns: - Dictionary of open connections with their settings + Dictionary of open connections with their settings and current mode """ status = {} for port, sc in _connections.items(): @@ -528,6 +605,7 @@ def get_connection_status() -> dict: try: status[port] = { "is_open": conn.is_open, + "mode": sc.mode, "baudrate": conn.baudrate, "bytesize": conn.bytesize, "parity": conn.parity, @@ -549,7 +627,7 @@ def get_connection_status() -> dict: "resource_uri": f"serial://{port}/data", } except serial.SerialException: - status[port] = {"is_open": False, "error": "Port disconnected"} + status[port] = {"is_open": False, "mode": sc.mode, "error": "Port disconnected"} return {"connections": status, "count": len(_connections)} @@ -579,10 +657,70 @@ def flush_serial(port: str, input_buffer: bool = True, output_buffer: bool = Tru return {"error": str(e), "success": False} +@mcp.tool() +def set_port_mode(port: str, mode: Literal["rs232", "rs485"]) -> dict: + """Switch a serial port between RS-232 and RS-485 modes. 
+ + Mode determines which tools are available: + + RS-232 mode (default): + - Standard point-to-point serial communication + - Full modem control lines (RTS, DTR, CTS, DSR, RI, CD) + - Tools: get_modem_lines, set_modem_lines, pulse_line, send_break + + RS-485 mode: + - Half-duplex multi-drop bus communication + - Automatic or manual TX/RX direction control + - Tools: set_rs485_mode, rs485_transact, rs485_scan_addresses, check_rs485_support + + Common tools work in both modes: + - open/close, read/write, configure, flush, detect_baud_rate + + Args: + port: Device path of the port to configure + mode: Target mode ("rs232" or "rs485") + + Returns: + Mode change status with available tools for the new mode + """ + if port not in _connections: + return {"error": f"Port {port} is not open", "success": False} + + old_mode = _connections[port].mode + _connections[port].mode = mode + + mode_tools = { + "rs232": [ + "get_modem_lines", + "set_modem_lines", + "pulse_line", + "send_break", + "set_break_condition", + ], + "rs485": [ + "set_rs485_mode", + "rs485_transact", + "rs485_scan_addresses", + "check_rs485_support", + ], + } + + return { + "success": True, + "port": port, + "previous_mode": old_mode, + "current_mode": mode, + "mode_tools": mode_tools[mode], + "hint": f"Port is now in {mode.upper()} mode. Use the listed tools for {mode.upper()} operations.", + } + + @mcp.tool() def get_modem_lines(port: str) -> dict: """Get all RS-232 modem control/status line states. + **Requires RS-232 mode** (default). Use set_port_mode() to switch if needed. 
+ Input lines (directly readable from device): - CTS: Clear To Send (device ready to receive) - DSR: Data Set Ready (device is present/powered) @@ -599,8 +737,8 @@ def get_modem_lines(port: str) -> dict: Returns: All modem line states """ - if port not in _connections: - return {"error": f"Port {port} is not open", "success": False} + if mode_error := _require_mode(port, "rs232"): + return mode_error try: conn = _connections[port].connection @@ -630,6 +768,8 @@ def set_modem_lines( ) -> dict: """Set RS-232 output control lines (RTS and DTR). + **Requires RS-232 mode** (default). Use set_port_mode() to switch if needed. + These lines can be used for: - Hardware flow control - Device reset sequences (many boards use DTR for reset) @@ -644,8 +784,8 @@ def set_modem_lines( Returns: Updated line states """ - if port not in _connections: - return {"error": f"Port {port} is not open", "success": False} + if mode_error := _require_mode(port, "rs232"): + return mode_error try: conn = _connections[port].connection @@ -673,6 +813,8 @@ def pulse_line( ) -> dict: """Pulse an RS-232 control line (useful for reset sequences). + **Requires RS-232 mode** (default). Use set_port_mode() to switch if needed. + Many devices use DTR or RTS for reset: - ESP32/ESP8266: DTR low + RTS sequence for bootloader - Arduino: DTR pulse for reset @@ -689,8 +831,8 @@ def pulse_line( """ import time - if port not in _connections: - return {"error": f"Port {port} is not open", "success": False} + if mode_error := _require_mode(port, "rs232"): + return mode_error if duration_ms < 1 or duration_ms > 5000: return {"error": "duration_ms must be between 1 and 5000", "success": False} @@ -729,6 +871,8 @@ def pulse_line( def send_break(port: str, duration_ms: int = 250) -> dict: """Send a serial break signal. + **Requires RS-232 mode** (default). Use set_port_mode() to switch if needed. 
+ A break is a sustained low signal longer than a character frame, used to get attention of remote device or trigger special modes. @@ -739,8 +883,8 @@ def send_break(port: str, duration_ms: int = 250) -> dict: Returns: Break operation status """ - if port not in _connections: - return {"error": f"Port {port} is not open", "success": False} + if mode_error := _require_mode(port, "rs232"): + return mode_error if duration_ms < 1 or duration_ms > 5000: return {"error": "duration_ms must be between 1 and 5000", "success": False} @@ -802,7 +946,9 @@ def set_rs485_mode( rts_level_for_rx: bool = False, loopback: bool = False, ) -> dict: - """Configure RS-485 mode for half-duplex communication. + """Configure RS-485 hardware mode for half-duplex communication. + + **Requires RS-485 mode.** Use set_port_mode(port, "rs485") first. RS-485 is commonly used in industrial applications (Modbus, etc.) where multiple devices share a bus. The driver must control the @@ -820,8 +966,8 @@ def set_rs485_mode( Returns: RS-485 configuration status """ - if port not in _connections: - return {"error": f"Port {port} is not open", "success": False} + if mode_error := _require_mode(port, "rs485"): + return mode_error try: conn = _connections[port].connection @@ -970,6 +1116,8 @@ def rs485_transact( ) -> dict: """Send data and receive response on RS-485 bus (half-duplex transaction). + **Requires RS-485 mode.** Use set_port_mode(port, "rs485") first. + Handles the TX→RX turnaround timing automatically. For devices without hardware RS-485 support, manually controls RTS around the transaction. @@ -989,8 +1137,8 @@ def rs485_transact( """ import time - if port not in _connections: - return {"error": f"Port {port} is not open", "success": False} + if mode_error := _require_mode(port, "rs485"): + return mode_error try: conn = _connections[port].connection @@ -1063,6 +1211,8 @@ def rs485_scan_addresses( ) -> dict: """Scan RS-485 bus for responding devices (address discovery). 
+ **Requires RS-485 mode.** Use set_port_mode(port, "rs485") first. + Sends a probe message to each address and records which ones respond. Useful for discovering Modbus or similar addressed devices on the bus. @@ -1080,8 +1230,8 @@ def rs485_scan_addresses( """ import time - if port not in _connections: - return {"error": f"Port {port} is not open", "success": False} + if mode_error := _require_mode(port, "rs485"): + return mode_error try: conn = _connections[port].connection @@ -1175,6 +1325,8 @@ def set_low_latency_mode(port: str, enabled: bool = True) -> dict: def set_break_condition(port: str, enabled: bool) -> dict: """Set or clear the break condition on a serial port. + **Requires RS-232 mode** (default). Use set_port_mode() to switch if needed. + Unlike send_break() which sends a timed pulse, this holds the break condition until explicitly cleared. Useful for protocols that require sustained break states. @@ -1186,8 +1338,8 @@ def set_break_condition(port: str, enabled: bool) -> dict: Returns: Break condition status """ - if port not in _connections: - return {"error": f"Port {port} is not open", "success": False} + if mode_error := _require_mode(port, "rs232"): + return mode_error try: conn = _connections[port].connection @@ -1440,6 +1592,236 @@ def detect_baud_rate( return {"error": str(e), "success": False} +# ============================================================================ +# FILE TRANSFER TOOLS - X/Y/ZMODEM protocols +# ============================================================================ + + +@mcp.tool() +def file_transfer_send( + port: str, + file_path: str, + protocol: Literal["xmodem", "xmodem1k", "ymodem", "zmodem"] = "zmodem", +) -> dict: + """Send a file over serial using X/Y/ZMODEM protocol. + + Works in both RS-232 and RS-485 modes. 
+ + Protocols: + - xmodem: 128-byte blocks, simple but compatible (1977) + - xmodem1k: 1024-byte blocks, faster than basic XMODEM + - ymodem: Batch mode, sends filename/size (1985) + - zmodem: Streaming, auto-resume, most efficient (1986) [recommended] + + The receiver must be waiting in receive mode before calling this. + For ZMODEM, the receiver typically auto-starts when it sees the init sequence. + + Args: + port: Device path of the open serial port + file_path: Path to the file to send + protocol: Transfer protocol to use (default: zmodem) + + Returns: + Transfer statistics including bytes sent and any errors + """ + from pathlib import Path + + if port not in _connections: + return {"error": f"Port {port} is not open", "success": False} + + filepath = Path(file_path) + if not filepath.exists(): + return {"error": f"File not found: {file_path}", "success": False} + + conn = _connections[port].connection + + # Create read/write functions for the protocol handlers + def read_func(n: int) -> bytes: + return conn.read(n) + + def write_func(data: bytes) -> int: + return conn.write(data) + + try: + if protocol in ("xmodem", "xmodem1k"): + from mcserial.xmodem import XModem + + xm = XModem(read_func, write_func, mode=protocol) + with open(filepath, "rb") as f: + result = xm.send(f) + result["protocol"] = protocol + result["file"] = str(filepath) + return result + + elif protocol == "ymodem": + from mcserial.ymodem import YModem + + ym = YModem(read_func, write_func) + result = ym.send([filepath]) + result["protocol"] = "ymodem" + return result + + elif protocol == "zmodem": + from mcserial.zmodem import ZModem + + zm = ZModem(read_func, write_func) + result = zm.send([filepath]) + result["protocol"] = "zmodem" + return result + + else: + return {"error": f"Unknown protocol: {protocol}", "success": False} + + except Exception as e: + return {"error": str(e), "success": False, "protocol": protocol} + + +@mcp.tool() +def file_transfer_receive( + port: str, + save_path: 
str, + protocol: Literal["xmodem", "ymodem", "zmodem"] = "zmodem", + overwrite: bool = False, +) -> dict: + """Receive a file over serial using X/Y/ZMODEM protocol. + + Works in both RS-232 and RS-485 modes. + + Protocols: + - xmodem: Must specify exact filename, no metadata from sender + - ymodem: Receives filename from sender, batch capable + - zmodem: Streaming, auto-resume capable [recommended] + + For XMODEM, save_path is the full file path. + For YMODEM/ZMODEM, save_path is a directory where files will be saved. + + Args: + port: Device path of the open serial port + save_path: File path (xmodem) or directory (ymodem/zmodem) + protocol: Transfer protocol to use (default: zmodem) + overwrite: Whether to overwrite existing files + + Returns: + Transfer statistics including bytes received and file paths + """ + from pathlib import Path + + if port not in _connections: + return {"error": f"Port {port} is not open", "success": False} + + conn = _connections[port].connection + + def read_func(n: int) -> bytes: + return conn.read(n) + + def write_func(data: bytes) -> int: + return conn.write(data) + + try: + if protocol == "xmodem": + from mcserial.xmodem import XModem + + filepath = Path(save_path) + if filepath.exists() and not overwrite: + return {"error": f"File exists: {save_path}", "success": False} + + filepath.parent.mkdir(parents=True, exist_ok=True) + + xm = XModem(read_func, write_func) + with open(filepath, "wb") as f: + result = xm.receive(f) + result["protocol"] = "xmodem" + result["file"] = str(filepath) + return result + + elif protocol == "ymodem": + from mcserial.ymodem import YModem + + directory = Path(save_path) + directory.mkdir(parents=True, exist_ok=True) + + ym = YModem(read_func, write_func) + result = ym.receive(directory, overwrite=overwrite) + result["protocol"] = "ymodem" + return result + + elif protocol == "zmodem": + from mcserial.zmodem import ZModem + + directory = Path(save_path) + directory.mkdir(parents=True, exist_ok=True) + + 
zm = ZModem(read_func, write_func) + result = zm.receive(directory, overwrite=overwrite) + result["protocol"] = "zmodem" + return result + + else: + return {"error": f"Unknown protocol: {protocol}", "success": False} + + except Exception as e: + return {"error": str(e), "success": False, "protocol": protocol} + + +@mcp.tool() +def file_transfer_send_batch( + port: str, + file_paths: list[str], + protocol: Literal["ymodem", "zmodem"] = "zmodem", +) -> dict: + """Send multiple files in a batch transfer. + + Only YMODEM and ZMODEM support batch transfers. + + Args: + port: Device path of the open serial port + file_paths: List of file paths to send + protocol: Transfer protocol (ymodem or zmodem) + + Returns: + Transfer statistics for all files + """ + from pathlib import Path + + if port not in _connections: + return {"error": f"Port {port} is not open", "success": False} + + if protocol not in ("ymodem", "zmodem"): + return {"error": "Batch transfer requires ymodem or zmodem", "success": False} + + # Validate all files exist + paths = [] + for fp in file_paths: + p = Path(fp) + if not p.exists(): + return {"error": f"File not found: {fp}", "success": False} + paths.append(p) + + conn = _connections[port].connection + + def read_func(n: int) -> bytes: + return conn.read(n) + + def write_func(data: bytes) -> int: + return conn.write(data) + + try: + if protocol == "ymodem": + from mcserial.ymodem import YModem + ym = YModem(read_func, write_func) + result = ym.send(paths) + else: + from mcserial.zmodem import ZModem + zm = ZModem(read_func, write_func) + result = zm.send(paths) + + result["protocol"] = protocol + return result + + except Exception as e: + return {"error": str(e), "success": False, "protocol": protocol} + + # ============================================================================ # RESOURCES - Dynamic data access via URIs # ============================================================================ @@ -1493,9 +1875,11 @@ def 
resource_port_status(port: str) -> str: return f"Port {port} is not open." try: - conn = _connections[port].connection + sc = _connections[port] + conn = sc.connection lines = [ f"# Serial Port Status: {port}", + f"- Mode: {sc.mode.upper()}", f"- Open: {conn.is_open}", f"- Baudrate: {conn.baudrate}", f"- Bytesize: {conn.bytesize}", diff --git a/src/mcserial/xmodem.py b/src/mcserial/xmodem.py new file mode 100644 index 0000000..d58e300 --- /dev/null +++ b/src/mcserial/xmodem.py @@ -0,0 +1,407 @@ +"""XMODEM file transfer protocol implementation. + +XMODEM is the simplest serial file transfer protocol: +- 128-byte blocks (XMODEM) or 1024-byte blocks (XMODEM-1K) +- Checksum or CRC-16 error detection +- Stop-and-wait: waits for ACK after each block + +Protocol flow: + Receiver sends NAK (checksum) or 'C' (CRC) to initiate + Sender transmits: SOH + block# + ~block# + 128 bytes + check + Receiver sends ACK or NAK + Repeat until EOT +""" + +from __future__ import annotations + +import logging +import time +from collections.abc import Callable +from typing import BinaryIO + +# Protocol constants +SOH = 0x01 # Start of 128-byte header +STX = 0x02 # Start of 1024-byte header (XMODEM-1K) +EOT = 0x04 # End of transmission +ACK = 0x06 # Acknowledge +NAK = 0x15 # Negative acknowledge +CAN = 0x18 # Cancel +CRC_MODE = 0x43 # 'C' - request CRC mode + +# Default parameters +DEFAULT_RETRY_LIMIT = 16 # Max retries per block +DEFAULT_TIMEOUT = 60.0 # Total transfer timeout in seconds +READ_TIMEOUT_RETRIES = 30 # Retries when waiting for response byte +INIT_TIMEOUT_RETRIES = 10 # Retries when waiting for transfer initiation + +logger = logging.getLogger(__name__) + + +def _calc_checksum(data: bytes) -> int: + """Calculate simple 8-bit checksum.""" + return sum(data) & 0xFF + + +def _calc_crc16(data: bytes) -> int: + """Calculate CRC-16-CCITT (XMODEM variant).""" + crc = 0 + for byte in data: + crc ^= byte << 8 + for _ in range(8): + if crc & 0x8000: + crc = (crc << 1) ^ 0x1021 + else: + 
crc <<= 1 + crc &= 0xFFFF + return crc + + +class XModemError(Exception): + """Base exception for XMODEM errors.""" + pass + + +class XModem: + """XMODEM file transfer protocol handler. + + Args: + read_func: Callable that reads n bytes from serial, returns bytes + write_func: Callable that writes bytes to serial + mode: "xmodem" (128-byte) or "xmodem1k" (1024-byte blocks) + use_crc: Use CRC-16 instead of checksum (default True) + """ + + def __init__( + self, + read_func: Callable[[int], bytes], + write_func: Callable[[bytes], int], + mode: str = "xmodem", + use_crc: bool = True, + ): + self.read = read_func + self.write = write_func + self.mode = mode + self.use_crc = use_crc + self.block_size = 1024 if mode == "xmodem1k" else 128 + + def _read_byte(self, timeout_retries: int = INIT_TIMEOUT_RETRIES) -> int | None: + """Read a single byte with retry. + + Args: + timeout_retries: Number of read attempts before returning None + + Returns: + Byte value (0-255) or None if no data available after retries + """ + for _ in range(timeout_retries): + data = self.read(1) + if data: + return data[0] + return None + + def _make_block(self, block_num: int, data: bytes) -> bytes: + """Create a complete XMODEM block with header, data, and checksum/CRC. 
+ + Args: + block_num: Block sequence number (0-255, wraps) + data: Payload data (will be padded to block_size with SUB chars) + + Returns: + Complete block bytes ready for transmission + """ + # Pad data to block size with SUB (0x1A) - standard XMODEM padding + if len(data) < self.block_size: + data = data + bytes([0x1A] * (self.block_size - len(data))) + + header = STX if self.block_size == 1024 else SOH + block = bytes([header, block_num & 0xFF, (255 - block_num) & 0xFF]) + data + + if self.use_crc: + crc = _calc_crc16(data) + block += bytes([crc >> 8, crc & 0xFF]) + else: + block += bytes([_calc_checksum(data)]) + + return block + + def _verify_block(self, data: bytes, check_bytes: bytes) -> bool: + """Verify block integrity using checksum or CRC. + + Args: + data: The block payload to verify + check_bytes: Received checksum (1 byte) or CRC (2 bytes) + + Returns: + True if verification passes, False otherwise + """ + if self.use_crc: + expected = _calc_crc16(data) + received = (check_bytes[0] << 8) | check_bytes[1] + return expected == received + else: + expected = _calc_checksum(data) + return expected == check_bytes[0] + + def send( + self, + stream: BinaryIO, + callback: Callable[[int, int], None] | None = None, + retry_limit: int = DEFAULT_RETRY_LIMIT, + timeout: float = DEFAULT_TIMEOUT, + ) -> dict: + """Send a file via XMODEM. 
+ + Args: + stream: File-like object to read from + callback: Progress callback(bytes_sent, total_bytes) + retry_limit: Max retries per block + timeout: Total timeout in seconds (enforced across entire transfer) + + Returns: + Dict with transfer statistics + """ + start_time = time.monotonic() + + # Get file size for progress + stream.seek(0, 2) # Seek to end + total_size = stream.tell() + stream.seek(0) # Seek back to start + + bytes_sent = 0 + block_num = 1 + errors = 0 + + # Wait for receiver to initiate + logger.debug("Waiting for receiver initiation...") + init_byte = None + for _ in range(retry_limit * 10): + if time.monotonic() - start_time > timeout: + return {"success": False, "error": f"Timeout ({timeout}s) waiting for receiver initiation"} + + b = self._read_byte(timeout_retries=3) + if b == CRC_MODE: + self.use_crc = True + init_byte = b + logger.debug("Receiver requested CRC mode") + break + elif b == NAK: + self.use_crc = False + init_byte = b + logger.debug("Receiver requested checksum mode") + break + elif b == CAN: + return {"success": False, "error": "Transfer cancelled by receiver"} + + if init_byte is None: + return {"success": False, "error": "Timeout waiting for receiver"} + + # Send blocks + while True: + # Check timeout + if time.monotonic() - start_time > timeout: + return {"success": False, "error": f"Timeout ({timeout}s) during transfer at block {block_num}"} + + data = stream.read(self.block_size) + if not data: + break + + block = self._make_block(block_num, data) + retries = 0 + + while retries < retry_limit: + self.write(block) + logger.debug(f"Sent block {block_num}, {len(data)} bytes") + + response = self._read_byte(timeout_retries=READ_TIMEOUT_RETRIES) + + if response == ACK: + bytes_sent += len(data) + if callback: + callback(bytes_sent, total_size) + block_num = (block_num + 1) & 0xFF + break + elif response == NAK: + retries += 1 + errors += 1 + logger.debug(f"Block {block_num} NAK'd, retry {retries}") + elif response == CAN: + 
return {"success": False, "error": "Transfer cancelled by receiver"} + else: + retries += 1 + errors += 1 + logger.debug(f"Block {block_num} no response, retry {retries}") + else: + return {"success": False, "error": f"Max retries exceeded at block {block_num}"} + + # Send EOT + for _ in range(retry_limit): + self.write(bytes([EOT])) + response = self._read_byte(timeout_retries=10) + if response == ACK: + break + elif response == NAK: + continue # Resend EOT + + return { + "success": True, + "bytes_sent": bytes_sent, + "blocks": block_num - 1, + "errors": errors, + "mode": "crc" if self.use_crc else "checksum", + } + + def receive( + self, + stream: BinaryIO, + callback: Callable[[int], None] | None = None, + retry_limit: int = DEFAULT_RETRY_LIMIT, + timeout: float = DEFAULT_TIMEOUT, + ) -> dict: + """Receive a file via XMODEM. + + Args: + stream: File-like object to write to + callback: Progress callback(bytes_received) + retry_limit: Max retries per block + timeout: Total timeout in seconds (enforced across entire transfer) + + Returns: + Dict with transfer statistics + """ + start_time = time.monotonic() + bytes_received = 0 + expected_block = 1 + errors = 0 + + # Initiate transfer + init_char = CRC_MODE if self.use_crc else NAK + logger.debug(f"Initiating transfer with {'CRC' if self.use_crc else 'checksum'} mode") + + for attempt in range(retry_limit): + if time.monotonic() - start_time > timeout: + return {"success": False, "error": f"Timeout ({timeout}s) waiting for sender"} + + self.write(bytes([init_char])) + + # Wait for first byte + header = self._read_byte(timeout_retries=READ_TIMEOUT_RETRIES) + if header in (SOH, STX): + break + elif header == CAN: + return {"success": False, "error": "Transfer cancelled by sender"} + + # Fallback to checksum if CRC not supported + if attempt == 3 and self.use_crc: + logger.debug("CRC mode not responding, trying checksum") + self.use_crc = False + init_char = NAK + else: + return {"success": False, "error": "Timeout 
waiting for sender"} + + # Receive blocks + while True: + # Check timeout + if time.monotonic() - start_time > timeout: + return {"success": False, "error": f"Timeout ({timeout}s) during receive at block {expected_block}"} + + if header == EOT: + self.write(bytes([ACK])) + break + + # Determine block size from header + block_size = 1024 if header == STX else 128 + check_size = 2 if self.use_crc else 1 + + # Read block number + block_num = self._read_byte() + block_num_comp = self._read_byte() + + if block_num is None or block_num_comp is None: + errors += 1 + self.write(bytes([NAK])) + header = self._read_byte(timeout_retries=READ_TIMEOUT_RETRIES) + continue + + # Verify block number complement + if (block_num + block_num_comp) & 0xFF != 0xFF: + logger.debug(f"Block number mismatch: {block_num} vs {block_num_comp}") + errors += 1 + self.write(bytes([NAK])) + header = self._read_byte(timeout_retries=READ_TIMEOUT_RETRIES) + continue + + # Read data + data = self.read(block_size) + if len(data) != block_size: + logger.debug(f"Short block: {len(data)} bytes") + errors += 1 + self.write(bytes([NAK])) + header = self._read_byte(timeout_retries=READ_TIMEOUT_RETRIES) + continue + + # Read and verify check bytes + check_bytes = self.read(check_size) + if len(check_bytes) != check_size: + errors += 1 + self.write(bytes([NAK])) + header = self._read_byte(timeout_retries=READ_TIMEOUT_RETRIES) + continue + + if not self._verify_block(data, check_bytes): + logger.debug(f"Block {block_num} checksum/CRC error") + errors += 1 + self.write(bytes([NAK])) + header = self._read_byte(timeout_retries=READ_TIMEOUT_RETRIES) + continue + + # Check sequence + if block_num == expected_block: + stream.write(data) + bytes_received += len(data) + expected_block = (expected_block + 1) & 0xFF + if callback: + callback(bytes_received) + elif block_num == (expected_block - 1) & 0xFF: + # Duplicate block, already ACK'd + logger.debug(f"Duplicate block {block_num}") + else: + logger.debug(f"Out of 
sequence: expected {expected_block}, got {block_num}") + errors += 1 + self.write(bytes([NAK])) + header = self._read_byte(timeout_retries=READ_TIMEOUT_RETRIES) + continue + + self.write(bytes([ACK])) + header = self._read_byte(timeout_retries=READ_TIMEOUT_RETRIES) + + return { + "success": True, + "bytes_received": bytes_received, + "blocks": expected_block - 1, + "errors": errors, + "mode": "crc" if self.use_crc else "checksum", + } + + +def send_xmodem( + read_func: Callable[[int], bytes], + write_func: Callable[[bytes], int], + stream: BinaryIO, + mode: str = "xmodem", + callback: Callable[[int, int], None] | None = None, +) -> dict: + """Convenience function to send a file via XMODEM.""" + xm = XModem(read_func, write_func, mode=mode) + return xm.send(stream, callback) + + +def receive_xmodem( + read_func: Callable[[int], bytes], + write_func: Callable[[bytes], int], + stream: BinaryIO, + mode: str = "xmodem", + callback: Callable[[int], None] | None = None, +) -> dict: + """Convenience function to receive a file via XMODEM.""" + xm = XModem(read_func, write_func, mode=mode) + return xm.receive(stream, callback) diff --git a/src/mcserial/ymodem.py b/src/mcserial/ymodem.py new file mode 100644 index 0000000..78d007b --- /dev/null +++ b/src/mcserial/ymodem.py @@ -0,0 +1,542 @@ +"""YMODEM file transfer protocol implementation. 
+ +YMODEM extends XMODEM-1K with: +- Batch transfers (multiple files in one session) +- File metadata in block 0 (filename, size, modification time) +- CRC-16 required (no checksum mode) + +Protocol flow: + Receiver sends 'C' to initiate + Sender sends block 0: filename + size + mtime (null-terminated strings) + Receiver ACKs, then sends 'C' again for data + Sender transmits data blocks (1024 bytes, XMODEM-1K style) + Sender sends EOT, receiver NAKs, sender sends EOT again, receiver ACKs + For batch: repeat with next file, or send empty block 0 to end session +""" + +from __future__ import annotations + +import contextlib +import logging +import os +import time +from collections.abc import Callable +from pathlib import Path + +from mcserial.xmodem import ( + ACK, + CAN, + CRC_MODE, + EOT, + NAK, + SOH, + STX, + XModemError, + _calc_crc16, +) + +logger = logging.getLogger(__name__) + +# Transfer limits +DEFAULT_MAX_TRANSFER_SIZE = 100 * 1024 * 1024 # 100MB default limit per file + + +def _sanitize_filename(filename: str) -> str: + """Remove path traversal attempts and dangerous characters from filename. + + Security: Prevents directory traversal attacks where malicious senders + could write files outside the target directory using names like + '../../../etc/passwd' or absolute paths like '/etc/cron.d/backdoor'. + """ + # Get just the basename, removing any directory components + name = Path(filename).name + + # Reject empty names + if not name: + name = "unnamed_file" + + # Prefix hidden files (starting with dot) to make them visible + if name.startswith("."): + name = "_" + name[1:] + + # Replace any remaining problematic characters + name = name.replace("\x00", "_").replace("/", "_").replace("\\", "_") + + return name + + +def _open_file_atomic(filepath: Path, overwrite: bool) -> tuple: + """Open file for writing with atomic creation to prevent TOCTOU races. 
+ + Uses O_CREAT | O_EXCL to atomically fail if file exists (when overwrite=False), + preventing race conditions between existence check and file creation. + + Returns: + Tuple of (file_object, None) on success, or (None, error_message) on failure + """ + if overwrite: + # Overwrite mode: just open normally + return open(filepath, "wb"), None + + try: + # Atomic create: fails if file exists (O_EXCL) + fd = os.open(filepath, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644) + return os.fdopen(fd, "wb"), None + except FileExistsError: + return None, "File exists" + except OSError as e: + return None, str(e) + + +class YModemError(XModemError): + """YMODEM-specific error.""" + pass + + +class YModem: + """YMODEM batch file transfer protocol handler. + + Args: + read_func: Callable that reads n bytes from serial, returns bytes + write_func: Callable that writes bytes to serial + """ + + def __init__( + self, + read_func: Callable[[int], bytes], + write_func: Callable[[bytes], int], + ): + self.read = read_func + self.write = write_func + + def _read_byte(self, timeout_retries: int = 10) -> int | None: + """Read a single byte with retry.""" + for _ in range(timeout_retries): + data = self.read(1) + if data: + return data[0] + return None + + def _make_block(self, block_num: int, data: bytes, block_size: int = 1024) -> bytes: + """Create a YMODEM block with CRC-16.""" + # Pad data to block size + if len(data) < block_size: + data = data + bytes(block_size - len(data)) + + header = STX if block_size == 1024 else SOH + block = bytes([header, block_num & 0xFF, (255 - block_num) & 0xFF]) + data + + crc = _calc_crc16(data) + block += bytes([crc >> 8, crc & 0xFF]) + + return block + + def _make_block0(self, filename: str, filesize: int, mtime: int | None = None) -> bytes: + """Create YMODEM block 0 with file metadata. 
+ + Block 0 format: filename\0size mtime\0 (space-separated, octal for time) + """ + # Build metadata string + if mtime is not None: + meta = f"{filename}\x00{filesize} {mtime:o}" + else: + meta = f"{filename}\x00{filesize}" + + data = meta.encode("latin-1") + return data + + def _parse_block0(self, data: bytes) -> tuple[dict | None, str]: + """Parse YMODEM block 0 metadata. + + Returns: + Tuple of (metadata_dict, status) where: + - (dict, "ok"): Successfully parsed file metadata + - (None, "end_of_batch"): Empty block 0, signals end of batch transfer + - (None, "parse_error"): Failed to parse block 0 data + """ + # Find null terminator after filename + try: + null_pos = data.index(0) + filename = data[:null_pos].decode("latin-1").strip() + + if not filename: + return None, "end_of_batch" + + # Parse size and optional mtime + rest = data[null_pos + 1:].split(b" ") + rest = [x for x in rest if x and x != b"\x00"] + + filesize = int(rest[0]) if rest else 0 + mtime = int(rest[1], 8) if len(rest) > 1 else None + + return { + "filename": filename, + "size": filesize, + "mtime": mtime, + }, "ok" + except (ValueError, IndexError): + return None, "parse_error" + + def send( + self, + files: list[str | Path], + callback: Callable[[str, int, int], None] | None = None, + retry_limit: int = 16, + ) -> dict: + """Send files via YMODEM batch transfer. 
+ + Args: + files: List of file paths to send + callback: Progress callback(filename, bytes_sent, total_bytes) + retry_limit: Max retries per block + + Returns: + Dict with transfer statistics + """ + results = [] + total_bytes = 0 + total_errors = 0 + + for filepath in files: + filepath = Path(filepath) + if not filepath.exists(): + results.append({"file": str(filepath), "error": "File not found"}) + continue + + filesize = filepath.stat().st_size + mtime = int(filepath.stat().st_mtime) + + # Wait for receiver 'C' + logger.debug(f"Waiting for receiver to initiate {filepath.name}...") + for _ in range(retry_limit * 10): + b = self._read_byte(timeout_retries=3) + if b == CRC_MODE: + break + elif b == CAN: + return {"success": False, "error": "Transfer cancelled", "files": results} + else: + return {"success": False, "error": "Timeout waiting for receiver", "files": results} + + # Send block 0 with metadata + block0_data = self._make_block0(filepath.name, filesize, mtime) + block0 = self._make_block(0, block0_data, block_size=128) + + retries = 0 + while retries < retry_limit: + self.write(block0) + response = self._read_byte(timeout_retries=30) + if response == ACK: + break + elif response == CAN: + return {"success": False, "error": "Transfer cancelled", "files": results} + retries += 1 + else: + results.append({"file": str(filepath), "error": "Block 0 not acknowledged"}) + continue + + # Wait for second 'C' to start data + for _ in range(retry_limit * 5): + b = self._read_byte(timeout_retries=3) + if b == CRC_MODE: + break + else: + results.append({"file": str(filepath), "error": "No data initiation"}) + continue + + # Send file data + bytes_sent = 0 + block_num = 1 + errors = 0 + + with open(filepath, "rb") as f: + while True: + data = f.read(1024) + if not data: + break + + block = self._make_block(block_num, data) + retries = 0 + + while retries < retry_limit: + self.write(block) + response = self._read_byte(timeout_retries=30) + + if response == ACK: + 
bytes_sent += len(data) + if callback: + callback(filepath.name, bytes_sent, filesize) + block_num = (block_num + 1) & 0xFF + break + elif response == CAN: + results.append({"file": str(filepath), "error": "Cancelled"}) + break + else: + retries += 1 + errors += 1 + else: + results.append({"file": str(filepath), "error": f"Max retries at block {block_num}"}) + break + + # Send EOT sequence (NAK then ACK expected) + for _ in range(retry_limit): + self.write(bytes([EOT])) + response = self._read_byte(timeout_retries=10) + if response == NAK: + self.write(bytes([EOT])) + response = self._read_byte(timeout_retries=10) + if response == ACK: + break + elif response == ACK: + break + + results.append({ + "file": str(filepath), + "bytes_sent": bytes_sent, + "blocks": block_num - 1, + "errors": errors, + "success": True, + }) + total_bytes += bytes_sent + total_errors += errors + + # Send empty block 0 to end batch + for _ in range(retry_limit * 5): + b = self._read_byte(timeout_retries=3) + if b == CRC_MODE: + break + + empty_block0 = self._make_block(0, b"", block_size=128) + self.write(empty_block0) + self._read_byte(timeout_retries=10) # ACK + + return { + "success": all(r.get("success", False) for r in results), + "files": results, + "total_bytes": total_bytes, + "total_errors": total_errors, + } + + def receive( + self, + directory: str | Path, + callback: Callable[[str, int, int], None] | None = None, + retry_limit: int = 16, + overwrite: bool = False, + max_transfer_size: int = DEFAULT_MAX_TRANSFER_SIZE, + ) -> dict: + """Receive files via YMODEM batch transfer. + + Args: + directory: Directory to save received files + callback: Progress callback(filename, bytes_received, total_bytes) + retry_limit: Max retries per block + overwrite: Overwrite existing files + max_transfer_size: Maximum bytes to receive per file (default 100MB). + Set to 0 to disable limit. Prevents unbounded memory usage. 
+ + Returns: + Dict with transfer statistics + """ + directory = Path(directory) + directory.mkdir(parents=True, exist_ok=True) + + results = [] + total_bytes = 0 + total_errors = 0 + + while True: + # Initiate with 'C' for CRC mode + logger.debug("Initiating YMODEM receive...") + for _attempt in range(retry_limit): + self.write(bytes([CRC_MODE])) + header = self._read_byte(timeout_retries=30) + if header in (SOH, STX): + break + elif header == CAN: + return {"success": False, "error": "Cancelled", "files": results} + else: + return {"success": False, "error": "Timeout", "files": results} + + # Receive block 0 + block_size = 1024 if header == STX else 128 + block_num = self._read_byte() + block_num_comp = self._read_byte() + + if block_num != 0: + self.write(bytes([NAK])) + continue + + data = self.read(block_size) + crc_bytes = self.read(2) + + # Verify CRC + expected_crc = _calc_crc16(data) + received_crc = (crc_bytes[0] << 8) | crc_bytes[1] + if expected_crc != received_crc: + self.write(bytes([NAK])) + continue + + # Parse metadata + meta, status = self._parse_block0(data) + if status == "end_of_batch": + self.write(bytes([ACK])) + break + elif status == "parse_error": + logger.debug("Failed to parse block 0 metadata") + self.write(bytes([NAK])) + continue + + filename = meta["filename"] + filesize = meta["size"] + logger.debug(f"Receiving: {filename} ({filesize} bytes)") + + # Security: sanitize filename to prevent path traversal attacks + safe_filename = _sanitize_filename(filename) + if safe_filename != filename: + logger.warning(f"Sanitized filename: {filename!r} -> {safe_filename!r}") + + filepath = directory / safe_filename + + # Atomic file creation to prevent TOCTOU race conditions + f, file_error = _open_file_atomic(filepath, overwrite) + if file_error: + results.append({"file": filename, "error": file_error}) + self.write(bytes([CAN, CAN])) + continue + + self.write(bytes([ACK])) + + # Send 'C' to start data transfer + self.write(bytes([CRC_MODE])) + 
+ # Receive data blocks + bytes_received = 0 + expected_block = 1 + errors = 0 + transfer_aborted = False + + try: + while bytes_received < filesize and not transfer_aborted: + header = self._read_byte(timeout_retries=60) + + if header == EOT: + self.write(bytes([NAK])) + header = self._read_byte(timeout_retries=10) + if header == EOT: + self.write(bytes([ACK])) + break + continue + + if header not in (SOH, STX): + errors += 1 + self.write(bytes([NAK])) + continue + + block_size = 1024 if header == STX else 128 + block_num = self._read_byte() + block_num_comp = self._read_byte() + + if (block_num + block_num_comp) & 0xFF != 0xFF: + errors += 1 + self.write(bytes([NAK])) + continue + + data = self.read(block_size) + crc_bytes = self.read(2) + + if len(data) != block_size or len(crc_bytes) != 2: + errors += 1 + self.write(bytes([NAK])) + continue + + expected_crc = _calc_crc16(data) + received_crc = (crc_bytes[0] << 8) | crc_bytes[1] + if expected_crc != received_crc: + errors += 1 + self.write(bytes([NAK])) + continue + + if block_num == expected_block: + # Trim last block to actual file size + remaining = filesize - bytes_received + if remaining < block_size: + data = data[:remaining] + + f.write(data) + bytes_received += len(data) + expected_block = (expected_block + 1) & 0xFF + + # Check transfer size limit to prevent unbounded memory usage + if max_transfer_size > 0 and bytes_received > max_transfer_size: + logger.warning( + f"Transfer aborted: {filename} exceeded {max_transfer_size} byte limit " + f"(received {bytes_received} bytes)" + ) + self.write(bytes([CAN, CAN])) + transfer_aborted = True + break + + if callback: + callback(filename, bytes_received, filesize) + + self.write(bytes([ACK])) + finally: + f.close() + + if transfer_aborted: + # Transfer was aborted due to size limit + results.append({ + "file": filename, + "path": str(filepath), + "error": f"Transfer size exceeded {max_transfer_size} byte limit", + "bytes_received": bytes_received, + "success": 
False, + }) + # Try to clean up partial file + with contextlib.suppress(OSError): + filepath.unlink() + continue + + # Set modification time if provided + if meta.get("mtime"): + os.utime(filepath, (time.time(), meta["mtime"])) + + results.append({ + "file": filename, + "path": str(filepath), + "bytes_received": bytes_received, + "errors": errors, + "success": True, + }) + total_bytes += bytes_received + total_errors += errors + + return { + "success": all(r.get("success", False) for r in results), + "files": results, + "total_bytes": total_bytes, + "total_errors": total_errors, + } + + +def send_ymodem( + read_func: Callable[[int], bytes], + write_func: Callable[[bytes], int], + files: list[str | Path], + callback: Callable[[str, int, int], None] | None = None, +) -> dict: + """Convenience function to send files via YMODEM.""" + ym = YModem(read_func, write_func) + return ym.send(files, callback) + + +def receive_ymodem( + read_func: Callable[[int], bytes], + write_func: Callable[[bytes], int], + directory: str | Path, + callback: Callable[[str, int, int], None] | None = None, + overwrite: bool = False, +) -> dict: + """Convenience function to receive files via YMODEM.""" + ym = YModem(read_func, write_func) + return ym.receive(directory, callback, overwrite=overwrite) diff --git a/src/mcserial/zmodem.py b/src/mcserial/zmodem.py new file mode 100644 index 0000000..124fc03 --- /dev/null +++ b/src/mcserial/zmodem.py @@ -0,0 +1,815 @@ +"""ZMODEM file transfer protocol implementation. 
+ +ZMODEM is the most sophisticated serial file transfer protocol: +- Streaming with selective retransmission (no stop-and-wait) +- CRC-16 or CRC-32 error detection +- Crash recovery (resume interrupted transfers) +- Auto-start capability +- Escape sequence encoding for 8-bit transparency + +Frame structure: + ZPAD ZPAD ZDLE frame_type [4 bytes header data] [CRC] + +Data subpacket: + [escaped data] ZDLE frame_end_type [CRC] +""" + +from __future__ import annotations + +import contextlib +import logging +import os +import time +from collections.abc import Callable +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def _sanitize_filename(filename: str) -> str: + """Remove path traversal attempts and dangerous characters from filename. + + Security: Prevents directory traversal attacks where malicious senders + could write files outside the target directory. + """ + name = Path(filename).name + if not name: + name = "unnamed_file" + if name.startswith("."): + name = "_" + name[1:] + name = name.replace("\x00", "_").replace("/", "_").replace("\\", "_") + return name + + +def _open_file_atomic(filepath: Path, overwrite: bool) -> tuple: + """Open file for writing with atomic creation to prevent TOCTOU races. + + Uses O_CREAT | O_EXCL to atomically fail if file exists (when overwrite=False), + preventing race conditions between existence check and file creation. 
+ + Returns: + Tuple of (file_object, None) on success, or (None, error_message) on failure + """ + if overwrite: + return open(filepath, "wb"), None + + try: + fd = os.open(filepath, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644) + return os.fdopen(fd, "wb"), None + except FileExistsError: + return None, "File exists" + except OSError as e: + return None, str(e) + + +# Protocol constants +ZPAD = 0x2A # '*' - Padding character + +# Transfer limits +DEFAULT_MAX_TRANSFER_SIZE = 100 * 1024 * 1024 # 100MB default limit per file +MAX_SUBPACKET_SIZE = 8192 # Per-packet size limit (sanity check) +ZDLE = 0x18 # Data Link Escape +ZDLEE = 0x58 # Escaped ZDLE (ZDLE ^ 0x40) + +# Frame types +ZBIN = 0x41 # 'A' - Binary header, CRC-16 +ZHEX = 0x42 # 'B' - Hex header, CRC-16 +ZBIN32 = 0x43 # 'C' - Binary header, CRC-32 + +# Header types +ZRQINIT = 0 # Request receive init +ZRINIT = 1 # Receive init +ZSINIT = 2 # Send init sequence +ZACK = 3 # ACK +ZFILE = 4 # File name/info +ZSKIP = 5 # Skip this file +ZNAK = 6 # Last packet was garbled +ZABORT = 7 # Abort batch transfers +ZFIN = 8 # Finish session +ZRPOS = 9 # Resume at position +ZDATA = 10 # Data packet follows +ZEOF = 11 # End of file +ZFERR = 12 # Fatal error +ZCRC = 13 # Request file CRC +ZCHALLENGE = 14 # Challenge +ZCOMPL = 15 # Request complete +ZCAN = 16 # Cancel (5 CANs) +ZFREECNT = 17 # Request free bytes +ZCOMMAND = 18 # Execute command +ZSTDERR = 19 # Output to stderr + +# Data subpacket end types +ZCRCE = 0x68 # 'h' - CRC next, end, no ACK +ZCRCG = 0x69 # 'i' - CRC next, not end, no ACK +ZCRCQ = 0x6A # 'j' - CRC next, not end, ACK requested +ZCRCW = 0x6B # 'k' - CRC next, end, ACK requested + +# Special characters to escape +ESCAPE_CHARS = {0x10, 0x11, 0x13, 0x90, 0x91, 0x93, ZDLE, 0x0D, 0x8D} + +# CRC-32 lookup table - initialized at module load for thread-safety +def _build_crc32_table() -> list[int]: + """Build CRC-32 lookup table (called once at module load).""" + table = [] + for i in range(256): + crc = i + for 
_ in range(8): + if crc & 1: + crc = (crc >> 1) ^ 0xEDB88320 + else: + crc >>= 1 + table.append(crc) + return table + + +_CRC32_TABLE: list[int] = _build_crc32_table() + + +def _calc_crc16(data: bytes) -> int: + """Calculate CRC-16-CCITT.""" + crc = 0 + for byte in data: + crc ^= byte << 8 + for _ in range(8): + if crc & 0x8000: + crc = (crc << 1) ^ 0x1021 + else: + crc <<= 1 + crc &= 0xFFFF + return crc + + +def _calc_crc32(data: bytes) -> int: + """Calculate CRC-32 using pre-computed lookup table.""" + crc = 0xFFFFFFFF + for byte in data: + crc = _CRC32_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8) + return crc ^ 0xFFFFFFFF + + +def _escape_byte(b: int) -> bytes: + """Escape a byte if needed for ZMODEM transmission.""" + if b == ZDLE: + return bytes([ZDLE, ZDLEE]) + elif b in ESCAPE_CHARS: + return bytes([ZDLE, b ^ 0x40]) + else: + return bytes([b]) + + +def _escape_data(data: bytes) -> bytes: + """Escape all bytes in data that need escaping.""" + result = bytearray() + for b in data: + result.extend(_escape_byte(b)) + return bytes(result) + + +def _to_hex(value: int, digits: int = 2) -> bytes: + """Convert value to hex string bytes.""" + return f"{value:0{digits}x}".encode("ascii") + + +class ZModemError(Exception): + """ZMODEM protocol error.""" + pass + + +class ZModem: + """ZMODEM file transfer protocol handler. 
+ + Args: + read_func: Callable that reads n bytes from serial + write_func: Callable that writes bytes to serial + use_crc32: Use CRC-32 instead of CRC-16 (default True) + """ + + def __init__( + self, + read_func: Callable[[int], bytes], + write_func: Callable[[bytes], int], + use_crc32: bool = True, + ): + self.read = read_func + self.write = write_func + self.use_crc32 = use_crc32 + self.last_sync_pos = 0 + self.tx_buffer_size = 1024 + self.rx_buffer_size = 0 + + def _read_byte(self, timeout_retries: int = 10) -> int | None: + """Read a single byte with retry.""" + for _ in range(timeout_retries): + data = self.read(1) + if data: + return data[0] + return None + + def _read_zdle_byte(self) -> int | None: + """Read a byte, handling ZDLE escaping.""" + b = self._read_byte(timeout_retries=30) + if b is None: + return None + if b == ZDLE: + b = self._read_byte(timeout_retries=30) + if b is None: + return None + if b == ZDLEE: + return ZDLE + elif b in (ZCRCE, ZCRCG, ZCRCQ, ZCRCW): + return b | 0x100 # Flag as frame end + else: + return b ^ 0x40 + return b + + def _send_cancel(self) -> None: + """Send ZMODEM cancel sequence. + + ZMODEM uses 8 CAN bytes followed by 8 backspaces to abort transfer. + This is the proper way to cancel - just sending ZCAN header may not + be recognized by all implementations. + + Note: CAN (Cancel) and ZDLE share the same byte value 0x18. In ZMODEM, + ZDLE is the escape character, but when sent 5+ times consecutively + outside of frame context, it signals abort (CAN sequence). 
+ """ + # CAN = 0x18 (same as ZDLE) - 8 consecutive cancels + 8 backspaces + cancel_seq = bytes([0x18] * 8 + [0x08] * 8) + self.write(cancel_seq) + + def _make_hex_header(self, frame_type: int, data: bytes = b"\x00\x00\x00\x00") -> bytes: + """Create a hex-encoded ZMODEM header.""" + header = bytes([frame_type]) + data[:4].ljust(4, b"\x00") + crc = _calc_crc16(header) + + result = bytearray([ZPAD, ZPAD, ZDLE, ZHEX]) + for b in header: + result.extend(_to_hex(b)) + result.extend(_to_hex(crc >> 8)) + result.extend(_to_hex(crc & 0xFF)) + result.extend(b"\r\n") + + return bytes(result) + + def _make_bin_header(self, frame_type: int, data: bytes = b"\x00\x00\x00\x00") -> bytes: + """Create a binary ZMODEM header.""" + header = bytes([frame_type]) + data[:4].ljust(4, b"\x00") + + if self.use_crc32: + crc = _calc_crc32(header) + result = bytearray([ZPAD, ZDLE, ZBIN32]) + result.extend(_escape_data(header)) + crc_bytes = crc.to_bytes(4, "little") + result.extend(_escape_data(crc_bytes)) + else: + crc = _calc_crc16(header) + result = bytearray([ZPAD, ZDLE, ZBIN]) + result.extend(_escape_data(header)) + result.extend(_escape_data(bytes([crc >> 8, crc & 0xFF]))) + + return bytes(result) + + def _make_data_subpacket(self, data: bytes, end_type: int) -> bytes: + """Create a data subpacket with CRC.""" + result = bytearray(_escape_data(data)) + result.extend([ZDLE, end_type]) + + if self.use_crc32: + crc = _calc_crc32(data + bytes([end_type])) + crc_bytes = crc.to_bytes(4, "little") + result.extend(_escape_data(crc_bytes)) + else: + crc = _calc_crc16(data + bytes([end_type])) + result.extend(_escape_data(bytes([crc >> 8, crc & 0xFF]))) + + return bytes(result) + + def _read_hex_header(self) -> tuple[int, bytes] | None: + """Read and parse a hex-encoded header.""" + # Read hex digits: type(2) + data(8) + crc(4) = 14 hex chars + hex_data = self.read(14) + if len(hex_data) != 14: + return None + + try: + raw = bytes.fromhex(hex_data.decode("ascii")) + except (ValueError, 
UnicodeDecodeError): + return None + + frame_type = raw[0] + data = raw[1:5] + recv_crc = (raw[5] << 8) | raw[6] + + calc_crc = _calc_crc16(raw[:5]) + if calc_crc != recv_crc: + logger.debug(f"Hex header CRC mismatch: {calc_crc:04x} vs {recv_crc:04x}") + return None + + # Read trailing CR/LF + self.read(2) + + return frame_type, data + + def _read_bin_header(self, use_crc32: bool) -> tuple[int, bytes] | None: + """Read and parse a binary header.""" + # Read 5 bytes: type + 4 data bytes + raw = bytearray() + for _ in range(5): + b = self._read_zdle_byte() + if b is None: + return None + raw.append(b & 0xFF) + + frame_type = raw[0] + data = bytes(raw[1:5]) + + # Read CRC + crc_len = 4 if use_crc32 else 2 + crc_bytes = bytearray() + for _ in range(crc_len): + b = self._read_zdle_byte() + if b is None: + return None + crc_bytes.append(b & 0xFF) + + # Verify CRC + if use_crc32: + recv_crc = int.from_bytes(crc_bytes, "little") + calc_crc = _calc_crc32(bytes(raw)) + else: + recv_crc = (crc_bytes[0] << 8) | crc_bytes[1] + calc_crc = _calc_crc16(bytes(raw)) + + if calc_crc != recv_crc: + logger.debug("Binary header CRC mismatch") + return None + + return frame_type, data + + def _read_header(self) -> tuple[int, bytes] | None: + """Wait for and read any ZMODEM header.""" + # Look for ZPAD ZPAD ZDLE or ZPAD ZDLE + sync_count = 0 + + for _ in range(1000): + b = self._read_byte(timeout_retries=10) + if b is None: + continue + + if b == ZPAD: + sync_count += 1 + elif b == ZDLE and sync_count >= 1: + # Found sync, read frame type + frame_marker = self._read_byte(timeout_retries=10) + if frame_marker == ZHEX: + return self._read_hex_header() + elif frame_marker == ZBIN: + return self._read_bin_header(use_crc32=False) + elif frame_marker == ZBIN32: + return self._read_bin_header(use_crc32=True) + sync_count = 0 + else: + sync_count = 0 + + return None + + def _read_data_subpacket(self) -> tuple[bytes, int] | None: + """Read a data subpacket.""" + data = bytearray() + + while True: 
+ b = self._read_zdle_byte() + if b is None: + return None + + if b & 0x100: # Frame end marker + end_type = b & 0xFF + break + + data.append(b) + if len(data) > MAX_SUBPACKET_SIZE: + logger.debug(f"Subpacket exceeded {MAX_SUBPACKET_SIZE} bytes limit") + return None + + # Read and verify CRC + crc_len = 4 if self.use_crc32 else 2 + crc_bytes = bytearray() + for _ in range(crc_len): + b = self._read_zdle_byte() + if b is None: + return None + crc_bytes.append(b & 0xFF) + + # Verify CRC + check_data = bytes(data) + bytes([end_type]) + if self.use_crc32: + recv_crc = int.from_bytes(crc_bytes, "little") + calc_crc = _calc_crc32(check_data) + else: + recv_crc = (crc_bytes[0] << 8) | crc_bytes[1] + calc_crc = _calc_crc16(check_data) + + if calc_crc != recv_crc: + logger.debug("Data subpacket CRC error") + return None + + return bytes(data), end_type + + def _pos_to_bytes(self, pos: int) -> bytes: + """Convert position to 4-byte little-endian. + + Raises: + OverflowError: If position exceeds 4GB (ZMODEM protocol limit) + """ + try: + return pos.to_bytes(4, "little") + except OverflowError: + raise OverflowError( + f"File position {pos} exceeds ZMODEM's 4GB limit (max 4,294,967,295 bytes). " + "ZMODEM uses 32-bit position encoding and cannot handle files larger than 4GB." + ) from None + + def _bytes_to_pos(self, data: bytes) -> int: + """Convert 4-byte little-endian to position.""" + return int.from_bytes(data[:4], "little") + + def send( + self, + files: list[str | Path], + callback: Callable[[str, int, int], None] | None = None, + retry_limit: int = 10, + ) -> dict: + """Send files via ZMODEM. 
+ + Args: + files: List of file paths to send + callback: Progress callback(filename, bytes_sent, total_bytes) + retry_limit: Max retries for errors + + Returns: + Dict with transfer statistics + """ + results = [] + total_bytes = 0 + + # Send ZRQINIT to request receiver init + logger.debug("Sending ZRQINIT...") + self.write(self._make_hex_header(ZRQINIT)) + + # Wait for ZRINIT from receiver + for _ in range(retry_limit * 3): + header = self._read_header() + if header is None: + self.write(self._make_hex_header(ZRQINIT)) + continue + + frame_type, data = header + if frame_type == ZRINIT: + self.rx_buffer_size = self._bytes_to_pos(data) + logger.debug(f"Received ZRINIT, buffer={self.rx_buffer_size}") + break + elif frame_type == ZCAN: + return {"success": False, "error": "Cancelled by receiver", "files": results} + else: + return {"success": False, "error": "No response from receiver", "files": results} + + # Send each file + for filepath in files: + filepath = Path(filepath) + if not filepath.exists(): + results.append({"file": str(filepath), "error": "File not found"}) + continue + + filesize = filepath.stat().st_size + mtime = int(filepath.stat().st_mtime) + + # Send ZFILE with filename info + file_info = f"{filepath.name}\x00{filesize} {mtime:o} 0 0 0 0\x00".encode("latin-1") + + self.write(self._make_bin_header(ZFILE)) + self.write(self._make_data_subpacket(file_info, ZCRCW)) + + # Wait for ZRPOS or ZSKIP + start_pos = 0 + for _ in range(retry_limit): + header = self._read_header() + if header is None: + continue + + frame_type, data = header + if frame_type == ZRPOS: + start_pos = self._bytes_to_pos(data) + logger.debug(f"Receiver wants to start at {start_pos}") + break + elif frame_type == ZSKIP: + logger.debug(f"Receiver skipping {filepath.name}") + results.append({"file": str(filepath), "skipped": True}) + break + elif frame_type == ZCAN: + return {"success": False, "error": "Cancelled", "files": results} + else: + results.append({"file": str(filepath), 
"error": "No ZRPOS received"}) + continue + + if frame_type == ZSKIP: + continue + + # Send file data + bytes_sent = start_pos + errors = 0 + + with open(filepath, "rb") as f: + f.seek(start_pos) + + # Send ZDATA header with position + self.write(self._make_bin_header(ZDATA, self._pos_to_bytes(bytes_sent))) + + while bytes_sent < filesize: + chunk = f.read(min(self.tx_buffer_size, filesize - bytes_sent)) + if not chunk: + break + + # ZCRCE = end of file, ZCRCG = more data coming + end_type = ZCRCE if bytes_sent + len(chunk) >= filesize else ZCRCG + + self.write(self._make_data_subpacket(chunk, end_type)) + bytes_sent += len(chunk) + + if callback: + callback(filepath.name, bytes_sent, filesize) + + # Check for ZACK or ZRPOS occasionally + if end_type == ZCRCE or bytes_sent % (self.tx_buffer_size * 10) == 0: + header = self._read_header() + if header: + htype, hdata = header + if htype == ZRPOS: + # Receiver wants us to resend from position + new_pos = self._bytes_to_pos(hdata) + logger.debug(f"ZRPOS received, resending from {new_pos}") + f.seek(new_pos) + bytes_sent = new_pos + errors += 1 + self.write(self._make_bin_header(ZDATA, self._pos_to_bytes(bytes_sent))) + elif htype == ZCAN: + return {"success": False, "error": "Cancelled", "files": results} + + # Send ZEOF + self.write(self._make_bin_header(ZEOF, self._pos_to_bytes(bytes_sent))) + + # Wait for ZRINIT (ready for next file) + for _ in range(retry_limit): + header = self._read_header() + if header and header[0] == ZRINIT: + break + + results.append({ + "file": str(filepath), + "bytes_sent": bytes_sent, + "errors": errors, + "success": True, + }) + total_bytes += bytes_sent + + # Send ZFIN to end session + self.write(self._make_hex_header(ZFIN)) + + # Wait for ZFIN response + for _ in range(retry_limit): + header = self._read_header() + if header and header[0] == ZFIN: + # Send OO to complete + self.write(b"OO") + break + + return { + "success": all(r.get("success", False) or r.get("skipped", False) for r in 
results), + "files": results, + "total_bytes": total_bytes, + } + + def receive( + self, + directory: str | Path, + callback: Callable[[str, int, int], None] | None = None, + retry_limit: int = 10, + overwrite: bool = False, + max_transfer_size: int = DEFAULT_MAX_TRANSFER_SIZE, + ) -> dict: + """Receive files via ZMODEM. + + Args: + directory: Directory to save received files + callback: Progress callback(filename, bytes_received, total_bytes) + retry_limit: Max retries for errors + overwrite: Overwrite existing files + max_transfer_size: Maximum bytes to receive per file (default 100MB). + Set to 0 to disable limit. Prevents unbounded memory usage. + + Returns: + Dict with transfer statistics + """ + directory = Path(directory) + directory.mkdir(parents=True, exist_ok=True) + + results = [] + total_bytes = 0 + + # Send ZRINIT to indicate ready + buffer_size = 8192 + logger.debug("Sending ZRINIT...") + self.write(self._make_hex_header(ZRINIT, self._pos_to_bytes(buffer_size))) + + while True: + # Wait for ZFILE or ZFIN + header = None + for _ in range(retry_limit * 3): + header = self._read_header() + if header: + break + # Resend ZRINIT + self.write(self._make_hex_header(ZRINIT, self._pos_to_bytes(buffer_size))) + + if header is None: + return {"success": False, "error": "Timeout", "files": results} + + frame_type, data = header + + if frame_type == ZFIN: + # Session complete + self.write(self._make_hex_header(ZFIN)) + break + elif frame_type == ZCAN: + return {"success": False, "error": "Cancelled", "files": results} + elif frame_type != ZFILE: + continue + + # Read file info subpacket + subpacket = self._read_data_subpacket() + if subpacket is None: + self.write(self._make_hex_header(ZNAK)) + continue + + file_data, _ = subpacket + + # Parse filename and metadata + try: + null_pos = file_data.index(0) + filename = file_data[:null_pos].decode("latin-1") + rest = file_data[null_pos + 1:].split(b" ") + filesize = int(rest[0]) if rest[0] else 0 + mtime = int(rest[1], 
8) if len(rest) > 1 and rest[1] else None + except (ValueError, IndexError): + filename = file_data.split(b"\x00")[0].decode("latin-1", errors="replace") + filesize = 0 + mtime = None + + logger.debug(f"Receiving: {filename} ({filesize} bytes)") + + # Security: sanitize filename to prevent path traversal attacks + safe_filename = _sanitize_filename(filename) + if safe_filename != filename: + logger.warning(f"Sanitized filename: {filename!r} -> {safe_filename!r}") + + filepath = directory / safe_filename + + # Atomic file creation to prevent TOCTOU race conditions + f, file_error = _open_file_atomic(filepath, overwrite) + if file_error: + logger.warning(f"Cannot create file {filepath}: {file_error}") + results.append({"file": filename, "error": file_error}) + self.write(self._make_hex_header(ZSKIP)) + continue + + # Check for resume (TODO: implement crash recovery using existing file size) + start_pos = 0 + + # Send ZRPOS to indicate where to start + self.write(self._make_bin_header(ZRPOS, self._pos_to_bytes(start_pos))) + + # Receive data + bytes_received = start_pos + errors = 0 + transfer_aborted = False + + try: + if start_pos > 0: + f.seek(start_pos) + + while not transfer_aborted: + # Wait for ZDATA header + header = self._read_header() + if header is None: + errors += 1 + self.write(self._make_bin_header(ZRPOS, self._pos_to_bytes(bytes_received))) + continue + + frame_type, data = header + + if frame_type == ZEOF: + eof_pos = self._bytes_to_pos(data) + logger.debug(f"ZEOF at position {eof_pos}") + break + elif frame_type == ZCAN: + results.append({"file": filename, "error": "Cancelled"}) + break + elif frame_type != ZDATA: + continue + + data_pos = self._bytes_to_pos(data) + if data_pos != bytes_received: + logger.debug(f"Position mismatch: expected {bytes_received}, got {data_pos}") + self.write(self._make_bin_header(ZRPOS, self._pos_to_bytes(bytes_received))) + continue + + # Read data subpackets + while True: + subpacket = self._read_data_subpacket() + if 
subpacket is None: + errors += 1 + self.write(self._make_bin_header(ZRPOS, self._pos_to_bytes(bytes_received))) + break + + chunk, end_type = subpacket + f.write(chunk) + bytes_received += len(chunk) + + # Check transfer size limit to prevent unbounded memory usage + if max_transfer_size > 0 and bytes_received > max_transfer_size: + logger.warning( + f"Transfer aborted: {filename} exceeded {max_transfer_size} byte limit " + f"(received {bytes_received} bytes)" + ) + self._send_cancel() + transfer_aborted = True + break + + if callback: + callback(filename, bytes_received, filesize) + + if end_type in (ZCRCE, ZCRCW): + # End of this ZDATA frame + if end_type == ZCRCW: + self.write(self._make_bin_header(ZACK, self._pos_to_bytes(bytes_received))) + break + finally: + f.close() + + if transfer_aborted: + # Transfer was aborted due to size limit + results.append({ + "file": filename, + "path": str(filepath), + "error": f"Transfer size exceeded {max_transfer_size} byte limit", + "bytes_received": bytes_received, + "success": False, + }) + # Try to clean up partial file + with contextlib.suppress(OSError): + filepath.unlink() + continue + + # Set modification time + if mtime: + os.utime(filepath, (time.time(), mtime)) + + # Send ZRINIT for next file + self.write(self._make_hex_header(ZRINIT, self._pos_to_bytes(buffer_size))) + + results.append({ + "file": filename, + "path": str(filepath), + "bytes_received": bytes_received, + "errors": errors, + "success": True, + }) + total_bytes += bytes_received + + return { + "success": all(r.get("success", False) for r in results), + "files": results, + "total_bytes": total_bytes, + } + + +def send_zmodem( + read_func: Callable[[int], bytes], + write_func: Callable[[bytes], int], + files: list[str | Path], + callback: Callable[[str, int, int], None] | None = None, +) -> dict: + """Send files via ZMODEM: build a ZModem from read_func/write_func and return the dict from its send(files, callback).""" + zm = ZModem(read_func, write_func) + return zm.send(files, callback) + + +def receive_zmodem( + 
read_func: Callable[[int], bytes], + write_func: Callable[[bytes], int], + directory: str | Path, + callback: Callable[[str, int, int], None] | None = None, + overwrite: bool = False, +) -> dict: + """Receive files via ZMODEM: build a ZModem from read_func/write_func and return the dict from its receive(directory, callback, overwrite=overwrite).""" + zm = ZModem(read_func, write_func) + return zm.receive(directory, callback, overwrite=overwrite)