diff --git a/src/gnuradio_mcp/middlewares/base.py b/src/gnuradio_mcp/middlewares/base.py index b6340e8..de0f943 100644 --- a/src/gnuradio_mcp/middlewares/base.py +++ b/src/gnuradio_mcp/middlewares/base.py @@ -11,9 +11,10 @@ class ElementMiddleware: def _rewrite(self): self._element.rewrite() - def validate(self): + def validate(self) -> bool: self._rewrite() self._element.validate() + return self._element.is_valid() def get_all_errors(self) -> list[ErrorModel]: self.validate() diff --git a/src/gnuradio_mcp/middlewares/flowgraph.py b/src/gnuradio_mcp/middlewares/flowgraph.py index 5b0ba49..9cdff8b 100644 --- a/src/gnuradio_mcp/middlewares/flowgraph.py +++ b/src/gnuradio_mcp/middlewares/flowgraph.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional, Set from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.FlowGraph import FlowGraph +from gnuradio.grc.core.ports.port import Port from gnuradio_mcp.middlewares.base import ElementMiddleware from gnuradio_mcp.middlewares.block import BlockMiddleware @@ -14,12 +15,25 @@ if TYPE_CHECKING: from gnuradio_mcp.middlewares.platform import PlatformMiddleware +def get_port_from_port_model_in_port_list( + port_list: list[Port], port_model: PortModel +) -> Block: + for port in port_list: + if port.key == port_model.key: + return port + raise ValueError(f"Port not found: {port_model.key}") + + def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block: block_from_port_model = flowgraph.get_block(port_model.parent) if port_model.direction == SOURCE: - return block_from_port_model.sources[port_model.key] + return get_port_from_port_model_in_port_list( + block_from_port_model.sources, port_model + ) elif port_model.direction == SINK: - return block_from_port_model.sinks[port_model.key] + return get_port_from_port_model_in_port_list( + block_from_port_model.sinks, port_model + ) else: raise ValueError(f"Invalid port direction: {port_model.direction}") diff --git a/src/gnuradio_mcp/middlewares/platform.py b/src/gnuradio_mcp/middlewares/platform.py index f18c9ed..4e43158 100644 --- a/src/gnuradio_mcp/middlewares/platform.py +++ b/src/gnuradio_mcp/middlewares/platform.py @@ -4,7 +4,7 @@ from gnuradio.grc.core.platform import Platform from gnuradio_mcp.middlewares.base import ElementMiddleware from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware -from gnuradio_mcp.models import BlockModel +from gnuradio_mcp.models import BlockTypeModel class PlatformMiddleware(ElementMiddleware): @@ -13,10 +13,14 @@ class PlatformMiddleware(ElementMiddleware): self._platform = self._element @property - def blocks(self) -> list[BlockModel]: + def blocks(self) -> list[BlockTypeModel]: return [ - BlockModel.from_block(block) for block in self._platform.blocks.values() + BlockTypeModel.from_block_type(block) + for block in self._platform.blocks.values() ] def make_flowgraph(self, filepath: str = "") -> FlowGraphMiddleware: return FlowGraphMiddleware.from_file(self, filepath) + + def save_flowgraph(self, filepath: str, flowgraph: FlowGraphMiddleware) -> None: + self._platform.save_flow_graph(filepath, flowgraph._flowgraph) diff --git a/src/gnuradio_mcp/models.py b/src/gnuradio_mcp/models.py index f9fb679..80f6ed1 100644 --- a/src/gnuradio_mcp/models.py +++ b/src/gnuradio_mcp/models.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Literal, get_args +from typing import Any, Literal, Type, get_args from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.Connection import Connection @@ -9,13 +9,22 @@ from gnuradio.grc.core.ports.port import Port from pydantic import BaseModel +class BlockTypeModel(BaseModel): + label: str + key: str + + @classmethod + def from_block_type(cls, block: Type[Block]) -> BlockTypeModel: + return cls(label=block.label, key=block.key) + + class BlockModel(BaseModel): label: str key: str @classmethod def from_block(cls, block: Block) -> BlockModel: - return cls(label=block.label, key=block.key) + return cls(label=block.label, key=block.name) class ParamModel(BaseModel): @@ -42,7 +51,7 @@ SINK, SOURCE = get_args(DirectionType) class PortModel(BaseModel): parent: str - key: int + key: str name: str dtype: str direction: DirectionType @@ -59,7 +68,7 @@ class PortModel(BaseModel): raise ValueError("Currently not supporting named ports") return cls( parent=port.parent.name, - key=int(port.key), + key=port.key, name=port.name, dtype=port.dtype, direction=direction, diff --git a/tests/unit/test_block.py b/tests/unit/test_block.py index 53e4a03..18b0689 100644 --- a/tests/unit/test_block.py +++ b/tests/unit/test_block.py @@ -75,7 +75,7 @@ def check_port_models(port_models, ports, direction): assert isinstance(port_models, list) assert len(port_models) == len(ports) for model, port in zip(port_models, ports): - assert model.key == int(port.key) + assert model.key == port.key assert model.name == port.name assert model.dtype == port.dtype assert model.direction == direction diff --git a/tests/unit/test_platform.py b/tests/unit/test_platform.py index 4ccf029..9d59e1c 100644 --- a/tests/unit/test_platform.py +++ b/tests/unit/test_platform.py @@ -3,18 +3,18 @@ from __future__ import annotations from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.platform import Platform -from gnuradio_mcp.middlewares.platform import BlockModel, PlatformMiddleware +from gnuradio_mcp.middlewares.platform import BlockTypeModel, PlatformMiddleware def test_block_model_from_block(platform: Platform): - block = Block(platform) - model = BlockModel.from_block(block) - assert model.label == block.label - assert model.key == block.key + block_type = Block + model = BlockTypeModel.from_block_type(block_type) + assert model.label == block_type.label + assert model.key == block_type.key def test_platform_middleware_blocks(platform: Platform): middleware = PlatformMiddleware(platform) block_models = middleware.blocks assert block_models # Checks that the list is not empty - assert all(isinstance(block_model, BlockModel) for block_model in block_models) + assert all(isinstance(block_model, BlockTypeModel) for block_model in block_models)