diff --git a/src/mcserial/_utils.py b/src/mcserial/_utils.py new file mode 100644 index 0000000..5a08094 --- /dev/null +++ b/src/mcserial/_utils.py @@ -0,0 +1,66 @@ +"""Shared utility functions for file transfer protocols. + +These functions provide security and robustness for file receive operations. +""" + +from __future__ import annotations + +import os +from pathlib import Path + + +def sanitize_filename(filename: str) -> str: + """Remove path traversal attempts and dangerous characters from filename. + + Security: Prevents directory traversal attacks where malicious senders + could write files outside the target directory using names like + '../../../etc/passwd' or absolute paths like '/etc/cron.d/backdoor'. + + Args: + filename: Raw filename from remote sender (untrusted input) + + Returns: + Safe filename with path components and dangerous characters removed + """ + # Get just the basename, removing any directory components + name = Path(filename).name + + # Reject empty names + if not name: + name = "unnamed_file" + + # Prefix hidden files (starting with dot) to make them visible + if name.startswith("."): + name = "_" + name[1:] + + # Replace any remaining problematic characters + name = name.replace("\x00", "_").replace("/", "_").replace("\\", "_") + + return name + + +def open_file_atomic(filepath: Path, overwrite: bool) -> tuple[object, str | None]: + """Open file for writing with atomic creation to prevent TOCTOU races. + + Uses O_CREAT | O_EXCL to atomically fail if file exists (when overwrite=False), + preventing race conditions between existence check and file creation. + + Args: + filepath: Path to the file to create/open + overwrite: If True, overwrite existing files. If False, fail if exists. + + Returns: + Tuple of (file_object, None) on success, or (None, error_message) on failure + """ + if overwrite: + # Overwrite mode: just open normally + return open(filepath, "wb"), None + + try: + # Atomic create: fails if file exists (O_EXCL) + fd = os.open(filepath, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644) + return os.fdopen(fd, "wb"), None + except FileExistsError: + return None, "File exists" + except OSError as e: + return None, str(e) diff --git a/src/mcserial/server.py b/src/mcserial/server.py index 4c65c5b..1db3a12 100644 --- a/src/mcserial/server.py +++ b/src/mcserial/server.py @@ -1672,8 +1672,12 @@ def file_transfer_send( else: return {"error": f"Unknown protocol: {protocol}", "success": False} - except Exception as e: - return {"error": str(e), "success": False, "protocol": protocol} + except serial.SerialException as e: + return {"error": f"Serial error: {e}", "success": False, "protocol": protocol} + except OSError as e: + return {"error": f"File error: {e}", "success": False, "protocol": protocol} + except (ValueError, OverflowError) as e: + return {"error": f"Protocol error: {e}", "success": False, "protocol": protocol} @mcp.tool() @@ -1759,8 +1763,12 @@ def file_transfer_receive( else: return {"error": f"Unknown protocol: {protocol}", "success": False} - except Exception as e: - return {"error": str(e), "success": False, "protocol": protocol} + except serial.SerialException as e: + return {"error": f"Serial error: {e}", "success": False, "protocol": protocol} + except OSError as e: + return {"error": f"File error: {e}", "success": False, "protocol": protocol} + except (ValueError, OverflowError) as e: + return {"error": f"Protocol error: {e}", "success": False, "protocol": protocol} @mcp.tool() @@ -1818,8 +1826,12 @@ def file_transfer_send_batch( result["protocol"] = protocol return result - except Exception as e: - return {"error": str(e), "success": False, "protocol": protocol} + except serial.SerialException as e: + return {"error": f"Serial error: {e}", "success": False, "protocol": protocol} + except OSError as e: + return {"error": f"File error: {e}", "success": False, "protocol": protocol} + except (ValueError, OverflowError) as e: + return {"error": f"Protocol error: {e}", "success": False, "protocol": protocol} # ============================================================================ diff --git a/src/mcserial/xmodem.py b/src/mcserial/xmodem.py index d58e300..71a2ad3 100644 --- a/src/mcserial/xmodem.py +++ b/src/mcserial/xmodem.py @@ -33,6 +33,7 @@ DEFAULT_RETRY_LIMIT = 16 # Max retries per block DEFAULT_TIMEOUT = 60.0 # Total transfer timeout in seconds READ_TIMEOUT_RETRIES = 30 # Retries when waiting for response byte INIT_TIMEOUT_RETRIES = 10 # Retries when waiting for transfer initiation +DEFAULT_MAX_TRANSFER_SIZE = 100 * 1024 * 1024 # 100MB default limit logger = logging.getLogger(__name__) @@ -134,6 +135,11 @@ class XModem: Returns: True if verification passes, False otherwise """ + # Validate check_bytes length before accessing indices + expected_len = 2 if self.use_crc else 1 + if len(check_bytes) < expected_len: + return False + if self.use_crc: expected = _calc_crc16(data) received = (check_bytes[0] << 8) | check_bytes[1] @@ -256,14 +262,22 @@ class XModem: callback: Callable[[int], None] | None = None, retry_limit: int = DEFAULT_RETRY_LIMIT, timeout: float = DEFAULT_TIMEOUT, + max_transfer_size: int = DEFAULT_MAX_TRANSFER_SIZE, ) -> dict: """Receive a file via XMODEM. + Note: XMODEM has no file metadata, so the exact file size is unknown. + Blocks are padded with SUB (0x1A) characters. For binary files, you may + need to know the expected size and truncate accordingly. For text files, + trailing 0x1A bytes are traditionally stripped by the receiver. + Args: stream: File-like object to write to callback: Progress callback(bytes_received) retry_limit: Max retries per block timeout: Total timeout in seconds (enforced across entire transfer) + max_transfer_size: Maximum bytes to receive (default 100MB). + Set to 0 to disable limit. Prevents disk exhaustion. Returns: Dict with transfer statistics @@ -358,7 +372,22 @@ class XModem: if block_num == expected_block: stream.write(data) bytes_received += len(data) + # Block numbers wrap at 256 (8-bit counter) expected_block = (expected_block + 1) & 0xFF + + # Check transfer size limit to prevent disk exhaustion + if max_transfer_size > 0 and bytes_received > max_transfer_size: + logger.warning( + f"Transfer aborted: exceeded {max_transfer_size} byte limit " + f"(received {bytes_received} bytes)" + ) + self.write(bytes([CAN, CAN])) + return { + "success": False, + "error": f"Transfer size exceeded {max_transfer_size} byte limit", + "bytes_received": bytes_received, + } + if callback: callback(bytes_received) elif block_num == (expected_block - 1) & 0xFF: diff --git a/src/mcserial/ymodem.py b/src/mcserial/ymodem.py index 78d007b..39269f3 100644 --- a/src/mcserial/ymodem.py +++ b/src/mcserial/ymodem.py @@ -23,6 +23,7 @@ import time from collections.abc import Callable from pathlib import Path +from mcserial._utils import open_file_atomic, sanitize_filename from mcserial.xmodem import ( ACK, CAN, @@ -37,55 +38,10 @@ from mcserial.xmodem import ( logger = logging.getLogger(__name__) -# Transfer limits +# Transfer limits and defaults DEFAULT_MAX_TRANSFER_SIZE = 100 * 1024 * 1024 # 100MB default limit per file - - -def _sanitize_filename(filename: str) -> str: - """Remove path traversal attempts and dangerous characters from filename. - - Security: Prevents directory traversal attacks where malicious senders - could write files outside the target directory using names like - '../../../etc/passwd' or absolute paths like '/etc/cron.d/backdoor'. - """ - # Get just the basename, removing any directory components - name = Path(filename).name - - # Reject empty names - if not name: - name = "unnamed_file" - - # Prefix hidden files (starting with dot) to make them visible - if name.startswith("."): - name = "_" + name[1:] - - # Replace any remaining problematic characters - name = name.replace("\x00", "_").replace("/", "_").replace("\\", "_") - - return name - - -def _open_file_atomic(filepath: Path, overwrite: bool) -> tuple: - """Open file for writing with atomic creation to prevent TOCTOU races. - - Uses O_CREAT | O_EXCL to atomically fail if file exists (when overwrite=False), - preventing race conditions between existence check and file creation. - - Returns: - Tuple of (file_object, None) on success, or (None, error_message) on failure - """ - if overwrite: - # Overwrite mode: just open normally - return open(filepath, "wb"), None - - try: - # Atomic create: fails if file exists (O_EXCL) - fd = os.open(filepath, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644) - return os.fdopen(fd, "wb"), None - except FileExistsError: - return None, "File exists" - except OSError as e: - return None, str(e) +DEFAULT_TIMEOUT = 60.0 # Total transfer timeout in seconds +DEFAULT_RETRY_LIMIT = 16 # Max retries per block class YModemError(XModemError): @@ -167,6 +123,12 @@ class YModem: rest = [x for x in rest if x and x != b"\x00"] filesize = int(rest[0]) if rest else 0 + # Clamp filesize to reasonable bounds (prevent memory issues from malicious values) + if filesize < 0: + filesize = 0 + elif filesize > 10 * 1024 * 1024 * 1024: # 10GB sanity limit + logger.warning(f"Filesize {filesize} exceeds 10GB limit, clamping") + filesize = 10 * 1024 * 1024 * 1024 mtime = int(rest[1], 8) if len(rest) > 1 else None return { @@ -181,7 +143,8 @@ class YModem: self, files: list[str | Path], callback: Callable[[str, int, int], None] | None = None, - retry_limit: int = 16, + retry_limit: int = DEFAULT_RETRY_LIMIT, + timeout: float = DEFAULT_TIMEOUT, ) -> dict: """Send files via YMODEM batch transfer. @@ -189,10 +152,12 @@ class YModem: files: List of file paths to send callback: Progress callback(filename, bytes_sent, total_bytes) retry_limit: Max retries per block + timeout: Total timeout in seconds (enforced across entire transfer) Returns: Dict with transfer statistics """ + start_time = time.monotonic() results = [] total_bytes = 0 total_errors = 0 @@ -209,6 +174,8 @@ class YModem: # Wait for receiver 'C' logger.debug(f"Waiting for receiver to initiate {filepath.name}...") for _ in range(retry_limit * 10): + if time.monotonic() - start_time > timeout: + return {"success": False, "error": f"Timeout ({timeout}s) waiting for receiver", "files": results} b = self._read_byte(timeout_retries=3) if b == CRC_MODE: break @@ -320,9 +287,10 @@ class YModem: self, directory: str | Path, callback: Callable[[str, int, int], None] | None = None, - retry_limit: int = 16, + retry_limit: int = DEFAULT_RETRY_LIMIT, overwrite: bool = False, max_transfer_size: int = DEFAULT_MAX_TRANSFER_SIZE, + timeout: float = DEFAULT_TIMEOUT, ) -> dict: """Receive files via YMODEM batch transfer. @@ -333,10 +301,12 @@ class YModem: overwrite: Overwrite existing files max_transfer_size: Maximum bytes to receive per file (default 100MB). Set to 0 to disable limit. Prevents unbounded memory usage. + timeout: Total timeout in seconds (enforced across entire transfer) Returns: Dict with transfer statistics """ + start_time = time.monotonic() directory = Path(directory) directory.mkdir(parents=True, exist_ok=True) @@ -345,9 +315,15 @@ class YModem: total_errors = 0 while True: + # Check timeout + if time.monotonic() - start_time > timeout: + return {"success": False, "error": f"Timeout ({timeout}s)", "files": results} + # Initiate with 'C' for CRC mode logger.debug("Initiating YMODEM receive...") for _attempt in range(retry_limit): + if time.monotonic() - start_time > timeout: + return {"success": False, "error": f"Timeout ({timeout}s)", "files": results} self.write(bytes([CRC_MODE])) header = self._read_byte(timeout_retries=30) if header in (SOH, STX): @@ -391,14 +367,14 @@ class YModem: logger.debug(f"Receiving: {filename} ({filesize} bytes)") # Security: sanitize filename to prevent path traversal attacks - safe_filename = _sanitize_filename(filename) + safe_filename = sanitize_filename(filename) if safe_filename != filename: logger.warning(f"Sanitized filename: {filename!r} -> {safe_filename!r}") filepath = directory / safe_filename # Atomic file creation to prevent TOCTOU race conditions - f, file_error = _open_file_atomic(filepath, overwrite) + f, file_error = open_file_atomic(filepath, overwrite) if file_error: results.append({"file": filename, "error": file_error}) self.write(bytes([CAN, CAN])) diff --git a/src/mcserial/zmodem.py b/src/mcserial/zmodem.py index 124fc03..1a72326 100644 --- a/src/mcserial/zmodem.py +++ b/src/mcserial/zmodem.py @@ -23,51 +23,19 @@ import time from collections.abc import Callable from pathlib import Path +from mcserial._utils import open_file_atomic, sanitize_filename + logger = logging.getLogger(__name__) -def _sanitize_filename(filename: str) -> str: - """Remove path traversal attempts and dangerous characters from filename. - - Security: Prevents directory traversal attacks where malicious senders - could write files outside the target directory. - """ - name = Path(filename).name - if not name: - name = "unnamed_file" - if name.startswith("."): - name = "_" + name[1:] - name = name.replace("\x00", "_").replace("/", "_").replace("\\", "_") - return name - - -def _open_file_atomic(filepath: Path, overwrite: bool) -> tuple: - """Open file for writing with atomic creation to prevent TOCTOU races. - - Uses O_CREAT | O_EXCL to atomically fail if file exists (when overwrite=False), - preventing race conditions between existence check and file creation. - - Returns: - Tuple of (file_object, None) on success, or (None, error_message) on failure - """ - if overwrite: - return open(filepath, "wb"), None - - try: - fd = os.open(filepath, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644) - return os.fdopen(fd, "wb"), None - except FileExistsError: - return None, "File exists" - except OSError as e: - return None, str(e) - - # Protocol constants ZPAD = 0x2A # '*' - Padding character # Transfer limits DEFAULT_MAX_TRANSFER_SIZE = 100 * 1024 * 1024 # 100MB default limit per file -MAX_SUBPACKET_SIZE = 8192 # Per-packet size limit (sanity check) +MAX_SUBPACKET_SIZE = 32768 # Per-packet size limit (32KB - matches common implementations) +DEFAULT_TIMEOUT = 60.0 # Total transfer timeout in seconds +MAX_HEADER_SEARCH_ITERATIONS = 1000 # Max iterations when searching for header sync ZDLE = 0x18 # Data Link Escape ZDLEE = 0x58 # Escaped ZDLE (ZDLE ^ 0x40) @@ -222,8 +190,8 @@ class ZModem: return b ^ 0x40 return b - def _send_cancel(self) -> None: - """Send ZMODEM cancel sequence. + def _send_cancel(self, retries: int = 3) -> None: + """Send ZMODEM cancel sequence with retry for reliability. ZMODEM uses 8 CAN bytes followed by 8 backspaces to abort transfer. This is the proper way to cancel - just sending ZCAN header may not @@ -232,10 +200,16 @@ class ZModem: Note: CAN (Cancel) and ZDLE share the same byte value 0x18. In ZMODEM, ZDLE is the escape character, but when sent 5+ times consecutively outside of frame context, it signals abort (CAN sequence). + + Args: + retries: Number of times to send cancel sequence for reliability """ # CAN = 0x18 (same as ZDLE) - 8 consecutive cancels + 8 backspaces cancel_seq = bytes([0x18] * 8 + [0x08] * 8) - self.write(cancel_seq) + for _ in range(retries): + self.write(cancel_seq) + # Brief delay between retries to ensure remote receives + time.sleep(0.1) def _make_hex_header(self, frame_type: int, data: bytes = b"\x00\x00\x00\x00") -> bytes: """Create a hex-encoded ZMODEM header.""" @@ -351,7 +325,7 @@ class ZModem: # Look for ZPAD ZPAD ZDLE or ZPAD ZDLE sync_count = 0 - for _ in range(1000): + for _ in range(MAX_HEADER_SEARCH_ITERATIONS): b = self._read_byte(timeout_retries=10) if b is None: continue @@ -429,15 +403,32 @@ class ZModem: "ZMODEM uses 32-bit position encoding and cannot handle files larger than 4GB." ) from None - def _bytes_to_pos(self, data: bytes) -> int: - """Convert 4-byte little-endian to position.""" - return int.from_bytes(data[:4], "little") + def _bytes_to_pos(self, data: bytes, max_valid: int | None = None) -> int: + """Convert 4-byte little-endian to position. + + Args: + data: At least 4 bytes of position data + max_valid: Optional maximum valid position (e.g., filesize). + If provided and position exceeds this, returns max_valid. + + Returns: + Position value, clamped to max_valid if specified + """ + if len(data) < 4: + return 0 + pos = int.from_bytes(data[:4], "little") + # Validate against maximum if provided (prevents seeking past EOF) + if max_valid is not None and pos > max_valid: + logger.debug(f"Position {pos} exceeds max {max_valid}, clamping") + return max_valid + return pos def send( self, files: list[str | Path], callback: Callable[[str, int, int], None] | None = None, retry_limit: int = 10, + timeout: float = DEFAULT_TIMEOUT, ) -> dict: """Send files via ZMODEM. @@ -445,10 +436,12 @@ class ZModem: files: List of file paths to send callback: Progress callback(filename, bytes_sent, total_bytes) retry_limit: Max retries for errors + timeout: Total timeout in seconds (enforced across entire transfer) Returns: Dict with transfer statistics """ + start_time = time.monotonic() results = [] total_bytes = 0 @@ -458,6 +451,8 @@ class ZModem: # Wait for ZRINIT from receiver for _ in range(retry_limit * 3): + if time.monotonic() - start_time > timeout: + return {"success": False, "error": f"Timeout ({timeout}s) waiting for receiver", "files": results} header = self._read_header() if header is None: self.write(self._make_hex_header(ZRQINIT)) @@ -595,6 +590,7 @@ class ZModem: retry_limit: int = 10, overwrite: bool = False, max_transfer_size: int = DEFAULT_MAX_TRANSFER_SIZE, + timeout: float = DEFAULT_TIMEOUT, ) -> dict: """Receive files via ZMODEM. @@ -605,10 +601,12 @@ class ZModem: overwrite: Overwrite existing files max_transfer_size: Maximum bytes to receive per file (default 100MB). Set to 0 to disable limit. Prevents unbounded memory usage. + timeout: Total timeout in seconds (enforced across entire transfer) Returns: Dict with transfer statistics """ + start_time = time.monotonic() directory = Path(directory) directory.mkdir(parents=True, exist_ok=True) @@ -621,9 +619,15 @@ class ZModem: self.write(self._make_hex_header(ZRINIT, self._pos_to_bytes(buffer_size))) while True: + # Check timeout + if time.monotonic() - start_time > timeout: + return {"success": False, "error": f"Timeout ({timeout}s)", "files": results} + # Wait for ZFILE or ZFIN header = None for _ in range(retry_limit * 3): + if time.monotonic() - start_time > timeout: + return {"success": False, "error": f"Timeout ({timeout}s)", "files": results} header = self._read_header() if header: break @@ -667,14 +671,14 @@ class ZModem: logger.debug(f"Receiving: {filename} ({filesize} bytes)") # Security: sanitize filename to prevent path traversal attacks - safe_filename = _sanitize_filename(filename) + safe_filename = sanitize_filename(filename) if safe_filename != filename: logger.warning(f"Sanitized filename: {filename!r} -> {safe_filename!r}") filepath = directory / safe_filename # Atomic file creation to prevent TOCTOU race conditions - f, file_error = _open_file_atomic(filepath, overwrite) + f, file_error = open_file_atomic(filepath, overwrite) if file_error: logger.warning(f"Cannot create file {filepath}: {file_error}") results.append({"file": filename, "error": file_error})