Harden smart backup against review findings

- Path traversal protection: validate manifest file paths stay within
  backup directory (prevents ../../etc/passwd in crafted manifests)
- Input validation: reject chunk_mb outside 1-16 range
- Memory: _is_chunk_empty uses count() instead of allocating comparison buffer
- Type safety: accept bytes | bytearray in hash/empty-check helpers
- Timeout cap: _compute_chunk_timeout capped at 240s (MCP limit is 300s)
- Manifest schema: validate required fields before processing
- Disk full: wrap file writes in OSError handling with partial_backup path
- Atomic manifest: write via temp+rename to prevent corrupt state
- Decompression cache: avoid decompressing twice during restore
- Tempfile race: close NamedTemporaryFile before Path.write_bytes
- Unused variable: replace sum(1 for v in ...) with len()
This commit is contained in:
Ryan Malloy 2026-02-25 16:17:09 -07:00
parent 9fa314dae9
commit a07b3a0fd3

View File

@ -66,9 +66,9 @@ def _parse_flash_size_string(size_str: str) -> int:
return int(size_str, 0)
def _is_chunk_empty(data: bytes) -> bool:
def _is_chunk_empty(data: bytes | bytearray) -> bool:
"""Check whether a data chunk is entirely erased flash (0xFF)."""
return data == b"\xff" * len(data)
return data.count(b"\xff") == len(data)
def _sha256_file(path: Path) -> str:
@ -83,7 +83,7 @@ def _sha256_file(path: Path) -> str:
return h.hexdigest()
def _sha256_bytes(data: bytes) -> str:
def _sha256_bytes(data: bytes | bytearray) -> str:
"""Compute SHA256 hex digest of in-memory bytes."""
return hashlib.sha256(data).hexdigest()
@ -93,6 +93,19 @@ def _normalize_mac(mac: str) -> str:
return mac.lower().replace(":", "-")
def _safe_resolve(base: Path, untrusted: str) -> Path:
"""Resolve an untrusted relative path within a base directory.
Prevents path traversal attacks (e.g. ../../etc/passwd) by verifying
the resolved path is a child of the base directory.
"""
resolved = (base / untrusted).resolve()
base_resolved = base.resolve()
if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
raise ValueError(f"Path traversal detected: {untrusted!r} escapes {base}")
return resolved
def _parse_partition_table_binary(raw: bytes) -> list[dict[str, Any]]:
"""Parse ESP32 binary partition table format.
@ -292,9 +305,9 @@ class SmartBackupManager:
return result
def _compute_chunk_timeout(self, chunk_bytes: int) -> float:
"""Compute per-chunk timeout: max(60, chunk_mb * 100) seconds."""
"""Compute per-chunk timeout, capped at 240s to stay under MCP 300s limit."""
chunk_mb = chunk_bytes / (1024 * 1024)
return max(60.0, chunk_mb * 100.0)
return min(240.0, max(60.0, chunk_mb * 100.0))
def _adapt_chunk_size(
self,
@ -440,6 +453,12 @@ class SmartBackupManager:
"error": "Port is required for smart backup",
}
if not (1 <= chunk_mb <= 16):
return {
"success": False,
"error": f"chunk_mb must be between 1 and 16, got {chunk_mb}",
}
overall_start = time.time()
# Step 1: Get chip info via flash-id
@ -637,13 +656,21 @@ class SmartBackupManager:
if is_empty:
# Don't write a file, just record it in the manifest
sha = _sha256_bytes(region_data)
elif compress:
compressed = gzip.compress(bytes(region_data), compresslevel=6)
region_file.write_bytes(compressed)
sha = _sha256_bytes(bytes(region_data))
else:
region_file.write_bytes(bytes(region_data))
sha = _sha256_file(region_file)
try:
if compress:
compressed = gzip.compress(bytes(region_data), compresslevel=6)
region_file.write_bytes(compressed)
sha = _sha256_bytes(bytes(region_data))
else:
region_file.write_bytes(bytes(region_data))
sha = _sha256_file(region_file)
except OSError as e:
return {
"success": False,
"error": f"Failed writing {region_file}: {e}",
"partial_backup": str(backup_path),
}
manifest_partitions.append({
"name": region["name"],
@ -676,7 +703,17 @@ class SmartBackupManager:
}
manifest_path = backup_path / "manifest.json"
manifest_path.write_text(json.dumps(manifest, indent=2) + "\n")
try:
# Atomic write: write to temp file in same directory, then rename
tmp_manifest = manifest_path.with_suffix(".json.tmp")
tmp_manifest.write_text(json.dumps(manifest, indent=2) + "\n")
tmp_manifest.rename(manifest_path)
except OSError as e:
return {
"success": False,
"error": f"Failed writing manifest: {e}",
"partial_backup": str(backup_path),
}
# Build summary
backed_up = [p for p in manifest_partitions if not p["empty"]]
@ -743,6 +780,16 @@ class SmartBackupManager:
all_partitions = manifest.get("partitions", [])
# Validate required fields in each partition entry
required_fields = {"name", "offset", "size_bytes"}
for idx, p in enumerate(all_partitions):
missing = required_fields - set(p.keys())
if missing:
return {
"success": False,
"error": f"Manifest partition[{idx}] missing fields: {', '.join(sorted(missing))}",
}
# Filter partitions
if partition_filter:
filter_set = set(partition_filter)
@ -760,18 +807,25 @@ class SmartBackupManager:
flashable = [p for p in all_partitions if not p.get("empty") and p.get("file")]
# Step 2: Verify all referenced files exist and check SHA256
# Cache decompressed data so we don't decompress twice (validation + flash)
_decompressed_cache: dict[str, bytes] = {}
validation_errors: list[str] = []
for part in flashable:
file_path = backup_path / part["file"]
try:
file_path = _safe_resolve(backup_path, part["file"])
except ValueError as e:
validation_errors.append(str(e))
continue
if not file_path.exists():
validation_errors.append(f"Missing file: {part['file']}")
continue
# Verify hash
if part.get("compressed"):
# Decompress to check hash of original data
# Decompress to check hash; cache for reuse during flash
try:
decompressed = gzip.decompress(file_path.read_bytes())
_decompressed_cache[part["file"]] = decompressed
actual_hash = _sha256_bytes(decompressed)
except Exception as e:
validation_errors.append(f"Decompression failed for {part['file']}: {e}")
@ -826,16 +880,18 @@ class SmartBackupManager:
try:
for part in flashable:
file_path = backup_path / part["file"]
file_path = _safe_resolve(backup_path, part["file"])
offset = part["offset"]
if part.get("compressed"):
# Decompress to temp file for flashing
decompressed = gzip.decompress(file_path.read_bytes())
# Use cached decompressed data from validation, or decompress now
decompressed = _decompressed_cache.pop(part["file"], None)
if decompressed is None:
decompressed = gzip.decompress(file_path.read_bytes())
tmp = tempfile.NamedTemporaryFile(suffix=".bin", delete=False)
tmp_path = Path(tmp.name)
tmp_path.write_bytes(decompressed)
tmp.close()
tmp_path.write_bytes(decompressed)
temp_files.append(tmp_path)
flash_pairs.extend([offset, str(tmp_path)])
total_flash_bytes += len(decompressed)
@ -916,10 +972,7 @@ class SmartBackupManager:
pass
try:
progress_done = sum(
1 for v in verify_results
)
await context.report_progress(progress_done, len(flashable))
await context.report_progress(len(verify_results), len(flashable))
except Exception:
logger.debug("Progress reporting not supported by client")