main - fiat: Implement BlockMiddleware and reorganize tests

This commit is contained in:
Yoel Bassin 2025-04-26 21:25:28 +03:00
parent c5b4be6950
commit 37cb51f056
9 changed files with 192 additions and 73 deletions

View File

@ -0,0 +1,21 @@
from typing import List
from gnuradio.grc.core.blocks.block import Block
from gnuradio_mcp.models import SINK, SOURCE, ParamModel, PortModel
class BlockMiddleware:
def __init__(self, block: Block):
self._block = block
@property
def params(self) -> List[ParamModel]:
return [ParamModel.from_param(param) for param in self._block.params.values()]
@property
def sinks(self) -> List[PortModel]:
return [PortModel.from_port(port, SINK) for port in self._block.sinks]
@property
def sources(self) -> List[PortModel]:
return [PortModel.from_port(port, SOURCE) for port in self._block.sources]

View File

@ -1,6 +1,6 @@
from typing import List, Optional from typing import List, Optional
from gnuradio.grc.core.FlowGraph import FlowGraph from gnuradio.grc.core.FlowGraph import FlowGraph
from gnuradio_mcp.middlewares.block import BlockMiddleware
from gnuradio_mcp.models import BlockModel from gnuradio_mcp.models import BlockModel
@ -19,7 +19,10 @@ class FlowGraphMiddleware:
block = self._flowgraph.new_block(block_type) block = self._flowgraph.new_block(block_type)
block.params["id"].set_value(block_name) block.params["id"].set_value(block_name)
def remove_block(self, block_name: str) -> None: def remove_block(self, block_name: str) -> None:
block = self._flowgraph.get_block(block_name) block = self._flowgraph.get_block(block_name)
self._flowgraph.remove_element(block) self._flowgraph.remove_element(block)
def get_block(self, block_name: str) -> BlockMiddleware:
block = self._flowgraph.get_block(block_name)
return BlockMiddleware(block)

View File

@ -1,6 +1,8 @@
from typing import Any, Literal, Optional, get_args
from pydantic import BaseModel from pydantic import BaseModel
from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.ports.port import Port
from gnuradio.grc.core.params.param import Param
class BlockModel(BaseModel): class BlockModel(BaseModel):
label: str label: str
@ -9,3 +11,48 @@ class BlockModel(BaseModel):
@classmethod @classmethod
def from_block(cls, block: Block) -> "BlockModel": def from_block(cls, block: Block) -> "BlockModel":
return cls(label=block.label, key=block.key) return cls(label=block.label, key=block.key)
class ParamModel(BaseModel):
parent: str
key: str
name: str
dtype: str
value: Any
@classmethod
def from_param(cls, param: Param) -> "ParamModel":
return cls(
parent=param.parent.name,
key=param.key,
name=param.name,
dtype=param.dtype,
value=param.get_value(),
)
DirectionType = Literal["sink", "source"]
SINK, SOURCE = get_args(DirectionType)
class PortModel(BaseModel):
parent: str
key: str
name: str
dtype: str
direction: DirectionType
optional: bool = False
@classmethod
def from_port(
cls, port: Port, direction: Optional[DirectionType] = None
) -> "PortModel":
direction = direction or port._dir
return cls(
parent=port.parent.name,
key=port.key,
name=port.name,
dtype=port.dtype,
direction=direction,
optional=port.optional,
)

0
tests/__init__.py Normal file
View File

22
tests/conftest.py Normal file
View File

@ -0,0 +1,22 @@
import pytest
from gnuradio.grc.core.platform import Platform
from gnuradio import gr
@pytest.fixture(scope="module")
def platform() -> Platform:
platform = Platform(
version=gr.version(),
version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()),
prefs=gr.prefs(),
)
platform.build_library()
return platform
@pytest.fixture(params=[1, 2, 10]) # Arbitrary number of blocks to test
def block_key(platform, request):
block_keys = list(platform.blocks.keys())
assert block_keys, "No blocks available in platform library."
idx = request.param
if idx < len(block_keys):
return block_keys[idx]
return block_keys[0]

58
tests/test_block.py Normal file
View File

@ -0,0 +1,58 @@
from typing import List
import pytest
from gnuradio.grc.core.platform import Platform
from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.FlowGraph import FlowGraph
from gnuradio import gr
from gnuradio_mcp.middlewares.block import BlockMiddleware
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
from gnuradio_mcp.models import SINK, SOURCE, ParamModel
@pytest.fixture
def flowgraph_middleware(platform: Platform):
flowgraph = platform.make_flow_graph("")
return FlowGraphMiddleware(flowgraph)
@pytest.fixture
def block(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
flowgraph_middleware.add_block(block_key, block_key)
return flowgraph_middleware._flowgraph.get_block(block_key)
def test_block_middleware_params(block: Block):
middleware = BlockMiddleware(block)
check_param_models(block, middleware.params)
def test_block_middleware_sinks(block: Block):
middleware = BlockMiddleware(block)
check_port_models(middleware.sinks, block.sinks, SINK)
def test_block_middleware_sources(block: Block):
middleware = BlockMiddleware(block)
check_port_models(middleware.sources, block.sources, SOURCE)
def check_param_models(block: Block, params: List[ParamModel]):
assert params
assert len(params) == len(block.params)
for param in params:
original_param = block.params[param.key]
assert param.key == original_param.key
assert param.name == original_param.name
assert param.dtype == original_param.dtype
assert param.value == original_param.value
def check_port_models(port_models, ports, direction):
assert isinstance(port_models, list)
assert len(port_models) == len(ports)
for model, port in zip(port_models, ports):
assert model.key == port.key
assert model.name == port.name
assert model.dtype == port.dtype
assert model.direction == direction
assert model.optional == port.optional

View File

@ -1,86 +1,46 @@
from typing import List
import pytest import pytest
from gnuradio.grc.core.platform import Platform from gnuradio.grc.core.platform import Platform
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
from gnuradio import gr from gnuradio import gr
from gnuradio_mcp.models import BlockModel from gnuradio_mcp.models import BlockModel
from .utils import get_block_keys, add_block, remove_block, get_current_blocks
BLOCK_KEYS_TO_TEST = [1, 10]
@pytest.fixture(scope="module")
def platform() -> Platform:
platform = Platform(
version=gr.version(),
version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()),
prefs=gr.prefs(),
)
platform.build_library()
return platform
@pytest.fixture @pytest.fixture
def flowgraph_middleware(platform): def flowgraph_middleware(platform: Platform):
flowgraph = platform.make_flow_graph("") flowgraph = platform.make_flow_graph("")
return FlowGraphMiddleware(flowgraph) return FlowGraphMiddleware(flowgraph)
@pytest.fixture @pytest.fixture
def initial_blocks(flowgraph_middleware): def initial_blocks(flowgraph_middleware: FlowGraphMiddleware):
return [ return [
BlockModel(key=block.key, label=block.label) BlockModel(key=block.key, label=block.label)
for block in flowgraph_middleware._flowgraph.blocks for block in flowgraph_middleware._flowgraph.blocks
] ]
def get_block_keys(platform): def test_flowgraph_block_addition_and_removal(
return list(platform.blocks.keys()) flowgraph_middleware: FlowGraphMiddleware,
platform: Platform,
initial_blocks: List[BlockModel],
def add_block(flowgraph_middleware, block_key, block_name): block_key: str,
flowgraph_middleware.add_block(block_key, block_name)
def remove_block(flowgraph_middleware, block_name):
flowgraph_middleware._flowgraph.remove_element(
flowgraph_middleware._flowgraph.get_block(block_name)
)
def get_current_blocks(flowgraph_middleware):
return [
BlockModel(key=block.key, label=block.label)
for block in flowgraph_middleware._flowgraph.blocks
]
def test_blocks_match_initial_state(flowgraph_middleware, initial_blocks):
assert flowgraph_middleware.blocks == initial_blocks
@pytest.mark.parametrize("block_index", BLOCK_KEYS_TO_TEST)
def test_add_block_preserves_and_adds(
flowgraph_middleware, platform, initial_blocks, block_index
): ):
block_keys = get_block_keys(platform) block_keys = get_block_keys(platform)
assert block_keys, "No blocks available in platform library." assert block_keys, "No blocks available in platform library."
block_key = block_keys[block_index] block_name = f"test_block_{block_key}"
block_name = f"test_block_{block_index}"
add_block(flowgraph_middleware, block_key, block_name) add_block(flowgraph_middleware, block_key, block_name)
blocks = flowgraph_middleware.blocks blocks = flowgraph_middleware.blocks
assert all(b in blocks for b in initial_blocks) assert all(b in blocks for b in initial_blocks)
assert any(b.key == block_key for b in blocks) assert any(b.key == block_key for b in blocks)
@pytest.mark.parametrize("block_index", BLOCK_KEYS_TO_TEST)
def test_remove_block_restores_initial(
flowgraph_middleware, platform, initial_blocks, block_index
):
block_keys = get_block_keys(platform)
block_key = block_keys[block_index]
block_name = f"block_to_remove_{block_index}"
add_block(flowgraph_middleware, block_key, block_name)
assert any(b.key == block_key for b in flowgraph_middleware.blocks)
remove_block(flowgraph_middleware, block_name) remove_block(flowgraph_middleware, block_name)
current_blocks = get_current_blocks(flowgraph_middleware) current_blocks = get_current_blocks(flowgraph_middleware)
assert current_blocks == initial_blocks assert current_blocks == initial_blocks
def test_flowgraph_initial_state(
flowgraph_middleware: FlowGraphMiddleware, initial_blocks: List[BlockModel]
):
assert flowgraph_middleware.blocks == initial_blocks

View File

@ -4,25 +4,15 @@ from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.platform import Platform from gnuradio.grc.core.platform import Platform
from gnuradio import gr from gnuradio import gr
@pytest.fixture
def platform() -> Platform:
platform = Platform(
version=gr.version(),
version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()),
prefs=gr.prefs(),
)
platform.build_library()
return platform
def test_block_model_from_block(platform: Platform):
def test_block_model_from_block(platform):
block = Block(platform) block = Block(platform)
model = BlockModel.from_block(block) model = BlockModel.from_block(block)
assert model.label == block.label assert model.label == block.label
assert model.key == block.key assert model.key == block.key
def test_platform_middleware_blocks(platform): def test_platform_middleware_blocks(platform: Platform):
middleware = PlatformMiddleware(platform) middleware = PlatformMiddleware(platform)
block_models = middleware.blocks block_models = middleware.blocks
assert block_models # Checks that the list is not empty assert block_models # Checks that the list is not empty

18
tests/utils.py Normal file
View File

@ -0,0 +1,18 @@
from gnuradio_mcp.models import BlockModel
def get_block_keys(platform):
return list(platform.blocks.keys())
def add_block(flowgraph_middleware, block_key, block_name):
flowgraph_middleware.add_block(block_key, block_name)
def remove_block(flowgraph_middleware, block_name):
flowgraph_middleware._flowgraph.remove_element(
flowgraph_middleware._flowgraph.get_block(block_name)
)
def get_current_blocks(flowgraph_middleware):
return [
BlockModel(key=block.key, label=block.label)
for block in flowgraph_middleware._flowgraph.blocks
]