Harden smart backup against review findings
- Path traversal protection: validate manifest file paths stay within backup directory (prevents ../../etc/passwd in crafted manifests) - Input validation: reject chunk_mb outside 1-16 range - Memory: _is_chunk_empty uses count() instead of allocating comparison buffer - Type safety: accept bytes | bytearray in hash/empty-check helpers - Timeout cap: _compute_chunk_timeout capped at 240s (MCP limit is 300s) - Manifest schema: validate required fields before processing - Disk full: wrap file writes in OSError handling with partial_backup path - Atomic manifest: write via temp+rename to prevent corrupt state - Decompression cache: avoid decompressing twice during restore - Tempfile race: close NamedTemporaryFile before Path.write_bytes - Unused variable: replace sum(1 for v in ...) with len()
This commit is contained in:
parent
9fa314dae9
commit
a07b3a0fd3
@ -66,9 +66,9 @@ def _parse_flash_size_string(size_str: str) -> int:
|
|||||||
return int(size_str, 0)
|
return int(size_str, 0)
|
||||||
|
|
||||||
|
|
||||||
def _is_chunk_empty(data: bytes) -> bool:
|
def _is_chunk_empty(data: bytes | bytearray) -> bool:
|
||||||
"""Check whether a data chunk is entirely erased flash (0xFF)."""
|
"""Check whether a data chunk is entirely erased flash (0xFF)."""
|
||||||
return data == b"\xff" * len(data)
|
return data.count(b"\xff") == len(data)
|
||||||
|
|
||||||
|
|
||||||
def _sha256_file(path: Path) -> str:
|
def _sha256_file(path: Path) -> str:
|
||||||
@ -83,7 +83,7 @@ def _sha256_file(path: Path) -> str:
|
|||||||
return h.hexdigest()
|
return h.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def _sha256_bytes(data: bytes) -> str:
|
def _sha256_bytes(data: bytes | bytearray) -> str:
|
||||||
"""Compute SHA256 hex digest of in-memory bytes."""
|
"""Compute SHA256 hex digest of in-memory bytes."""
|
||||||
return hashlib.sha256(data).hexdigest()
|
return hashlib.sha256(data).hexdigest()
|
||||||
|
|
||||||
@ -93,6 +93,19 @@ def _normalize_mac(mac: str) -> str:
|
|||||||
return mac.lower().replace(":", "-")
|
return mac.lower().replace(":", "-")
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_resolve(base: Path, untrusted: str) -> Path:
|
||||||
|
"""Resolve an untrusted relative path within a base directory.
|
||||||
|
|
||||||
|
Prevents path traversal attacks (e.g. ../../etc/passwd) by verifying
|
||||||
|
the resolved path is a child of the base directory.
|
||||||
|
"""
|
||||||
|
resolved = (base / untrusted).resolve()
|
||||||
|
base_resolved = base.resolve()
|
||||||
|
if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
|
||||||
|
raise ValueError(f"Path traversal detected: {untrusted!r} escapes {base}")
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
|
||||||
def _parse_partition_table_binary(raw: bytes) -> list[dict[str, Any]]:
|
def _parse_partition_table_binary(raw: bytes) -> list[dict[str, Any]]:
|
||||||
"""Parse ESP32 binary partition table format.
|
"""Parse ESP32 binary partition table format.
|
||||||
|
|
||||||
@ -292,9 +305,9 @@ class SmartBackupManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def _compute_chunk_timeout(self, chunk_bytes: int) -> float:
|
def _compute_chunk_timeout(self, chunk_bytes: int) -> float:
|
||||||
"""Compute per-chunk timeout: max(60, chunk_mb * 100) seconds."""
|
"""Compute per-chunk timeout, capped at 240s to stay under MCP 300s limit."""
|
||||||
chunk_mb = chunk_bytes / (1024 * 1024)
|
chunk_mb = chunk_bytes / (1024 * 1024)
|
||||||
return max(60.0, chunk_mb * 100.0)
|
return min(240.0, max(60.0, chunk_mb * 100.0))
|
||||||
|
|
||||||
def _adapt_chunk_size(
|
def _adapt_chunk_size(
|
||||||
self,
|
self,
|
||||||
@ -440,6 +453,12 @@ class SmartBackupManager:
|
|||||||
"error": "Port is required for smart backup",
|
"error": "Port is required for smart backup",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if not (1 <= chunk_mb <= 16):
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"chunk_mb must be between 1 and 16, got {chunk_mb}",
|
||||||
|
}
|
||||||
|
|
||||||
overall_start = time.time()
|
overall_start = time.time()
|
||||||
|
|
||||||
# Step 1: Get chip info via flash-id
|
# Step 1: Get chip info via flash-id
|
||||||
@ -637,13 +656,21 @@ class SmartBackupManager:
|
|||||||
if is_empty:
|
if is_empty:
|
||||||
# Don't write a file, just record it in the manifest
|
# Don't write a file, just record it in the manifest
|
||||||
sha = _sha256_bytes(region_data)
|
sha = _sha256_bytes(region_data)
|
||||||
elif compress:
|
|
||||||
compressed = gzip.compress(bytes(region_data), compresslevel=6)
|
|
||||||
region_file.write_bytes(compressed)
|
|
||||||
sha = _sha256_bytes(bytes(region_data))
|
|
||||||
else:
|
else:
|
||||||
region_file.write_bytes(bytes(region_data))
|
try:
|
||||||
sha = _sha256_file(region_file)
|
if compress:
|
||||||
|
compressed = gzip.compress(bytes(region_data), compresslevel=6)
|
||||||
|
region_file.write_bytes(compressed)
|
||||||
|
sha = _sha256_bytes(bytes(region_data))
|
||||||
|
else:
|
||||||
|
region_file.write_bytes(bytes(region_data))
|
||||||
|
sha = _sha256_file(region_file)
|
||||||
|
except OSError as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Failed writing {region_file}: {e}",
|
||||||
|
"partial_backup": str(backup_path),
|
||||||
|
}
|
||||||
|
|
||||||
manifest_partitions.append({
|
manifest_partitions.append({
|
||||||
"name": region["name"],
|
"name": region["name"],
|
||||||
@ -676,7 +703,17 @@ class SmartBackupManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
manifest_path = backup_path / "manifest.json"
|
manifest_path = backup_path / "manifest.json"
|
||||||
manifest_path.write_text(json.dumps(manifest, indent=2) + "\n")
|
try:
|
||||||
|
# Atomic write: write to temp file in same directory, then rename
|
||||||
|
tmp_manifest = manifest_path.with_suffix(".json.tmp")
|
||||||
|
tmp_manifest.write_text(json.dumps(manifest, indent=2) + "\n")
|
||||||
|
tmp_manifest.rename(manifest_path)
|
||||||
|
except OSError as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Failed writing manifest: {e}",
|
||||||
|
"partial_backup": str(backup_path),
|
||||||
|
}
|
||||||
|
|
||||||
# Build summary
|
# Build summary
|
||||||
backed_up = [p for p in manifest_partitions if not p["empty"]]
|
backed_up = [p for p in manifest_partitions if not p["empty"]]
|
||||||
@ -743,6 +780,16 @@ class SmartBackupManager:
|
|||||||
|
|
||||||
all_partitions = manifest.get("partitions", [])
|
all_partitions = manifest.get("partitions", [])
|
||||||
|
|
||||||
|
# Validate required fields in each partition entry
|
||||||
|
required_fields = {"name", "offset", "size_bytes"}
|
||||||
|
for idx, p in enumerate(all_partitions):
|
||||||
|
missing = required_fields - set(p.keys())
|
||||||
|
if missing:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Manifest partition[{idx}] missing fields: {', '.join(sorted(missing))}",
|
||||||
|
}
|
||||||
|
|
||||||
# Filter partitions
|
# Filter partitions
|
||||||
if partition_filter:
|
if partition_filter:
|
||||||
filter_set = set(partition_filter)
|
filter_set = set(partition_filter)
|
||||||
@ -760,18 +807,25 @@ class SmartBackupManager:
|
|||||||
flashable = [p for p in all_partitions if not p.get("empty") and p.get("file")]
|
flashable = [p for p in all_partitions if not p.get("empty") and p.get("file")]
|
||||||
|
|
||||||
# Step 2: Verify all referenced files exist and check SHA256
|
# Step 2: Verify all referenced files exist and check SHA256
|
||||||
|
# Cache decompressed data so we don't decompress twice (validation + flash)
|
||||||
|
_decompressed_cache: dict[str, bytes] = {}
|
||||||
validation_errors: list[str] = []
|
validation_errors: list[str] = []
|
||||||
for part in flashable:
|
for part in flashable:
|
||||||
file_path = backup_path / part["file"]
|
try:
|
||||||
|
file_path = _safe_resolve(backup_path, part["file"])
|
||||||
|
except ValueError as e:
|
||||||
|
validation_errors.append(str(e))
|
||||||
|
continue
|
||||||
if not file_path.exists():
|
if not file_path.exists():
|
||||||
validation_errors.append(f"Missing file: {part['file']}")
|
validation_errors.append(f"Missing file: {part['file']}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Verify hash
|
# Verify hash
|
||||||
if part.get("compressed"):
|
if part.get("compressed"):
|
||||||
# Decompress to check hash of original data
|
# Decompress to check hash; cache for reuse during flash
|
||||||
try:
|
try:
|
||||||
decompressed = gzip.decompress(file_path.read_bytes())
|
decompressed = gzip.decompress(file_path.read_bytes())
|
||||||
|
_decompressed_cache[part["file"]] = decompressed
|
||||||
actual_hash = _sha256_bytes(decompressed)
|
actual_hash = _sha256_bytes(decompressed)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
validation_errors.append(f"Decompression failed for {part['file']}: {e}")
|
validation_errors.append(f"Decompression failed for {part['file']}: {e}")
|
||||||
@ -826,16 +880,18 @@ class SmartBackupManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
for part in flashable:
|
for part in flashable:
|
||||||
file_path = backup_path / part["file"]
|
file_path = _safe_resolve(backup_path, part["file"])
|
||||||
offset = part["offset"]
|
offset = part["offset"]
|
||||||
|
|
||||||
if part.get("compressed"):
|
if part.get("compressed"):
|
||||||
# Decompress to temp file for flashing
|
# Use cached decompressed data from validation, or decompress now
|
||||||
decompressed = gzip.decompress(file_path.read_bytes())
|
decompressed = _decompressed_cache.pop(part["file"], None)
|
||||||
|
if decompressed is None:
|
||||||
|
decompressed = gzip.decompress(file_path.read_bytes())
|
||||||
tmp = tempfile.NamedTemporaryFile(suffix=".bin", delete=False)
|
tmp = tempfile.NamedTemporaryFile(suffix=".bin", delete=False)
|
||||||
tmp_path = Path(tmp.name)
|
tmp_path = Path(tmp.name)
|
||||||
tmp_path.write_bytes(decompressed)
|
|
||||||
tmp.close()
|
tmp.close()
|
||||||
|
tmp_path.write_bytes(decompressed)
|
||||||
temp_files.append(tmp_path)
|
temp_files.append(tmp_path)
|
||||||
flash_pairs.extend([offset, str(tmp_path)])
|
flash_pairs.extend([offset, str(tmp_path)])
|
||||||
total_flash_bytes += len(decompressed)
|
total_flash_bytes += len(decompressed)
|
||||||
@ -916,10 +972,7 @@ class SmartBackupManager:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
progress_done = sum(
|
await context.report_progress(len(verify_results), len(flashable))
|
||||||
1 for v in verify_results
|
|
||||||
)
|
|
||||||
await context.report_progress(progress_done, len(flashable))
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Progress reporting not supported by client")
|
logger.debug("Progress reporting not supported by client")
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user