diff --git a/main.py b/main.py index c9bc623..9a0946e 100644 --- a/main.py +++ b/main.py @@ -1,23 +1,16 @@ from __future__ import annotations -import sys +from fastmcp import FastMCP from gnuradio_mcp.middlewares.platform import PlatformMiddleware +from gnuradio_mcp.providers.mcp import McpPlatformProvider -# Load GNU Radio try: from gnuradio import gr + from gnuradio.grc.core.platform import Platform except ImportError: - # Throw a new exception with more information - print( - "Cannot find GNU Radio! (Have you sourced the environment file?)", - file=sys.stderr, - ) - # Throw the new exception raise Exception("Cannot find GNU Radio!") from None -from gnuradio.grc.core.platform import Platform - platform = Platform( version=gr.version(), version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()), @@ -25,9 +18,10 @@ platform = Platform( ) platform.build_library() +app: FastMCP = FastMCP( + "GNU Radio MCP", description="Provide a MCP interface to GNU Radio" +) -platform_middleware = PlatformMiddleware(platform) -flowgraph_mw = platform_middleware.make_flowgraph() -flowgraph_mw.add_block("blocks_add_xx") -for error in flowgraph_mw.get_all_errors(): - print(error) +McpPlatformProvider.from_platform_middleware(app, PlatformMiddleware(platform)) + +app.run(transport="sse") diff --git a/src/gnuradio_mcp/middlewares/flowgraph.py b/src/gnuradio_mcp/middlewares/flowgraph.py index a827937..9a045bf 100644 --- a/src/gnuradio_mcp/middlewares/flowgraph.py +++ b/src/gnuradio_mcp/middlewares/flowgraph.py @@ -5,40 +5,20 @@ from typing import TYPE_CHECKING, Optional from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.FlowGraph import FlowGraph -from gnuradio.grc.core.ports.port import Port from gnuradio_mcp.middlewares.base import ElementMiddleware from gnuradio_mcp.middlewares.block import BlockMiddleware -from gnuradio_mcp.models import SINK, SOURCE, BlockModel, ConnectionModel, PortModel -from gnuradio_mcp.utils import get_unique_id +from gnuradio_mcp.models import ( + BlockModel, + ConnectionModel, + PortModel, +) +from gnuradio_mcp.utils import get_port_from_port_model, get_unique_id if TYPE_CHECKING: from gnuradio_mcp.middlewares.platform import PlatformMiddleware -def get_port_from_port_model_in_port_list( - port_list: list[Port], port_model: PortModel -) -> Block: - for port in port_list: - if port.key == port_model.key: - return port - raise ValueError(f"Port not found: {port_model.key}") - - -def get_port_from_port_model(flowgraph, port_model: PortModel) -> Port: - block_from_port_model = flowgraph.get_block(port_model.parent) - if port_model.direction == SOURCE: - return get_port_from_port_model_in_port_list( - block_from_port_model.sources, port_model - ) - elif port_model.direction == SINK: - return get_port_from_port_model_in_port_list( - block_from_port_model.sinks, port_model - ) - else: - raise ValueError(f"Invalid port direction: {port_model.direction}") - - def set_block_name(block: Block, name: str): block.params["id"].set_value(name) diff --git a/src/gnuradio_mcp/providers/__init__.py b/src/gnuradio_mcp/providers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/gnuradio_mcp/providers/base.py b/src/gnuradio_mcp/providers/base.py new file mode 100644 index 0000000..c91c617 --- /dev/null +++ b/src/gnuradio_mcp/providers/base.py @@ -0,0 +1,103 @@ +from typing import Any, Dict, List + +from gnuradio_mcp.middlewares.platform import PlatformMiddleware +from gnuradio_mcp.models import ( + SINK, + SOURCE, + BlockModel, + BlockTypeModel, + ConnectionModel, + ErrorModel, + ParamModel, + PortModel, +) +from gnuradio_mcp.utils import get_port_by_key + + +class PlatformProvider: + def __init__(self, platform_mw: PlatformMiddleware, flowgraph_path: str = ""): + self._platform_mw = platform_mw + self._flowgraph_mw = platform_mw.make_flowgraph(flowgraph_path) + + ############################################## + # Flowgraph Management + ############################################## + + def get_blocks(self) -> list[BlockModel]: + return self._flowgraph_mw.blocks + + def make_block(self, block_name: str) -> str: + block_mw = self._flowgraph_mw.add_block(block_name) + return block_mw.name + + def remove_block(self, block_name: str) -> bool: + self._flowgraph_mw.remove_block(block_name) + return True + + ############################################## + # Block Management + ############################################## + + def get_block_params(self, block_name: str) -> List[ParamModel]: + return self._flowgraph_mw.get_block(block_name).params + + def set_block_params(self, block_name: str, params: Dict[str, Any]) -> bool: + self._flowgraph_mw.get_block(block_name).set_params(params) + return True + + def get_block_sources(self, block_name: str) -> list[PortModel]: + return self._flowgraph_mw.get_block(block_name).sources + + def get_block_sinks(self, block_name: str) -> list[PortModel]: + return self._flowgraph_mw.get_block(block_name).sinks + + ############################################## + # Connection Management + ############################################## + + def get_connections(self) -> list[ConnectionModel]: + return self._flowgraph_mw.get_connections() + + def connect_blocks( + self, + source_block_name: str, + sink_block_name: str, + source_port_name: str, + sink_port_name: str, + ) -> bool: + source_port = get_port_by_key( + self._flowgraph_mw, source_block_name, source_port_name, SOURCE + ) + sink_port = get_port_by_key( + self._flowgraph_mw, sink_block_name, sink_port_name, SINK + ) + self._flowgraph_mw.connect_blocks(source_port, sink_port) + return True + + def disconnect_blocks(self, source_port: PortModel, sink_port: PortModel) -> bool: + self._flowgraph_mw.disconnect_blocks(source_port, sink_port) + return True + + ############################################## + # Flowgraph Validation + ############################################## + + def validate_block(self, block_name: str) -> bool: + return self._flowgraph_mw.get_block(block_name).validate() + + def validate_flowgraph(self) -> bool: + return self._flowgraph_mw.validate() + + def get_all_errors(self) -> list[ErrorModel]: + return self._flowgraph_mw.get_all_errors() + + ############################################## + # Platform Management + ############################################## + + def get_all_available_blocks(self) -> list[BlockTypeModel]: + return self._platform_mw.blocks + + def save_flowgraph(self, filepath: str) -> bool: + self._platform_mw.save_flowgraph(filepath, self._flowgraph_mw) + return True diff --git a/src/gnuradio_mcp/providers/mcp.py b/src/gnuradio_mcp/providers/mcp.py new file mode 100644 index 0000000..dea55b3 --- /dev/null +++ b/src/gnuradio_mcp/providers/mcp.py @@ -0,0 +1,42 @@ +from fastmcp import FastMCP + +from gnuradio_mcp.middlewares.platform import PlatformMiddleware +from gnuradio_mcp.providers.base import PlatformProvider + + +class McpPlatformProvider: + def __init__(self, mcp_instance: FastMCP, platform_provider: PlatformProvider): + self._mcp_instance = mcp_instance + self._platform_provider = platform_provider + self.__init_tools() + + def __init_tools(self): + self._mcp_instance.add_tool(self._platform_provider.get_blocks) + self._mcp_instance.add_tool(self._platform_provider.make_block) + self._mcp_instance.add_tool(self._platform_provider.remove_block) + self._mcp_instance.add_tool(self._platform_provider.get_block_params) + self._mcp_instance.add_tool(self._platform_provider.set_block_params) + self._mcp_instance.add_tool(self._platform_provider.get_block_sources) + self._mcp_instance.add_tool(self._platform_provider.get_block_sinks) + self._mcp_instance.add_tool(self._platform_provider.get_connections) + self._mcp_instance.add_tool(self._platform_provider.connect_blocks) + self._mcp_instance.add_tool(self._platform_provider.disconnect_blocks) + self._mcp_instance.add_tool(self._platform_provider.validate_block) + self._mcp_instance.add_tool(self._platform_provider.validate_flowgraph) + self._mcp_instance.add_tool(self._platform_provider.get_all_errors) + self._mcp_instance.add_tool(self._platform_provider.save_flowgraph) + self._mcp_instance.add_tool(self._platform_provider.get_all_available_blocks) + + @property + def app(self) -> FastMCP: + return self._mcp_instance + + @classmethod + def from_platform_middleware( + cls, + mcp_instance: FastMCP, + platform_middleware: PlatformMiddleware, + flowgraph_path: str = "", + ): + platform_provider = PlatformProvider(platform_middleware, flowgraph_path) + return cls(mcp_instance, platform_provider) diff --git a/src/gnuradio_mcp/utils.py b/src/gnuradio_mcp/utils.py index 1c8656e..bc0c2a5 100644 --- a/src/gnuradio_mcp/utils.py +++ b/src/gnuradio_mcp/utils.py @@ -8,8 +8,11 @@ from gnuradio.grc.core.ports.port import Port from pydantic import BaseModel from gnuradio_mcp.models import ( + SINK, + SOURCE, BlockModel, ConnectionModel, + DirectionType, ErrorModel, ParamModel, PortModel, @@ -48,3 +51,29 @@ def format_error_message(elem, msg) -> ErrorModel: key=model, # type: ignore message=msg, ) + + +def get_port_by_key_in_port_list(port_list: list[Port], key: str) -> Block: + for port in port_list: + if port.key == key: + return port + raise ValueError(f"Port not found: {key}") + + +def get_port_by_key( + flowgraph, block_name: str, port_name: str, direction: DirectionType +) -> Port: + block = flowgraph.get_block(block_name) + if direction == SOURCE: + return get_port_by_key_in_port_list(block.sources, port_name) + elif direction == SINK: + return get_port_by_key_in_port_list(block.sinks, port_name) + else: + raise ValueError(f"Invalid port direction: {direction}") + + +def get_port_from_port_model(flowgraph, port_model: PortModel) -> Port: + block_from_port_model = flowgraph.get_block(port_model.parent) + return get_port_by_key( + flowgraph, block_from_port_model.name, port_model.key, port_model.direction + )