diff --git a/src/gnuradio_mcp/middlewares/block.py b/src/gnuradio_mcp/middlewares/block.py new file mode 100644 index 0000000..058bbec --- /dev/null +++ b/src/gnuradio_mcp/middlewares/block.py @@ -0,0 +1,21 @@ +from typing import List +from gnuradio.grc.core.blocks.block import Block + +from gnuradio_mcp.models import SINK, SOURCE, ParamModel, PortModel + + +class BlockMiddleware: + def __init__(self, block: Block): + self._block = block + + @property + def params(self) -> List[ParamModel]: + return [ParamModel.from_param(param) for param in self._block.params.values()] + + @property + def sinks(self) -> List[PortModel]: + return [PortModel.from_port(port, SINK) for port in self._block.sinks] + + @property + def sources(self) -> List[PortModel]: + return [PortModel.from_port(port, SOURCE) for port in self._block.sources] diff --git a/src/gnuradio_mcp/middlewares/flowgraph.py b/src/gnuradio_mcp/middlewares/flowgraph.py index 1694240..6776cd3 100644 --- a/src/gnuradio_mcp/middlewares/flowgraph.py +++ b/src/gnuradio_mcp/middlewares/flowgraph.py @@ -1,6 +1,6 @@ - from typing import List, Optional from gnuradio.grc.core.FlowGraph import FlowGraph +from gnuradio_mcp.middlewares.block import BlockMiddleware from gnuradio_mcp.models import BlockModel @@ -14,12 +14,15 @@ class FlowGraphMiddleware: BlockModel(key=block.key, label=block.label) for block in self._flowgraph.blocks ] - + def add_block(self, block_type: str, block_name: str) -> None: block = self._flowgraph.new_block(block_type) block.params["id"].set_value(block_name) - def remove_block(self, block_name: str) -> None: block = self._flowgraph.get_block(block_name) self._flowgraph.remove_element(block) + + def get_block(self, block_name: str) -> BlockMiddleware: + block = self._flowgraph.get_block(block_name) + return BlockMiddleware(block) diff --git a/src/gnuradio_mcp/models.py b/src/gnuradio_mcp/models.py index f0efefb..4649d69 100644 --- a/src/gnuradio_mcp/models.py +++ b/src/gnuradio_mcp/models.py @@ -1,6 +1,8 @@ +from typing import Any, Literal, Optional, get_args from pydantic import BaseModel from gnuradio.grc.core.blocks.block import Block - +from gnuradio.grc.core.ports.port import Port +from gnuradio.grc.core.params.param import Param class BlockModel(BaseModel): label: str @@ -8,4 +10,49 @@ class BlockModel(BaseModel): @classmethod def from_block(cls, block: Block) -> "BlockModel": - return cls(label=block.label, key=block.key) \ No newline at end of file + return cls(label=block.label, key=block.key) + + +class ParamModel(BaseModel): + parent: str + key: str + name: str + dtype: str + value: Any + + @classmethod + def from_param(cls, param: Param) -> "ParamModel": + return cls( + parent=param.parent.name, + key=param.key, + name=param.name, + dtype=param.dtype, + value=param.get_value(), + ) + + + +DirectionType = Literal["sink", "source"] +SINK, SOURCE = get_args(DirectionType) + +class PortModel(BaseModel): + parent: str + key: str + name: str + dtype: str + direction: DirectionType + optional: bool = False + + @classmethod + def from_port( + cls, port: Port, direction: Optional[DirectionType] = None + ) -> "PortModel": + direction = direction or port._dir + return cls( + parent=port.parent.name, + key=port.key, + name=port.name, + dtype=port.dtype, + direction=direction, + optional=port.optional, + ) \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..f377ce2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,22 @@ +import pytest +from gnuradio.grc.core.platform import Platform +from gnuradio import gr + +@pytest.fixture(scope="module") +def platform() -> Platform: + platform = Platform( + version=gr.version(), + version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()), + prefs=gr.prefs(), + ) + platform.build_library() + return platform + +@pytest.fixture(params=[1, 2, 10]) # Arbitrary number of blocks to test +def block_key(platform, request): + block_keys = list(platform.blocks.keys()) + assert block_keys, "No blocks available in platform library." + idx = request.param + if idx < len(block_keys): + return block_keys[idx] + return block_keys[0] \ No newline at end of file diff --git a/tests/test_block.py b/tests/test_block.py new file mode 100644 index 0000000..1b9a697 --- /dev/null +++ b/tests/test_block.py @@ -0,0 +1,58 @@ +from typing import List +import pytest +from gnuradio.grc.core.platform import Platform +from gnuradio.grc.core.blocks.block import Block +from gnuradio.grc.core.FlowGraph import FlowGraph +from gnuradio import gr +from gnuradio_mcp.middlewares.block import BlockMiddleware +from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware +from gnuradio_mcp.models import SINK, SOURCE, ParamModel + + +@pytest.fixture +def flowgraph_middleware(platform: Platform): + flowgraph = platform.make_flow_graph("") + return FlowGraphMiddleware(flowgraph) + + +@pytest.fixture +def block(flowgraph_middleware: FlowGraphMiddleware, block_key: str): + flowgraph_middleware.add_block(block_key, block_key) + return flowgraph_middleware._flowgraph.get_block(block_key) + + +def test_block_middleware_params(block: Block): + middleware = BlockMiddleware(block) + check_param_models(block, middleware.params) + + +def test_block_middleware_sinks(block: Block): + middleware = BlockMiddleware(block) + check_port_models(middleware.sinks, block.sinks, SINK) + + +def test_block_middleware_sources(block: Block): + middleware = BlockMiddleware(block) + check_port_models(middleware.sources, block.sources, SOURCE) + + +def check_param_models(block: Block, params: List[ParamModel]): + assert params + assert len(params) == len(block.params) + for param in params: + original_param = block.params[param.key] + assert param.key == original_param.key + assert param.name == original_param.name + assert param.dtype == original_param.dtype + assert param.value == original_param.value + + +def check_port_models(port_models, ports, direction): + assert isinstance(port_models, list) + assert len(port_models) == len(ports) + for model, port in zip(port_models, ports): + assert model.key == port.key + assert model.name == port.name + assert model.dtype == port.dtype + assert model.direction == direction + assert model.optional == port.optional diff --git a/tests/test_flowgraph.py b/tests/test_flowgraph.py index 79b3235..6375503 100644 --- a/tests/test_flowgraph.py +++ b/tests/test_flowgraph.py @@ -1,86 +1,46 @@ +from typing import List import pytest from gnuradio.grc.core.platform import Platform from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware from gnuradio import gr - from gnuradio_mcp.models import BlockModel - -BLOCK_KEYS_TO_TEST = [1, 10] - - -@pytest.fixture(scope="module") -def platform() -> Platform: - platform = Platform( - version=gr.version(), - version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()), - prefs=gr.prefs(), - ) - platform.build_library() - return platform +from .utils import get_block_keys, add_block, remove_block, get_current_blocks @pytest.fixture -def flowgraph_middleware(platform): +def flowgraph_middleware(platform: Platform): flowgraph = platform.make_flow_graph("") return FlowGraphMiddleware(flowgraph) @pytest.fixture -def initial_blocks(flowgraph_middleware): +def initial_blocks(flowgraph_middleware: FlowGraphMiddleware): return [ BlockModel(key=block.key, label=block.label) for block in flowgraph_middleware._flowgraph.blocks ] -def get_block_keys(platform): - return list(platform.blocks.keys()) - - -def add_block(flowgraph_middleware, block_key, block_name): - flowgraph_middleware.add_block(block_key, block_name) - - -def remove_block(flowgraph_middleware, block_name): - flowgraph_middleware._flowgraph.remove_element( - flowgraph_middleware._flowgraph.get_block(block_name) - ) - - -def get_current_blocks(flowgraph_middleware): - return [ - BlockModel(key=block.key, label=block.label) - for block in flowgraph_middleware._flowgraph.blocks - ] - - -def test_blocks_match_initial_state(flowgraph_middleware, initial_blocks): - assert flowgraph_middleware.blocks == initial_blocks - - -@pytest.mark.parametrize("block_index", BLOCK_KEYS_TO_TEST) -def test_add_block_preserves_and_adds( - flowgraph_middleware, platform, initial_blocks, block_index +def test_flowgraph_block_addition_and_removal( + flowgraph_middleware: FlowGraphMiddleware, + platform: Platform, + initial_blocks: List[BlockModel], + block_key: str, ): block_keys = get_block_keys(platform) assert block_keys, "No blocks available in platform library." - block_key = block_keys[block_index] - block_name = f"test_block_{block_index}" + block_name = f"test_block_{block_key}" add_block(flowgraph_middleware, block_key, block_name) blocks = flowgraph_middleware.blocks assert all(b in blocks for b in initial_blocks) assert any(b.key == block_key for b in blocks) - -@pytest.mark.parametrize("block_index", BLOCK_KEYS_TO_TEST) -def test_remove_block_restores_initial( - flowgraph_middleware, platform, initial_blocks, block_index -): - block_keys = get_block_keys(platform) - block_key = block_keys[block_index] - block_name = f"block_to_remove_{block_index}" - add_block(flowgraph_middleware, block_key, block_name) - assert any(b.key == block_key for b in flowgraph_middleware.blocks) remove_block(flowgraph_middleware, block_name) current_blocks = get_current_blocks(flowgraph_middleware) assert current_blocks == initial_blocks + + +def test_flowgraph_initial_state( + flowgraph_middleware: FlowGraphMiddleware, initial_blocks: List[BlockModel] +): + assert flowgraph_middleware.blocks == initial_blocks diff --git a/tests/test_platform.py b/tests/test_platform.py index 455d5f4..d0d4ee8 100644 --- a/tests/test_platform.py +++ b/tests/test_platform.py @@ -4,25 +4,15 @@ from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.platform import Platform from gnuradio import gr -@pytest.fixture -def platform() -> Platform: - platform = Platform( - version=gr.version(), - version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()), - prefs=gr.prefs(), - ) - platform.build_library() - return platform - -def test_block_model_from_block(platform): +def test_block_model_from_block(platform: Platform): block = Block(platform) model = BlockModel.from_block(block) assert model.label == block.label assert model.key == block.key -def test_platform_middleware_blocks(platform): +def test_platform_middleware_blocks(platform: Platform): middleware = PlatformMiddleware(platform) block_models = middleware.blocks assert block_models # Checks that the list is not empty diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..fe03ede --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,18 @@ +from gnuradio_mcp.models import BlockModel + +def get_block_keys(platform): + return list(platform.blocks.keys()) + +def add_block(flowgraph_middleware, block_key, block_name): + flowgraph_middleware.add_block(block_key, block_name) + +def remove_block(flowgraph_middleware, block_name): + flowgraph_middleware._flowgraph.remove_element( + flowgraph_middleware._flowgraph.get_block(block_name) + ) + +def get_current_blocks(flowgraph_middleware): + return [ + BlockModel(key=block.key, label=block.label) + for block in flowgraph_middleware._flowgraph.blocks + ] \ No newline at end of file