diff --git a/src/mcesptool/components/smart_backup.py b/src/mcesptool/components/smart_backup.py index 1ba0793..d36ea4f 100644 --- a/src/mcesptool/components/smart_backup.py +++ b/src/mcesptool/components/smart_backup.py @@ -66,9 +66,9 @@ def _parse_flash_size_string(size_str: str) -> int: return int(size_str, 0) -def _is_chunk_empty(data: bytes) -> bool: +def _is_chunk_empty(data: bytes | bytearray) -> bool: """Check whether a data chunk is entirely erased flash (0xFF).""" - return data == b"\xff" * len(data) + return data.count(b"\xff") == len(data) def _sha256_file(path: Path) -> str: @@ -83,7 +83,7 @@ def _sha256_file(path: Path) -> str: return h.hexdigest() -def _sha256_bytes(data: bytes) -> str: +def _sha256_bytes(data: bytes | bytearray) -> str: """Compute SHA256 hex digest of in-memory bytes.""" return hashlib.sha256(data).hexdigest() @@ -93,6 +93,19 @@ def _normalize_mac(mac: str) -> str: return mac.lower().replace(":", "-") +def _safe_resolve(base: Path, untrusted: str) -> Path: + """Resolve an untrusted relative path within a base directory. + + Prevents path traversal attacks (e.g. ../../etc/passwd) by verifying + the resolved path is a child of the base directory. + """ + resolved = (base / untrusted).resolve() + base_resolved = base.resolve() + if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved: + raise ValueError(f"Path traversal detected: {untrusted!r} escapes {base}") + return resolved + + def _parse_partition_table_binary(raw: bytes) -> list[dict[str, Any]]: """Parse ESP32 binary partition table format. @@ -292,9 +305,9 @@ class SmartBackupManager: return result def _compute_chunk_timeout(self, chunk_bytes: int) -> float: - """Compute per-chunk timeout: max(60, chunk_mb * 100) seconds.""" + """Compute per-chunk timeout, capped at 240s to stay under MCP 300s limit.""" chunk_mb = chunk_bytes / (1024 * 1024) - return max(60.0, chunk_mb * 100.0) + return min(240.0, max(60.0, chunk_mb * 100.0)) def _adapt_chunk_size( self, @@ -440,6 +453,12 @@ class SmartBackupManager: "error": "Port is required for smart backup", } + if not (1 <= chunk_mb <= 16): + return { + "success": False, + "error": f"chunk_mb must be between 1 and 16, got {chunk_mb}", + } + overall_start = time.time() # Step 1: Get chip info via flash-id @@ -637,13 +656,21 @@ class SmartBackupManager: if is_empty: # Don't write a file, just record it in the manifest sha = _sha256_bytes(region_data) - elif compress: - compressed = gzip.compress(bytes(region_data), compresslevel=6) - region_file.write_bytes(compressed) - sha = _sha256_bytes(bytes(region_data)) else: - region_file.write_bytes(bytes(region_data)) - sha = _sha256_file(region_file) + try: + if compress: + compressed = gzip.compress(bytes(region_data), compresslevel=6) + region_file.write_bytes(compressed) + sha = _sha256_bytes(bytes(region_data)) + else: + region_file.write_bytes(bytes(region_data)) + sha = _sha256_file(region_file) + except OSError as e: + return { + "success": False, + "error": f"Failed writing {region_file}: {e}", + "partial_backup": str(backup_path), + } manifest_partitions.append({ "name": region["name"], @@ -676,7 +703,17 @@ class SmartBackupManager: } manifest_path = backup_path / "manifest.json" - manifest_path.write_text(json.dumps(manifest, indent=2) + "\n") + try: + # Atomic write: write to temp file in same directory, then rename + tmp_manifest = manifest_path.with_suffix(".json.tmp") + tmp_manifest.write_text(json.dumps(manifest, indent=2) + "\n") + tmp_manifest.rename(manifest_path) + except OSError as e: + return { + "success": False, + "error": f"Failed writing manifest: {e}", + "partial_backup": str(backup_path), + } # Build summary backed_up = [p for p in manifest_partitions if not p["empty"]] @@ -743,6 +780,16 @@ class SmartBackupManager: all_partitions = manifest.get("partitions", []) + # Validate required fields in each partition entry + required_fields = {"name", "offset", "size_bytes"} + for idx, p in enumerate(all_partitions): + missing = required_fields - set(p.keys()) + if missing: + return { + "success": False, + "error": f"Manifest partition[{idx}] missing fields: {', '.join(sorted(missing))}", + } + # Filter partitions if partition_filter: filter_set = set(partition_filter) @@ -760,18 +807,25 @@ class SmartBackupManager: flashable = [p for p in all_partitions if not p.get("empty") and p.get("file")] # Step 2: Verify all referenced files exist and check SHA256 + # Cache decompressed data so we don't decompress twice (validation + flash) + _decompressed_cache: dict[str, bytes] = {} validation_errors: list[str] = [] for part in flashable: - file_path = backup_path / part["file"] + try: + file_path = _safe_resolve(backup_path, part["file"]) + except ValueError as e: + validation_errors.append(str(e)) + continue if not file_path.exists(): validation_errors.append(f"Missing file: {part['file']}") continue # Verify hash if part.get("compressed"): - # Decompress to check hash of original data + # Decompress to check hash; cache for reuse during flash try: decompressed = gzip.decompress(file_path.read_bytes()) + _decompressed_cache[part["file"]] = decompressed actual_hash = _sha256_bytes(decompressed) except Exception as e: validation_errors.append(f"Decompression failed for {part['file']}: {e}") @@ -826,16 +880,18 @@ class SmartBackupManager: try: for part in flashable: - file_path = backup_path / part["file"] + file_path = _safe_resolve(backup_path, part["file"]) offset = part["offset"] if part.get("compressed"): - # Decompress to temp file for flashing - decompressed = gzip.decompress(file_path.read_bytes()) + # Use cached decompressed data from validation, or decompress now + decompressed = _decompressed_cache.pop(part["file"], None) + if decompressed is None: + decompressed = gzip.decompress(file_path.read_bytes()) tmp = tempfile.NamedTemporaryFile(suffix=".bin", delete=False) tmp_path = Path(tmp.name) - tmp_path.write_bytes(decompressed) tmp.close() + tmp_path.write_bytes(decompressed) temp_files.append(tmp_path) flash_pairs.extend([offset, str(tmp_path)]) total_flash_bytes += len(decompressed) @@ -916,10 +972,7 @@ class SmartBackupManager: pass try: - progress_done = sum( - 1 for v in verify_results - ) - await context.report_progress(progress_done, len(flashable)) + await context.report_progress(len(verify_results), len(flashable)) except Exception: logger.debug("Progress reporting not supported by client")