main - fiat: Implement BlockMiddleware and reorganize tests
This commit is contained in:
parent
c5b4be6950
commit
37cb51f056
21
src/gnuradio_mcp/middlewares/block.py
Normal file
21
src/gnuradio_mcp/middlewares/block.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from typing import List
|
||||||
|
from gnuradio.grc.core.blocks.block import Block
|
||||||
|
|
||||||
|
from gnuradio_mcp.models import SINK, SOURCE, ParamModel, PortModel
|
||||||
|
|
||||||
|
|
||||||
|
class BlockMiddleware:
|
||||||
|
def __init__(self, block: Block):
|
||||||
|
self._block = block
|
||||||
|
|
||||||
|
@property
|
||||||
|
def params(self) -> List[ParamModel]:
|
||||||
|
return [ParamModel.from_param(param) for param in self._block.params.values()]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sinks(self) -> List[PortModel]:
|
||||||
|
return [PortModel.from_port(port, SINK) for port in self._block.sinks]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sources(self) -> List[PortModel]:
|
||||||
|
return [PortModel.from_port(port, SOURCE) for port in self._block.sources]
|
||||||
@ -1,6 +1,6 @@
|
|||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from gnuradio.grc.core.FlowGraph import FlowGraph
|
from gnuradio.grc.core.FlowGraph import FlowGraph
|
||||||
|
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
||||||
from gnuradio_mcp.models import BlockModel
|
from gnuradio_mcp.models import BlockModel
|
||||||
|
|
||||||
|
|
||||||
@ -14,12 +14,15 @@ class FlowGraphMiddleware:
|
|||||||
BlockModel(key=block.key, label=block.label)
|
BlockModel(key=block.key, label=block.label)
|
||||||
for block in self._flowgraph.blocks
|
for block in self._flowgraph.blocks
|
||||||
]
|
]
|
||||||
|
|
||||||
def add_block(self, block_type: str, block_name: str) -> None:
|
def add_block(self, block_type: str, block_name: str) -> None:
|
||||||
block = self._flowgraph.new_block(block_type)
|
block = self._flowgraph.new_block(block_type)
|
||||||
block.params["id"].set_value(block_name)
|
block.params["id"].set_value(block_name)
|
||||||
|
|
||||||
|
|
||||||
def remove_block(self, block_name: str) -> None:
|
def remove_block(self, block_name: str) -> None:
|
||||||
block = self._flowgraph.get_block(block_name)
|
block = self._flowgraph.get_block(block_name)
|
||||||
self._flowgraph.remove_element(block)
|
self._flowgraph.remove_element(block)
|
||||||
|
|
||||||
|
def get_block(self, block_name: str) -> BlockMiddleware:
|
||||||
|
block = self._flowgraph.get_block(block_name)
|
||||||
|
return BlockMiddleware(block)
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
|
from typing import Any, Literal, Optional, get_args
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from gnuradio.grc.core.blocks.block import Block
|
from gnuradio.grc.core.blocks.block import Block
|
||||||
|
from gnuradio.grc.core.ports.port import Port
|
||||||
|
from gnuradio.grc.core.params.param import Param
|
||||||
|
|
||||||
class BlockModel(BaseModel):
|
class BlockModel(BaseModel):
|
||||||
label: str
|
label: str
|
||||||
@ -8,4 +10,49 @@ class BlockModel(BaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_block(cls, block: Block) -> "BlockModel":
|
def from_block(cls, block: Block) -> "BlockModel":
|
||||||
return cls(label=block.label, key=block.key)
|
return cls(label=block.label, key=block.key)
|
||||||
|
|
||||||
|
|
||||||
|
class ParamModel(BaseModel):
|
||||||
|
parent: str
|
||||||
|
key: str
|
||||||
|
name: str
|
||||||
|
dtype: str
|
||||||
|
value: Any
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_param(cls, param: Param) -> "ParamModel":
|
||||||
|
return cls(
|
||||||
|
parent=param.parent.name,
|
||||||
|
key=param.key,
|
||||||
|
name=param.name,
|
||||||
|
dtype=param.dtype,
|
||||||
|
value=param.get_value(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
DirectionType = Literal["sink", "source"]
|
||||||
|
SINK, SOURCE = get_args(DirectionType)
|
||||||
|
|
||||||
|
class PortModel(BaseModel):
|
||||||
|
parent: str
|
||||||
|
key: str
|
||||||
|
name: str
|
||||||
|
dtype: str
|
||||||
|
direction: DirectionType
|
||||||
|
optional: bool = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_port(
|
||||||
|
cls, port: Port, direction: Optional[DirectionType] = None
|
||||||
|
) -> "PortModel":
|
||||||
|
direction = direction or port._dir
|
||||||
|
return cls(
|
||||||
|
parent=port.parent.name,
|
||||||
|
key=port.key,
|
||||||
|
name=port.name,
|
||||||
|
dtype=port.dtype,
|
||||||
|
direction=direction,
|
||||||
|
optional=port.optional,
|
||||||
|
)
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
22
tests/conftest.py
Normal file
22
tests/conftest.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import pytest
|
||||||
|
from gnuradio.grc.core.platform import Platform
|
||||||
|
from gnuradio import gr
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def platform() -> Platform:
|
||||||
|
platform = Platform(
|
||||||
|
version=gr.version(),
|
||||||
|
version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()),
|
||||||
|
prefs=gr.prefs(),
|
||||||
|
)
|
||||||
|
platform.build_library()
|
||||||
|
return platform
|
||||||
|
|
||||||
|
@pytest.fixture(params=[1, 2, 10]) # Arbitrary number of blocks to test
|
||||||
|
def block_key(platform, request):
|
||||||
|
block_keys = list(platform.blocks.keys())
|
||||||
|
assert block_keys, "No blocks available in platform library."
|
||||||
|
idx = request.param
|
||||||
|
if idx < len(block_keys):
|
||||||
|
return block_keys[idx]
|
||||||
|
return block_keys[0]
|
||||||
58
tests/test_block.py
Normal file
58
tests/test_block.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from typing import List
|
||||||
|
import pytest
|
||||||
|
from gnuradio.grc.core.platform import Platform
|
||||||
|
from gnuradio.grc.core.blocks.block import Block
|
||||||
|
from gnuradio.grc.core.FlowGraph import FlowGraph
|
||||||
|
from gnuradio import gr
|
||||||
|
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
||||||
|
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
||||||
|
from gnuradio_mcp.models import SINK, SOURCE, ParamModel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def flowgraph_middleware(platform: Platform):
|
||||||
|
flowgraph = platform.make_flow_graph("")
|
||||||
|
return FlowGraphMiddleware(flowgraph)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def block(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
|
||||||
|
flowgraph_middleware.add_block(block_key, block_key)
|
||||||
|
return flowgraph_middleware._flowgraph.get_block(block_key)
|
||||||
|
|
||||||
|
|
||||||
|
def test_block_middleware_params(block: Block):
|
||||||
|
middleware = BlockMiddleware(block)
|
||||||
|
check_param_models(block, middleware.params)
|
||||||
|
|
||||||
|
|
||||||
|
def test_block_middleware_sinks(block: Block):
|
||||||
|
middleware = BlockMiddleware(block)
|
||||||
|
check_port_models(middleware.sinks, block.sinks, SINK)
|
||||||
|
|
||||||
|
|
||||||
|
def test_block_middleware_sources(block: Block):
|
||||||
|
middleware = BlockMiddleware(block)
|
||||||
|
check_port_models(middleware.sources, block.sources, SOURCE)
|
||||||
|
|
||||||
|
|
||||||
|
def check_param_models(block: Block, params: List[ParamModel]):
|
||||||
|
assert params
|
||||||
|
assert len(params) == len(block.params)
|
||||||
|
for param in params:
|
||||||
|
original_param = block.params[param.key]
|
||||||
|
assert param.key == original_param.key
|
||||||
|
assert param.name == original_param.name
|
||||||
|
assert param.dtype == original_param.dtype
|
||||||
|
assert param.value == original_param.value
|
||||||
|
|
||||||
|
|
||||||
|
def check_port_models(port_models, ports, direction):
|
||||||
|
assert isinstance(port_models, list)
|
||||||
|
assert len(port_models) == len(ports)
|
||||||
|
for model, port in zip(port_models, ports):
|
||||||
|
assert model.key == port.key
|
||||||
|
assert model.name == port.name
|
||||||
|
assert model.dtype == port.dtype
|
||||||
|
assert model.direction == direction
|
||||||
|
assert model.optional == port.optional
|
||||||
@ -1,86 +1,46 @@
|
|||||||
|
from typing import List
|
||||||
import pytest
|
import pytest
|
||||||
from gnuradio.grc.core.platform import Platform
|
from gnuradio.grc.core.platform import Platform
|
||||||
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
||||||
from gnuradio import gr
|
from gnuradio import gr
|
||||||
|
|
||||||
from gnuradio_mcp.models import BlockModel
|
from gnuradio_mcp.models import BlockModel
|
||||||
|
from .utils import get_block_keys, add_block, remove_block, get_current_blocks
|
||||||
BLOCK_KEYS_TO_TEST = [1, 10]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def platform() -> Platform:
|
|
||||||
platform = Platform(
|
|
||||||
version=gr.version(),
|
|
||||||
version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()),
|
|
||||||
prefs=gr.prefs(),
|
|
||||||
)
|
|
||||||
platform.build_library()
|
|
||||||
return platform
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def flowgraph_middleware(platform):
|
def flowgraph_middleware(platform: Platform):
|
||||||
flowgraph = platform.make_flow_graph("")
|
flowgraph = platform.make_flow_graph("")
|
||||||
return FlowGraphMiddleware(flowgraph)
|
return FlowGraphMiddleware(flowgraph)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def initial_blocks(flowgraph_middleware):
|
def initial_blocks(flowgraph_middleware: FlowGraphMiddleware):
|
||||||
return [
|
return [
|
||||||
BlockModel(key=block.key, label=block.label)
|
BlockModel(key=block.key, label=block.label)
|
||||||
for block in flowgraph_middleware._flowgraph.blocks
|
for block in flowgraph_middleware._flowgraph.blocks
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_block_keys(platform):
|
def test_flowgraph_block_addition_and_removal(
|
||||||
return list(platform.blocks.keys())
|
flowgraph_middleware: FlowGraphMiddleware,
|
||||||
|
platform: Platform,
|
||||||
|
initial_blocks: List[BlockModel],
|
||||||
def add_block(flowgraph_middleware, block_key, block_name):
|
block_key: str,
|
||||||
flowgraph_middleware.add_block(block_key, block_name)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_block(flowgraph_middleware, block_name):
|
|
||||||
flowgraph_middleware._flowgraph.remove_element(
|
|
||||||
flowgraph_middleware._flowgraph.get_block(block_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_blocks(flowgraph_middleware):
|
|
||||||
return [
|
|
||||||
BlockModel(key=block.key, label=block.label)
|
|
||||||
for block in flowgraph_middleware._flowgraph.blocks
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_blocks_match_initial_state(flowgraph_middleware, initial_blocks):
|
|
||||||
assert flowgraph_middleware.blocks == initial_blocks
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("block_index", BLOCK_KEYS_TO_TEST)
|
|
||||||
def test_add_block_preserves_and_adds(
|
|
||||||
flowgraph_middleware, platform, initial_blocks, block_index
|
|
||||||
):
|
):
|
||||||
block_keys = get_block_keys(platform)
|
block_keys = get_block_keys(platform)
|
||||||
assert block_keys, "No blocks available in platform library."
|
assert block_keys, "No blocks available in platform library."
|
||||||
block_key = block_keys[block_index]
|
block_name = f"test_block_{block_key}"
|
||||||
block_name = f"test_block_{block_index}"
|
|
||||||
add_block(flowgraph_middleware, block_key, block_name)
|
add_block(flowgraph_middleware, block_key, block_name)
|
||||||
blocks = flowgraph_middleware.blocks
|
blocks = flowgraph_middleware.blocks
|
||||||
assert all(b in blocks for b in initial_blocks)
|
assert all(b in blocks for b in initial_blocks)
|
||||||
assert any(b.key == block_key for b in blocks)
|
assert any(b.key == block_key for b in blocks)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("block_index", BLOCK_KEYS_TO_TEST)
|
|
||||||
def test_remove_block_restores_initial(
|
|
||||||
flowgraph_middleware, platform, initial_blocks, block_index
|
|
||||||
):
|
|
||||||
block_keys = get_block_keys(platform)
|
|
||||||
block_key = block_keys[block_index]
|
|
||||||
block_name = f"block_to_remove_{block_index}"
|
|
||||||
add_block(flowgraph_middleware, block_key, block_name)
|
|
||||||
assert any(b.key == block_key for b in flowgraph_middleware.blocks)
|
|
||||||
remove_block(flowgraph_middleware, block_name)
|
remove_block(flowgraph_middleware, block_name)
|
||||||
current_blocks = get_current_blocks(flowgraph_middleware)
|
current_blocks = get_current_blocks(flowgraph_middleware)
|
||||||
assert current_blocks == initial_blocks
|
assert current_blocks == initial_blocks
|
||||||
|
|
||||||
|
|
||||||
|
def test_flowgraph_initial_state(
|
||||||
|
flowgraph_middleware: FlowGraphMiddleware, initial_blocks: List[BlockModel]
|
||||||
|
):
|
||||||
|
assert flowgraph_middleware.blocks == initial_blocks
|
||||||
|
|||||||
@ -4,25 +4,15 @@ from gnuradio.grc.core.blocks.block import Block
|
|||||||
from gnuradio.grc.core.platform import Platform
|
from gnuradio.grc.core.platform import Platform
|
||||||
from gnuradio import gr
|
from gnuradio import gr
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def platform() -> Platform:
|
|
||||||
platform = Platform(
|
|
||||||
version=gr.version(),
|
|
||||||
version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()),
|
|
||||||
prefs=gr.prefs(),
|
|
||||||
)
|
|
||||||
platform.build_library()
|
|
||||||
return platform
|
|
||||||
|
|
||||||
|
def test_block_model_from_block(platform: Platform):
|
||||||
def test_block_model_from_block(platform):
|
|
||||||
block = Block(platform)
|
block = Block(platform)
|
||||||
model = BlockModel.from_block(block)
|
model = BlockModel.from_block(block)
|
||||||
assert model.label == block.label
|
assert model.label == block.label
|
||||||
assert model.key == block.key
|
assert model.key == block.key
|
||||||
|
|
||||||
|
|
||||||
def test_platform_middleware_blocks(platform):
|
def test_platform_middleware_blocks(platform: Platform):
|
||||||
middleware = PlatformMiddleware(platform)
|
middleware = PlatformMiddleware(platform)
|
||||||
block_models = middleware.blocks
|
block_models = middleware.blocks
|
||||||
assert block_models # Checks that the list is not empty
|
assert block_models # Checks that the list is not empty
|
||||||
|
|||||||
18
tests/utils.py
Normal file
18
tests/utils.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from gnuradio_mcp.models import BlockModel
|
||||||
|
|
||||||
|
def get_block_keys(platform):
|
||||||
|
return list(platform.blocks.keys())
|
||||||
|
|
||||||
|
def add_block(flowgraph_middleware, block_key, block_name):
|
||||||
|
flowgraph_middleware.add_block(block_key, block_name)
|
||||||
|
|
||||||
|
def remove_block(flowgraph_middleware, block_name):
|
||||||
|
flowgraph_middleware._flowgraph.remove_element(
|
||||||
|
flowgraph_middleware._flowgraph.get_block(block_name)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_current_blocks(flowgraph_middleware):
|
||||||
|
return [
|
||||||
|
BlockModel(key=block.key, label=block.label)
|
||||||
|
for block in flowgraph_middleware._flowgraph.blocks
|
||||||
|
]
|
||||||
Loading…
x
Reference in New Issue
Block a user