main - feat: Implement gnuradio element validation and errors

This commit is contained in:
Yoel Bassin 2025-04-27 12:33:29 +03:00
parent 0be8e77596
commit 42484d6c7d
11 changed files with 172 additions and 45 deletions

View File

@ -22,11 +22,12 @@ platform = Platform(
version=gr.version(), version=gr.version(),
version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()), version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()),
prefs=gr.prefs(), prefs=gr.prefs(),
# install_prefix=gr.prefix()
) )
platform.build_library() platform.build_library()
platform_middleware = PlatformMiddleware(platform) platform_middleware = PlatformMiddleware(platform)
flowgraph_mw = platform_middleware.make_flowgraph()
print(platform_middleware.blocks) flowgraph_mw.add_block("blocks_add_xx")
for error in flowgraph_mw.get_all_errors():
print(error)

View File

@ -0,0 +1,23 @@
from gnuradio.grc.core.base import Element
from gnuradio_mcp.models import ErrorModel
from gnuradio_mcp.utils import format_error_message
class ElementMiddleware:
def __init__(self, element: Element):
self._element = element
def _rewrite(self):
self._element.rewrite()
def validate(self):
self._rewrite()
self._element.validate()
def get_all_errors(self) -> list[ErrorModel]:
self.validate()
return [
format_error_message(elem, msg)
for elem, msg in self._element.iter_error_messages()
]

View File

@ -4,17 +4,23 @@ from typing import Any, Dict
from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.blocks.block import Block
from gnuradio_mcp.middlewares.base import ElementMiddleware
from gnuradio_mcp.models import SINK, SOURCE, ParamModel, PortModel from gnuradio_mcp.models import SINK, SOURCE, ParamModel, PortModel
class BlockMiddleware: class BlockMiddleware(ElementMiddleware):
def __init__(self, block: Block): def __init__(self, block: Block):
self._block = block super().__init__(block)
self._block = self._element
@property @property
def name(self) -> str: def name(self) -> str:
return self._block.name return self._block.name
@name.setter
def name(self, name: str):
self._block.params["id"].set_value(name)
def set_param(self, param_name: str, param_value: Any): def set_param(self, param_name: str, param_value: Any):
self._block.params[param_name].set_value(param_value) self._block.params[param_name].set_value(param_value)
@ -50,6 +56,3 @@ class BlockMiddleware:
except ValueError: except ValueError:
pass pass
return ports return ports
def _rewrite(self):
self._block.rewrite()

View File

@ -1,19 +1,34 @@
from __future__ import annotations from __future__ import annotations
from typing import Dict, Optional from typing import TYPE_CHECKING, Optional, Set
from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.FlowGraph import FlowGraph from gnuradio.grc.core.FlowGraph import FlowGraph
from gnuradio_mcp.middlewares.base import ElementMiddleware
from gnuradio_mcp.middlewares.block import BlockMiddleware from gnuradio_mcp.middlewares.block import BlockMiddleware
from gnuradio_mcp.models import BlockModel, ConnectionModel, PortModel from gnuradio_mcp.models import SINK, SOURCE, BlockModel, ConnectionModel, PortModel
from gnuradio_mcp.utils import get_unique_id from gnuradio_mcp.utils import get_unique_id
if TYPE_CHECKING:
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
class FlowGraphMiddleware:
def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block:
block_from_port_model = flowgraph.get_block(port_model.parent)
if port_model.direction == SOURCE:
return block_from_port_model.sources[port_model.key]
elif port_model.direction == SINK:
return block_from_port_model.sinks[port_model.key]
else:
raise ValueError(f"Invalid port direction: {port_model.direction}")
class FlowGraphMiddleware(ElementMiddleware):
def __init__(self, flowgraph: FlowGraph): def __init__(self, flowgraph: FlowGraph):
self._flowgraph = flowgraph super().__init__(flowgraph)
self._blocks: Dict[str, BlockMiddleware] = {} self._flowgraph = self._element
self._blocks: Set[BlockMiddleware] = set()
@property @property
def blocks(self) -> list[BlockModel]: def blocks(self) -> list[BlockModel]:
@ -25,39 +40,34 @@ class FlowGraphMiddleware:
def add_block( def add_block(
self, block_type: str, block_name: Optional[str] = None self, block_type: str, block_name: Optional[str] = None
) -> BlockMiddleware: ) -> BlockMiddleware:
block_name = block_name or get_unique_id(self._flowgraph.blocks) block_name = block_name or get_unique_id(self._flowgraph.blocks, block_type)
block = self._flowgraph.new_block(block_type) block = self._flowgraph.new_block(block_type)
block.params["id"].set_value(block_name) assert block is not None, f"Failed to create block: {block_type}"
self._blocks[block_name] = BlockMiddleware(block) block_middleware = BlockMiddleware(block)
return self._blocks[block_name] block_middleware.name = block_name
self._blocks.add(block_middleware)
return block_middleware
def remove_block(self, block_name: str) -> None: def remove_block(self, block_name: str) -> None:
block = self._flowgraph.get_block(block_name) block_middleware = self.get_block(block_name)
self._flowgraph.remove_element(block) self._flowgraph.remove_element(block_middleware._block)
del self._blocks[block_name] self._blocks.remove(block_middleware)
def get_block(self, block_name: str) -> BlockMiddleware: def get_block(self, block_name: str) -> BlockMiddleware:
# TODO: Check if calling two times you get different results return next(block for block in self._blocks if block.name == block_name)
return self._blocks[block_name]
def connect_blocks( def connect_blocks(
self, src_port_model: PortModel, dst_port_model: PortModel self, src_port_model: PortModel, dst_port_model: PortModel
) -> None: ) -> None:
def get_block_by_port_model(port_model: PortModel) -> Block: src_port = get_port_from_port_model(self._flowgraph, src_port_model)
return self._flowgraph.get_block(port_model.parent) dst_port = get_port_from_port_model(self._flowgraph, dst_port_model)
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
self._flowgraph.connect(src_port, dst_port) self._flowgraph.connect(src_port, dst_port)
def disconnect_blocks( def disconnect_blocks(
self, src_port_model: PortModel, dst_port_model: PortModel self, src_port_model: PortModel, dst_port_model: PortModel
) -> None: ) -> None:
def get_block_by_port_model(port_model: PortModel) -> Block: src_port = get_port_from_port_model(self._flowgraph, src_port_model)
return self._flowgraph.get_block(port_model.parent) dst_port = get_port_from_port_model(self._flowgraph, dst_port_model)
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
self._flowgraph.disconnect(src_port, dst_port) self._flowgraph.disconnect(src_port, dst_port)
def get_connections(self) -> list[ConnectionModel]: def get_connections(self) -> list[ConnectionModel]:
@ -65,3 +75,12 @@ class FlowGraphMiddleware:
ConnectionModel.from_connection(connection) ConnectionModel.from_connection(connection)
for connection in self._flowgraph.connections for connection in self._flowgraph.connections
] ]
@classmethod
def from_file(
cls, platform: "PlatformMiddleware", filepath: str = ""
) -> FlowGraphMiddleware:
initial_state = platform._platform.parse_flow_graph(filepath)
flowgraph = FlowGraph(platform._platform)
flowgraph.import_data(initial_state)
return cls(flowgraph)

View File

@ -2,15 +2,15 @@ from __future__ import annotations
from gnuradio.grc.core.platform import Platform from gnuradio.grc.core.platform import Platform
from gnuradio_mcp.middlewares.base import ElementMiddleware
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
from gnuradio_mcp.models import BlockModel from gnuradio_mcp.models import BlockModel
class PlatformMiddleware: class PlatformMiddleware(ElementMiddleware):
def __init__(self, platform: Platform): def __init__(self, platform: Platform):
self._platform = platform super().__init__(platform)
flowgraph = self._platform.make_flow_graph("") self._platform = self._element
self._flowgraph_mw = FlowGraphMiddleware(flowgraph)
@property @property
def blocks(self) -> list[BlockModel]: def blocks(self) -> list[BlockModel]:
@ -18,6 +18,5 @@ class PlatformMiddleware:
BlockModel.from_block(block) for block in self._platform.blocks.values() BlockModel.from_block(block) for block in self._platform.blocks.values()
] ]
@property def make_flowgraph(self, filepath: str = "") -> FlowGraphMiddleware:
def flowgraph(self) -> FlowGraphMiddleware: return FlowGraphMiddleware.from_file(self, filepath)
return self._flowgraph_mw

View File

@ -77,3 +77,9 @@ class ConnectionModel(BaseModel):
source=PortModel.from_port(connection.source_port), source=PortModel.from_port(connection.source_port),
sink=PortModel.from_port(connection.sink_port), sink=PortModel.from_port(connection.sink_port),
) )
class ErrorModel(BaseModel):
type: str
key: BaseModel
message: str

View File

@ -1,5 +1,20 @@
import re
from itertools import count from itertools import count
from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.Connection import Connection
from gnuradio.grc.core.params.param import Param
from gnuradio.grc.core.ports.port import Port
from pydantic import BaseModel
from gnuradio_mcp.models import (
BlockModel,
ConnectionModel,
ErrorModel,
ParamModel,
PortModel,
)
def get_unique_id(flowgraph_blocks, base_id=""): def get_unique_id(flowgraph_blocks, base_id=""):
block_ids = set(b.name for b in flowgraph_blocks) block_ids = set(b.name for b in flowgraph_blocks)
@ -8,3 +23,28 @@ def get_unique_id(flowgraph_blocks, base_id=""):
if block_id not in block_ids: if block_id not in block_ids:
break break
return block_id return block_id
def format_error_message(elem, msg) -> ErrorModel:
msg = re.sub("[^A-Za-z0-9]+", " ", msg).strip()
model: BaseModel
match (elem):
case Connection():
model = ConnectionModel.from_connection(elem)
case Param():
model = ParamModel.from_param(elem)
case Port():
model = PortModel.from_port(elem)
case Block():
model = BlockModel.from_block(elem)
case _:
raise ValueError(f"Unsupported element type: {type(elem)}")
return ErrorModel(
type=type(model).__name__,
key=model,
message=msg,
)

View File

@ -4,6 +4,8 @@ import pytest
from gnuradio import gr from gnuradio import gr
from gnuradio.grc.core.platform import Platform from gnuradio.grc.core.platform import Platform
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def platform() -> Platform: def platform() -> Platform:
@ -20,6 +22,11 @@ def platform() -> Platform:
return platform return platform
@pytest.fixture(scope="module")
def platform_middleware(platform: Platform) -> PlatformMiddleware:
return PlatformMiddleware(platform)
@pytest.fixture(params=[1, 2, 10]) # Arbitrary number of blocks to test @pytest.fixture(params=[1, 2, 10]) # Arbitrary number of blocks to test
def block_key(platform, request): def block_key(platform, request):
block_keys = list(platform.blocks.keys()) block_keys = list(platform.blocks.keys())

View File

@ -6,7 +6,7 @@ from gnuradio.grc.core.platform import Platform
from gnuradio_mcp.middlewares.block import BlockMiddleware from gnuradio_mcp.middlewares.block import BlockMiddleware
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
from gnuradio_mcp.models import SINK, SOURCE, ParamModel from gnuradio_mcp.models import SINK, SOURCE, ErrorModel, ParamModel
@pytest.fixture @pytest.fixture
@ -44,6 +44,22 @@ def test_block_middleware_set_params(block_middleware: BlockMiddleware):
assert block_middleware._block.params["id"].get_value() == "my_custom_block_name" assert block_middleware._block.params["id"].get_value() == "my_custom_block_name"
@pytest.mark.parametrize(
"block_key, initial_errors_number", [("blocks_add_xx", 3), ("blocks_copy", 2)]
)
def test_block_errors(
flowgraph_middleware: FlowGraphMiddleware,
block_key: str,
initial_errors_number: int,
):
block_middleware = flowgraph_middleware.add_block(block_key)
for error in block_middleware.get_all_errors():
assert isinstance(error, ErrorModel)
assert len(block_middleware.get_all_errors()) == initial_errors_number
# Call again to check that the errors are not duplicated
assert len(block_middleware.get_all_errors()) == initial_errors_number
def check_param_models(block: Block, params: list[ParamModel]): def check_param_models(block: Block, params: list[ParamModel]):
assert params assert params
assert len(params) == len(block.params) assert len(params) == len(block.params)

View File

@ -3,17 +3,16 @@ from __future__ import annotations
from typing import Generator from typing import Generator
import pytest import pytest
from gnuradio.grc.core.platform import Platform
from gnuradio_mcp.middlewares.block import BlockMiddleware from gnuradio_mcp.middlewares.block import BlockMiddleware
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
from gnuradio_mcp.models import BlockModel, ConnectionModel from gnuradio_mcp.middlewares.platform import PlatformMiddleware
from gnuradio_mcp.models import BlockModel, ConnectionModel, ErrorModel
@pytest.fixture @pytest.fixture
def flowgraph_middleware(platform: Platform): def flowgraph_middleware(platform_middleware: PlatformMiddleware):
flowgraph = platform.make_flow_graph("") return platform_middleware.make_flowgraph()
return FlowGraphMiddleware(flowgraph)
@pytest.fixture @pytest.fixture
@ -67,6 +66,11 @@ def test_block_unique_names_for_same_type(
assert first_block.name != second_block.name assert first_block.name != second_block.name
def test_block_default_name(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
block = flowgraph_middleware.add_block(block_key)
assert block_key in block.name
@pytest.mark.parametrize( @pytest.mark.parametrize(
"block_key, sinks_number, sources_number", "block_key, sinks_number, sources_number",
[("blocks_add_xx", 2, 1), ("blocks_copy", 1, 1), ("blocks_selector", 2, 2)], [("blocks_add_xx", 2, 1), ("blocks_copy", 1, 1), ("blocks_selector", 2, 2)],
@ -109,6 +113,12 @@ def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_ke
assert len(flowgraph_middleware.get_connections()) == 0 assert len(flowgraph_middleware.get_connections()) == 0
def test_default_flowgraph_errors(flowgraph_middleware: FlowGraphMiddleware):
for error in flowgraph_middleware.get_all_errors():
assert isinstance(error, ErrorModel)
assert len(flowgraph_middleware.get_all_errors()) == 0
def util_iter_possible_connections( def util_iter_possible_connections(
source_block: BlockMiddleware, source_block: BlockMiddleware,
dest_block: BlockMiddleware, dest_block: BlockMiddleware,

View File

@ -1,2 +1,5 @@
[pep8] [pep8]
ignore = E265,E501,W504 ignore = E265,E501,W504
[isort]
profile = black