main - feat: Imlement basic MCP server

This commit is contained in:
Yoel Bassin 2025-04-27 22:16:58 +03:00
parent f5a0629da7
commit 425b308556
6 changed files with 189 additions and 41 deletions

24
main.py
View File

@ -1,23 +1,16 @@
from __future__ import annotations from __future__ import annotations
import sys from fastmcp import FastMCP
from gnuradio_mcp.middlewares.platform import PlatformMiddleware from gnuradio_mcp.middlewares.platform import PlatformMiddleware
from gnuradio_mcp.providers.mcp import McpPlatformProvider
# Load GNU Radio
try: try:
from gnuradio import gr from gnuradio import gr
from gnuradio.grc.core.platform import Platform
except ImportError: except ImportError:
# Throw a new exception with more information
print(
"Cannot find GNU Radio! (Have you sourced the environment file?)",
file=sys.stderr,
)
# Throw the new exception
raise Exception("Cannot find GNU Radio!") from None raise Exception("Cannot find GNU Radio!") from None
from gnuradio.grc.core.platform import Platform
platform = Platform( platform = Platform(
version=gr.version(), version=gr.version(),
version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()), version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()),
@ -25,9 +18,10 @@ platform = Platform(
) )
platform.build_library() platform.build_library()
app: FastMCP = FastMCP(
"GNU Radio MCP", description="Provide a MCP interface to GNU Radio"
)
platform_middleware = PlatformMiddleware(platform) McpPlatformProvider.from_platform_middleware(app, PlatformMiddleware(platform))
flowgraph_mw = platform_middleware.make_flowgraph()
flowgraph_mw.add_block("blocks_add_xx") app.run(transport="sse")
for error in flowgraph_mw.get_all_errors():
print(error)

View File

@ -5,40 +5,20 @@ from typing import TYPE_CHECKING, Optional
from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.FlowGraph import FlowGraph from gnuradio.grc.core.FlowGraph import FlowGraph
from gnuradio.grc.core.ports.port import Port
from gnuradio_mcp.middlewares.base import ElementMiddleware from gnuradio_mcp.middlewares.base import ElementMiddleware
from gnuradio_mcp.middlewares.block import BlockMiddleware from gnuradio_mcp.middlewares.block import BlockMiddleware
from gnuradio_mcp.models import SINK, SOURCE, BlockModel, ConnectionModel, PortModel from gnuradio_mcp.models import (
from gnuradio_mcp.utils import get_unique_id BlockModel,
ConnectionModel,
PortModel,
)
from gnuradio_mcp.utils import get_port_from_port_model, get_unique_id
if TYPE_CHECKING: if TYPE_CHECKING:
from gnuradio_mcp.middlewares.platform import PlatformMiddleware from gnuradio_mcp.middlewares.platform import PlatformMiddleware
def get_port_from_port_model_in_port_list(
port_list: list[Port], port_model: PortModel
) -> Block:
for port in port_list:
if port.key == port_model.key:
return port
raise ValueError(f"Port not found: {port_model.key}")
def get_port_from_port_model(flowgraph, port_model: PortModel) -> Port:
block_from_port_model = flowgraph.get_block(port_model.parent)
if port_model.direction == SOURCE:
return get_port_from_port_model_in_port_list(
block_from_port_model.sources, port_model
)
elif port_model.direction == SINK:
return get_port_from_port_model_in_port_list(
block_from_port_model.sinks, port_model
)
else:
raise ValueError(f"Invalid port direction: {port_model.direction}")
def set_block_name(block: Block, name: str): def set_block_name(block: Block, name: str):
block.params["id"].set_value(name) block.params["id"].set_value(name)

View File

View File

@ -0,0 +1,103 @@
from typing import Any, Dict, List
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
from gnuradio_mcp.models import (
SINK,
SOURCE,
BlockModel,
BlockTypeModel,
ConnectionModel,
ErrorModel,
ParamModel,
PortModel,
)
from gnuradio_mcp.utils import get_port_by_key
class PlatformProvider:
def __init__(self, platform_mw: PlatformMiddleware, flowgraph_path: str = ""):
self._platform_mw = platform_mw
self._flowgraph_mw = platform_mw.make_flowgraph(flowgraph_path)
##############################################
# Flowgraph Management
##############################################
def get_blocks(self) -> list[BlockModel]:
return self._flowgraph_mw.blocks
def make_block(self, block_name: str) -> str:
block_mw = self._flowgraph_mw.add_block(block_name)
return block_mw.name
def remove_block(self, block_name: str) -> bool:
self._flowgraph_mw.remove_block(block_name)
return True
##############################################
# Block Management
##############################################
def get_block_params(self, block_name: str) -> List[ParamModel]:
return self._flowgraph_mw.get_block(block_name).params
def set_block_params(self, block_name: str, params: Dict[str, Any]) -> bool:
self._flowgraph_mw.get_block(block_name).set_params(params)
return True
def get_block_sources(self, block_name: str) -> list[PortModel]:
return self._flowgraph_mw.get_block(block_name).sources
def get_block_sinks(self, block_name: str) -> list[PortModel]:
return self._flowgraph_mw.get_block(block_name).sinks
##############################################
# Connection Management
##############################################
def get_connections(self) -> list[ConnectionModel]:
return self._flowgraph_mw.get_connections()
def connect_blocks(
self,
source_block_name: str,
sink_block_name: str,
source_port_name: str,
sink_port_name: str,
) -> bool:
source_port = get_port_by_key(
self._flowgraph_mw, source_block_name, source_port_name, SOURCE
)
sink_port = get_port_by_key(
self._flowgraph_mw, sink_block_name, sink_port_name, SINK
)
self._flowgraph_mw.connect_blocks(source_port, sink_port)
return True
def disconnect_blocks(self, source_port: PortModel, sink_port: PortModel) -> bool:
self._flowgraph_mw.disconnect_blocks(source_port, sink_port)
return True
##############################################
# Flowgraph Validation
##############################################
def validate_block(self, block_name: str) -> bool:
return self._flowgraph_mw.get_block(block_name).validate()
def validate_flowgraph(self) -> bool:
return self._flowgraph_mw.validate()
def get_all_errors(self) -> list[ErrorModel]:
return self._flowgraph_mw.get_all_errors()
##############################################
# Platform Management
##############################################
def get_all_available_blocks(self) -> list[BlockTypeModel]:
return self._platform_mw.blocks
def save_flowgraph(self, filepath: str) -> bool:
self._platform_mw.save_flowgraph(filepath, self._flowgraph_mw)
return True

View File

@ -0,0 +1,42 @@
from fastmcp import FastMCP
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
from gnuradio_mcp.providers.base import PlatformProvider
class McpPlatformProvider:
def __init__(self, mcp_instance: FastMCP, platform_provider: PlatformProvider):
self._mcp_instance = mcp_instance
self._platform_provider = platform_provider
self.__init_tools()
def __init_tools(self):
self._mcp_instance.add_tool(self._platform_provider.get_blocks)
self._mcp_instance.add_tool(self._platform_provider.make_block)
self._mcp_instance.add_tool(self._platform_provider.remove_block)
self._mcp_instance.add_tool(self._platform_provider.get_block_params)
self._mcp_instance.add_tool(self._platform_provider.set_block_params)
self._mcp_instance.add_tool(self._platform_provider.get_block_sources)
self._mcp_instance.add_tool(self._platform_provider.get_block_sinks)
self._mcp_instance.add_tool(self._platform_provider.get_connections)
self._mcp_instance.add_tool(self._platform_provider.connect_blocks)
self._mcp_instance.add_tool(self._platform_provider.disconnect_blocks)
self._mcp_instance.add_tool(self._platform_provider.validate_block)
self._mcp_instance.add_tool(self._platform_provider.validate_flowgraph)
self._mcp_instance.add_tool(self._platform_provider.get_all_errors)
self._mcp_instance.add_tool(self._platform_provider.save_flowgraph)
self._mcp_instance.add_tool(self._platform_provider.get_all_available_blocks)
@property
def app(self) -> FastMCP:
return self._mcp_instance
@classmethod
def from_platform_middleware(
cls,
mcp_instance: FastMCP,
platform_middleware: PlatformMiddleware,
flowgraph_path: str = "",
):
platform_provider = PlatformProvider(platform_middleware, flowgraph_path)
return cls(mcp_instance, platform_provider)

View File

@ -8,8 +8,11 @@ from gnuradio.grc.core.ports.port import Port
from pydantic import BaseModel from pydantic import BaseModel
from gnuradio_mcp.models import ( from gnuradio_mcp.models import (
SINK,
SOURCE,
BlockModel, BlockModel,
ConnectionModel, ConnectionModel,
DirectionType,
ErrorModel, ErrorModel,
ParamModel, ParamModel,
PortModel, PortModel,
@ -48,3 +51,29 @@ def format_error_message(elem, msg) -> ErrorModel:
key=model, # type: ignore key=model, # type: ignore
message=msg, message=msg,
) )
def get_port_by_key_in_port_list(port_list: list[Port], key: str) -> Block:
for port in port_list:
if port.key == key:
return port
raise ValueError(f"Port not found: {key}")
def get_port_by_key(
flowgraph, block_name: str, port_name: str, direction: DirectionType
) -> Port:
block = flowgraph.get_block(block_name)
if direction == SOURCE:
return get_port_by_key_in_port_list(block.sources, port_name)
elif direction == SINK:
return get_port_by_key_in_port_list(block.sinks, port_name)
else:
raise ValueError(f"Invalid port direction: {direction}")
def get_port_from_port_model(flowgraph, port_model: PortModel) -> Port:
block_from_port_model = flowgraph.get_block(port_model.parent)
return get_port_by_key(
flowgraph, block_from_port_model.name, port_model.key, port_model.direction
)