From 5485413efd00e9e2938e142819ebf38a79f117d5 Mon Sep 17 00:00:00 2001 From: Yoel Bassin Date: Sun, 27 Apr 2025 00:02:19 +0300 Subject: [PATCH] main - feat: Implement block connection --- src/gnuradio_mcp/middlewares/block.py | 29 +++++- src/gnuradio_mcp/middlewares/flowgraph.py | 44 +++++++-- src/gnuradio_mcp/models.py | 19 +++- tests/test_block.py | 26 +++--- tests/test_flowgraph.py | 107 ++++++++++++++-------- 5 files changed, 163 insertions(+), 62 deletions(-) diff --git a/src/gnuradio_mcp/middlewares/block.py b/src/gnuradio_mcp/middlewares/block.py index cc6ed61..6a4a179 100644 --- a/src/gnuradio_mcp/middlewares/block.py +++ b/src/gnuradio_mcp/middlewares/block.py @@ -9,14 +9,39 @@ class BlockMiddleware: def __init__(self, block: Block): self._block = block + @property + def name(self) -> str: + return self._block.name + + # TODO: Check if rewrite is needed + @property def params(self) -> list[ParamModel]: return [ParamModel.from_param(param) for param in self._block.params.values()] @property def sinks(self) -> list[PortModel]: - return [PortModel.from_port(port, SINK) for port in self._block.sinks] + self._rewrite() + ports = [] + for port in self._block.sinks: + try: + port_model = PortModel.from_port(port, SINK) + ports.append(port_model) + except ValueError: + pass + return ports @property def sources(self) -> list[PortModel]: - return [PortModel.from_port(port, SOURCE) for port in self._block.sources] + self._rewrite() + ports = [] + for port in self._block.sources: + try: + port_model = PortModel.from_port(port, SOURCE) + ports.append(port_model) + except ValueError: + pass + return ports + + def _rewrite(self): + self._block.rewrite() diff --git a/src/gnuradio_mcp/middlewares/flowgraph.py b/src/gnuradio_mcp/middlewares/flowgraph.py index a0d74d3..48f3786 100644 --- a/src/gnuradio_mcp/middlewares/flowgraph.py +++ b/src/gnuradio_mcp/middlewares/flowgraph.py @@ -1,17 +1,19 @@ from __future__ import annotations -from typing import Optional +from typing import Dict, Optional +from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.FlowGraph import FlowGraph from gnuradio_mcp.middlewares.block import BlockMiddleware -from gnuradio_mcp.models import BlockModel +from gnuradio_mcp.models import BlockModel, ConnectionModel, PortModel from gnuradio_mcp.utils import get_unique_id class FlowGraphMiddleware: def __init__(self, flowgraph: FlowGraph): self._flowgraph = flowgraph + self._blocks: Dict[str, BlockMiddleware] = {} @property def blocks(self) -> list[BlockModel]: @@ -20,16 +22,46 @@ class FlowGraphMiddleware: for block in self._flowgraph.blocks ] - def add_block(self, block_type: str, block_name: Optional[str] = None) -> str: + def add_block( + self, block_type: str, block_name: Optional[str] = None + ) -> BlockMiddleware: block_name = block_name or get_unique_id(self._flowgraph.blocks) block = self._flowgraph.new_block(block_type) block.params["id"].set_value(block_name) - return block_name + self._blocks[block_name] = BlockMiddleware(block) + return self._blocks[block_name] def remove_block(self, block_name: str) -> None: block = self._flowgraph.get_block(block_name) self._flowgraph.remove_element(block) + del self._blocks[block_name] def get_block(self, block_name: str) -> BlockMiddleware: - block = self._flowgraph.get_block(block_name) - return BlockMiddleware(block) + # TODO: Check if calling two times you get different results + return self._blocks[block_name] + + def connect_blocks( + self, src_port_model: PortModel, dst_port_model: PortModel + ) -> None: + def get_block_by_port_model(port_model: PortModel) -> Block: + return self._flowgraph.get_block(port_model.parent) + + src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key] + dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key] + self._flowgraph.connect(src_port, dst_port) + + def disconnect_blocks( + self, src_port_model: PortModel, dst_port_model: PortModel + ) -> None: + def get_block_by_port_model(port_model: PortModel) -> Block: + return self._flowgraph.get_block(port_model.parent) + + src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key] + dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key] + self._flowgraph.disconnect(src_port, dst_port) + + def get_connections(self) -> list[ConnectionModel]: + return [ + ConnectionModel.from_connection(connection) + for connection in self._flowgraph.connections + ] diff --git a/src/gnuradio_mcp/models.py b/src/gnuradio_mcp/models.py index b32b81d..8205a0b 100644 --- a/src/gnuradio_mcp/models.py +++ b/src/gnuradio_mcp/models.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import Any, Literal, get_args from gnuradio.grc.core.blocks.block import Block +from gnuradio.grc.core.Connection import Connection from gnuradio.grc.core.params.param import Param from gnuradio.grc.core.ports.port import Port from pydantic import BaseModel @@ -41,7 +42,7 @@ SINK, SOURCE = get_args(DirectionType) class PortModel(BaseModel): parent: str - key: str + key: int name: str dtype: str direction: DirectionType @@ -54,11 +55,25 @@ class PortModel(BaseModel): direction: DirectionType | None = None, ) -> PortModel: direction = direction or port._dir + if not port.key.isnumeric(): + raise ValueError("Currently not supporting named ports") return cls( parent=port.parent.name, - key=port.key, + key=int(port.key), name=port.name, dtype=port.dtype, direction=direction, optional=port.optional, ) + + +class ConnectionModel(BaseModel): + source: PortModel + sink: PortModel + + @classmethod + def from_connection(cls, connection: Connection) -> "ConnectionModel": + return cls( + source=PortModel.from_port(connection.source_port), + sink=PortModel.from_port(connection.sink_port), + ) diff --git a/tests/test_block.py b/tests/test_block.py index dfc1b33..090920b 100644 --- a/tests/test_block.py +++ b/tests/test_block.py @@ -10,30 +10,28 @@ from gnuradio_mcp.models import SINK, SOURCE, ParamModel @pytest.fixture -def flowgraph_middleware(platform: Platform): +def flowgraph_middleware(platform: Platform) -> FlowGraphMiddleware: flowgraph = platform.make_flow_graph("") return FlowGraphMiddleware(flowgraph) @pytest.fixture -def block(flowgraph_middleware: FlowGraphMiddleware, block_key: str): - block_name = flowgraph_middleware.add_block(block_key) - return flowgraph_middleware._flowgraph.get_block(block_name) +def block_middleware( + flowgraph_middleware: FlowGraphMiddleware, block_key: str +) -> BlockMiddleware: + return flowgraph_middleware.add_block(block_key) -def test_block_middleware_params(block: Block): - middleware = BlockMiddleware(block) - check_param_models(block, middleware.params) +def test_block_middleware_params(block_middleware: BlockMiddleware): + check_param_models(block_middleware._block, block_middleware.params) -def test_block_middleware_sinks(block: Block): - middleware = BlockMiddleware(block) - check_port_models(middleware.sinks, block.sinks, SINK) +def test_block_middleware_sinks(block_middleware: BlockMiddleware): + check_port_models(block_middleware.sinks, block_middleware._block.sinks, SINK) -def test_block_middleware_sources(block: Block): - middleware = BlockMiddleware(block) - check_port_models(middleware.sources, block.sources, SOURCE) +def test_block_middleware_sources(block_middleware: BlockMiddleware): + check_port_models(block_middleware.sources, block_middleware._block.sources, SOURCE) def check_param_models(block: Block, params: list[ParamModel]): @@ -51,7 +49,7 @@ def check_port_models(port_models, ports, direction): assert isinstance(port_models, list) assert len(port_models) == len(ports) for model, port in zip(port_models, ports): - assert model.key == port.key + assert model.key == int(port.key) assert model.name == port.name assert model.dtype == port.dtype assert model.direction == direction diff --git a/tests/test_flowgraph.py b/tests/test_flowgraph.py index d784e4f..33a3461 100644 --- a/tests/test_flowgraph.py +++ b/tests/test_flowgraph.py @@ -1,10 +1,13 @@ from __future__ import annotations +from typing import Generator + import pytest from gnuradio.grc.core.platform import Platform +from gnuradio_mcp.middlewares.block import BlockMiddleware from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware -from gnuradio_mcp.models import BlockModel +from gnuradio_mcp.models import BlockModel, ConnectionModel @pytest.fixture @@ -23,26 +26,21 @@ def initial_blocks(flowgraph_middleware: FlowGraphMiddleware): def test_flowgraph_block_addition_and_removal( flowgraph_middleware: FlowGraphMiddleware, - platform: Platform, initial_blocks: list[BlockModel], block_key: str, ): - block_keys = platform.blocks.keys() - assert block_keys, "No blocks available in platform library." - block_name = f"test_block_{block_key}" - flowgraph_middleware.add_block(block_key, block_name) + explicit_name = "my_custom_block_name" + + flowgraph_middleware.add_block(block_key, explicit_name) + blocks = flowgraph_middleware.blocks assert all(b in blocks for b in initial_blocks) assert any(b.key == block_key for b in blocks) - flowgraph_middleware._flowgraph.remove_element( - flowgraph_middleware._flowgraph.get_block(block_name), - ) - current_blocks = [ - BlockModel(key=block.key, label=block.label) - for block in flowgraph_middleware._flowgraph.blocks - ] - assert current_blocks == initial_blocks + flowgraph_middleware.remove_block(explicit_name) + + blocks = flowgraph_middleware.blocks + assert all(b in initial_blocks for b in blocks) def test_flowgraph_initial_state( @@ -53,35 +51,68 @@ def test_flowgraph_initial_state( def test_block_naming(flowgraph_middleware: FlowGraphMiddleware, block_key: str): - # Explicit name explicit_name = "my_custom_block_name" - flowgraph_middleware.add_block(block_key, explicit_name) - block = flowgraph_middleware._flowgraph.get_block(explicit_name) - assert block.params["id"].get_value() == explicit_name - # Remove for clean state - flowgraph_middleware._flowgraph.remove_element(block) - # Implicit name (should use get_unique_id logic) - flowgraph_middleware.add_block(block_key) - # The last block added should be the last in the blocks list - last_block = flowgraph_middleware._flowgraph.blocks[-1] - # The id param should match the block's name - assert last_block.params["id"].get_value() == last_block.name - # Remove for clean state - flowgraph_middleware._flowgraph.remove_element(last_block) + block = flowgraph_middleware.add_block(block_key, explicit_name) + + assert block.name == explicit_name def test_block_unique_names_for_same_type( flowgraph_middleware: FlowGraphMiddleware, block_key: str ): - # Add two blocks of the same type without explicit names - flowgraph_middleware.add_block(block_key) - first_block = flowgraph_middleware._flowgraph.blocks[-1] - flowgraph_middleware.add_block(block_key) - second_block = flowgraph_middleware._flowgraph.blocks[-1] - # They should have different names + first_block = flowgraph_middleware.add_block(block_key) + second_block = flowgraph_middleware.add_block(block_key) + assert first_block.name != second_block.name - assert first_block.params["id"].get_value() != second_block.params["id"].get_value() - # Clean up - flowgraph_middleware._flowgraph.remove_element(first_block) - flowgraph_middleware._flowgraph.remove_element(second_block) + + +@pytest.mark.parametrize( + "block_key, sinks_number, sources_number", + [("blocks_add_xx", 2, 1), ("blocks_copy", 1, 1), ("blocks_selector", 2, 2)], +) +def test_block_connections( + flowgraph_middleware: FlowGraphMiddleware, + block_key: str, + sources_number: int, + sinks_number: int, +): + source_block = flowgraph_middleware.add_block(block_key) + dest_block = flowgraph_middleware.add_block(block_key) + + for connection in util_iter_possible_connections(source_block, dest_block): + flowgraph_middleware.connect_blocks(connection.source, connection.sink) + connections = flowgraph_middleware.get_connections() + assert any( + c.source.key == connection.source.key and c.sink.key == connection.sink.key + for c in connections + ) + + assert len(flowgraph_middleware.get_connections()) == sources_number * sinks_number + + +def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_key: str): + source_block = flowgraph_middleware.add_block(block_key) + dest_block = flowgraph_middleware.add_block(block_key) + + for connection in util_iter_possible_connections(source_block, dest_block): + flowgraph_middleware.connect_blocks(connection.source, connection.sink) + + for connection in util_iter_possible_connections(source_block, dest_block): + flowgraph_middleware.disconnect_blocks(connection.source, connection.sink) + connections = flowgraph_middleware.get_connections() + assert not any( + c.source.key == connection.source.key and c.sink.key == connection.sink.key + for c in connections + ) + + assert len(flowgraph_middleware.get_connections()) == 0 + + +def util_iter_possible_connections( + source_block: BlockMiddleware, + dest_block: BlockMiddleware, +) -> Generator[ConnectionModel]: + for sink in source_block.sinks: + for source in dest_block.sources: + yield ConnectionModel(source=source, sink=sink)