diff --git a/docs-site/src/content/docs/reference/index.mdx b/docs-site/src/content/docs/reference/index.mdx index 74fb5b6..a4e1a95 100644 --- a/docs-site/src/content/docs/reference/index.mdx +++ b/docs-site/src/content/docs/reference/index.mdx @@ -103,6 +103,14 @@ All tools follow a consistent pattern: they return a JSON object with a `success | `esp_list_tools` | List all tools by category | | `esp_health_check` | Environment health check | +### Smart Backup (2 tools + 1 prompt) + +| Tool | Description | +|------|-------------| +| `esp_smart_backup` | Partition-aware flash backup with chunked reads, skip-empty, and compression | +| `esp_smart_restore` | Restore a partition-aware backup with SHA256 verification before flashing | +| `esp_backup_plan` *(prompt)* | Generate an informed backup plan based on device partition layout | + ### Product Catalog (5 tools) | Tool | Description | diff --git a/src/mcesptool/components/__init__.py b/src/mcesptool/components/__init__.py index 3e66bd6..19c7678 100644 --- a/src/mcesptool/components/__init__.py +++ b/src/mcesptool/components/__init__.py @@ -15,6 +15,7 @@ from .product_catalog import ProductCatalog from .production_tools import ProductionTools from .qemu_manager import QemuManager from .security_manager import SecurityManager +from .smart_backup import SmartBackupManager # Component registry for dynamic loading COMPONENT_REGISTRY = { @@ -28,6 +29,7 @@ COMPONENT_REGISTRY = { "diagnostics": Diagnostics, "qemu_manager": QemuManager, "product_catalog": ProductCatalog, + "smart_backup": SmartBackupManager, } __all__ = [ @@ -41,5 +43,6 @@ __all__ = [ "Diagnostics", "QemuManager", "ProductCatalog", + "SmartBackupManager", "COMPONENT_REGISTRY", ] diff --git a/src/mcesptool/components/smart_backup.py b/src/mcesptool/components/smart_backup.py new file mode 100644 index 0000000..d36ea4f --- /dev/null +++ b/src/mcesptool/components/smart_backup.py @@ -0,0 +1,1116 @@ +""" +Smart Backup Manager Component + +Partition-aware flash backup and restore for ESP devices. Reads the partition +table first, then backs up each partition individually with chunked reads to +stay well under the 300-second MCP timeout. Supports skip-empty detection, +gzip compression, and adaptive chunk sizing based on measured throughput. +""" + +import asyncio +import gzip +import hashlib +import json +import logging +import math +import re +import struct +import tempfile +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from fastmcp import Context, FastMCP + +from ..config import ESPToolServerConfig + +logger = logging.getLogger(__name__) + +# ESP partition table binary format constants +_ENTRY_SIZE = 32 +_PT_MAGIC = 0x50AA + +# Reverse lookups for partition type/subtype names (matches partition_manager.py) +_PARTITION_TYPES = {0x00: "app", 0x01: "data"} +_APP_SUBTYPES = { + 0x00: "factory", 0x10: "ota_0", 0x11: "ota_1", + 0x12: "ota_2", 0x13: "ota_3", 0x20: "test", +} +_DATA_SUBTYPES = { + 0x00: "ota", 0x01: "phy", 0x02: "nvs", 0x03: "coredump", + 0x04: "nvs_keys", 0x05: "efuse", 0x81: "fat", + 0x82: "spiffs", 0x83: "littlefs", +} + +# Default estimated throughput for read operations (bytes/sec) +_DEFAULT_THROUGHPUT_BPS = 11 * 1024 # ~11 KB/s conservative estimate + + +def _format_size(size_bytes: int) -> str: + """Format byte count as human-readable size.""" + if size_bytes >= 1024 * 1024 and size_bytes % (1024 * 1024) == 0: + return f"{size_bytes // (1024 * 1024)}MB" + if size_bytes >= 1024 and size_bytes % 1024 == 0: + return f"{size_bytes // 1024}KB" + return f"{size_bytes}B" + + +def _parse_flash_size_string(size_str: str) -> int: + """Parse a flash size like '4MB', '16MB', '512KB' into bytes.""" + size_str = size_str.strip().upper() + if size_str.endswith("MB"): + return int(size_str[:-2]) * 1024 * 1024 + if size_str.endswith("KB"): + return int(size_str[:-2]) * 1024 + return int(size_str, 0) + + +def _is_chunk_empty(data: bytes | bytearray) -> bool: + """Check whether a data chunk is entirely erased flash (0xFF).""" + return data.count(b"\xff") == len(data) + + +def _sha256_file(path: Path) -> str: + """Compute SHA256 hex digest of a file.""" + h = hashlib.sha256() + with open(path, "rb") as f: + while True: + block = f.read(65536) + if not block: + break + h.update(block) + return h.hexdigest() + + +def _sha256_bytes(data: bytes | bytearray) -> str: + """Compute SHA256 hex digest of in-memory bytes.""" + return hashlib.sha256(data).hexdigest() + + +def _normalize_mac(mac: str) -> str: + """Normalize a MAC address to lowercase-dashed form (80-f1-b2-d1-c5-4e).""" + return mac.lower().replace(":", "-") + + +def _safe_resolve(base: Path, untrusted: str) -> Path: + """Resolve an untrusted relative path within a base directory. + + Prevents path traversal attacks (e.g. ../../etc/passwd) by verifying + the resolved path is a child of the base directory. + """ + resolved = (base / untrusted).resolve() + base_resolved = base.resolve() + if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved: + raise ValueError(f"Path traversal detected: {untrusted!r} escapes {base}") + return resolved + + +def _parse_partition_table_binary(raw: bytes) -> list[dict[str, Any]]: + """Parse ESP32 binary partition table format. + + Each entry is 32 bytes: + - 2 bytes: magic (0xAA50) + - 1 byte: type + - 1 byte: subtype + - 4 bytes: offset (LE) + - 4 bytes: size (LE) + - 16 bytes: name (null-terminated) + - 4 bytes: flags + """ + partitions: list[dict[str, Any]] = [] + for i in range(0, len(raw) - _ENTRY_SIZE + 1, _ENTRY_SIZE): + entry = raw[i : i + _ENTRY_SIZE] + magic = struct.unpack_from(" dict[str, Any]: + """Extract chip info from esptool flash-id output.""" + info: dict[str, Any] = {} + + chip_match = re.search(r"Chip type:\s*(.+?)(?:\n|$)", output) + if not chip_match: + chip_match = re.search(r"Chip is\s+(.+?)(?:\n|$)", output) + if not chip_match: + chip_match = re.search(r"Detecting chip type[.\u2026]+\s*(\S+)", output) + if chip_match: + info["type"] = chip_match.group(1).strip() + + mac_match = re.search(r"MAC:\s*([0-9a-f:]+)", output, re.IGNORECASE) + if mac_match: + info["mac"] = mac_match.group(1) + + flash_size_match = re.search(r"Detected flash size:\s*(\S+)", output) + if flash_size_match: + size_str = flash_size_match.group(1) + info["flash_size_str"] = size_str + try: + info["flash_size_bytes"] = _parse_flash_size_string(size_str) + except (ValueError, TypeError): + pass + + return info + + +class SmartBackupManager: + """Partition-aware flash backup and restore for ESP devices""" + + def __init__(self, app: FastMCP, config: ESPToolServerConfig) -> None: + self.app = app + self.config = config + self._register_tools() + self._register_prompts() + + # ------------------------------------------------------------------ + # esptool subprocess runner (matches component pattern) + # ------------------------------------------------------------------ + + async def _run_esptool( + self, + port: str, + args: list[str], + timeout: float = 120.0, + ) -> dict[str, Any]: + """Run esptool with arbitrary args as an async subprocess. + + Args: + port: Serial port or socket:// URI + args: esptool arguments after --port + timeout: Timeout in seconds + + Returns: + dict with "success", "output", and optionally "error" + """ + cmd = [ + self.config.esptool_path, + "--port", port, + *args, + ] + proc = None + try: + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout) + output = (stdout or b"").decode() + (stderr or b"").decode() + + if proc.returncode != 0: + return {"success": False, "error": output.strip()[:500]} + + return {"success": True, "output": output} + + except asyncio.TimeoutError: + if proc and proc.returncode is None: + proc.kill() + await proc.wait() + return {"success": False, "error": f"Timeout after {timeout}s"} + except FileNotFoundError: + return { + "success": False, + "error": f"esptool not found at {self.config.esptool_path}", + } + except Exception as e: + if proc and proc.returncode is None: + proc.kill() + await proc.wait() + return {"success": False, "error": str(e)} + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + async def _get_chip_info(self, port: str) -> dict[str, Any]: + """Run flash-id and return parsed chip info.""" + result = await self._run_esptool(port, ["flash-id"], timeout=15.0) + if not result["success"]: + return {"success": False, "error": result["error"]} + + info = _parse_chip_info(result["output"]) + if not info.get("flash_size_bytes"): + return {"success": False, "error": "Could not determine flash size from flash-id output"} + + info["success"] = True + return info + + async def _read_partition_table(self, port: str) -> dict[str, Any]: + """Read and parse partition table from device.""" + with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + result = await self._run_esptool( + port, + ["read-flash", "0x8000", "0xC00", str(tmp_path)], + timeout=60.0, + ) + if not result["success"]: + return {"success": False, "error": result["error"]} + + raw = tmp_path.read_bytes() + partitions = _parse_partition_table_binary(raw) + return {"success": True, "partitions": partitions, "raw": raw[:0xC00]} + finally: + try: + tmp_path.unlink() + except OSError: + pass + + async def _read_flash_chunk( + self, + port: str, + offset: int, + size: int, + out_path: Path, + timeout: float, + ) -> dict[str, Any]: + """Read a single chunk of flash to a file.""" + result = await self._run_esptool( + port, + ["read-flash", f"0x{offset:x}", str(size), str(out_path)], + timeout=timeout, + ) + return result + + def _compute_chunk_timeout(self, chunk_bytes: int) -> float: + """Compute per-chunk timeout, capped at 240s to stay under MCP 300s limit.""" + chunk_mb = chunk_bytes / (1024 * 1024) + return min(240.0, max(60.0, chunk_mb * 100.0)) + + def _adapt_chunk_size( + self, + measured_seconds: float, + chunk_bytes: int, + current_chunk_mb: int, + ) -> int: + """Adapt chunk size based on measured throughput. + + If the first chunk was faster than expected (< chunk_mb * 80s), + increase the chunk size for remaining reads, capping so that each + chunk stays under a 60-second safety margin from the 300s MCP timeout. + """ + if measured_seconds <= 0: + return current_chunk_mb + + throughput_bps = chunk_bytes / measured_seconds + # Target: each chunk finishes within 240 seconds (300s - 60s margin) + max_bytes_per_chunk = throughput_bps * 240.0 + new_chunk_mb = int(max_bytes_per_chunk / (1024 * 1024)) + + # Keep within reasonable bounds + new_chunk_mb = max(current_chunk_mb, min(new_chunk_mb, 16)) + + expected_time = current_chunk_mb * 80 + if measured_seconds < expected_time: + logger.info( + "Chunk read faster than expected (%.1fs < %ds), " + "increasing chunk size from %dMB to %dMB", + measured_seconds, expected_time, current_chunk_mb, new_chunk_mb, + ) + return new_chunk_mb + + return current_chunk_mb + + # ------------------------------------------------------------------ + # Tool registration + # ------------------------------------------------------------------ + + def _register_tools(self) -> None: + """Register smart backup/restore tools""" + + @self.app.tool("esp_smart_backup") + async def smart_backup( + context: Context, + output_dir: str, + port: str | None = None, + partitions: list[str] | None = None, + skip_empty: bool = True, + compress: bool = False, + chunk_mb: int = 2, + dry_run: bool = False, + ) -> dict[str, Any]: + """Partition-aware flash backup that reads each partition individually. + + Avoids the 300-second MCP timeout by chunking large reads and + reporting progress. Optionally skips erased (0xFF) regions and + compresses output files. + + Creates a backup directory containing one file per partition, a + manifest.json with SHA256 hashes, and chip metadata for restore. + + Args: + output_dir: Base directory for backup files. A subdirectory + is created using the device MAC and a timestamp. + port: Serial port or socket:// URI (auto-detect if not specified) + partitions: Specific partition names to back up, or all if omitted + skip_empty: Detect 0xFF-only chunks and skip writing them (default: true) + compress: Gzip individual partition backup files (default: false) + chunk_mb: Chunk size in megabytes for large partition reads (default: 2) + dry_run: Estimate backup time without reading flash (default: false) + """ + return await self._smart_backup_impl( + context, output_dir, port, partitions, + skip_empty, compress, chunk_mb, dry_run, + ) + + @self.app.tool("esp_smart_restore") + async def smart_restore( + context: Context, + backup_dir: str, + port: str | None = None, + partitions: list[str] | None = None, + verify: bool = True, + dry_run: bool = False, + ) -> dict[str, Any]: + """Restore a partition-aware backup created by esp_smart_backup. + + Reads manifest.json from the backup directory, validates file + integrity via SHA256, and flashes each partition back to the device. + + Args: + backup_dir: Directory containing manifest.json and backup files + port: Serial port or socket:// URI (auto-detect if not specified) + partitions: Specific partition names to restore, or all if omitted + verify: Read back and verify hashes after flashing (default: true) + dry_run: Show what would be flashed without writing (default: false) + """ + return await self._smart_restore_impl( + context, backup_dir, port, partitions, verify, dry_run, + ) + + # ------------------------------------------------------------------ + # Prompt registration + # ------------------------------------------------------------------ + + def _register_prompts(self) -> None: + """Register MCP prompts""" + + @self.app.prompt("esp_backup_plan") + async def backup_plan(port: str | None = None) -> str: + """Generate an informed backup plan for the connected ESP device. + + Reads the partition table and chip info, then calculates estimated + backup time per partition based on conservative throughput estimates. + + Args: + port: Serial port or socket:// URI (optional, required if + multiple devices are connected) + """ + return await self._backup_plan_impl(port) + + # ------------------------------------------------------------------ + # Smart backup implementation + # ------------------------------------------------------------------ + + async def _smart_backup_impl( + self, + context: Context, + output_dir: str, + port: str | None, + partition_filter: list[str] | None, + skip_empty: bool, + compress: bool, + chunk_mb: int, + dry_run: bool, + ) -> dict[str, Any]: + """Partition-aware backup implementation.""" + + if not port: + return { + "success": False, + "error": "Port is required for smart backup", + } + + if not (1 <= chunk_mb <= 16): + return { + "success": False, + "error": f"chunk_mb must be between 1 and 16, got {chunk_mb}", + } + + overall_start = time.time() + + # Step 1: Get chip info via flash-id + chip_info = await self._get_chip_info(port) + if not chip_info.get("success"): + return { + "success": False, + "error": f"Could not identify chip: {chip_info.get('error', 'unknown')}", + } + + flash_size_bytes = chip_info["flash_size_bytes"] + mac = chip_info.get("mac", "unknown") + chip_type = chip_info.get("type", "unknown") + + # Step 2: Read partition table + pt_result = await self._read_partition_table(port) + if not pt_result.get("success"): + return { + "success": False, + "error": f"Could not read partition table: {pt_result.get('error', 'unknown')}", + } + + device_partitions = pt_result["partitions"] + + # Filter partitions if requested + if partition_filter: + filter_set = set(partition_filter) + selected = [p for p in device_partitions if p["name"] in filter_set] + unknown = filter_set - {p["name"] for p in selected} + if unknown: + return { + "success": False, + "error": f"Unknown partition(s): {', '.join(sorted(unknown))}", + "available": [p["name"] for p in device_partitions], + } + device_partitions = selected + + # Determine first partition offset for bootloader region + if device_partitions: + first_offset = min(p["offset"] for p in device_partitions) + else: + first_offset = 0x8000 # default: up to partition table + + # Build the list of regions to back up: + # 1) bootloader (0x0 .. first_offset) + # 2) partition table (0x8000 .. 0x8C00) + # 3) each partition + backup_regions: list[dict[str, Any]] = [] + + # Only include bootloader if we are not filtering to specific partitions + if not partition_filter: + bootloader_size = min(first_offset, 0x8000) + if bootloader_size > 0: + backup_regions.append({ + "name": "bootloader", + "type": "boot", + "subtype": "boot", + "offset": 0, + "size": bootloader_size, + "file": "bootloader.bin", + }) + + backup_regions.append({ + "name": "partition-table", + "type": "data", + "subtype": "partition-table", + "offset": 0x8000, + "size": 0xC00, + "file": "partition-table.bin", + }) + + for p in device_partitions: + safe_name = re.sub(r"[^a-zA-Z0-9_-]", "_", p["name"]) + ext = ".bin.gz" if compress else ".bin" + backup_regions.append({ + "name": p["name"], + "type": p["type"], + "subtype": p["subtype"], + "offset": p["offset"], + "size": p["size"], + "file": f"{safe_name}{ext}", + }) + + total_bytes = sum(r["size"] for r in backup_regions) + + # Dry run: estimate time and return plan + if dry_run: + estimated_seconds = total_bytes / _DEFAULT_THROUGHPUT_BPS + return { + "success": True, + "dry_run": True, + "chip": { + "type": chip_type, + "mac": mac, + "flash_size_bytes": flash_size_bytes, + }, + "regions": [ + { + "name": r["name"], + "offset": f"0x{r['offset']:x}", + "size_bytes": r["size"], + "size_human": _format_size(r["size"]), + "estimated_seconds": round(r["size"] / _DEFAULT_THROUGHPUT_BPS, 1), + "chunks": math.ceil(r["size"] / (chunk_mb * 1024 * 1024)), + } + for r in backup_regions + ], + "total_bytes": total_bytes, + "total_human": _format_size(total_bytes), + "estimated_seconds": round(estimated_seconds, 1), + "estimated_minutes": round(estimated_seconds / 60, 1), + } + + # Create backup directory: {output_dir}/{mac_normalized}/{timestamp} + mac_dir = _normalize_mac(mac) if mac != "unknown" else "unknown" + ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H%M%S") + backup_path = Path(output_dir) / mac_dir / ts + backup_path.mkdir(parents=True, exist_ok=True) + + # Step 3-5: Read each region chunk by chunk + manifest_partitions: list[dict[str, Any]] = [] + bytes_read_total = 0 + current_chunk_mb = chunk_mb + first_chunk_measured = False + + for region in backup_regions: + region_offset = region["offset"] + region_size = region["size"] + region_file = backup_path / region["file"] + + num_chunks = math.ceil(region_size / (current_chunk_mb * 1024 * 1024)) + + region_data = bytearray() + region_empty = True + chunk_offset = region_offset + + for chunk_idx in range(num_chunks): + remaining = region_size - (chunk_idx * current_chunk_mb * 1024 * 1024) + this_chunk_size = min(current_chunk_mb * 1024 * 1024, remaining) + chunk_timeout = self._compute_chunk_timeout(this_chunk_size) + + with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as tmp: + tmp_chunk = Path(tmp.name) + + try: + chunk_start = time.time() + result = await self._read_flash_chunk( + port, chunk_offset, this_chunk_size, tmp_chunk, chunk_timeout, + ) + + if not result["success"]: + return { + "success": False, + "error": ( + f"Failed reading {region['name']} at " + f"0x{chunk_offset:x}: {result['error']}" + ), + "partial_backup": str(backup_path), + } + + chunk_elapsed = time.time() - chunk_start + chunk_data = tmp_chunk.read_bytes() + + # Adaptive chunk sizing after first chunk + if not first_chunk_measured and chunk_elapsed > 0: + first_chunk_measured = True + current_chunk_mb = self._adapt_chunk_size( + chunk_elapsed, len(chunk_data), current_chunk_mb, + ) + + if not _is_chunk_empty(chunk_data): + region_empty = False + region_data.extend(chunk_data) + + bytes_read_total += len(chunk_data) + chunk_offset += this_chunk_size + + # Report progress (best-effort) + try: + await context.report_progress(bytes_read_total, total_bytes) + except Exception: + logger.debug("Progress reporting not supported by client") + + finally: + try: + tmp_chunk.unlink() + except OSError: + pass + + # Write region file (skip if entirely empty and skip_empty is on) + actual_bytes = len(region_data) + is_empty = region_empty and skip_empty + sha = "" + + if is_empty: + # Don't write a file, just record it in the manifest + sha = _sha256_bytes(region_data) + else: + try: + if compress: + compressed = gzip.compress(bytes(region_data), compresslevel=6) + region_file.write_bytes(compressed) + sha = _sha256_bytes(bytes(region_data)) + else: + region_file.write_bytes(bytes(region_data)) + sha = _sha256_file(region_file) + except OSError as e: + return { + "success": False, + "error": f"Failed writing {region_file}: {e}", + "partial_backup": str(backup_path), + } + + manifest_partitions.append({ + "name": region["name"], + "type": region["type"], + "subtype": region.get("subtype", ""), + "offset": f"0x{region['offset']:x}", + "size_bytes": region["size"], + "file": region["file"] if not is_empty else None, + "sha256": sha, + "empty": is_empty, + "compressed": compress and not is_empty, + "actual_bytes": actual_bytes, + }) + + elapsed = round(time.time() - overall_start, 1) + + # Write manifest + manifest = { + "version": 1, + "timestamp": datetime.now(timezone.utc).isoformat(), + "chip": { + "type": chip_type, + "mac": mac, + "flash_size_bytes": flash_size_bytes, + }, + "partitions": manifest_partitions, + "total_bytes_read": bytes_read_total, + "total_flash_bytes": flash_size_bytes, + "elapsed_seconds": elapsed, + } + + manifest_path = backup_path / "manifest.json" + try: + # Atomic write: write to temp file in same directory, then rename + tmp_manifest = manifest_path.with_suffix(".json.tmp") + tmp_manifest.write_text(json.dumps(manifest, indent=2) + "\n") + tmp_manifest.rename(manifest_path) + except OSError as e: + return { + "success": False, + "error": f"Failed writing manifest: {e}", + "partial_backup": str(backup_path), + } + + # Build summary + backed_up = [p for p in manifest_partitions if not p["empty"]] + skipped = [p for p in manifest_partitions if p["empty"]] + + return { + "success": True, + "backup_dir": str(backup_path), + "manifest": str(manifest_path), + "chip": { + "type": chip_type, + "mac": mac, + "flash_size_bytes": flash_size_bytes, + }, + "partitions_backed_up": len(backed_up), + "partitions_skipped_empty": len(skipped), + "total_bytes_read": bytes_read_total, + "total_human": _format_size(bytes_read_total), + "elapsed_seconds": elapsed, + "files": [p["file"] for p in manifest_partitions if p["file"]], + } + + # ------------------------------------------------------------------ + # Smart restore implementation + # ------------------------------------------------------------------ + + async def _smart_restore_impl( + self, + context: Context, + backup_dir: str, + port: str | None, + partition_filter: list[str] | None, + verify: bool, + dry_run: bool, + ) -> dict[str, Any]: + """Partition-aware restore implementation.""" + + if not port: + return { + "success": False, + "error": "Port is required for smart restore", + } + + backup_path = Path(backup_dir) + manifest_path = backup_path / "manifest.json" + + if not manifest_path.exists(): + return { + "success": False, + "error": f"manifest.json not found in {backup_dir}", + } + + # Step 1: Read and validate manifest + try: + manifest = json.loads(manifest_path.read_text()) + except (json.JSONDecodeError, OSError) as e: + return {"success": False, "error": f"Failed to read manifest: {e}"} + + if manifest.get("version") != 1: + return { + "success": False, + "error": f"Unsupported manifest version: {manifest.get('version')}", + } + + all_partitions = manifest.get("partitions", []) + + # Validate required fields in each partition entry + required_fields = {"name", "offset", "size_bytes"} + for idx, p in enumerate(all_partitions): + missing = required_fields - set(p.keys()) + if missing: + return { + "success": False, + "error": f"Manifest partition[{idx}] missing fields: {', '.join(sorted(missing))}", + } + + # Filter partitions + if partition_filter: + filter_set = set(partition_filter) + selected = [p for p in all_partitions if p["name"] in filter_set] + unknown = filter_set - {p["name"] for p in selected} + if unknown: + return { + "success": False, + "error": f"Unknown partition(s) in manifest: {', '.join(sorted(unknown))}", + "available": [p["name"] for p in all_partitions], + } + all_partitions = selected + + # Skip empty partitions (nothing to flash) + flashable = [p for p in all_partitions if not p.get("empty") and p.get("file")] + + # Step 2: Verify all referenced files exist and check SHA256 + # Cache decompressed data so we don't decompress twice (validation + flash) + _decompressed_cache: dict[str, bytes] = {} + validation_errors: list[str] = [] + for part in flashable: + try: + file_path = _safe_resolve(backup_path, part["file"]) + except ValueError as e: + validation_errors.append(str(e)) + continue + if not file_path.exists(): + validation_errors.append(f"Missing file: {part['file']}") + continue + + # Verify hash + if part.get("compressed"): + # Decompress to check hash; cache for reuse during flash + try: + decompressed = gzip.decompress(file_path.read_bytes()) + _decompressed_cache[part["file"]] = decompressed + actual_hash = _sha256_bytes(decompressed) + except Exception as e: + validation_errors.append(f"Decompression failed for {part['file']}: {e}") + continue + else: + actual_hash = _sha256_file(file_path) + + if actual_hash != part.get("sha256"): + validation_errors.append( + f"SHA256 mismatch for {part['file']}: " + f"expected {part.get('sha256', 'N/A')[:16]}..., " + f"got {actual_hash[:16]}..." + ) + + if validation_errors: + return { + "success": False, + "error": "Backup validation failed", + "validation_errors": validation_errors, + } + + # Step 3: Dry run + if dry_run: + return { + "success": True, + "dry_run": True, + "chip_info": manifest.get("chip", {}), + "partitions_to_flash": [ + { + "name": p["name"], + "offset": p["offset"], + "size_bytes": p["size_bytes"], + "size_human": _format_size(p["size_bytes"]), + "file": p["file"], + "compressed": p.get("compressed", False), + } + for p in flashable + ], + "partitions_empty_skip": [ + p["name"] for p in all_partitions if p.get("empty") + ], + "total_bytes": sum(p["size_bytes"] for p in flashable), + } + + overall_start = time.time() + + # Step 4-5: Decompress if needed, then build flash args + # esptool write-flash takes pairs of (offset, file) + flash_pairs: list[str] = [] + temp_files: list[Path] = [] + total_flash_bytes = 0 + + try: + for part in flashable: + file_path = _safe_resolve(backup_path, part["file"]) + offset = part["offset"] + + if part.get("compressed"): + # Use cached decompressed data from validation, or decompress now + decompressed = _decompressed_cache.pop(part["file"], None) + if decompressed is None: + decompressed = gzip.decompress(file_path.read_bytes()) + tmp = tempfile.NamedTemporaryFile(suffix=".bin", delete=False) + tmp_path = Path(tmp.name) + tmp.close() + tmp_path.write_bytes(decompressed) + temp_files.append(tmp_path) + flash_pairs.extend([offset, str(tmp_path)]) + total_flash_bytes += len(decompressed) + else: + flash_pairs.extend([offset, str(file_path)]) + total_flash_bytes += file_path.stat().st_size + + # Flash everything in one esptool invocation + args = ["write-flash", "--compress"] + args.extend(flash_pairs) + + # Generous timeout: scale by data size + flash_timeout = max(180.0, total_flash_bytes / (10 * 1024)) + + try: + await context.report_progress(0, total_flash_bytes) + except Exception: + logger.debug("Progress reporting not supported by client") + + result = await self._run_esptool(port, args, timeout=flash_timeout) + + if not result["success"]: + return { + "success": False, + "error": f"Flash failed: {result['error']}", + "partitions_attempted": [p["name"] for p in flashable], + } + + try: + await context.report_progress(total_flash_bytes, total_flash_bytes) + except Exception: + logger.debug("Progress reporting not supported by client") + + finally: + # Clean up temp files + for tmp_path in temp_files: + try: + tmp_path.unlink() + except OSError: + pass + + elapsed_flash = round(time.time() - overall_start, 1) + + # Step 6: Verify by reading back and comparing hashes + verify_results: list[dict[str, Any]] = [] + if verify: + for part in flashable: + with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as tmp: + verify_path = Path(tmp.name) + + try: + vresult = await self._read_flash_chunk( + port, + int(part["offset"], 16), + part["size_bytes"], + verify_path, + timeout=self._compute_chunk_timeout(part["size_bytes"]), + ) + if vresult["success"]: + readback_hash = _sha256_file(verify_path) + matched = readback_hash == part["sha256"] + verify_results.append({ + "name": part["name"], + "verified": matched, + "expected_sha256": part["sha256"][:16] + "...", + "actual_sha256": readback_hash[:16] + "...", + }) + else: + verify_results.append({ + "name": part["name"], + "verified": False, + "error": vresult["error"], + }) + finally: + try: + verify_path.unlink() + except OSError: + pass + + try: + await context.report_progress(len(verify_results), len(flashable)) + except Exception: + logger.debug("Progress reporting not supported by client") + + total_elapsed = round(time.time() - overall_start, 1) + + all_verified = all(v.get("verified") for v in verify_results) if verify_results else None + + return { + "success": True, + "backup_dir": str(backup_path), + "chip_info": manifest.get("chip", {}), + "partitions_flashed": len(flashable), + "partitions_skipped_empty": len(all_partitions) - len(flashable), + "total_bytes_flashed": total_flash_bytes, + "flash_elapsed_seconds": elapsed_flash, + "verified": all_verified, + "verify_results": verify_results if verify else None, + "total_elapsed_seconds": total_elapsed, + } + + # ------------------------------------------------------------------ + # Backup plan prompt implementation + # ------------------------------------------------------------------ + + async def _backup_plan_impl(self, port: str | None) -> str: + """Generate a structured backup plan message.""" + + if not port: + return ( + "I need a serial port to generate a backup plan. " + "Please provide the port parameter (e.g. /dev/ttyUSB0 or COM3). " + "You can use esp_scan_ports or esp_detect_chip to find connected devices." + ) + + # Get chip info + chip_info = await self._get_chip_info(port) + if not chip_info.get("success"): + return ( + f"Could not read chip info on port {port}: {chip_info.get('error', 'unknown')}. " + "Make sure the device is connected and not in use by another program." + ) + + # Get partition table + pt_result = await self._read_partition_table(port) + if not pt_result.get("success"): + return ( + f"Could not read partition table on port {port}: {pt_result.get('error', 'unknown')}. " + "The device flash may be blank or the partition table corrupted." + ) + + chip_type = chip_info.get("type", "Unknown") + mac = chip_info.get("mac", "unknown") + flash_bytes = chip_info.get("flash_size_bytes", 0) + flash_str = chip_info.get("flash_size_str", _format_size(flash_bytes)) + partitions = pt_result["partitions"] + + # Calculate estimates + lines: list[str] = [] + lines.append(f"# Backup Plan for {chip_type}") + lines.append("") + lines.append(f"- **Chip:** {chip_type}") + lines.append(f"- **MAC:** {mac}") + lines.append(f"- **Flash size:** {flash_str} ({flash_bytes:,} bytes)") + lines.append(f"- **Partitions found:** {len(partitions)}") + lines.append("") + lines.append("## Partition Layout") + lines.append("") + lines.append("| Name | Type | Subtype | Offset | Size | Est. Time |") + lines.append("|------|------|---------|--------|------|-----------|") + + total_est = 0.0 + for p in partitions: + est_sec = p["size"] / _DEFAULT_THROUGHPUT_BPS + total_est += est_sec + if est_sec < 60: + time_str = f"{est_sec:.0f}s" + else: + time_str = f"{est_sec / 60:.1f}min" + + lines.append( + f"| {p['name']} | {p['type']} | {p['subtype']} " + f"| 0x{p['offset']:x} | {_format_size(p['size'])} | ~{time_str} |" + ) + + # Add bootloader and partition table overhead + boot_est = 0x8000 / _DEFAULT_THROUGHPUT_BPS + pt_est = 0xC00 / _DEFAULT_THROUGHPUT_BPS + total_est += boot_est + pt_est + + lines.append("") + lines.append("## Time Estimates") + lines.append("") + lines.append(f"- Bootloader region (0x0-0x8000): ~{boot_est:.0f}s") + lines.append(f"- Partition table: ~{pt_est:.0f}s") + lines.append( + f"- **Total estimated time:** ~{total_est / 60:.1f} minutes " + f"(at {_DEFAULT_THROUGHPUT_BPS // 1024} KB/s)" + ) + lines.append("") + + if flash_bytes >= 8 * 1024 * 1024: + lines.append("## Recommendations") + lines.append("") + lines.append( + f"This device has {flash_str} of flash. A full backup at default " + "throughput will take a while. Consider:" + ) + lines.append("") + lines.append("- Use `skip_empty: true` (default) to skip erased regions") + lines.append("- Back up only specific partitions if you don't need everything") + lines.append( + "- The `chunk_mb` parameter controls read chunk size; " + "larger chunks are faster but risk timeouts on slow connections" + ) + lines.append("") + + lines.append("## Example Command") + lines.append("") + lines.append("To back up all partitions:") + lines.append("```") + lines.append(f'esp_smart_backup(port="{port}", output_dir="./backups")') + lines.append("```") + lines.append("") + lines.append("To back up specific partitions:") + names = [p["name"] for p in partitions[:3]] + lines.append("```") + lines.append( + f'esp_smart_backup(port="{port}", output_dir="./backups", ' + f'partitions={names})' + ) + lines.append("```") + + return "\n".join(lines) + + # ------------------------------------------------------------------ + # Health check + # ------------------------------------------------------------------ + + async def health_check(self) -> dict[str, Any]: + """Component health check""" + return {"status": "healthy", "note": "Smart backup manager ready"} diff --git a/src/mcesptool/server.py b/src/mcesptool/server.py index 8efd5d8..7fd8c9f 100644 --- a/src/mcesptool/server.py +++ b/src/mcesptool/server.py @@ -26,6 +26,7 @@ from .components import ( ProductionTools, QemuManager, SecurityManager, + SmartBackupManager, ) from .config import ESPToolServerConfig, get_config, set_config @@ -72,6 +73,7 @@ class ESPToolServer: self.components["chip_control"] = ChipControl(self.app, self.config) self.components["flash_manager"] = FlashManager(self.app, self.config) self.components["partition_manager"] = PartitionManager(self.app, self.config) + self.components["smart_backup"] = SmartBackupManager(self.app, self.config) # Advanced features self.components["security_manager"] = SecurityManager(self.app, self.config) @@ -179,6 +181,10 @@ class ESPToolServer: "esp_performance_profile", "esp_diagnostic_report", ], + "smart_backup": [ + "esp_smart_backup", + "esp_smart_restore", + ], "product_catalog": [ "esp_product_search", "esp_product_info",