main - feat: Imlement basic MCP server
This commit is contained in:
parent
f5a0629da7
commit
425b308556
26
main.py
26
main.py
@ -1,22 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
|
||||
from gnuradio_mcp.providers.mcp import McpPlatformProvider
|
||||
|
||||
# Load GNU Radio
|
||||
try:
|
||||
from gnuradio import gr
|
||||
except ImportError:
|
||||
# Throw a new exception with more information
|
||||
print(
|
||||
"Cannot find GNU Radio! (Have you sourced the environment file?)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
# Throw the new exception
|
||||
raise Exception("Cannot find GNU Radio!") from None
|
||||
|
||||
from gnuradio.grc.core.platform import Platform
|
||||
except ImportError:
|
||||
raise Exception("Cannot find GNU Radio!") from None
|
||||
|
||||
platform = Platform(
|
||||
version=gr.version(),
|
||||
@ -25,9 +18,10 @@ platform = Platform(
|
||||
)
|
||||
platform.build_library()
|
||||
|
||||
app: FastMCP = FastMCP(
|
||||
"GNU Radio MCP", description="Provide a MCP interface to GNU Radio"
|
||||
)
|
||||
|
||||
platform_middleware = PlatformMiddleware(platform)
|
||||
flowgraph_mw = platform_middleware.make_flowgraph()
|
||||
flowgraph_mw.add_block("blocks_add_xx")
|
||||
for error in flowgraph_mw.get_all_errors():
|
||||
print(error)
|
||||
McpPlatformProvider.from_platform_middleware(app, PlatformMiddleware(platform))
|
||||
|
||||
app.run(transport="sse")
|
||||
|
||||
@ -5,40 +5,20 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from gnuradio.grc.core.blocks.block import Block
|
||||
from gnuradio.grc.core.FlowGraph import FlowGraph
|
||||
from gnuradio.grc.core.ports.port import Port
|
||||
|
||||
from gnuradio_mcp.middlewares.base import ElementMiddleware
|
||||
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
||||
from gnuradio_mcp.models import SINK, SOURCE, BlockModel, ConnectionModel, PortModel
|
||||
from gnuradio_mcp.utils import get_unique_id
|
||||
from gnuradio_mcp.models import (
|
||||
BlockModel,
|
||||
ConnectionModel,
|
||||
PortModel,
|
||||
)
|
||||
from gnuradio_mcp.utils import get_port_from_port_model, get_unique_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
|
||||
|
||||
|
||||
def get_port_from_port_model_in_port_list(
|
||||
port_list: list[Port], port_model: PortModel
|
||||
) -> Block:
|
||||
for port in port_list:
|
||||
if port.key == port_model.key:
|
||||
return port
|
||||
raise ValueError(f"Port not found: {port_model.key}")
|
||||
|
||||
|
||||
def get_port_from_port_model(flowgraph, port_model: PortModel) -> Port:
|
||||
block_from_port_model = flowgraph.get_block(port_model.parent)
|
||||
if port_model.direction == SOURCE:
|
||||
return get_port_from_port_model_in_port_list(
|
||||
block_from_port_model.sources, port_model
|
||||
)
|
||||
elif port_model.direction == SINK:
|
||||
return get_port_from_port_model_in_port_list(
|
||||
block_from_port_model.sinks, port_model
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid port direction: {port_model.direction}")
|
||||
|
||||
|
||||
def set_block_name(block: Block, name: str):
|
||||
block.params["id"].set_value(name)
|
||||
|
||||
|
||||
0
src/gnuradio_mcp/providers/__init__.py
Normal file
0
src/gnuradio_mcp/providers/__init__.py
Normal file
103
src/gnuradio_mcp/providers/base.py
Normal file
103
src/gnuradio_mcp/providers/base.py
Normal file
@ -0,0 +1,103 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
|
||||
from gnuradio_mcp.models import (
|
||||
SINK,
|
||||
SOURCE,
|
||||
BlockModel,
|
||||
BlockTypeModel,
|
||||
ConnectionModel,
|
||||
ErrorModel,
|
||||
ParamModel,
|
||||
PortModel,
|
||||
)
|
||||
from gnuradio_mcp.utils import get_port_by_key
|
||||
|
||||
|
||||
class PlatformProvider:
|
||||
def __init__(self, platform_mw: PlatformMiddleware, flowgraph_path: str = ""):
|
||||
self._platform_mw = platform_mw
|
||||
self._flowgraph_mw = platform_mw.make_flowgraph(flowgraph_path)
|
||||
|
||||
##############################################
|
||||
# Flowgraph Management
|
||||
##############################################
|
||||
|
||||
def get_blocks(self) -> list[BlockModel]:
|
||||
return self._flowgraph_mw.blocks
|
||||
|
||||
def make_block(self, block_name: str) -> str:
|
||||
block_mw = self._flowgraph_mw.add_block(block_name)
|
||||
return block_mw.name
|
||||
|
||||
def remove_block(self, block_name: str) -> bool:
|
||||
self._flowgraph_mw.remove_block(block_name)
|
||||
return True
|
||||
|
||||
##############################################
|
||||
# Block Management
|
||||
##############################################
|
||||
|
||||
def get_block_params(self, block_name: str) -> List[ParamModel]:
|
||||
return self._flowgraph_mw.get_block(block_name).params
|
||||
|
||||
def set_block_params(self, block_name: str, params: Dict[str, Any]) -> bool:
|
||||
self._flowgraph_mw.get_block(block_name).set_params(params)
|
||||
return True
|
||||
|
||||
def get_block_sources(self, block_name: str) -> list[PortModel]:
|
||||
return self._flowgraph_mw.get_block(block_name).sources
|
||||
|
||||
def get_block_sinks(self, block_name: str) -> list[PortModel]:
|
||||
return self._flowgraph_mw.get_block(block_name).sinks
|
||||
|
||||
##############################################
|
||||
# Connection Management
|
||||
##############################################
|
||||
|
||||
def get_connections(self) -> list[ConnectionModel]:
|
||||
return self._flowgraph_mw.get_connections()
|
||||
|
||||
def connect_blocks(
|
||||
self,
|
||||
source_block_name: str,
|
||||
sink_block_name: str,
|
||||
source_port_name: str,
|
||||
sink_port_name: str,
|
||||
) -> bool:
|
||||
source_port = get_port_by_key(
|
||||
self._flowgraph_mw, source_block_name, source_port_name, SOURCE
|
||||
)
|
||||
sink_port = get_port_by_key(
|
||||
self._flowgraph_mw, sink_block_name, sink_port_name, SINK
|
||||
)
|
||||
self._flowgraph_mw.connect_blocks(source_port, sink_port)
|
||||
return True
|
||||
|
||||
def disconnect_blocks(self, source_port: PortModel, sink_port: PortModel) -> bool:
|
||||
self._flowgraph_mw.disconnect_blocks(source_port, sink_port)
|
||||
return True
|
||||
|
||||
##############################################
|
||||
# Flowgraph Validation
|
||||
##############################################
|
||||
|
||||
def validate_block(self, block_name: str) -> bool:
|
||||
return self._flowgraph_mw.get_block(block_name).validate()
|
||||
|
||||
def validate_flowgraph(self) -> bool:
|
||||
return self._flowgraph_mw.validate()
|
||||
|
||||
def get_all_errors(self) -> list[ErrorModel]:
|
||||
return self._flowgraph_mw.get_all_errors()
|
||||
|
||||
##############################################
|
||||
# Platform Management
|
||||
##############################################
|
||||
|
||||
def get_all_available_blocks(self) -> list[BlockTypeModel]:
|
||||
return self._platform_mw.blocks
|
||||
|
||||
def save_flowgraph(self, filepath: str) -> bool:
|
||||
self._platform_mw.save_flowgraph(filepath, self._flowgraph_mw)
|
||||
return True
|
||||
42
src/gnuradio_mcp/providers/mcp.py
Normal file
42
src/gnuradio_mcp/providers/mcp.py
Normal file
@ -0,0 +1,42 @@
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
|
||||
from gnuradio_mcp.providers.base import PlatformProvider
|
||||
|
||||
|
||||
class McpPlatformProvider:
|
||||
def __init__(self, mcp_instance: FastMCP, platform_provider: PlatformProvider):
|
||||
self._mcp_instance = mcp_instance
|
||||
self._platform_provider = platform_provider
|
||||
self.__init_tools()
|
||||
|
||||
def __init_tools(self):
|
||||
self._mcp_instance.add_tool(self._platform_provider.get_blocks)
|
||||
self._mcp_instance.add_tool(self._platform_provider.make_block)
|
||||
self._mcp_instance.add_tool(self._platform_provider.remove_block)
|
||||
self._mcp_instance.add_tool(self._platform_provider.get_block_params)
|
||||
self._mcp_instance.add_tool(self._platform_provider.set_block_params)
|
||||
self._mcp_instance.add_tool(self._platform_provider.get_block_sources)
|
||||
self._mcp_instance.add_tool(self._platform_provider.get_block_sinks)
|
||||
self._mcp_instance.add_tool(self._platform_provider.get_connections)
|
||||
self._mcp_instance.add_tool(self._platform_provider.connect_blocks)
|
||||
self._mcp_instance.add_tool(self._platform_provider.disconnect_blocks)
|
||||
self._mcp_instance.add_tool(self._platform_provider.validate_block)
|
||||
self._mcp_instance.add_tool(self._platform_provider.validate_flowgraph)
|
||||
self._mcp_instance.add_tool(self._platform_provider.get_all_errors)
|
||||
self._mcp_instance.add_tool(self._platform_provider.save_flowgraph)
|
||||
self._mcp_instance.add_tool(self._platform_provider.get_all_available_blocks)
|
||||
|
||||
@property
|
||||
def app(self) -> FastMCP:
|
||||
return self._mcp_instance
|
||||
|
||||
@classmethod
|
||||
def from_platform_middleware(
|
||||
cls,
|
||||
mcp_instance: FastMCP,
|
||||
platform_middleware: PlatformMiddleware,
|
||||
flowgraph_path: str = "",
|
||||
):
|
||||
platform_provider = PlatformProvider(platform_middleware, flowgraph_path)
|
||||
return cls(mcp_instance, platform_provider)
|
||||
@ -8,8 +8,11 @@ from gnuradio.grc.core.ports.port import Port
|
||||
from pydantic import BaseModel
|
||||
|
||||
from gnuradio_mcp.models import (
|
||||
SINK,
|
||||
SOURCE,
|
||||
BlockModel,
|
||||
ConnectionModel,
|
||||
DirectionType,
|
||||
ErrorModel,
|
||||
ParamModel,
|
||||
PortModel,
|
||||
@ -48,3 +51,29 @@ def format_error_message(elem, msg) -> ErrorModel:
|
||||
key=model, # type: ignore
|
||||
message=msg,
|
||||
)
|
||||
|
||||
|
||||
def get_port_by_key_in_port_list(port_list: list[Port], key: str) -> Block:
|
||||
for port in port_list:
|
||||
if port.key == key:
|
||||
return port
|
||||
raise ValueError(f"Port not found: {key}")
|
||||
|
||||
|
||||
def get_port_by_key(
|
||||
flowgraph, block_name: str, port_name: str, direction: DirectionType
|
||||
) -> Port:
|
||||
block = flowgraph.get_block(block_name)
|
||||
if direction == SOURCE:
|
||||
return get_port_by_key_in_port_list(block.sources, port_name)
|
||||
elif direction == SINK:
|
||||
return get_port_by_key_in_port_list(block.sinks, port_name)
|
||||
else:
|
||||
raise ValueError(f"Invalid port direction: {direction}")
|
||||
|
||||
|
||||
def get_port_from_port_model(flowgraph, port_model: PortModel) -> Port:
|
||||
block_from_port_model = flowgraph.get_block(port_model.parent)
|
||||
return get_port_by_key(
|
||||
flowgraph, block_from_port_model.name, port_model.key, port_model.direction
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user