diff --git a/src/gnuradio_mcp/middlewares/block.py b/src/gnuradio_mcp/middlewares/block.py index 2ef37ad..a566095 100644 --- a/src/gnuradio_mcp/middlewares/block.py +++ b/src/gnuradio_mcp/middlewares/block.py @@ -38,11 +38,9 @@ class BlockMiddleware(ElementMiddleware): self._rewrite() ports = [] for port in self._block.sinks: - try: - port_model = PortModel.from_port(port, SINK) + port_model = PortModel.from_port(port, SINK) + if not port_model.hidden: ports.append(port_model) - except ValueError: - pass return ports @property @@ -50,9 +48,7 @@ class BlockMiddleware(ElementMiddleware): self._rewrite() ports = [] for port in self._block.sources: - try: - port_model = PortModel.from_port(port, SOURCE) + port_model = PortModel.from_port(port, SOURCE) + if not port_model.hidden: ports.append(port_model) - except ValueError: - pass return ports diff --git a/src/gnuradio_mcp/middlewares/flowgraph.py b/src/gnuradio_mcp/middlewares/flowgraph.py index c1038c4..a827937 100644 --- a/src/gnuradio_mcp/middlewares/flowgraph.py +++ b/src/gnuradio_mcp/middlewares/flowgraph.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Set +from functools import lru_cache +from typing import TYPE_CHECKING, Optional from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.FlowGraph import FlowGraph @@ -24,7 +25,7 @@ def get_port_from_port_model_in_port_list( raise ValueError(f"Port not found: {port_model.key}") -def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block: +def get_port_from_port_model(flowgraph, port_model: PortModel) -> Port: block_from_port_model = flowgraph.get_block(port_model.parent) if port_model.direction == SOURCE: return get_port_from_port_model_in_port_list( @@ -38,11 +39,14 @@ def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block: raise ValueError(f"Invalid port direction: {port_model.direction}") +def set_block_name(block: Block, name: str): + block.params["id"].set_value(name) + + class FlowGraphMiddleware(ElementMiddleware): def __init__(self, flowgraph: FlowGraph): super().__init__(flowgraph) self._flowgraph = self._element - self._blocks: Set[BlockMiddleware] = set() @property def blocks(self) -> list[BlockModel]: @@ -50,22 +54,22 @@ class FlowGraphMiddleware(ElementMiddleware): def add_block( self, block_type: str, block_name: Optional[str] = None - ) -> BlockMiddleware: + ) -> BlockModel: block_name = block_name or get_unique_id(self._flowgraph.blocks, block_type) block = self._flowgraph.new_block(block_type) assert block is not None, f"Failed to create block: {block_type}" - block_middleware = BlockMiddleware(block) - block_middleware.name = block_name - self._blocks.add(block_middleware) - return block_middleware + set_block_name(block, block_name) + return BlockModel.from_block(block) def remove_block(self, block_name: str) -> None: block_middleware = self.get_block(block_name) self._flowgraph.remove_element(block_middleware._block) - self._blocks.remove(block_middleware) + @lru_cache(maxsize=None) def get_block(self, block_name: str) -> BlockMiddleware: - return next(block for block in self._blocks if block.name == block_name) + return BlockMiddleware( + next(block for block in self._flowgraph.blocks if block.name == block_name) + ) def connect_blocks( self, src_port_model: PortModel, dst_port_model: PortModel diff --git a/src/gnuradio_mcp/models.py b/src/gnuradio_mcp/models.py index 98bd97c..cb6e613 100644 --- a/src/gnuradio_mcp/models.py +++ b/src/gnuradio_mcp/models.py @@ -1,12 +1,12 @@ from __future__ import annotations -from typing import Any, Literal, Type, get_args +from typing import Any, Literal, Protocol, Type, get_args from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.Connection import Connection from gnuradio.grc.core.params.param import Param from gnuradio.grc.core.ports.port import Port -from pydantic import BaseModel +from pydantic import BaseModel, field_validator class BlockTypeModel(BaseModel): @@ -18,6 +18,11 @@ class BlockTypeModel(BaseModel): return cls(label=block.label, key=block.key) +class KeyedModel(Protocol): + def to_key(self) -> str: + ... + + class BlockModel(BaseModel): label: str name: str @@ -26,6 +31,9 @@ class BlockModel(BaseModel): def from_block(cls, block: Block) -> BlockModel: return cls(label=block.label, name=block.name) + def to_key(self) -> str: + return f"{self.label}:{self.name}" + class ParamModel(BaseModel): parent: str @@ -44,6 +52,9 @@ class ParamModel(BaseModel): value=param.get_value(), ) + def to_key(self) -> str: + return f"{self.parent}:{self.key}" + DirectionType = Literal["sink", "source"] SINK, SOURCE = get_args(DirectionType) @@ -56,6 +67,7 @@ class PortModel(BaseModel): dtype: str direction: DirectionType optional: bool = False + hidden: bool = False @classmethod def from_port( @@ -64,8 +76,6 @@ class PortModel(BaseModel): direction: DirectionType | None = None, ) -> PortModel: direction = direction or port._dir - if not port.key.isnumeric(): - raise ValueError("Currently not supporting named ports") return cls( parent=port.parent.name, key=port.key, @@ -73,8 +83,12 @@ class PortModel(BaseModel): dtype=port.dtype, direction=direction, optional=port.optional, + hidden=port.hidden, ) + def to_key(self) -> str: + return f"{self.parent}:{self.direction}[{self.key}]" + class ConnectionModel(BaseModel): source: PortModel @@ -87,8 +101,16 @@ class ConnectionModel(BaseModel): sink=PortModel.from_port(connection.sink_port), ) + def to_key(self) -> str: + return f"{self.source.to_key()}-{self.sink.to_key()}" + class ErrorModel(BaseModel): type: str - key: BaseModel + key: str message: str + + @field_validator("key", mode="before") + @classmethod + def transform_key(cls, v: KeyedModel) -> str: + return v.to_key() diff --git a/src/gnuradio_mcp/utils.py b/src/gnuradio_mcp/utils.py index 800d37d..1c8656e 100644 --- a/src/gnuradio_mcp/utils.py +++ b/src/gnuradio_mcp/utils.py @@ -45,6 +45,6 @@ def format_error_message(elem, msg) -> ErrorModel: raise ValueError(f"Unsupported element type: {type(elem)}") return ErrorModel( type=type(model).__name__, - key=model, + key=model, # type: ignore message=msg, ) diff --git a/tests/unit/test_block.py b/tests/unit/test_block.py index 18b0689..6f2931c 100644 --- a/tests/unit/test_block.py +++ b/tests/unit/test_block.py @@ -19,7 +19,8 @@ def flowgraph_middleware(platform: Platform) -> FlowGraphMiddleware: def block_middleware( flowgraph_middleware: FlowGraphMiddleware, block_key: str ) -> BlockMiddleware: - return flowgraph_middleware.add_block(block_key) + block_model = flowgraph_middleware.add_block(block_key) + return flowgraph_middleware.get_block(block_model.name) def test_block_middleware_params(block_middleware: BlockMiddleware): @@ -52,7 +53,8 @@ def test_block_errors( block_key: str, initial_errors_number: int, ): - block_middleware = flowgraph_middleware.add_block(block_key) + block_model = flowgraph_middleware.add_block(block_key) + block_middleware = flowgraph_middleware.get_block(block_model.name) for error in block_middleware.get_all_errors(): assert isinstance(error, ErrorModel) assert len(block_middleware.get_all_errors()) == initial_errors_number diff --git a/tests/unit/test_flowgraph.py b/tests/unit/test_flowgraph.py index 643b452..159fb39 100644 --- a/tests/unit/test_flowgraph.py +++ b/tests/unit/test_flowgraph.py @@ -66,8 +66,17 @@ def test_block_unique_names_for_same_type( def test_block_default_name(flowgraph_middleware: FlowGraphMiddleware, block_key: str): - block = flowgraph_middleware.add_block(block_key) - assert block_key in block.name + block_model = flowgraph_middleware.add_block(block_key) + assert block_key in block_model.name + + +def test_remove_existing_block(flowgraph_middleware: FlowGraphMiddleware): + DEFAULT_VARIABLE_BLOCK_NAME = "samp_rate" + flowgraph_middleware.remove_block(DEFAULT_VARIABLE_BLOCK_NAME) + assert not any( + block.name == DEFAULT_VARIABLE_BLOCK_NAME + for block in flowgraph_middleware.blocks + ) @pytest.mark.parametrize( @@ -80,11 +89,11 @@ def test_block_connections( sources_number: int, sinks_number: int, ): - source_block = flowgraph_middleware.add_block(block_key) - dest_block = flowgraph_middleware.add_block(block_key) + source_block, dest_block = create_and_connect_blocks( + flowgraph_middleware, block_key, block_key + ) for connection in util_iter_possible_connections(source_block, dest_block): - flowgraph_middleware.connect_blocks(connection.source, connection.sink) connections = flowgraph_middleware.get_connections() assert any( c.source.key == connection.source.key and c.sink.key == connection.sink.key @@ -95,11 +104,9 @@ def test_block_connections( def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_key: str): - source_block = flowgraph_middleware.add_block(block_key) - dest_block = flowgraph_middleware.add_block(block_key) - - for connection in util_iter_possible_connections(source_block, dest_block): - flowgraph_middleware.connect_blocks(connection.source, connection.sink) + source_block, dest_block = create_and_connect_blocks( + flowgraph_middleware, block_key, block_key + ) for connection in util_iter_possible_connections(source_block, dest_block): flowgraph_middleware.disconnect_blocks(connection.source, connection.sink) @@ -118,6 +125,23 @@ def test_default_flowgraph_errors(flowgraph_middleware: FlowGraphMiddleware): assert len(flowgraph_middleware.get_all_errors()) == 0 +def create_and_connect_blocks( + flowgraph_middleware: FlowGraphMiddleware, + source_block_key: str, + dest_block_key: str, +): + source_block_model = flowgraph_middleware.add_block(source_block_key) + dest_block_model = flowgraph_middleware.add_block(dest_block_key) + + source_block = flowgraph_middleware.get_block(source_block_model.name) + dest_block = flowgraph_middleware.get_block(dest_block_model.name) + + for connection in util_iter_possible_connections(source_block, dest_block): + flowgraph_middleware.connect_blocks(connection.source, connection.sink) + + return source_block, dest_block + + def util_iter_possible_connections( source_block: BlockMiddleware, dest_block: BlockMiddleware,