Harden smart backup against review findings
- Path traversal protection: validate manifest file paths stay within the backup directory (prevents ../../etc/passwd in crafted manifests)
- Input validation: reject chunk_mb outside the 1-16 range
- Memory: _is_chunk_empty uses count() instead of allocating a comparison buffer
- Type safety: accept bytes | bytearray in hash/empty-check helpers
- Timeout cap: _compute_chunk_timeout capped at 240s (MCP limit is 300s)
- Manifest schema: validate required fields before processing
- Disk full: wrap file writes in OSError handling with partial_backup path
- Atomic manifest: write via temp+rename to prevent corrupt state
- Decompression cache: avoid decompressing twice during restore
- Tempfile race: close NamedTemporaryFile before Path.write_bytes
- Unused variable: replace sum(1 for v in ...) with len()
This commit is contained in:
parent
9fa314dae9
commit
a07b3a0fd3
@ -66,9 +66,9 @@ def _parse_flash_size_string(size_str: str) -> int:
|
||||
return int(size_str, 0)
|
||||
|
||||
|
||||
def _is_chunk_empty(data: bytes) -> bool:
|
||||
def _is_chunk_empty(data: bytes | bytearray) -> bool:
|
||||
"""Check whether a data chunk is entirely erased flash (0xFF)."""
|
||||
return data == b"\xff" * len(data)
|
||||
return data.count(b"\xff") == len(data)
|
||||
|
||||
|
||||
def _sha256_file(path: Path) -> str:
|
||||
@ -83,7 +83,7 @@ def _sha256_file(path: Path) -> str:
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def _sha256_bytes(data: bytes) -> str:
|
||||
def _sha256_bytes(data: bytes | bytearray) -> str:
|
||||
"""Compute SHA256 hex digest of in-memory bytes."""
|
||||
return hashlib.sha256(data).hexdigest()
|
||||
|
||||
@ -93,6 +93,19 @@ def _normalize_mac(mac: str) -> str:
|
||||
return mac.lower().replace(":", "-")
|
||||
|
||||
|
||||
def _safe_resolve(base: Path, untrusted: str) -> Path:
|
||||
"""Resolve an untrusted relative path within a base directory.
|
||||
|
||||
Prevents path traversal attacks (e.g. ../../etc/passwd) by verifying
|
||||
the resolved path is a child of the base directory.
|
||||
"""
|
||||
resolved = (base / untrusted).resolve()
|
||||
base_resolved = base.resolve()
|
||||
if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
|
||||
raise ValueError(f"Path traversal detected: {untrusted!r} escapes {base}")
|
||||
return resolved
|
||||
|
||||
|
||||
def _parse_partition_table_binary(raw: bytes) -> list[dict[str, Any]]:
|
||||
"""Parse ESP32 binary partition table format.
|
||||
|
||||
@ -292,9 +305,9 @@ class SmartBackupManager:
|
||||
return result
|
||||
|
||||
def _compute_chunk_timeout(self, chunk_bytes: int) -> float:
|
||||
"""Compute per-chunk timeout: max(60, chunk_mb * 100) seconds."""
|
||||
"""Compute per-chunk timeout, capped at 240s to stay under MCP 300s limit."""
|
||||
chunk_mb = chunk_bytes / (1024 * 1024)
|
||||
return max(60.0, chunk_mb * 100.0)
|
||||
return min(240.0, max(60.0, chunk_mb * 100.0))
|
||||
|
||||
def _adapt_chunk_size(
|
||||
self,
|
||||
@ -440,6 +453,12 @@ class SmartBackupManager:
|
||||
"error": "Port is required for smart backup",
|
||||
}
|
||||
|
||||
if not (1 <= chunk_mb <= 16):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"chunk_mb must be between 1 and 16, got {chunk_mb}",
|
||||
}
|
||||
|
||||
overall_start = time.time()
|
||||
|
||||
# Step 1: Get chip info via flash-id
|
||||
@ -637,13 +656,21 @@ class SmartBackupManager:
|
||||
if is_empty:
|
||||
# Don't write a file, just record it in the manifest
|
||||
sha = _sha256_bytes(region_data)
|
||||
elif compress:
|
||||
compressed = gzip.compress(bytes(region_data), compresslevel=6)
|
||||
region_file.write_bytes(compressed)
|
||||
sha = _sha256_bytes(bytes(region_data))
|
||||
else:
|
||||
region_file.write_bytes(bytes(region_data))
|
||||
sha = _sha256_file(region_file)
|
||||
try:
|
||||
if compress:
|
||||
compressed = gzip.compress(bytes(region_data), compresslevel=6)
|
||||
region_file.write_bytes(compressed)
|
||||
sha = _sha256_bytes(bytes(region_data))
|
||||
else:
|
||||
region_file.write_bytes(bytes(region_data))
|
||||
sha = _sha256_file(region_file)
|
||||
except OSError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed writing {region_file}: {e}",
|
||||
"partial_backup": str(backup_path),
|
||||
}
|
||||
|
||||
manifest_partitions.append({
|
||||
"name": region["name"],
|
||||
@ -676,7 +703,17 @@ class SmartBackupManager:
|
||||
}
|
||||
|
||||
manifest_path = backup_path / "manifest.json"
|
||||
manifest_path.write_text(json.dumps(manifest, indent=2) + "\n")
|
||||
try:
|
||||
# Atomic write: write to temp file in same directory, then rename
|
||||
tmp_manifest = manifest_path.with_suffix(".json.tmp")
|
||||
tmp_manifest.write_text(json.dumps(manifest, indent=2) + "\n")
|
||||
tmp_manifest.rename(manifest_path)
|
||||
except OSError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed writing manifest: {e}",
|
||||
"partial_backup": str(backup_path),
|
||||
}
|
||||
|
||||
# Build summary
|
||||
backed_up = [p for p in manifest_partitions if not p["empty"]]
|
||||
@ -743,6 +780,16 @@ class SmartBackupManager:
|
||||
|
||||
all_partitions = manifest.get("partitions", [])
|
||||
|
||||
# Validate required fields in each partition entry
|
||||
required_fields = {"name", "offset", "size_bytes"}
|
||||
for idx, p in enumerate(all_partitions):
|
||||
missing = required_fields - set(p.keys())
|
||||
if missing:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Manifest partition[{idx}] missing fields: {', '.join(sorted(missing))}",
|
||||
}
|
||||
|
||||
# Filter partitions
|
||||
if partition_filter:
|
||||
filter_set = set(partition_filter)
|
||||
@ -760,18 +807,25 @@ class SmartBackupManager:
|
||||
flashable = [p for p in all_partitions if not p.get("empty") and p.get("file")]
|
||||
|
||||
# Step 2: Verify all referenced files exist and check SHA256
|
||||
# Cache decompressed data so we don't decompress twice (validation + flash)
|
||||
_decompressed_cache: dict[str, bytes] = {}
|
||||
validation_errors: list[str] = []
|
||||
for part in flashable:
|
||||
file_path = backup_path / part["file"]
|
||||
try:
|
||||
file_path = _safe_resolve(backup_path, part["file"])
|
||||
except ValueError as e:
|
||||
validation_errors.append(str(e))
|
||||
continue
|
||||
if not file_path.exists():
|
||||
validation_errors.append(f"Missing file: {part['file']}")
|
||||
continue
|
||||
|
||||
# Verify hash
|
||||
if part.get("compressed"):
|
||||
# Decompress to check hash of original data
|
||||
# Decompress to check hash; cache for reuse during flash
|
||||
try:
|
||||
decompressed = gzip.decompress(file_path.read_bytes())
|
||||
_decompressed_cache[part["file"]] = decompressed
|
||||
actual_hash = _sha256_bytes(decompressed)
|
||||
except Exception as e:
|
||||
validation_errors.append(f"Decompression failed for {part['file']}: {e}")
|
||||
@ -826,16 +880,18 @@ class SmartBackupManager:
|
||||
|
||||
try:
|
||||
for part in flashable:
|
||||
file_path = backup_path / part["file"]
|
||||
file_path = _safe_resolve(backup_path, part["file"])
|
||||
offset = part["offset"]
|
||||
|
||||
if part.get("compressed"):
|
||||
# Decompress to temp file for flashing
|
||||
decompressed = gzip.decompress(file_path.read_bytes())
|
||||
# Use cached decompressed data from validation, or decompress now
|
||||
decompressed = _decompressed_cache.pop(part["file"], None)
|
||||
if decompressed is None:
|
||||
decompressed = gzip.decompress(file_path.read_bytes())
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".bin", delete=False)
|
||||
tmp_path = Path(tmp.name)
|
||||
tmp_path.write_bytes(decompressed)
|
||||
tmp.close()
|
||||
tmp_path.write_bytes(decompressed)
|
||||
temp_files.append(tmp_path)
|
||||
flash_pairs.extend([offset, str(tmp_path)])
|
||||
total_flash_bytes += len(decompressed)
|
||||
@ -916,10 +972,7 @@ class SmartBackupManager:
|
||||
pass
|
||||
|
||||
try:
|
||||
progress_done = sum(
|
||||
1 for v in verify_results
|
||||
)
|
||||
await context.report_progress(progress_done, len(flashable))
|
||||
await context.report_progress(len(verify_results), len(flashable))
|
||||
except Exception:
|
||||
logger.debug("Progress reporting not supported by client")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user