diff --git a/src/mcserial/server.py b/src/mcserial/server.py index 5438b6f..057853b 100644 --- a/src/mcserial/server.py +++ b/src/mcserial/server.py @@ -28,6 +28,123 @@ class SerialConnection: # Active connections registry _connections: dict[str, SerialConnection] = {} + +def _detect_baud_rate_internal( + port: str, + probe: str | None = None, + timeout_per_rate: float = 0.3, + baudrates: list[int] | None = None, +) -> dict: + """Internal baud rate detection (not exposed as MCP tool). + + Used by open_serial_port for auto-detection. + """ + import math + import time + from collections import Counter + + # Common baud rates ordered by popularity + default_rates = [ + 115200, 9600, 57600, 38400, 19200, + 230400, 460800, 921600, + 4800, 2400, 1200, + ] + rates_to_try = baudrates or default_rates + results = [] + + def score_data(data: bytes) -> dict: + """Score data readability with sync pattern analysis.""" + if not data: + return {"score": 0, "bytes_received": 0} + + # Printable ASCII percentage + printable = sum(1 for b in data if 32 <= b <= 126 or b in (9, 10, 13)) + printable_pct = (printable / len(data)) * 100 + + # Line endings indicator + has_newlines = b'\n' in data or b'\r' in data + + # Null byte penalty + null_pct = (data.count(0x00) / len(data)) * 100 + + # 0x55 sync pattern detection + sync_count = sum(1 for b in data if b in (0x55, 0xAA)) + sync_pct = (sync_count / len(data)) * 100 + + # Bit transitions (0x55 has 7 per byte) + transitions = sum( + sum(1 for i in range(7) if ((b >> i) & 1) != ((b >> (i + 1)) & 1)) + for b in data + ) + avg_transitions = transitions / len(data) + + # Text clustering + counter = Counter(data) + text_cluster = sum(counter.get(b, 0) for b in range(0x20, 0x7F)) / len(data) * 100 + + # Entropy calculation + entropy = 0 + for count in counter.values(): + if count > 0: + p = count / len(data) + entropy -= p * math.log2(p) + norm_entropy = (entropy / 8) * 100 + + # Composite score + score = printable_pct + score += sync_pct * 0.5 + score += min(100, (avg_transitions / 7) * 100) * 0.2 if avg_transitions > 3 else 0 + score += text_cluster * 0.3 + if has_newlines: + score += 15 + if null_pct > 30: + score -= 40 + if norm_entropy > 70: + score -= 20 + + # UTF-8 bonus + try: + data.decode('utf-8') + score += 10 + except UnicodeDecodeError: + pass + + return { + "score": round(max(0, min(100, score)), 1), + "bytes_received": len(data), + } + + for rate in rates_to_try: + try: + conn = serial.Serial(port=port, baudrate=rate, timeout=timeout_per_rate) + + if probe: + conn.write(probe.encode()) + conn.flush() + time.sleep(0.05) + + time.sleep(timeout_per_rate) + available = conn.in_waiting + data = conn.read(available) if available else b"" + conn.close() + + score_result = score_data(data) + if score_result["bytes_received"] > 0: + results.append({"baudrate": rate, **score_result}) + + except serial.SerialException: + continue + + results.sort(key=lambda x: x["score"], reverse=True) + best = results[0] if results else None + + return { + "detected_baudrate": best["baudrate"] if best and best["score"] > 50 else None, + "confidence": best["score"] if best else 0, + "results": results[:5], + } + + mcp = FastMCP( name="mcserial", instructions="""Serial port MCP server. Use tools to open/close/write to serial ports. @@ -74,7 +191,7 @@ def list_serial_ports(usb_only: bool = True) -> list[dict]: @mcp.tool() def open_serial_port( port: str, - baudrate: int = DEFAULT_BAUDRATE, + baudrate: int | None = None, bytesize: Literal[5, 6, 7, 8] = 8, parity: Literal["N", "E", "O", "M", "S"] = "N", stopbits: Literal[1, 1.5, 2] = 1, @@ -85,12 +202,18 @@ def open_serial_port( rtscts: bool = False, dsrdtr: bool = False, exclusive: bool = False, + autobaud_probe: str | None = None, + autobaud_timeout: float = 0.3, ) -> dict: - """Open a serial port connection. + """Open a serial port connection with optional auto-baud detection. + + If baudrate is not specified (None), automatically detects the baud rate + by analyzing incoming data patterns. Works best when the device is actively + sending data or responds to a probe string. Args: port: Device path (e.g., '/dev/ttyUSB0', 'COM3') - baudrate: Baud rate (default from env or 9600) + baudrate: Baud rate. If None, auto-detect (recommended for unknown devices) bytesize: Data bits (5, 6, 7, or 8) parity: Parity checking (N=None, E=Even, O=Odd, M=Mark, S=Space) stopbits: Stop bits (1, 1.5, or 2) @@ -101,9 +224,11 @@ def open_serial_port( rtscts: Enable hardware RTS/CTS flow control dsrdtr: Enable hardware DSR/DTR flow control exclusive: Request exclusive access (lock port from other processes) + autobaud_probe: String to send during auto-detection (e.g., "UUUUU" for sync) + autobaud_timeout: Timeout per rate during auto-detection (default 0.3s) Returns: - Connection status and details + Connection status and details (includes auto-detection info if used) """ if port in _connections: return {"error": f"Port {port} is already open", "success": False} @@ -111,6 +236,34 @@ def open_serial_port( if len(_connections) >= MAX_CONNECTIONS: return {"error": f"Maximum connections ({MAX_CONNECTIONS}) reached", "success": False} + # Auto-detect baud rate if not specified + detected_info = None + if baudrate is None: + # Import the detect function's logic inline to avoid circular issues + detection_result = _detect_baud_rate_internal( + port=port, + probe=autobaud_probe, + timeout_per_rate=autobaud_timeout, + ) + if detection_result.get("detected_baudrate"): + baudrate = detection_result["detected_baudrate"] + detected_info = { + "auto_detected": True, + "detection_confidence": detection_result.get("confidence", 0), + "detection_candidates": [ + {"baudrate": r["baudrate"], "score": r["score"]} + for r in detection_result.get("results", [])[:3] + ], + } + else: + # Fall back to default if detection fails + baudrate = DEFAULT_BAUDRATE + detected_info = { + "auto_detected": False, + "fallback_reason": "No data received or low confidence", + "using_default": DEFAULT_BAUDRATE, + } + try: conn = serial.Serial( port=port, @@ -127,7 +280,7 @@ def open_serial_port( exclusive=exclusive, ) _connections[port] = SerialConnection(port=port, connection=conn) - return { + result = { "success": True, "port": port, "baudrate": baudrate, @@ -140,6 +293,9 @@ def open_serial_port( "exclusive": exclusive, "resource_uri": f"serial://{port}/data", } + if detected_info: + result["autobaud"] = detected_info + return result except serial.SerialException as e: return {"error": str(e), "success": False}