Address code review findings for file transfer protocols

Critical fixes:
- Add max_transfer_size (100MB default) to XMODEM receive
- Validate ZMODEM position values with optional max_valid bound

High priority:
- Extract shared utils to _utils.py (sanitize_filename, open_file_atomic)
- Document XMODEM padding behavior (protocol limitation)
- Add filesize bounds checking in YMODEM (clamp to 10GB)
- Increase ZMODEM subpacket limit from 8KB to 32KB

Medium priority:
- Add timeout parameter to YMODEM/ZMODEM send/receive methods
- Narrow exception handling (SerialException, OSError, ValueError)
- Make ZMODEM cancel more robust (3 retries with delays)
- Add length validation in _verify_block to prevent IndexError
This commit is contained in:
Ryan Malloy 2026-01-28 19:59:22 -07:00
parent fb671a7c34
commit e8a6197b8c
5 changed files with 191 additions and 104 deletions

66
src/mcserial/_utils.py Normal file
View File

@ -0,0 +1,66 @@
"""Shared utility functions for file transfer protocols.
These functions provide security and robustness for file receive operations.
"""
from __future__ import annotations
import os
from pathlib import Path
def sanitize_filename(filename: str) -> str:
"""Remove path traversal attempts and dangerous characters from filename.
Security: Prevents directory traversal attacks where malicious senders
could write files outside the target directory using names like
'../../../etc/passwd' or absolute paths like '/etc/cron.d/backdoor'.
Args:
filename: Raw filename from remote sender (untrusted input)
Returns:
Safe filename with path components and dangerous characters removed
"""
# Get just the basename, removing any directory components
name = Path(filename).name
# Reject empty names
if not name:
name = "unnamed_file"
# Prefix hidden files (starting with dot) to make them visible
if name.startswith("."):
name = "_" + name[1:]
# Replace any remaining problematic characters
name = name.replace("\x00", "_").replace("/", "_").replace("\\", "_")
return name
def open_file_atomic(filepath: Path, overwrite: bool) -> tuple[object, str | None]:
"""Open file for writing with atomic creation to prevent TOCTOU races.
Uses O_CREAT | O_EXCL to atomically fail if file exists (when overwrite=False),
preventing race conditions between existence check and file creation.
Args:
filepath: Path to the file to create/open
overwrite: If True, overwrite existing files. If False, fail if exists.
Returns:
Tuple of (file_object, None) on success, or (None, error_message) on failure
"""
if overwrite:
# Overwrite mode: just open normally
return open(filepath, "wb"), None
try:
# Atomic create: fails if file exists (O_EXCL)
fd = os.open(filepath, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
return os.fdopen(fd, "wb"), None
except FileExistsError:
return None, "File exists"
except OSError as e:
return None, str(e)

View File

@ -1672,8 +1672,12 @@ def file_transfer_send(
else:
return {"error": f"Unknown protocol: {protocol}", "success": False}
except Exception as e:
return {"error": str(e), "success": False, "protocol": protocol}
except serial.SerialException as e:
return {"error": f"Serial error: {e}", "success": False, "protocol": protocol}
except OSError as e:
return {"error": f"File error: {e}", "success": False, "protocol": protocol}
except (ValueError, OverflowError) as e:
return {"error": f"Protocol error: {e}", "success": False, "protocol": protocol}
@mcp.tool()
@ -1759,8 +1763,12 @@ def file_transfer_receive(
else:
return {"error": f"Unknown protocol: {protocol}", "success": False}
except Exception as e:
return {"error": str(e), "success": False, "protocol": protocol}
except serial.SerialException as e:
return {"error": f"Serial error: {e}", "success": False, "protocol": protocol}
except OSError as e:
return {"error": f"File error: {e}", "success": False, "protocol": protocol}
except (ValueError, OverflowError) as e:
return {"error": f"Protocol error: {e}", "success": False, "protocol": protocol}
@mcp.tool()
@ -1818,8 +1826,12 @@ def file_transfer_send_batch(
result["protocol"] = protocol
return result
except Exception as e:
return {"error": str(e), "success": False, "protocol": protocol}
except serial.SerialException as e:
return {"error": f"Serial error: {e}", "success": False, "protocol": protocol}
except OSError as e:
return {"error": f"File error: {e}", "success": False, "protocol": protocol}
except (ValueError, OverflowError) as e:
return {"error": f"Protocol error: {e}", "success": False, "protocol": protocol}
# ============================================================================

View File

@ -33,6 +33,7 @@ DEFAULT_RETRY_LIMIT = 16 # Max retries per block
DEFAULT_TIMEOUT = 60.0 # Total transfer timeout in seconds
READ_TIMEOUT_RETRIES = 30 # Retries when waiting for response byte
INIT_TIMEOUT_RETRIES = 10 # Retries when waiting for transfer initiation
DEFAULT_MAX_TRANSFER_SIZE = 100 * 1024 * 1024 # 100MB default limit
logger = logging.getLogger(__name__)
@ -134,6 +135,11 @@ class XModem:
Returns:
True if verification passes, False otherwise
"""
# Validate check_bytes length before accessing indices
expected_len = 2 if self.use_crc else 1
if len(check_bytes) < expected_len:
return False
if self.use_crc:
expected = _calc_crc16(data)
received = (check_bytes[0] << 8) | check_bytes[1]
@ -256,14 +262,22 @@ class XModem:
callback: Callable[[int], None] | None = None,
retry_limit: int = DEFAULT_RETRY_LIMIT,
timeout: float = DEFAULT_TIMEOUT,
max_transfer_size: int = DEFAULT_MAX_TRANSFER_SIZE,
) -> dict:
"""Receive a file via XMODEM.
Note: XMODEM has no file metadata, so the exact file size is unknown.
Blocks are padded with SUB (0x1A) characters. For binary files, you may
need to know the expected size and truncate accordingly. For text files,
trailing 0x1A bytes are traditionally stripped by the receiver.
Args:
stream: File-like object to write to
callback: Progress callback(bytes_received)
retry_limit: Max retries per block
timeout: Total timeout in seconds (enforced across entire transfer)
max_transfer_size: Maximum bytes to receive (default 100MB).
Set to 0 to disable limit. Prevents disk exhaustion.
Returns:
Dict with transfer statistics
@ -358,7 +372,22 @@ class XModem:
if block_num == expected_block:
stream.write(data)
bytes_received += len(data)
# Block numbers wrap at 256 (8-bit counter)
expected_block = (expected_block + 1) & 0xFF
# Check transfer size limit to prevent disk exhaustion
if max_transfer_size > 0 and bytes_received > max_transfer_size:
logger.warning(
f"Transfer aborted: exceeded {max_transfer_size} byte limit "
f"(received {bytes_received} bytes)"
)
self.write(bytes([CAN, CAN]))
return {
"success": False,
"error": f"Transfer size exceeded {max_transfer_size} byte limit",
"bytes_received": bytes_received,
}
if callback:
callback(bytes_received)
elif block_num == (expected_block - 1) & 0xFF:

View File

@ -23,6 +23,7 @@ import time
from collections.abc import Callable
from pathlib import Path
from mcserial._utils import open_file_atomic, sanitize_filename
from mcserial.xmodem import (
ACK,
CAN,
@ -37,55 +38,10 @@ from mcserial.xmodem import (
logger = logging.getLogger(__name__)
# Transfer limits
# Transfer limits and defaults
DEFAULT_MAX_TRANSFER_SIZE = 100 * 1024 * 1024 # 100MB default limit per file
def _sanitize_filename(filename: str) -> str:
"""Remove path traversal attempts and dangerous characters from filename.
Security: Prevents directory traversal attacks where malicious senders
could write files outside the target directory using names like
'../../../etc/passwd' or absolute paths like '/etc/cron.d/backdoor'.
"""
# Get just the basename, removing any directory components
name = Path(filename).name
# Reject empty names
if not name:
name = "unnamed_file"
# Prefix hidden files (starting with dot) to make them visible
if name.startswith("."):
name = "_" + name[1:]
# Replace any remaining problematic characters
name = name.replace("\x00", "_").replace("/", "_").replace("\\", "_")
return name
def _open_file_atomic(filepath: Path, overwrite: bool) -> tuple:
"""Open file for writing with atomic creation to prevent TOCTOU races.
Uses O_CREAT | O_EXCL to atomically fail if file exists (when overwrite=False),
preventing race conditions between existence check and file creation.
Returns:
Tuple of (file_object, None) on success, or (None, error_message) on failure
"""
if overwrite:
# Overwrite mode: just open normally
return open(filepath, "wb"), None
try:
# Atomic create: fails if file exists (O_EXCL)
fd = os.open(filepath, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
return os.fdopen(fd, "wb"), None
except FileExistsError:
return None, "File exists"
except OSError as e:
return None, str(e)
DEFAULT_TIMEOUT = 60.0 # Total transfer timeout in seconds
DEFAULT_RETRY_LIMIT = 16 # Max retries per block
class YModemError(XModemError):
@ -167,6 +123,12 @@ class YModem:
rest = [x for x in rest if x and x != b"\x00"]
filesize = int(rest[0]) if rest else 0
# Clamp filesize to reasonable bounds (prevent memory issues from malicious values)
if filesize < 0:
filesize = 0
elif filesize > 10 * 1024 * 1024 * 1024: # 10GB sanity limit
logger.warning(f"Filesize {filesize} exceeds 10GB limit, clamping")
filesize = 10 * 1024 * 1024 * 1024
mtime = int(rest[1], 8) if len(rest) > 1 else None
return {
@ -181,7 +143,8 @@ class YModem:
self,
files: list[str | Path],
callback: Callable[[str, int, int], None] | None = None,
retry_limit: int = 16,
retry_limit: int = DEFAULT_RETRY_LIMIT,
timeout: float = DEFAULT_TIMEOUT,
) -> dict:
"""Send files via YMODEM batch transfer.
@ -189,10 +152,12 @@ class YModem:
files: List of file paths to send
callback: Progress callback(filename, bytes_sent, total_bytes)
retry_limit: Max retries per block
timeout: Total timeout in seconds (enforced across entire transfer)
Returns:
Dict with transfer statistics
"""
start_time = time.monotonic()
results = []
total_bytes = 0
total_errors = 0
@ -209,6 +174,8 @@ class YModem:
# Wait for receiver 'C'
logger.debug(f"Waiting for receiver to initiate {filepath.name}...")
for _ in range(retry_limit * 10):
if time.monotonic() - start_time > timeout:
return {"success": False, "error": f"Timeout ({timeout}s) waiting for receiver", "files": results}
b = self._read_byte(timeout_retries=3)
if b == CRC_MODE:
break
@ -320,9 +287,10 @@ class YModem:
self,
directory: str | Path,
callback: Callable[[str, int, int], None] | None = None,
retry_limit: int = 16,
retry_limit: int = DEFAULT_RETRY_LIMIT,
overwrite: bool = False,
max_transfer_size: int = DEFAULT_MAX_TRANSFER_SIZE,
timeout: float = DEFAULT_TIMEOUT,
) -> dict:
"""Receive files via YMODEM batch transfer.
@ -333,10 +301,12 @@ class YModem:
overwrite: Overwrite existing files
max_transfer_size: Maximum bytes to receive per file (default 100MB).
Set to 0 to disable limit. Prevents unbounded memory usage.
timeout: Total timeout in seconds (enforced across entire transfer)
Returns:
Dict with transfer statistics
"""
start_time = time.monotonic()
directory = Path(directory)
directory.mkdir(parents=True, exist_ok=True)
@ -345,9 +315,15 @@ class YModem:
total_errors = 0
while True:
# Check timeout
if time.monotonic() - start_time > timeout:
return {"success": False, "error": f"Timeout ({timeout}s)", "files": results}
# Initiate with 'C' for CRC mode
logger.debug("Initiating YMODEM receive...")
for _attempt in range(retry_limit):
if time.monotonic() - start_time > timeout:
return {"success": False, "error": f"Timeout ({timeout}s)", "files": results}
self.write(bytes([CRC_MODE]))
header = self._read_byte(timeout_retries=30)
if header in (SOH, STX):
@ -391,14 +367,14 @@ class YModem:
logger.debug(f"Receiving: {filename} ({filesize} bytes)")
# Security: sanitize filename to prevent path traversal attacks
safe_filename = _sanitize_filename(filename)
safe_filename = sanitize_filename(filename)
if safe_filename != filename:
logger.warning(f"Sanitized filename: {filename!r} -> {safe_filename!r}")
filepath = directory / safe_filename
# Atomic file creation to prevent TOCTOU race conditions
f, file_error = _open_file_atomic(filepath, overwrite)
f, file_error = open_file_atomic(filepath, overwrite)
if file_error:
results.append({"file": filename, "error": file_error})
self.write(bytes([CAN, CAN]))

View File

@ -23,51 +23,19 @@ import time
from collections.abc import Callable
from pathlib import Path
from mcserial._utils import open_file_atomic, sanitize_filename
logger = logging.getLogger(__name__)
def _sanitize_filename(filename: str) -> str:
"""Remove path traversal attempts and dangerous characters from filename.
Security: Prevents directory traversal attacks where malicious senders
could write files outside the target directory.
"""
name = Path(filename).name
if not name:
name = "unnamed_file"
if name.startswith("."):
name = "_" + name[1:]
name = name.replace("\x00", "_").replace("/", "_").replace("\\", "_")
return name
def _open_file_atomic(filepath: Path, overwrite: bool) -> tuple:
"""Open file for writing with atomic creation to prevent TOCTOU races.
Uses O_CREAT | O_EXCL to atomically fail if file exists (when overwrite=False),
preventing race conditions between existence check and file creation.
Returns:
Tuple of (file_object, None) on success, or (None, error_message) on failure
"""
if overwrite:
return open(filepath, "wb"), None
try:
fd = os.open(filepath, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
return os.fdopen(fd, "wb"), None
except FileExistsError:
return None, "File exists"
except OSError as e:
return None, str(e)
# Protocol constants
ZPAD = 0x2A # '*' - Padding character
# Transfer limits
DEFAULT_MAX_TRANSFER_SIZE = 100 * 1024 * 1024 # 100MB default limit per file
MAX_SUBPACKET_SIZE = 8192 # Per-packet size limit (sanity check)
MAX_SUBPACKET_SIZE = 32768 # Per-packet size limit (32KB - matches common implementations)
DEFAULT_TIMEOUT = 60.0 # Total transfer timeout in seconds
MAX_HEADER_SEARCH_ITERATIONS = 1000 # Max iterations when searching for header sync
ZDLE = 0x18 # Data Link Escape
ZDLEE = 0x58 # Escaped ZDLE (ZDLE ^ 0x40)
@ -222,8 +190,8 @@ class ZModem:
return b ^ 0x40
return b
def _send_cancel(self) -> None:
"""Send ZMODEM cancel sequence.
def _send_cancel(self, retries: int = 3) -> None:
"""Send ZMODEM cancel sequence with retry for reliability.
ZMODEM uses 8 CAN bytes followed by 8 backspaces to abort transfer.
This is the proper way to cancel - just sending ZCAN header may not
@ -232,10 +200,16 @@ class ZModem:
Note: CAN (Cancel) and ZDLE share the same byte value 0x18. In ZMODEM,
ZDLE is the escape character, but when sent 5+ times consecutively
outside of frame context, it signals abort (CAN sequence).
Args:
retries: Number of times to send cancel sequence for reliability
"""
# CAN = 0x18 (same as ZDLE) - 8 consecutive cancels + 8 backspaces
cancel_seq = bytes([0x18] * 8 + [0x08] * 8)
for _ in range(retries):
self.write(cancel_seq)
# Brief delay between retries to ensure remote receives
time.sleep(0.1)
def _make_hex_header(self, frame_type: int, data: bytes = b"\x00\x00\x00\x00") -> bytes:
"""Create a hex-encoded ZMODEM header."""
@ -351,7 +325,7 @@ class ZModem:
# Look for ZPAD ZPAD ZDLE or ZPAD ZDLE
sync_count = 0
for _ in range(1000):
for _ in range(MAX_HEADER_SEARCH_ITERATIONS):
b = self._read_byte(timeout_retries=10)
if b is None:
continue
@ -429,15 +403,32 @@ class ZModem:
"ZMODEM uses 32-bit position encoding and cannot handle files larger than 4GB."
) from None
def _bytes_to_pos(self, data: bytes) -> int:
"""Convert 4-byte little-endian to position."""
return int.from_bytes(data[:4], "little")
def _bytes_to_pos(self, data: bytes, max_valid: int | None = None) -> int:
"""Convert 4-byte little-endian to position.
Args:
data: At least 4 bytes of position data
max_valid: Optional maximum valid position (e.g., filesize).
If provided and position exceeds this, returns max_valid.
Returns:
Position value, clamped to max_valid if specified
"""
if len(data) < 4:
return 0
pos = int.from_bytes(data[:4], "little")
# Validate against maximum if provided (prevents seeking past EOF)
if max_valid is not None and pos > max_valid:
logger.debug(f"Position {pos} exceeds max {max_valid}, clamping")
return max_valid
return pos
def send(
self,
files: list[str | Path],
callback: Callable[[str, int, int], None] | None = None,
retry_limit: int = 10,
timeout: float = DEFAULT_TIMEOUT,
) -> dict:
"""Send files via ZMODEM.
@ -445,10 +436,12 @@ class ZModem:
files: List of file paths to send
callback: Progress callback(filename, bytes_sent, total_bytes)
retry_limit: Max retries for errors
timeout: Total timeout in seconds (enforced across entire transfer)
Returns:
Dict with transfer statistics
"""
start_time = time.monotonic()
results = []
total_bytes = 0
@ -458,6 +451,8 @@ class ZModem:
# Wait for ZRINIT from receiver
for _ in range(retry_limit * 3):
if time.monotonic() - start_time > timeout:
return {"success": False, "error": f"Timeout ({timeout}s) waiting for receiver", "files": results}
header = self._read_header()
if header is None:
self.write(self._make_hex_header(ZRQINIT))
@ -595,6 +590,7 @@ class ZModem:
retry_limit: int = 10,
overwrite: bool = False,
max_transfer_size: int = DEFAULT_MAX_TRANSFER_SIZE,
timeout: float = DEFAULT_TIMEOUT,
) -> dict:
"""Receive files via ZMODEM.
@ -605,10 +601,12 @@ class ZModem:
overwrite: Overwrite existing files
max_transfer_size: Maximum bytes to receive per file (default 100MB).
Set to 0 to disable limit. Prevents unbounded memory usage.
timeout: Total timeout in seconds (enforced across entire transfer)
Returns:
Dict with transfer statistics
"""
start_time = time.monotonic()
directory = Path(directory)
directory.mkdir(parents=True, exist_ok=True)
@ -621,9 +619,15 @@ class ZModem:
self.write(self._make_hex_header(ZRINIT, self._pos_to_bytes(buffer_size)))
while True:
# Check timeout
if time.monotonic() - start_time > timeout:
return {"success": False, "error": f"Timeout ({timeout}s)", "files": results}
# Wait for ZFILE or ZFIN
header = None
for _ in range(retry_limit * 3):
if time.monotonic() - start_time > timeout:
return {"success": False, "error": f"Timeout ({timeout}s)", "files": results}
header = self._read_header()
if header:
break
@ -667,14 +671,14 @@ class ZModem:
logger.debug(f"Receiving: {filename} ({filesize} bytes)")
# Security: sanitize filename to prevent path traversal attacks
safe_filename = _sanitize_filename(filename)
safe_filename = sanitize_filename(filename)
if safe_filename != filename:
logger.warning(f"Sanitized filename: {filename!r} -> {safe_filename!r}")
filepath = directory / safe_filename
# Atomic file creation to prevent TOCTOU race conditions
f, file_error = _open_file_atomic(filepath, overwrite)
f, file_error = open_file_atomic(filepath, overwrite)
if file_error:
logger.warning(f"Cannot create file {filepath}: {file_error}")
results.append({"file": filename, "error": file_error})