From 42484d6c7db2867c542948c2743afb67b52fb9c7 Mon Sep 17 00:00:00 2001 From: Yoel Bassin Date: Sun, 27 Apr 2025 12:33:29 +0300 Subject: [PATCH] main - feat: Implement gnuradio element validation and errors --- main.py | 7 ++- src/gnuradio_mcp/middlewares/base.py | 23 ++++++++ src/gnuradio_mcp/middlewares/block.py | 13 +++-- src/gnuradio_mcp/middlewares/flowgraph.py | 67 +++++++++++++++-------- src/gnuradio_mcp/middlewares/platform.py | 13 ++--- src/gnuradio_mcp/models.py | 6 ++ src/gnuradio_mcp/utils.py | 40 ++++++++++++++ tests/conftest.py | 7 +++ tests/unit/test_block.py | 18 +++++- tests/unit/test_flowgraph.py | 20 +++++-- tox.ini | 3 + 11 files changed, 172 insertions(+), 45 deletions(-) create mode 100644 src/gnuradio_mcp/middlewares/base.py diff --git a/main.py b/main.py index 027fe43..c9bc623 100644 --- a/main.py +++ b/main.py @@ -22,11 +22,12 @@ platform = Platform( version=gr.version(), version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()), prefs=gr.prefs(), - # install_prefix=gr.prefix() ) platform.build_library() platform_middleware = PlatformMiddleware(platform) - -print(platform_middleware.blocks) +flowgraph_mw = platform_middleware.make_flowgraph() +flowgraph_mw.add_block("blocks_add_xx") +for error in flowgraph_mw.get_all_errors(): + print(error) diff --git a/src/gnuradio_mcp/middlewares/base.py b/src/gnuradio_mcp/middlewares/base.py new file mode 100644 index 0000000..b6340e8 --- /dev/null +++ b/src/gnuradio_mcp/middlewares/base.py @@ -0,0 +1,23 @@ +from gnuradio.grc.core.base import Element + +from gnuradio_mcp.models import ErrorModel +from gnuradio_mcp.utils import format_error_message + + +class ElementMiddleware: + def __init__(self, element: Element): + self._element = element + + def _rewrite(self): + self._element.rewrite() + + def validate(self): + self._rewrite() + self._element.validate() + + def get_all_errors(self) -> list[ErrorModel]: + self.validate() + return [ + format_error_message(elem, msg) + for elem, msg in self._element.iter_error_messages() + ] diff --git a/src/gnuradio_mcp/middlewares/block.py b/src/gnuradio_mcp/middlewares/block.py index 7effc6a..2ef37ad 100644 --- a/src/gnuradio_mcp/middlewares/block.py +++ b/src/gnuradio_mcp/middlewares/block.py @@ -4,17 +4,23 @@ from typing import Any, Dict from gnuradio.grc.core.blocks.block import Block +from gnuradio_mcp.middlewares.base import ElementMiddleware from gnuradio_mcp.models import SINK, SOURCE, ParamModel, PortModel -class BlockMiddleware: +class BlockMiddleware(ElementMiddleware): def __init__(self, block: Block): - self._block = block + super().__init__(block) + self._block = self._element @property def name(self) -> str: return self._block.name + @name.setter + def name(self, name: str): + self._block.params["id"].set_value(name) + def set_param(self, param_name: str, param_value: Any): self._block.params[param_name].set_value(param_value) @@ -50,6 +56,3 @@ class BlockMiddleware: except ValueError: pass return ports - - def _rewrite(self): - self._block.rewrite() diff --git a/src/gnuradio_mcp/middlewares/flowgraph.py b/src/gnuradio_mcp/middlewares/flowgraph.py index 48f3786..5b0ba49 100644 --- a/src/gnuradio_mcp/middlewares/flowgraph.py +++ b/src/gnuradio_mcp/middlewares/flowgraph.py @@ -1,19 +1,34 @@ from __future__ import annotations -from typing import Dict, Optional +from typing import TYPE_CHECKING, Optional, Set from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.FlowGraph import FlowGraph +from gnuradio_mcp.middlewares.base import ElementMiddleware from gnuradio_mcp.middlewares.block import BlockMiddleware -from gnuradio_mcp.models import BlockModel, ConnectionModel, PortModel +from gnuradio_mcp.models import SINK, SOURCE, BlockModel, ConnectionModel, PortModel from gnuradio_mcp.utils import get_unique_id +if TYPE_CHECKING: + from gnuradio_mcp.middlewares.platform import PlatformMiddleware -class FlowGraphMiddleware: + +def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block: + block_from_port_model = flowgraph.get_block(port_model.parent) + if port_model.direction == SOURCE: + return block_from_port_model.sources[port_model.key] + elif port_model.direction == SINK: + return block_from_port_model.sinks[port_model.key] + else: + raise ValueError(f"Invalid port direction: {port_model.direction}") + + +class FlowGraphMiddleware(ElementMiddleware): def __init__(self, flowgraph: FlowGraph): - self._flowgraph = flowgraph - self._blocks: Dict[str, BlockMiddleware] = {} + super().__init__(flowgraph) + self._flowgraph = self._element + self._blocks: Set[BlockMiddleware] = set() @property def blocks(self) -> list[BlockModel]: @@ -25,39 +40,34 @@ class FlowGraphMiddleware: def add_block( self, block_type: str, block_name: Optional[str] = None ) -> BlockMiddleware: - block_name = block_name or get_unique_id(self._flowgraph.blocks) + block_name = block_name or get_unique_id(self._flowgraph.blocks, block_type) block = self._flowgraph.new_block(block_type) - block.params["id"].set_value(block_name) - self._blocks[block_name] = BlockMiddleware(block) - return self._blocks[block_name] + assert block is not None, f"Failed to create block: {block_type}" + block_middleware = BlockMiddleware(block) + block_middleware.name = block_name + self._blocks.add(block_middleware) + return block_middleware def remove_block(self, block_name: str) -> None: - block = self._flowgraph.get_block(block_name) - self._flowgraph.remove_element(block) - del self._blocks[block_name] + block_middleware = self.get_block(block_name) + self._flowgraph.remove_element(block_middleware._block) + self._blocks.remove(block_middleware) def get_block(self, block_name: str) -> BlockMiddleware: - # TODO: Check if calling two times you get different results - return self._blocks[block_name] + return next(block for block in self._blocks if block.name == block_name) def connect_blocks( self, src_port_model: PortModel, dst_port_model: PortModel ) -> None: - def get_block_by_port_model(port_model: PortModel) -> Block: - return self._flowgraph.get_block(port_model.parent) - - src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key] - dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key] + src_port = get_port_from_port_model(self._flowgraph, src_port_model) + dst_port = get_port_from_port_model(self._flowgraph, dst_port_model) self._flowgraph.connect(src_port, dst_port) def disconnect_blocks( self, src_port_model: PortModel, dst_port_model: PortModel ) -> None: - def get_block_by_port_model(port_model: PortModel) -> Block: - return self._flowgraph.get_block(port_model.parent) - - src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key] - dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key] + src_port = get_port_from_port_model(self._flowgraph, src_port_model) + dst_port = get_port_from_port_model(self._flowgraph, dst_port_model) self._flowgraph.disconnect(src_port, dst_port) def get_connections(self) -> list[ConnectionModel]: @@ -65,3 +75,12 @@ class FlowGraphMiddleware: ConnectionModel.from_connection(connection) for connection in self._flowgraph.connections ] + + @classmethod + def from_file( + cls, platform: "PlatformMiddleware", filepath: str = "" + ) -> FlowGraphMiddleware: + initial_state = platform._platform.parse_flow_graph(filepath) + flowgraph = FlowGraph(platform._platform) + flowgraph.import_data(initial_state) + return cls(flowgraph) diff --git a/src/gnuradio_mcp/middlewares/platform.py b/src/gnuradio_mcp/middlewares/platform.py index 325b191..f18c9ed 100644 --- a/src/gnuradio_mcp/middlewares/platform.py +++ b/src/gnuradio_mcp/middlewares/platform.py @@ -2,15 +2,15 @@ from __future__ import annotations from gnuradio.grc.core.platform import Platform +from gnuradio_mcp.middlewares.base import ElementMiddleware from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware from gnuradio_mcp.models import BlockModel -class PlatformMiddleware: +class PlatformMiddleware(ElementMiddleware): def __init__(self, platform: Platform): - self._platform = platform - flowgraph = self._platform.make_flow_graph("") - self._flowgraph_mw = FlowGraphMiddleware(flowgraph) + super().__init__(platform) + self._platform = self._element @property def blocks(self) -> list[BlockModel]: @@ -18,6 +18,5 @@ class PlatformMiddleware: BlockModel.from_block(block) for block in self._platform.blocks.values() ] - @property - def flowgraph(self) -> FlowGraphMiddleware: - return self._flowgraph_mw + def make_flowgraph(self, filepath: str = "") -> FlowGraphMiddleware: + return FlowGraphMiddleware.from_file(self, filepath) diff --git a/src/gnuradio_mcp/models.py b/src/gnuradio_mcp/models.py index 8205a0b..f9fb679 100644 --- a/src/gnuradio_mcp/models.py +++ b/src/gnuradio_mcp/models.py @@ -77,3 +77,9 @@ class ConnectionModel(BaseModel): source=PortModel.from_port(connection.source_port), sink=PortModel.from_port(connection.sink_port), ) + + +class ErrorModel(BaseModel): + type: str + key: BaseModel + message: str diff --git a/src/gnuradio_mcp/utils.py b/src/gnuradio_mcp/utils.py index 11428d2..800d37d 100644 --- a/src/gnuradio_mcp/utils.py +++ b/src/gnuradio_mcp/utils.py @@ -1,5 +1,20 @@ +import re from itertools import count +from gnuradio.grc.core.blocks.block import Block +from gnuradio.grc.core.Connection import Connection +from gnuradio.grc.core.params.param import Param +from gnuradio.grc.core.ports.port import Port +from pydantic import BaseModel + +from gnuradio_mcp.models import ( + BlockModel, + ConnectionModel, + ErrorModel, + ParamModel, + PortModel, +) + def get_unique_id(flowgraph_blocks, base_id=""): block_ids = set(b.name for b in flowgraph_blocks) @@ -8,3 +23,28 @@ def get_unique_id(flowgraph_blocks, base_id=""): if block_id not in block_ids: break return block_id + + +def format_error_message(elem, msg) -> ErrorModel: + msg = re.sub("[^A-Za-z0-9]+", " ", msg).strip() + model: BaseModel + match (elem): + case Connection(): + model = ConnectionModel.from_connection(elem) + + case Param(): + model = ParamModel.from_param(elem) + + case Port(): + model = PortModel.from_port(elem) + + case Block(): + model = BlockModel.from_block(elem) + + case _: + raise ValueError(f"Unsupported element type: {type(elem)}") + return ErrorModel( + type=type(model).__name__, + key=model, + message=msg, + ) diff --git a/tests/conftest.py b/tests/conftest.py index fa3989d..bdb4734 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,8 @@ import pytest from gnuradio import gr from gnuradio.grc.core.platform import Platform +from gnuradio_mcp.middlewares.platform import PlatformMiddleware + @pytest.fixture(scope="module") def platform() -> Platform: @@ -20,6 +22,11 @@ def platform() -> Platform: return platform +@pytest.fixture(scope="module") +def platform_middleware(platform: Platform) -> PlatformMiddleware: + return PlatformMiddleware(platform) + + @pytest.fixture(params=[1, 2, 10]) # Arbitrary number of blocks to test def block_key(platform, request): block_keys = list(platform.blocks.keys()) diff --git a/tests/unit/test_block.py b/tests/unit/test_block.py index fe2653d..53e4a03 100644 --- a/tests/unit/test_block.py +++ b/tests/unit/test_block.py @@ -6,7 +6,7 @@ from gnuradio.grc.core.platform import Platform from gnuradio_mcp.middlewares.block import BlockMiddleware from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware -from gnuradio_mcp.models import SINK, SOURCE, ParamModel +from gnuradio_mcp.models import SINK, SOURCE, ErrorModel, ParamModel @pytest.fixture @@ -44,6 +44,22 @@ def test_block_middleware_set_params(block_middleware: BlockMiddleware): assert block_middleware._block.params["id"].get_value() == "my_custom_block_name" +@pytest.mark.parametrize( + "block_key, initial_errors_number", [("blocks_add_xx", 3), ("blocks_copy", 2)] +) +def test_block_errors( + flowgraph_middleware: FlowGraphMiddleware, + block_key: str, + initial_errors_number: int, +): + block_middleware = flowgraph_middleware.add_block(block_key) + for error in block_middleware.get_all_errors(): + assert isinstance(error, ErrorModel) + assert len(block_middleware.get_all_errors()) == initial_errors_number + # Call again to check that the errors are not duplicated + assert len(block_middleware.get_all_errors()) == initial_errors_number + + def check_param_models(block: Block, params: list[ParamModel]): assert params assert len(params) == len(block.params) diff --git a/tests/unit/test_flowgraph.py b/tests/unit/test_flowgraph.py index 33a3461..9406b91 100644 --- a/tests/unit/test_flowgraph.py +++ b/tests/unit/test_flowgraph.py @@ -3,17 +3,16 @@ from __future__ import annotations from typing import Generator import pytest -from gnuradio.grc.core.platform import Platform from gnuradio_mcp.middlewares.block import BlockMiddleware from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware -from gnuradio_mcp.models import BlockModel, ConnectionModel +from gnuradio_mcp.middlewares.platform import PlatformMiddleware +from gnuradio_mcp.models import BlockModel, ConnectionModel, ErrorModel @pytest.fixture -def flowgraph_middleware(platform: Platform): - flowgraph = platform.make_flow_graph("") - return FlowGraphMiddleware(flowgraph) +def flowgraph_middleware(platform_middleware: PlatformMiddleware): + return platform_middleware.make_flowgraph() @pytest.fixture @@ -67,6 +66,11 @@ def test_block_unique_names_for_same_type( assert first_block.name != second_block.name +def test_block_default_name(flowgraph_middleware: FlowGraphMiddleware, block_key: str): + block = flowgraph_middleware.add_block(block_key) + assert block_key in block.name + + @pytest.mark.parametrize( "block_key, sinks_number, sources_number", [("blocks_add_xx", 2, 1), ("blocks_copy", 1, 1), ("blocks_selector", 2, 2)], @@ -109,6 +113,12 @@ def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_ke assert len(flowgraph_middleware.get_connections()) == 0 +def test_default_flowgraph_errors(flowgraph_middleware: FlowGraphMiddleware): + for error in flowgraph_middleware.get_all_errors(): + assert isinstance(error, ErrorModel) + assert len(flowgraph_middleware.get_all_errors()) == 0 + + def util_iter_possible_connections( source_block: BlockMiddleware, dest_block: BlockMiddleware, diff --git a/tox.ini b/tox.ini index 0cd64bc..29c5ee1 100644 --- a/tox.ini +++ b/tox.ini @@ -1,2 +1,5 @@ [pep8] ignore = E265,E501,W504 + +[isort] +profile = black