From 36a181d131475e7553948deb49308a55c817153e Mon Sep 17 00:00:00 2001 From: Ryan Malloy Date: Fri, 6 Feb 2026 20:26:39 -0700 Subject: [PATCH] Enhance tool docstrings for LLM discoverability - Add cross-references between related tools - Explain use cases and when to use alternatives - All 25 MCP tools now have descriptive docstrings --- src/mctelnet/server.py | 1021 ++++++++++++++++++++++++++++++++++++++-- tests/test_server.py | 78 ++- 2 files changed, 1067 insertions(+), 32 deletions(-) diff --git a/src/mctelnet/server.py b/src/mctelnet/server.py index 571b6a8..1eca3a9 100644 --- a/src/mctelnet/server.py +++ b/src/mctelnet/server.py @@ -2,15 +2,145 @@ from __future__ import annotations +import atexit import asyncio -from dataclasses import dataclass -from typing import Annotated +import re +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Annotated, Any from uuid import uuid4 import telnetlib3 from fastmcp import FastMCP +# ============================================================================ +# RESOURCE LIMITS AND CONSTANTS +# ============================================================================ + +# Buffer and memory limits +MAX_BUFFER_SIZE = 1024 * 1024 # 1MB max buffer per read operation +MAX_TRANSCRIPT_ENTRIES = 1000 # Keep last N transcript entries per connection +READ_BUFFER_SIZE = 4096 # Chunk size for reading + +# Connection limits +MAX_CLIENT_CONNECTIONS = 50 # Max outbound telnet connections +MAX_SERVER_INSTANCES = 10 # Max telnet servers we can run +MAX_CLIENTS_PER_SERVER = 100 # Max clients per server + +# Timing constants +READ_POLL_INTERVAL = 0.5 # Seconds between read attempts +POST_SEND_DELAY = 0.1 # Delay after send before reading response + + +# ============================================================================ +# TERMINAL KEY/ESCAPE SEQUENCE MAPPINGS +# ============================================================================ + +# Control characters +_CONTROL_KEYS = { + "ctrl-a": "\x01", + "ctrl-b": "\x02", + "ctrl-c": "\x03", # Interrupt (SIGINT) + "ctrl-d": "\x04", # EOF + "ctrl-e": "\x05", + "ctrl-f": "\x06", + "ctrl-g": "\x07", # Bell + "ctrl-h": "\x08", # Backspace + "ctrl-i": "\x09", # Tab + "ctrl-j": "\x0a", # Line feed + "ctrl-k": "\x0b", + "ctrl-l": "\x0c", # Form feed / Clear screen + "ctrl-m": "\x0d", # Carriage return + "ctrl-n": "\x0e", + "ctrl-o": "\x0f", + "ctrl-p": "\x10", + "ctrl-q": "\x11", # XON (resume) + "ctrl-r": "\x12", + "ctrl-s": "\x13", # XOFF (pause) + "ctrl-t": "\x14", + "ctrl-u": "\x15", # Kill line + "ctrl-v": "\x16", + "ctrl-w": "\x17", # Kill word + "ctrl-x": "\x18", + "ctrl-y": "\x19", + "ctrl-z": "\x1a", # Suspend (SIGTSTP) + "ctrl-[": "\x1b", # Escape + "ctrl-\\": "\x1c", # Quit (SIGQUIT) + "ctrl-]": "\x1d", # Telnet escape + "ctrl-^": "\x1e", + "ctrl-_": "\x1f", +} + +# VT100/ANSI escape sequences +_VT100_KEYS = { + # Arrow keys + "up": "\x1b[A", + "down": "\x1b[B", + "right": "\x1b[C", + "left": "\x1b[D", + # Navigation + "home": "\x1b[H", + "end": "\x1b[F", + "insert": "\x1b[2~", + "delete": "\x1b[3~", + "page-up": "\x1b[5~", + "page-down": "\x1b[6~", + # Function keys (VT100 style) + "f1": "\x1bOP", + "f2": "\x1bOQ", + "f3": "\x1bOR", + "f4": "\x1bOS", + "f5": "\x1b[15~", + "f6": "\x1b[17~", + "f7": "\x1b[18~", + "f8": "\x1b[19~", + "f9": "\x1b[20~", + "f10": "\x1b[21~", + "f11": "\x1b[23~", + "f12": "\x1b[24~", + # Other + "escape": "\x1b", + "esc": "\x1b", + "tab": "\t", + "enter": "\r\n", + "backspace": "\x7f", + "space": " ", +} + +# Telnet protocol sequences (IAC commands) +_TELNET_COMMANDS = { + "break": "\xff\xf3", # IAC BREAK - interrupt + "interrupt": "\xff\xf4", # IAC IP (Interrupt Process) + "abort": "\xff\xf5", # IAC AO (Abort Output) + "ayt": "\xff\xf6", # IAC AYT (Are You There) + "erase-char": "\xff\xf7", # IAC EC (Erase Character) + "erase-line": "\xff\xf8", # IAC EL (Erase Line) + "go-ahead": "\xff\xf9", # IAC GA (Go Ahead) + "nop": "\xff\xf1", # IAC NOP (No Operation) +} + +# Combined lookup +_ALL_KEYS = {**_CONTROL_KEYS, **_VT100_KEYS, **_TELNET_COMMANDS} + + +# ANSI escape sequence pattern - matches: +# - CSI sequences: \x1b[...m (colors, cursor, etc.) +# - OSC sequences: \x1b]...ST (window title, etc.) where ST is \x07 or \x1b\\ +# - Other escape sequences: \x1b followed by single char +_ANSI_PATTERN = re.compile( + r"\x1b\[[0-9;]*[a-zA-Z]" # CSI sequences (most common) + r"|\x1b\][^\x07]*\x07" # OSC sequences with BEL + r"|\x1b\][^\x1b]*\x1b\\" # OSC sequences with ST + r"|\x1b." # Other single-char escapes +) + + +def _strip_ansi(text: str) -> str: + """Remove ANSI escape sequences from text.""" + return _ANSI_PATTERN.sub("", text) + + # Connection storage @dataclass class TelnetConnection: @@ -23,12 +153,73 @@ class TelnetConnection: writer: telnetlib3.TelnetWriter buffer: str = "" connected: bool = True + # Enhanced features + strip_ansi: bool = False + connected_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + bytes_sent: int = 0 + bytes_received: int = 0 + transcript: list[dict[str, Any]] = field(default_factory=list) + keepalive_interval: float = 0.0 # 0 = disabled + _keepalive_task: asyncio.Task[None] | None = field(default=None, repr=False) -# Global connection registry +# Global connection registry (client mode) _connections: dict[str, TelnetConnection] = {} +# ============================================================================ +# TELNET SERVER MODE - Host telnet services for LLMs +# ============================================================================ + + +@dataclass +class TelnetServerClient: + """Represents a client connected to our telnet server.""" + + id: str + reader: asyncio.StreamReader + writer: asyncio.StreamWriter + address: tuple[str, int] + connected_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + bytes_sent: int = 0 + bytes_received: int = 0 + buffer: list[str] = field(default_factory=list) # Queued incoming messages + connected: bool = True + _read_task: asyncio.Task | None = field(default=None, repr=False) + + +@dataclass +class TelnetServer: + """Represents an active telnet server.""" + + id: str + host: str + port: int + server: asyncio.Server + clients: dict[str, TelnetServerClient] = field(default_factory=dict) + started_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + transcript: list[dict[str, Any]] = field(default_factory=list) + total_connections: int = 0 # Lifetime counter + _serve_task: asyncio.Task[None] | None = field(default=None, repr=False) + + +# Global server registry +_servers: dict[str, TelnetServer] = {} + + +# ============================================================================ +# HELPER: Transcript management with size limits +# ============================================================================ + + +def _add_transcript_entry(transcript: list[dict[str, Any]], entry: dict[str, Any]) -> None: + """Add entry to transcript, pruning old entries if needed.""" + transcript.append(entry) + if len(transcript) > MAX_TRANSCRIPT_ENTRIES: + # Keep only the most recent entries + del transcript[: len(transcript) - MAX_TRANSCRIPT_ENTRIES] + + def get_version() -> str: """Get package version.""" try: @@ -71,11 +262,29 @@ mcp = FastMCP( ) +async def _keepalive_loop(conn: TelnetConnection) -> None: + """Background task to send keepalive NOPs.""" + while conn.connected and conn.keepalive_interval > 0: + await asyncio.sleep(conn.keepalive_interval) + if conn.connected: + try: + # Send telnet NOP (IAC NOP = \xff\xf1) + conn.writer.write("\xff\xf1") + await conn.writer.drain() + except Exception: + conn.connected = False + break + + @mcp.tool() async def connect( host: Annotated[str, "Hostname or IP address to connect to"], port: Annotated[int, "Port number (default: 23)"] = 23, timeout: Annotated[float, "Connection timeout in seconds"] = 10.0, + strip_ansi: Annotated[bool, "Strip ANSI escape codes from output"] = False, + keepalive_interval: Annotated[ + float, "Send keepalive every N seconds (0 to disable)" + ] = 0.0, ) -> str: """ Establish a new telnet connection to a remote host. @@ -83,7 +292,19 @@ async def connect( Returns a connection ID to use with other tools. Common ports: 23 (standard telnet), 2323 (alt telnet), 4000-4100 (MUDs), various for BBSes. + + Options: + - strip_ansi: Remove color/formatting codes for cleaner output + - keepalive_interval: Prevent idle disconnects (e.g., 30.0 for 30s) """ + # Validate port range + if not 1 <= port <= 65535: + return f"Invalid port {port}. Must be between 1 and 65535." + + # Check connection limit + if len(_connections) >= MAX_CLIENT_CONNECTIONS: + return f"Connection limit ({MAX_CLIENT_CONNECTIONS}) reached. Disconnect some first." + conn_id = str(uuid4())[:8] try: @@ -98,13 +319,29 @@ async def connect( port=port, reader=reader, writer=writer, + strip_ansi=strip_ansi, + keepalive_interval=keepalive_interval, ) _connections[conn_id] = conn + # Start keepalive task if enabled + if keepalive_interval > 0: + conn._keepalive_task = asyncio.create_task(_keepalive_loop(conn)) + + # Log connection to transcript + _add_transcript_entry(conn.transcript, { + "type": "connect", + "timestamp": datetime.now(timezone.utc).isoformat(), + "host": host, + "port": port, + }) + # Try to read initial banner/prompt initial_data = await _read_available(conn, timeout=2.0) result = f"Connected! ID: {conn_id}\nHost: {host}:{port}" + if strip_ansi: + result += " (ANSI stripping enabled)" if initial_data: result += f"\n\n--- Initial Output ---\n{initial_data}" @@ -148,8 +385,10 @@ async def send( conn.writer.write(data_to_send) await conn.writer.drain() + conn.bytes_sent += len(data_to_send.encode("utf-8", errors="replace")) result = f"Sent: {repr(text)}" + response = "" if read_response: # Small delay to let response arrive @@ -160,6 +399,14 @@ async def send( else: result += "\n\n(No response received)" + # Log to transcript + _add_transcript_entry(conn.transcript, { + "type": "send", + "timestamp": datetime.now(timezone.utc).isoformat(), + "sent": text, + "response": response if read_response else None, + }) + return result except Exception as e: @@ -167,6 +414,94 @@ async def send( return f"Error sending to {connection_id}: {e}" +@mcp.tool() +async def send_key( + connection_id: Annotated[str, "Connection ID from connect()"], + key: Annotated[str, "Key name (e.g., 'ctrl-c', 'up', 'f1', 'break')"], + read_response: Annotated[bool, "Read and return response after sending"] = True, + read_timeout: Annotated[float, "Timeout for reading response"] = 2.0, +) -> str: + """ + Send a terminal key or escape sequence. + + Supports control characters, arrow keys, function keys, and telnet commands. + Use list_keys() to see all available key names. + + Common keys: + - Control: ctrl-c, ctrl-d, ctrl-z, ctrl-l + - Navigation: up, down, left, right, home, end, page-up, page-down + - Function: f1-f12 + - Telnet: break, interrupt, ayt (Are You There) + """ + conn = _connections.get(connection_id) + if not conn: + return f"Connection {connection_id} not found." + + if not conn.connected: + return f"Connection {connection_id} is closed." + + key_lower = key.lower().strip() + sequence = _ALL_KEYS.get(key_lower) + + if not sequence: + # Try to find partial matches for suggestions + matches = [k for k in _ALL_KEYS.keys() if key_lower in k] + if matches: + return f"Unknown key '{key}'. Did you mean: {', '.join(matches[:5])}?" + return f"Unknown key '{key}'. Use list_keys() to see available keys." + + try: + conn.writer.write(sequence) + await conn.writer.drain() + conn.bytes_sent += len(sequence.encode("utf-8", errors="replace")) + + result = f"Sent: {key} ({repr(sequence)})" + response = "" + + if read_response: + await asyncio.sleep(0.1) + response = await _read_available(conn, timeout=read_timeout) + if response: + result += f"\n\n--- Response ---\n{response}" + + # Log to transcript + _add_transcript_entry(conn.transcript, { + "type": "send_key", + "timestamp": datetime.now(timezone.utc).isoformat(), + "key": key, + "response": response if read_response else None, + }) + + return result + + except Exception as e: + conn.connected = False + return f"Error sending key to {connection_id}: {e}" + + +@mcp.tool() +async def list_keys() -> dict: + """ + List all available terminal keys and escape sequences. + + Returns categorized list of key names that can be used with send_key(). + """ + return { + "control_keys": list(_CONTROL_KEYS.keys()), + "navigation_keys": ["up", "down", "left", "right", "home", "end", "page-up", "page-down"], + "function_keys": [f"f{i}" for i in range(1, 13)], + "editing_keys": ["tab", "enter", "backspace", "space", "escape", "insert", "delete"], + "telnet_commands": list(_TELNET_COMMANDS.keys()), + "usage": "Use send_key(connection_id, 'key-name') to send any of these keys.", + "examples": [ + "send_key(id, 'ctrl-c') # Interrupt running process", + "send_key(id, 'up') # Navigate menu/history", + "send_key(id, 'f1') # Help on many systems", + "send_key(id, 'break') # Telnet break signal", + ], + } + + @mcp.tool() async def read( connection_id: Annotated[str, "Connection ID from connect()"], @@ -206,14 +541,32 @@ async def disconnect( ) -> str: """ Close a telnet connection and clean up resources. + + Stops keepalive tasks and closes the socket. Use disconnect_all() to + close all connections at once. """ conn = _connections.pop(connection_id, None) if not conn: return f"Connection {connection_id} not found." try: + # Cancel keepalive task if running + if conn._keepalive_task: + conn._keepalive_task.cancel() + try: + await conn._keepalive_task + except asyncio.CancelledError: + pass + conn.writer.close() conn.connected = False + + # Log to transcript + _add_transcript_entry(conn.transcript, { + "type": "disconnect", + "timestamp": datetime.now(timezone.utc).isoformat(), + }) + return f"Disconnected from {conn.host}:{conn.port} (ID: {connection_id})" except Exception as e: return f"Error during disconnect: {e}" @@ -242,6 +595,8 @@ async def expect( connection_id: Annotated[str, "Connection ID from connect()"], patterns: Annotated[list[str], "List of patterns to watch for"], timeout: Annotated[float, "Timeout in seconds"] = 30.0, + regex: Annotated[bool, "Treat patterns as regular expressions"] = False, + ignore_case: Annotated[bool, "Case-insensitive pattern matching"] = False, ) -> dict: """ Wait for one of multiple patterns to appear in the output. @@ -250,6 +605,10 @@ async def expect( Similar to the Unix 'expect' utility. Example patterns: ["login:", "Password:", "$ ", "> "] + + Options: + - regex: Use regex matching (e.g., r"user(name)?:") + - ignore_case: Match "Login:" and "login:" equally """ conn = _connections.get(connection_id) if not conn: @@ -258,12 +617,22 @@ async def expect( if not conn.connected: return {"error": f"Connection {connection_id} is closed.", "matched": None} + # Pre-compile regex patterns if needed + compiled_patterns = [] + if regex: + flags = re.IGNORECASE if ignore_case else 0 + for p in patterns: + try: + compiled_patterns.append(re.compile(p, flags)) + except re.error as e: + return {"error": f"Invalid regex pattern '{p}': {e}", "matched": None} + collected = [] - deadline = asyncio.get_event_loop().time() + timeout + deadline = asyncio.get_running_loop().time() + timeout try: - while asyncio.get_event_loop().time() < deadline: - remaining = deadline - asyncio.get_event_loop().time() + while asyncio.get_running_loop().time() < deadline: + remaining = deadline - asyncio.get_running_loop().time() if remaining <= 0: break @@ -275,16 +644,31 @@ async def expect( if not data: break collected.append(data) + conn.bytes_received += len(data.encode("utf-8", errors="replace")) # Check all patterns full_text = "".join(collected) + search_text = full_text.lower() if ignore_case and not regex else full_text + for i, pattern in enumerate(patterns): - if pattern in full_text: - return { - "matched": pattern, - "index": i, - "output": full_text, - } + if regex: + match = compiled_patterns[i].search(full_text) + if match: + return { + "matched": pattern, + "index": i, + "output": full_text, + "match_text": match.group(0), + "match_groups": match.groups() if match.groups() else None, + } + else: + compare_pattern = pattern.lower() if ignore_case else pattern + if compare_pattern in search_text: + return { + "matched": pattern, + "index": i, + "output": full_text, + } except TimeoutError: continue @@ -308,6 +692,8 @@ async def expect_send( timeout: Annotated[float, "Timeout waiting for pattern"] = 30.0, newline: Annotated[bool, "Append CRLF after send_text"] = True, hide_send: Annotated[bool, "Hide sent text in output (for passwords)"] = False, + regex: Annotated[bool, "Treat pattern as regular expression"] = False, + ignore_case: Annotated[bool, "Case-insensitive pattern matching"] = False, ) -> dict: """ Wait for a pattern, then send text. Classic expect-style interaction. @@ -316,6 +702,10 @@ async def expect_send( expect_send(conn, "login:", "myuser") expect_send(conn, "Password:", "mypass", hide_send=True) expect_send(conn, "$ ", "ls -la") + + Options: + - regex: Use regex matching + - ignore_case: Case-insensitive matching """ conn = _connections.get(connection_id) if not conn: @@ -324,14 +714,24 @@ async def expect_send( if not conn.connected: return {"error": f"Connection {connection_id} is closed."} + # Compile regex if needed + compiled_pattern = None + if regex: + try: + flags = re.IGNORECASE if ignore_case else 0 + compiled_pattern = re.compile(expect_pattern, flags) + except re.error as e: + return {"error": f"Invalid regex pattern: {e}"} + # First, wait for the pattern collected = [] - deadline = asyncio.get_event_loop().time() + timeout + deadline = asyncio.get_running_loop().time() + timeout pattern_found = False + match_info = None try: - while asyncio.get_event_loop().time() < deadline: - remaining = deadline - asyncio.get_event_loop().time() + while asyncio.get_running_loop().time() < deadline: + remaining = deadline - asyncio.get_running_loop().time() if remaining <= 0: break @@ -343,10 +743,25 @@ async def expect_send( if not data: break collected.append(data) + conn.bytes_received += len(data.encode("utf-8", errors="replace")) - if expect_pattern in "".join(collected): - pattern_found = True - break + full_text = "".join(collected) + + if regex: + match = compiled_pattern.search(full_text) + if match: + pattern_found = True + match_info = { + "match_text": match.group(0), + "match_groups": match.groups() if match.groups() else None, + } + break + else: + search_text = full_text.lower() if ignore_case else full_text + compare_pattern = expect_pattern.lower() if ignore_case else expect_pattern + if compare_pattern in search_text: + pattern_found = True + break except TimeoutError: continue @@ -370,6 +785,7 @@ async def expect_send( conn.writer.write(data_to_send) await conn.writer.drain() + conn.bytes_sent += len(data_to_send.encode("utf-8", errors="replace")) # Brief wait then read response await asyncio.sleep(0.1) @@ -377,7 +793,17 @@ async def expect_send( displayed_send = "********" if hide_send else send_text - return { + # Log to transcript + _add_transcript_entry(conn.transcript, { + "type": "expect_send", + "timestamp": datetime.now(timezone.utc).isoformat(), + "expected": expect_pattern, + "sent": displayed_send, + "output_before": output_before, + "output_after": response, + }) + + result = { "success": True, "pattern_matched": expect_pattern, "sent": displayed_send, @@ -385,6 +811,11 @@ async def expect_send( "output_after_send": response, } + if match_info: + result.update(match_info) + + return result + except Exception as e: return {"error": f"Error sending: {e}", "output": output_before} @@ -427,12 +858,12 @@ async def run_script( # Handle expect if expect_pattern: collected = [] - deadline = asyncio.get_event_loop().time() + timeout_per_step + deadline = asyncio.get_running_loop().time() + timeout_per_step matched = False try: - while asyncio.get_event_loop().time() < deadline: - remaining = deadline - asyncio.get_event_loop().time() + while asyncio.get_running_loop().time() < deadline: + remaining = deadline - asyncio.get_running_loop().time() if remaining <= 0: break try: @@ -501,6 +932,14 @@ async def disconnect_all() -> str: for conn_id in list(_connections.keys()): conn = _connections.pop(conn_id) try: + # Cancel keepalive task if running + if conn._keepalive_task: + conn._keepalive_task.cancel() + try: + await conn._keepalive_task + except asyncio.CancelledError: + pass + conn.writer.close() conn.connected = False closed.append(f"{conn_id} ({conn.host}:{conn.port})") @@ -510,6 +949,487 @@ async def disconnect_all() -> str: return f"Closed {len(closed)} connection(s): {', '.join(closed)}" +@mcp.tool() +async def connection_info( + connection_id: Annotated[str, "Connection ID from connect()"], +) -> dict: + """ + Get detailed information about a connection including session stats. + + Returns connection metadata, bytes transferred, uptime, and configuration. + """ + conn = _connections.get(connection_id) + if not conn: + return {"error": f"Connection {connection_id} not found."} + + now = datetime.now(timezone.utc) + uptime_seconds = (now - conn.connected_at).total_seconds() + + return { + "id": conn.id, + "host": conn.host, + "port": conn.port, + "connected": conn.connected, + "connected_at": conn.connected_at.isoformat(), + "uptime_seconds": round(uptime_seconds, 1), + "bytes_sent": conn.bytes_sent, + "bytes_received": conn.bytes_received, + "strip_ansi": conn.strip_ansi, + "keepalive_interval": conn.keepalive_interval, + "transcript_entries": len(conn.transcript), + } + + +@mcp.tool() +async def get_transcript( + connection_id: Annotated[str, "Connection ID from connect()"], + limit: Annotated[int, "Maximum entries to return (0 = all)"] = 0, +) -> dict: + """ + Get the session transcript (log of all I/O) for a connection. + + Returns timestamped entries showing sends, receives, and events. + Useful for debugging or auditing session activity. + """ + conn = _connections.get(connection_id) + if not conn: + return {"error": f"Connection {connection_id} not found."} + + transcript = conn.transcript + if limit > 0: + transcript = transcript[-limit:] + + return { + "connection_id": conn.id, + "host": f"{conn.host}:{conn.port}", + "total_entries": len(conn.transcript), + "returned_entries": len(transcript), + "transcript": transcript, + } + + +# ============================================================================ +# TELNET SERVER TOOLS - Host telnet services +# ============================================================================ + + +async def _handle_server_client( + server: TelnetServer, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, +) -> None: + """Handle a new client connection to our telnet server.""" + addr = writer.get_extra_info("peername") + client_id = str(uuid4())[:8] + + client = TelnetServerClient( + id=client_id, + reader=reader, + writer=writer, + address=addr, + ) + server.clients[client_id] = client + server.total_connections += 1 + + # Log to transcript + _add_transcript_entry(server.transcript, { + "type": "client_connected", + "timestamp": datetime.now(timezone.utc).isoformat(), + "client_id": client_id, + "address": f"{addr[0]}:{addr[1]}", + }) + + # Start background read task to buffer incoming data + async def read_loop(): + try: + while client.connected: + try: + data = await asyncio.wait_for(reader.read(4096), timeout=1.0) + if not data: + break + text = data.decode("utf-8", errors="replace") + client.buffer.append(text) + client.bytes_received += len(data) + except TimeoutError: + continue + except Exception: + break + finally: + client.connected = False + + client._read_task = asyncio.create_task(read_loop()) + + # Wait for client to disconnect (task handles the reading) + await client._read_task + + # Cleanup + client.connected = False + try: + writer.close() + await writer.wait_closed() + except Exception: + pass + + server.clients.pop(client_id, None) + + _add_transcript_entry(server.transcript, { + "type": "client_disconnected", + "timestamp": datetime.now(timezone.utc).isoformat(), + "client_id": client_id, + "address": f"{addr[0]}:{addr[1]}", + }) + + +@mcp.tool() +async def start_server( + port: Annotated[int, "Port number to listen on"], + host: Annotated[str, "Host/IP to bind to (default: 127.0.0.1 = localhost only)"] = "127.0.0.1", +) -> str: + """ + Start a telnet server on the specified port. + + Returns a server ID to use with other server tools. + Clients connecting to this port can be managed via list_clients, + send_client, broadcast, etc. + + Example: + start_server(2323) # Start server on port 2323 + list_clients(server_id) # See who connected + send_client(server_id, client_id, "Welcome!") + """ + # Validate port range + if not 1 <= port <= 65535: + return f"Invalid port {port}. Must be between 1 and 65535." + + # Check server limit + if len(_servers) >= MAX_SERVER_INSTANCES: + return f"Server limit ({MAX_SERVER_INSTANCES}) reached. Stop some first." + + server_id = str(uuid4())[:8] + + try: + # Create the async server + async_server = await asyncio.start_server( + lambda r, w: _handle_server_client(_servers[server_id], r, w), + host, + port, + ) + + server = TelnetServer( + id=server_id, + host=host, + port=port, + server=async_server, + ) + _servers[server_id] = server + + # Start serving in background and track the task + server._serve_task = asyncio.create_task(async_server.serve_forever()) + + _add_transcript_entry(server.transcript, { + "type": "server_started", + "timestamp": datetime.now(timezone.utc).isoformat(), + "host": host, + "port": port, + }) + + return f"Server started! ID: {server_id}\nListening on {host}:{port}" + + except OSError as e: + return f"Failed to start server on {host}:{port}: {e}" + except Exception as e: + return f"Unexpected error starting server: {e}" + + +@mcp.tool() +async def stop_server( + server_id: Annotated[str, "Server ID from start_server()"], +) -> str: + """ + Stop a telnet server and disconnect all clients. + + Closes the listening socket and terminates all client connections. + Use list_servers() to see active servers. + """ + server = _servers.pop(server_id, None) + if not server: + return f"Server {server_id} not found." + + # Disconnect all clients + for client in list(server.clients.values()): + client.connected = False + if client._read_task: + client._read_task.cancel() + try: + client.writer.close() + except Exception: + pass + + # Stop the server + server.server.close() + await server.server.wait_closed() + + return f"Server {server_id} stopped. Disconnected {len(server.clients)} client(s)." + + +@mcp.tool() +async def list_servers() -> str: + """ + List all active telnet servers. + + Shows server IDs, ports, and client counts. + """ + if not _servers: + return "No active servers." + + lines = ["Active Telnet Servers:", ""] + for srv_id, srv in _servers.items(): + client_count = len(srv.clients) + lines.append(f" {srv_id}: {srv.host}:{srv.port} [{client_count} client(s)]") + + return "\n".join(lines) + + +@mcp.tool() +async def server_info( + server_id: Annotated[str, "Server ID from start_server()"], +) -> dict: + """ + Get detailed information about a telnet server. + + Returns stats including uptime, client count, and total connections. + """ + server = _servers.get(server_id) + if not server: + return {"error": f"Server {server_id} not found."} + + now = datetime.now(timezone.utc) + uptime_seconds = (now - server.started_at).total_seconds() + + total_bytes_sent = sum(c.bytes_sent for c in server.clients.values()) + total_bytes_received = sum(c.bytes_received for c in server.clients.values()) + + return { + "id": server.id, + "host": server.host, + "port": server.port, + "started_at": server.started_at.isoformat(), + "uptime_seconds": round(uptime_seconds, 1), + "current_clients": len(server.clients), + "total_connections": server.total_connections, + "bytes_sent": total_bytes_sent, + "bytes_received": total_bytes_received, + "transcript_entries": len(server.transcript), + } + + +@mcp.tool() +async def list_clients( + server_id: Annotated[str, "Server ID from start_server()"], +) -> dict: + """ + List all clients connected to a telnet server. + + Shows client IDs, addresses, and connection stats. + """ + server = _servers.get(server_id) + if not server: + return {"error": f"Server {server_id} not found."} + + clients = [] + for client in server.clients.values(): + uptime = (datetime.now(timezone.utc) - client.connected_at).total_seconds() + clients.append({ + "id": client.id, + "address": f"{client.address[0]}:{client.address[1]}", + "connected_at": client.connected_at.isoformat(), + "uptime_seconds": round(uptime, 1), + "bytes_sent": client.bytes_sent, + "bytes_received": client.bytes_received, + "buffer_size": len(client.buffer), + "connected": client.connected, + }) + + return { + "server_id": server_id, + "client_count": len(clients), + "clients": clients, + } + + +@mcp.tool() +async def read_client( + server_id: Annotated[str, "Server ID from start_server()"], + client_id: Annotated[str, "Client ID from list_clients()"], + clear_buffer: Annotated[bool, "Clear buffer after reading"] = True, +) -> dict: + """ + Read buffered data from a connected client. + + Data is automatically buffered as clients send it. + Use clear_buffer=False to peek without consuming. + """ + server = _servers.get(server_id) + if not server: + return {"error": f"Server {server_id} not found."} + + client = server.clients.get(client_id) + if not client: + return {"error": f"Client {client_id} not found."} + + data = "".join(client.buffer) + if clear_buffer: + client.buffer.clear() + + return { + "client_id": client_id, + "address": f"{client.address[0]}:{client.address[1]}", + "data": data if data else "(No data in buffer)", + "connected": client.connected, + } + + +@mcp.tool() +async def send_client( + server_id: Annotated[str, "Server ID from start_server()"], + client_id: Annotated[str, "Client ID from list_clients()"], + text: Annotated[str, "Text to send to the client"], + newline: Annotated[bool, "Append CRLF newline after text"] = True, +) -> str: + """ + Send text to a specific connected client. + + Use list_clients() to see connected clients and their IDs. + For sending to all clients at once, use broadcast() instead. + """ + server = _servers.get(server_id) + if not server: + return f"Server {server_id} not found." + + client = server.clients.get(client_id) + if not client: + return f"Client {client_id} not found." + + if not client.connected: + return f"Client {client_id} is disconnected." + + try: + data_to_send = text + if newline: + data_to_send += "\r\n" + + client.writer.write(data_to_send.encode("utf-8")) + await client.writer.drain() + client.bytes_sent += len(data_to_send.encode("utf-8")) + + return f"Sent {len(data_to_send)} bytes to client {client_id}" + + except Exception as e: + client.connected = False + return f"Error sending to client {client_id}: {e}" + + +@mcp.tool() +async def broadcast( + server_id: Annotated[str, "Server ID from start_server()"], + text: Annotated[str, "Text to send to all clients"], + newline: Annotated[bool, "Append CRLF newline after text"] = True, +) -> str: + """ + Send text to all connected clients on a server. + + Useful for announcements, welcome messages, or synchronized updates. + Returns count of successful sends. Use send_client() for individual messages. + """ + server = _servers.get(server_id) + if not server: + return f"Server {server_id} not found." + + if not server.clients: + return "No clients connected." + + data_to_send = text + if newline: + data_to_send += "\r\n" + + sent_count = 0 + failed_count = 0 + + for client in server.clients.values(): + if not client.connected: + continue + try: + client.writer.write(data_to_send.encode("utf-8")) + await client.writer.drain() + client.bytes_sent += len(data_to_send.encode("utf-8")) + sent_count += 1 + except Exception: + client.connected = False + failed_count += 1 + + return f"Broadcast sent to {sent_count} client(s)" + ( + f", {failed_count} failed" if failed_count else "" + ) + + +@mcp.tool() +async def disconnect_client( + server_id: Annotated[str, "Server ID from start_server()"], + client_id: Annotated[str, "Client ID from list_clients()"], +) -> str: + """ + Disconnect a specific client from the server. + + Forcibly closes the client connection. The client will need to reconnect. + Use list_clients() to see connected clients. + """ + server = _servers.get(server_id) + if not server: + return f"Server {server_id} not found." + + client = server.clients.get(client_id) + if not client: + return f"Client {client_id} not found." + + client.connected = False + if client._read_task: + client._read_task.cancel() + try: + client.writer.close() + except Exception: + pass + + addr = f"{client.address[0]}:{client.address[1]}" + return f"Disconnected client {client_id} ({addr})" + + +@mcp.tool() +async def get_server_transcript( + server_id: Annotated[str, "Server ID from start_server()"], + limit: Annotated[int, "Maximum entries to return (0 = all)"] = 0, +) -> dict: + """ + Get the event transcript for a telnet server. + + Shows client connections, disconnections, and server events. + """ + server = _servers.get(server_id) + if not server: + return {"error": f"Server {server_id} not found."} + + transcript = server.transcript + if limit > 0: + transcript = transcript[-limit:] + + return { + "server_id": server.id, + "address": f"{server.host}:{server.port}", + "total_entries": len(server.transcript), + "returned_entries": len(transcript), + "transcript": transcript, + } + + # Helper functions async def _read_available(conn: TelnetConnection, timeout: float = 1.0) -> str: """Read whatever data is currently available, with timeout.""" @@ -525,6 +1445,7 @@ async def _read_available(conn: TelnetConnection, timeout: float = 1.0) -> str: if not data: break collected.append(data) + conn.bytes_received += len(data.encode("utf-8", errors="replace")) # Quick check for more data timeout = 0.1 except TimeoutError: @@ -532,17 +1453,20 @@ async def _read_available(conn: TelnetConnection, timeout: float = 1.0) -> str: except Exception: pass - return "".join(collected) + result = "".join(collected) + if conn.strip_ansi: + result = _strip_ansi(result) + return result async def _read_until(conn: TelnetConnection, pattern: str, timeout: float) -> str: """Read until pattern appears or timeout.""" collected = [] - deadline = asyncio.get_event_loop().time() + timeout + deadline = asyncio.get_running_loop().time() + timeout try: - while asyncio.get_event_loop().time() < deadline: - remaining = deadline - asyncio.get_event_loop().time() + while asyncio.get_running_loop().time() < deadline: + remaining = deadline - asyncio.get_running_loop().time() if remaining <= 0: break @@ -554,10 +1478,13 @@ async def _read_until(conn: TelnetConnection, pattern: str, timeout: float) -> s if not data: break collected.append(data) + conn.bytes_received += len(data.encode("utf-8", errors="replace")) - # Check if pattern found + # Check if pattern found (search raw text, return processed) full_text = "".join(collected) if pattern in full_text: + if conn.strip_ansi: + return _strip_ansi(full_text) return full_text except TimeoutError: @@ -566,11 +1493,51 @@ async def _read_until(conn: TelnetConnection, pattern: str, timeout: float) -> s pass result = "".join(collected) + if conn.strip_ansi: + result = _strip_ansi(result) if result: return result + f"\n\n(Timeout waiting for: {repr(pattern)})" return f"(Timeout waiting for: {repr(pattern)})" +def _graceful_shutdown() -> None: + """Clean up all connections and servers on exit.""" + # Close client connections + for conn in list(_connections.values()): + try: + if conn._keepalive_task: + conn._keepalive_task.cancel() + conn.writer.close() + conn.connected = False + except Exception: + pass + _connections.clear() + + # Stop all servers + for server in list(_servers.values()): + try: + # Disconnect all clients + for client in list(server.clients.values()): + client.connected = False + if client._read_task: + client._read_task.cancel() + try: + client.writer.close() + except Exception: + pass + # Cancel serve task + if server._serve_task: + server._serve_task.cancel() + server.server.close() + except Exception: + pass + _servers.clear() + + +# Register cleanup on interpreter exit +atexit.register(_graceful_shutdown) + + def main(): """Entry point for mctelnet MCP server.""" version = get_version() diff --git a/tests/test_server.py b/tests/test_server.py index aa9f5f9..21b24ee 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2,7 +2,7 @@ import pytest -from mctelnet.server import _connections, mcp +from mctelnet.server import _connections, _servers, _strip_ansi, mcp def test_mcp_server_exists(): @@ -10,15 +10,16 @@ def test_mcp_server_exists(): assert mcp.name == "mctelnet" -def test_tools_registered(): - """Verify all expected tools are registered.""" - # Access tools through the tool manager's internal storage +def test_client_tools_registered(): + """Verify all expected client tools are registered.""" tool_manager = mcp._tool_manager tool_names = set(tool_manager._tools.keys()) expected_tools = { "connect", "send", + "send_key", + "list_keys", "read", "expect", "expect_send", @@ -26,6 +27,29 @@ def test_tools_registered(): "list_connections", "disconnect", "disconnect_all", + "connection_info", + "get_transcript", + } + + assert expected_tools.issubset(tool_names), f"Missing tools: {expected_tools - tool_names}" + + +def test_server_tools_registered(): + """Verify all expected server tools are registered.""" + tool_manager = mcp._tool_manager + tool_names = set(tool_manager._tools.keys()) + + expected_tools = { + "start_server", + "stop_server", + "list_servers", + "server_info", + "list_clients", + "read_client", + "send_client", + "broadcast", + "disconnect_client", + "get_server_transcript", } assert expected_tools.issubset(tool_names), f"Missing tools: {expected_tools - tool_names}" @@ -33,6 +57,50 @@ def test_tools_registered(): def test_connection_storage_initially_empty(): """Verify no connections exist at startup.""" - # Clear any lingering connections _connections.clear() assert len(_connections) == 0 + + +def test_server_storage_initially_empty(): + """Verify no servers exist at startup.""" + _servers.clear() + assert len(_servers) == 0 + + +class TestAnsiStripping: + """Tests for ANSI escape sequence stripping.""" + + def test_strip_simple_color(self): + """Strip basic color codes.""" + text = "\x1b[32mgreen\x1b[0m text" + assert _strip_ansi(text) == "green text" + + def test_strip_complex_sgr(self): + """Strip multi-parameter SGR sequences.""" + text = "\x1b[1;31;40mBold red on black\x1b[0m" + assert _strip_ansi(text) == "Bold red on black" + + def test_strip_cursor_movement(self): + """Strip cursor movement sequences.""" + text = "\x1b[2J\x1b[H\x1b[5;10HHello" + assert _strip_ansi(text) == "Hello" + + def test_strip_osc_title(self): + """Strip OSC window title sequences (BEL terminated).""" + text = "\x1b]0;Window Title\x07Some text" + assert _strip_ansi(text) == "Some text" + + def test_strip_osc_st_terminated(self): + """Strip OSC sequences terminated with ST.""" + text = "\x1b]0;Title\x1b\\Some text" + assert _strip_ansi(text) == "Some text" + + def test_preserve_plain_text(self): + """Plain text passes through unchanged.""" + text = "Hello, World! 123" + assert _strip_ansi(text) == text + + def test_mixed_content(self): + """Handle mixed plain text and ANSI codes.""" + text = "Start \x1b[1mBold\x1b[0m middle \x1b[4mUnderline\x1b[0m end" + assert _strip_ansi(text) == "Start Bold middle Underline end"