diff --git a/src/mcp_esptool_server/components/flash_manager.py b/src/mcp_esptool_server/components/flash_manager.py index 0f88b7a..6f022de 100644 --- a/src/mcp_esptool_server/components/flash_manager.py +++ b/src/mcp_esptool_server/components/flash_manager.py @@ -1,11 +1,16 @@ """ Flash Manager Component -Provides comprehensive ESP flash memory operations including reading, writing, -erasing, verification, and backup with production-grade safety features. +Provides ESP flash memory operations: write, read, erase, and backup. +All operations shell out to esptool as an async subprocess, matching +the pattern established in chip_control.py. """ +import asyncio import logging +import re +import time +from pathlib import Path from typing import Any from fastmcp import Context, FastMCP @@ -18,21 +23,84 @@ logger = logging.getLogger(__name__) class FlashManager: """ESP flash memory management and operations""" - def __init__(self, app: FastMCP, config: ESPToolServerConfig): + def __init__(self, app: FastMCP, config: ESPToolServerConfig) -> None: self.app = app self.config = config self._register_tools() + async def _run_esptool( + self, + port: str, + args: list[str], + timeout: float = 120.0, + ) -> dict[str, Any]: + """Run esptool with arbitrary args as an async subprocess. + + Args: + port: Serial port or socket:// URI + args: esptool arguments after --port (e.g. ["write-flash", "0x0", "fw.bin"]) + timeout: Timeout in seconds (flash operations can be slow) + + Returns: + dict with "success", "output", and optionally "error" + """ + cmd = [ + self.config.esptool_path, + "--port", port, + *args, + ] + proc = None + try: + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout) + output = (stdout or b"").decode() + (stderr or b"").decode() + + if proc.returncode != 0: + return {"success": False, "error": output.strip()[:500]} + + return {"success": True, "output": output} + + except asyncio.TimeoutError: + if proc and proc.returncode is None: + proc.kill() + await proc.wait() + return {"success": False, "error": f"Timeout after {timeout}s"} + except FileNotFoundError: + return { + "success": False, + "error": f"esptool not found at {self.config.esptool_path}", + } + except Exception as e: + if proc and proc.returncode is None: + proc.kill() + await proc.wait() + return {"success": False, "error": str(e)} + def _register_tools(self) -> None: """Register flash management tools""" @self.app.tool("esp_flash_firmware") async def flash_firmware( - context: Context, firmware_path: str, port: str | None = None, verify: bool = True + context: Context, + firmware_path: str, + port: str | None = None, + verify: bool = True, ) -> dict[str, Any]: - """Flash firmware to ESP device""" - # Implementation placeholder - return {"success": True, "note": "Implementation coming soon"} + """Flash firmware to ESP device. + + Writes a binary firmware file to the device's flash memory using esptool. + Supports any port including socket:// URIs for QEMU virtual devices. + + Args: + firmware_path: Path to the firmware binary (.bin) to flash + port: Serial port or socket:// URI (auto-detect if not specified) + verify: Verify flash contents after writing (default: true) + """ + return await self._flash_firmware_impl(context, firmware_path, port, verify) @self.app.tool("esp_flash_read") async def flash_read( @@ -42,8 +110,18 @@ class FlashManager: start_address: str = "0x0", size: str | None = None, ) -> dict[str, Any]: - """Read flash memory contents""" - return {"success": True, "note": "Implementation coming soon"} + """Read flash memory contents to a file. + + Reads raw bytes from flash and saves to the specified output path. + If size is not specified, reads the entire flash. + + Args: + output_path: File path to save the flash contents + port: Serial port or socket:// URI (auto-detect if not specified) + start_address: Flash offset to start reading from (hex string, default: "0x0") + size: Number of bytes to read (hex or decimal string, reads all if not specified) + """ + return await self._flash_read_impl(context, output_path, port, start_address, size) @self.app.tool("esp_flash_erase") async def flash_erase( @@ -52,8 +130,17 @@ class FlashManager: start_address: str = "0x0", size: str | None = None, ) -> dict[str, Any]: - """Erase flash memory regions""" - return {"success": True, "note": "Implementation coming soon"} + """Erase flash memory regions. + + Erases the entire flash if no start_address and size are given. + Otherwise erases the specified region. Erased bytes become 0xFF. + + Args: + port: Serial port or socket:// URI (auto-detect if not specified) + start_address: Flash offset to start erasing (hex string, default: "0x0") + size: Number of bytes to erase (hex or decimal string, erases all if not specified) + """ + return await self._flash_erase_impl(context, port, start_address, size) @self.app.tool("esp_flash_backup") async def flash_backup( @@ -62,8 +149,186 @@ class FlashManager: port: str | None = None, include_bootloader: bool = True, ) -> dict[str, Any]: - """Create complete flash backup""" - return {"success": True, "note": "Implementation coming soon"} + """Create complete flash backup to a file. + + Reads the entire flash contents and saves to the specified path. + The resulting file can be restored with esp_flash_firmware. + + Args: + backup_path: File path to save the flash backup + port: Serial port or socket:// URI (auto-detect if not specified) + include_bootloader: Start from address 0x0 to include bootloader (default: true) + """ + return await self._flash_backup_impl(context, backup_path, port, include_bootloader) + + async def _flash_firmware_impl( + self, + context: Context, + firmware_path: str, + port: str | None, + verify: bool, + ) -> dict[str, Any]: + """Write firmware to flash via esptool write-flash.""" + + fw_path = Path(firmware_path) + if not fw_path.exists(): + return {"success": False, "error": f"Firmware file not found: {firmware_path}"} + + if not port: + return {"success": False, "error": "Port is required (no auto-detect for flash operations)"} + + start_time = time.time() + + args = ["--no-stub", "write-flash", "0x0", str(fw_path)] + if not verify: + args.insert(0, "--no-verify") + + result = await self._run_esptool(port, args, timeout=180.0) + + if not result["success"]: + return { + "success": False, + "error": result["error"], + "port": port, + "firmware_path": firmware_path, + } + + output = result["output"] + elapsed = round(time.time() - start_time, 1) + + # Parse bytes written from output + bytes_written = 0 + write_matches = re.findall(r"Wrote (\d+) bytes", output) + for match in write_matches: + bytes_written += int(match) + + verified = "Hash of data verified" in output or "Verified" in output + + return { + "success": True, + "port": port, + "firmware_path": firmware_path, + "firmware_size": fw_path.stat().st_size, + "bytes_written": bytes_written, + "verified": verified if verify else None, + "elapsed_seconds": elapsed, + } + + async def _flash_read_impl( + self, + context: Context, + output_path: str, + port: str | None, + start_address: str, + size: str | None, + ) -> dict[str, Any]: + """Read flash contents via esptool read-flash.""" + + if not port: + return {"success": False, "error": "Port is required (no auto-detect for flash operations)"} + + # Determine read size — if not specified, read entire flash (detect first) + if not size: + detect = await self._run_esptool(port, ["flash-id"], timeout=15.0) + if not detect["success"]: + return {"success": False, "error": f"Could not detect flash size: {detect['error']}"} + + # Parse flash size from output + flash_size_match = re.search(r"Detected flash size:\s*(\d+)([KMG]B)", detect["output"]) + if flash_size_match: + num = int(flash_size_match.group(1)) + unit = flash_size_match.group(2) + multiplier = {"KB": 1024, "MB": 1024 * 1024, "GB": 1024 * 1024 * 1024} + size = str(num * multiplier.get(unit, 1)) + else: + return {"success": False, "error": "Could not determine flash size. Specify size manually."} + + # Ensure output directory exists + out = Path(output_path) + out.parent.mkdir(parents=True, exist_ok=True) + + start_time = time.time() + result = await self._run_esptool( + port, + ["--no-stub", "read-flash", start_address, size, str(out)], + timeout=300.0, + ) + + if not result["success"]: + return {"success": False, "error": result["error"], "port": port} + + elapsed = round(time.time() - start_time, 1) + + return { + "success": True, + "port": port, + "output_path": str(out), + "start_address": start_address, + "bytes_read": out.stat().st_size if out.exists() else 0, + "elapsed_seconds": elapsed, + } + + async def _flash_erase_impl( + self, + context: Context, + port: str | None, + start_address: str, + size: str | None, + ) -> dict[str, Any]: + """Erase flash via esptool erase-flash or erase-region.""" + + if not port: + return {"success": False, "error": "Port is required (no auto-detect for flash operations)"} + + start_time = time.time() + + if size: + # Erase specific region + result = await self._run_esptool( + port, + ["--no-stub", "erase-region", start_address, size], + timeout=60.0, + ) + else: + # Erase entire flash + result = await self._run_esptool( + port, + ["--no-stub", "erase-flash"], + timeout=60.0, + ) + + if not result["success"]: + return {"success": False, "error": result["error"], "port": port} + + elapsed = round(time.time() - start_time, 1) + + return { + "success": True, + "port": port, + "erase_type": "region" if size else "full", + "start_address": start_address if size else "0x0", + "size": size, + "elapsed_seconds": elapsed, + } + + async def _flash_backup_impl( + self, + context: Context, + backup_path: str, + port: str | None, + include_bootloader: bool, + ) -> dict[str, Any]: + """Read entire flash to create a backup file.""" + + start_address = "0x0" if include_bootloader else "0x1000" + + return await self._flash_read_impl( + context, + output_path=backup_path, + port=port, + start_address=start_address, + size=None, # auto-detect full flash + ) async def health_check(self) -> dict[str, Any]: """Component health check"""