main - feat: Implement gnuradio element validation and errors

This commit is contained in:
Yoel Bassin 2025-04-27 12:33:29 +03:00
parent 0be8e77596
commit 42484d6c7d
11 changed files with 172 additions and 45 deletions

View File

@ -22,11 +22,12 @@ platform = Platform(
version=gr.version(),
version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()),
prefs=gr.prefs(),
# install_prefix=gr.prefix()
)
platform.build_library()
platform_middleware = PlatformMiddleware(platform)
print(platform_middleware.blocks)
flowgraph_mw = platform_middleware.make_flowgraph()
flowgraph_mw.add_block("blocks_add_xx")
for error in flowgraph_mw.get_all_errors():
print(error)

View File

@ -0,0 +1,23 @@
from gnuradio.grc.core.base import Element
from gnuradio_mcp.models import ErrorModel
from gnuradio_mcp.utils import format_error_message
class ElementMiddleware:
def __init__(self, element: Element):
self._element = element
def _rewrite(self):
self._element.rewrite()
def validate(self):
self._rewrite()
self._element.validate()
def get_all_errors(self) -> list[ErrorModel]:
self.validate()
return [
format_error_message(elem, msg)
for elem, msg in self._element.iter_error_messages()
]

View File

@ -4,17 +4,23 @@ from typing import Any, Dict
from gnuradio.grc.core.blocks.block import Block
from gnuradio_mcp.middlewares.base import ElementMiddleware
from gnuradio_mcp.models import SINK, SOURCE, ParamModel, PortModel
class BlockMiddleware:
class BlockMiddleware(ElementMiddleware):
def __init__(self, block: Block):
self._block = block
super().__init__(block)
self._block = self._element
@property
def name(self) -> str:
return self._block.name
@name.setter
def name(self, name: str):
self._block.params["id"].set_value(name)
def set_param(self, param_name: str, param_value: Any):
self._block.params[param_name].set_value(param_value)
@ -50,6 +56,3 @@ class BlockMiddleware:
except ValueError:
pass
return ports
def _rewrite(self):
self._block.rewrite()

View File

@ -1,19 +1,34 @@
from __future__ import annotations
from typing import Dict, Optional
from typing import TYPE_CHECKING, Optional, Set
from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.FlowGraph import FlowGraph
from gnuradio_mcp.middlewares.base import ElementMiddleware
from gnuradio_mcp.middlewares.block import BlockMiddleware
from gnuradio_mcp.models import BlockModel, ConnectionModel, PortModel
from gnuradio_mcp.models import SINK, SOURCE, BlockModel, ConnectionModel, PortModel
from gnuradio_mcp.utils import get_unique_id
if TYPE_CHECKING:
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
class FlowGraphMiddleware:
def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block:
block_from_port_model = flowgraph.get_block(port_model.parent)
if port_model.direction == SOURCE:
return block_from_port_model.sources[port_model.key]
elif port_model.direction == SINK:
return block_from_port_model.sinks[port_model.key]
else:
raise ValueError(f"Invalid port direction: {port_model.direction}")
class FlowGraphMiddleware(ElementMiddleware):
def __init__(self, flowgraph: FlowGraph):
self._flowgraph = flowgraph
self._blocks: Dict[str, BlockMiddleware] = {}
super().__init__(flowgraph)
self._flowgraph = self._element
self._blocks: Set[BlockMiddleware] = set()
@property
def blocks(self) -> list[BlockModel]:
@ -25,39 +40,34 @@ class FlowGraphMiddleware:
def add_block(
self, block_type: str, block_name: Optional[str] = None
) -> BlockMiddleware:
block_name = block_name or get_unique_id(self._flowgraph.blocks)
block_name = block_name or get_unique_id(self._flowgraph.blocks, block_type)
block = self._flowgraph.new_block(block_type)
block.params["id"].set_value(block_name)
self._blocks[block_name] = BlockMiddleware(block)
return self._blocks[block_name]
assert block is not None, f"Failed to create block: {block_type}"
block_middleware = BlockMiddleware(block)
block_middleware.name = block_name
self._blocks.add(block_middleware)
return block_middleware
def remove_block(self, block_name: str) -> None:
block = self._flowgraph.get_block(block_name)
self._flowgraph.remove_element(block)
del self._blocks[block_name]
block_middleware = self.get_block(block_name)
self._flowgraph.remove_element(block_middleware._block)
self._blocks.remove(block_middleware)
def get_block(self, block_name: str) -> BlockMiddleware:
# TODO: Check if calling two times you get different results
return self._blocks[block_name]
return next(block for block in self._blocks if block.name == block_name)
def connect_blocks(
self, src_port_model: PortModel, dst_port_model: PortModel
) -> None:
def get_block_by_port_model(port_model: PortModel) -> Block:
return self._flowgraph.get_block(port_model.parent)
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
src_port = get_port_from_port_model(self._flowgraph, src_port_model)
dst_port = get_port_from_port_model(self._flowgraph, dst_port_model)
self._flowgraph.connect(src_port, dst_port)
def disconnect_blocks(
self, src_port_model: PortModel, dst_port_model: PortModel
) -> None:
def get_block_by_port_model(port_model: PortModel) -> Block:
return self._flowgraph.get_block(port_model.parent)
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
src_port = get_port_from_port_model(self._flowgraph, src_port_model)
dst_port = get_port_from_port_model(self._flowgraph, dst_port_model)
self._flowgraph.disconnect(src_port, dst_port)
def get_connections(self) -> list[ConnectionModel]:
@ -65,3 +75,12 @@ class FlowGraphMiddleware:
ConnectionModel.from_connection(connection)
for connection in self._flowgraph.connections
]
@classmethod
def from_file(
cls, platform: "PlatformMiddleware", filepath: str = ""
) -> FlowGraphMiddleware:
initial_state = platform._platform.parse_flow_graph(filepath)
flowgraph = FlowGraph(platform._platform)
flowgraph.import_data(initial_state)
return cls(flowgraph)

View File

@ -2,15 +2,15 @@ from __future__ import annotations
from gnuradio.grc.core.platform import Platform
from gnuradio_mcp.middlewares.base import ElementMiddleware
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
from gnuradio_mcp.models import BlockModel
class PlatformMiddleware:
class PlatformMiddleware(ElementMiddleware):
def __init__(self, platform: Platform):
self._platform = platform
flowgraph = self._platform.make_flow_graph("")
self._flowgraph_mw = FlowGraphMiddleware(flowgraph)
super().__init__(platform)
self._platform = self._element
@property
def blocks(self) -> list[BlockModel]:
@ -18,6 +18,5 @@ class PlatformMiddleware:
BlockModel.from_block(block) for block in self._platform.blocks.values()
]
@property
def flowgraph(self) -> FlowGraphMiddleware:
return self._flowgraph_mw
def make_flowgraph(self, filepath: str = "") -> FlowGraphMiddleware:
return FlowGraphMiddleware.from_file(self, filepath)

View File

@ -77,3 +77,9 @@ class ConnectionModel(BaseModel):
source=PortModel.from_port(connection.source_port),
sink=PortModel.from_port(connection.sink_port),
)
class ErrorModel(BaseModel):
type: str
key: BaseModel
message: str

View File

@ -1,5 +1,20 @@
import re
from itertools import count
from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.Connection import Connection
from gnuradio.grc.core.params.param import Param
from gnuradio.grc.core.ports.port import Port
from pydantic import BaseModel
from gnuradio_mcp.models import (
BlockModel,
ConnectionModel,
ErrorModel,
ParamModel,
PortModel,
)
def get_unique_id(flowgraph_blocks, base_id=""):
block_ids = set(b.name for b in flowgraph_blocks)
@ -8,3 +23,28 @@ def get_unique_id(flowgraph_blocks, base_id=""):
if block_id not in block_ids:
break
return block_id
def format_error_message(elem, msg) -> ErrorModel:
msg = re.sub("[^A-Za-z0-9]+", " ", msg).strip()
model: BaseModel
match (elem):
case Connection():
model = ConnectionModel.from_connection(elem)
case Param():
model = ParamModel.from_param(elem)
case Port():
model = PortModel.from_port(elem)
case Block():
model = BlockModel.from_block(elem)
case _:
raise ValueError(f"Unsupported element type: {type(elem)}")
return ErrorModel(
type=type(model).__name__,
key=model,
message=msg,
)

View File

@ -4,6 +4,8 @@ import pytest
from gnuradio import gr
from gnuradio.grc.core.platform import Platform
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
@pytest.fixture(scope="module")
def platform() -> Platform:
@ -20,6 +22,11 @@ def platform() -> Platform:
return platform
@pytest.fixture(scope="module")
def platform_middleware(platform: Platform) -> PlatformMiddleware:
return PlatformMiddleware(platform)
@pytest.fixture(params=[1, 2, 10]) # Arbitrary number of blocks to test
def block_key(platform, request):
block_keys = list(platform.blocks.keys())

View File

@ -6,7 +6,7 @@ from gnuradio.grc.core.platform import Platform
from gnuradio_mcp.middlewares.block import BlockMiddleware
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
from gnuradio_mcp.models import SINK, SOURCE, ParamModel
from gnuradio_mcp.models import SINK, SOURCE, ErrorModel, ParamModel
@pytest.fixture
@ -44,6 +44,22 @@ def test_block_middleware_set_params(block_middleware: BlockMiddleware):
assert block_middleware._block.params["id"].get_value() == "my_custom_block_name"
@pytest.mark.parametrize(
"block_key, initial_errors_number", [("blocks_add_xx", 3), ("blocks_copy", 2)]
)
def test_block_errors(
flowgraph_middleware: FlowGraphMiddleware,
block_key: str,
initial_errors_number: int,
):
block_middleware = flowgraph_middleware.add_block(block_key)
for error in block_middleware.get_all_errors():
assert isinstance(error, ErrorModel)
assert len(block_middleware.get_all_errors()) == initial_errors_number
# Call again to check that the errors are not duplicated
assert len(block_middleware.get_all_errors()) == initial_errors_number
def check_param_models(block: Block, params: list[ParamModel]):
assert params
assert len(params) == len(block.params)

View File

@ -3,17 +3,16 @@ from __future__ import annotations
from typing import Generator
import pytest
from gnuradio.grc.core.platform import Platform
from gnuradio_mcp.middlewares.block import BlockMiddleware
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
from gnuradio_mcp.models import BlockModel, ConnectionModel
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
from gnuradio_mcp.models import BlockModel, ConnectionModel, ErrorModel
@pytest.fixture
def flowgraph_middleware(platform: Platform):
flowgraph = platform.make_flow_graph("")
return FlowGraphMiddleware(flowgraph)
def flowgraph_middleware(platform_middleware: PlatformMiddleware):
return platform_middleware.make_flowgraph()
@pytest.fixture
@ -67,6 +66,11 @@ def test_block_unique_names_for_same_type(
assert first_block.name != second_block.name
def test_block_default_name(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
block = flowgraph_middleware.add_block(block_key)
assert block_key in block.name
@pytest.mark.parametrize(
"block_key, sinks_number, sources_number",
[("blocks_add_xx", 2, 1), ("blocks_copy", 1, 1), ("blocks_selector", 2, 2)],
@ -109,6 +113,12 @@ def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_ke
assert len(flowgraph_middleware.get_connections()) == 0
def test_default_flowgraph_errors(flowgraph_middleware: FlowGraphMiddleware):
for error in flowgraph_middleware.get_all_errors():
assert isinstance(error, ErrorModel)
assert len(flowgraph_middleware.get_all_errors()) == 0
def util_iter_possible_connections(
source_block: BlockMiddleware,
dest_block: BlockMiddleware,

View File

@ -1,2 +1,5 @@
[pep8]
ignore = E265,E501,W504
[isort]
profile = black