main - feat: Implement gnuradio element validation and errors
This commit is contained in:
parent
0be8e77596
commit
42484d6c7d
7
main.py
7
main.py
@ -22,11 +22,12 @@ platform = Platform(
|
||||
version=gr.version(),
|
||||
version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()),
|
||||
prefs=gr.prefs(),
|
||||
# install_prefix=gr.prefix()
|
||||
)
|
||||
platform.build_library()
|
||||
|
||||
|
||||
platform_middleware = PlatformMiddleware(platform)
|
||||
|
||||
print(platform_middleware.blocks)
|
||||
flowgraph_mw = platform_middleware.make_flowgraph()
|
||||
flowgraph_mw.add_block("blocks_add_xx")
|
||||
for error in flowgraph_mw.get_all_errors():
|
||||
print(error)
|
||||
|
||||
23
src/gnuradio_mcp/middlewares/base.py
Normal file
23
src/gnuradio_mcp/middlewares/base.py
Normal file
@ -0,0 +1,23 @@
|
||||
from gnuradio.grc.core.base import Element
|
||||
|
||||
from gnuradio_mcp.models import ErrorModel
|
||||
from gnuradio_mcp.utils import format_error_message
|
||||
|
||||
|
||||
class ElementMiddleware:
|
||||
def __init__(self, element: Element):
|
||||
self._element = element
|
||||
|
||||
def _rewrite(self):
|
||||
self._element.rewrite()
|
||||
|
||||
def validate(self):
|
||||
self._rewrite()
|
||||
self._element.validate()
|
||||
|
||||
def get_all_errors(self) -> list[ErrorModel]:
|
||||
self.validate()
|
||||
return [
|
||||
format_error_message(elem, msg)
|
||||
for elem, msg in self._element.iter_error_messages()
|
||||
]
|
||||
@ -4,17 +4,23 @@ from typing import Any, Dict
|
||||
|
||||
from gnuradio.grc.core.blocks.block import Block
|
||||
|
||||
from gnuradio_mcp.middlewares.base import ElementMiddleware
|
||||
from gnuradio_mcp.models import SINK, SOURCE, ParamModel, PortModel
|
||||
|
||||
|
||||
class BlockMiddleware:
|
||||
class BlockMiddleware(ElementMiddleware):
|
||||
def __init__(self, block: Block):
|
||||
self._block = block
|
||||
super().__init__(block)
|
||||
self._block = self._element
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._block.name
|
||||
|
||||
@name.setter
|
||||
def name(self, name: str):
|
||||
self._block.params["id"].set_value(name)
|
||||
|
||||
def set_param(self, param_name: str, param_value: Any):
|
||||
self._block.params[param_name].set_value(param_value)
|
||||
|
||||
@ -50,6 +56,3 @@ class BlockMiddleware:
|
||||
except ValueError:
|
||||
pass
|
||||
return ports
|
||||
|
||||
def _rewrite(self):
|
||||
self._block.rewrite()
|
||||
|
||||
@ -1,19 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional
|
||||
from typing import TYPE_CHECKING, Optional, Set
|
||||
|
||||
from gnuradio.grc.core.blocks.block import Block
|
||||
from gnuradio.grc.core.FlowGraph import FlowGraph
|
||||
|
||||
from gnuradio_mcp.middlewares.base import ElementMiddleware
|
||||
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
||||
from gnuradio_mcp.models import BlockModel, ConnectionModel, PortModel
|
||||
from gnuradio_mcp.models import SINK, SOURCE, BlockModel, ConnectionModel, PortModel
|
||||
from gnuradio_mcp.utils import get_unique_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
|
||||
|
||||
class FlowGraphMiddleware:
|
||||
|
||||
def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block:
|
||||
block_from_port_model = flowgraph.get_block(port_model.parent)
|
||||
if port_model.direction == SOURCE:
|
||||
return block_from_port_model.sources[port_model.key]
|
||||
elif port_model.direction == SINK:
|
||||
return block_from_port_model.sinks[port_model.key]
|
||||
else:
|
||||
raise ValueError(f"Invalid port direction: {port_model.direction}")
|
||||
|
||||
|
||||
class FlowGraphMiddleware(ElementMiddleware):
|
||||
def __init__(self, flowgraph: FlowGraph):
|
||||
self._flowgraph = flowgraph
|
||||
self._blocks: Dict[str, BlockMiddleware] = {}
|
||||
super().__init__(flowgraph)
|
||||
self._flowgraph = self._element
|
||||
self._blocks: Set[BlockMiddleware] = set()
|
||||
|
||||
@property
|
||||
def blocks(self) -> list[BlockModel]:
|
||||
@ -25,39 +40,34 @@ class FlowGraphMiddleware:
|
||||
def add_block(
|
||||
self, block_type: str, block_name: Optional[str] = None
|
||||
) -> BlockMiddleware:
|
||||
block_name = block_name or get_unique_id(self._flowgraph.blocks)
|
||||
block_name = block_name or get_unique_id(self._flowgraph.blocks, block_type)
|
||||
block = self._flowgraph.new_block(block_type)
|
||||
block.params["id"].set_value(block_name)
|
||||
self._blocks[block_name] = BlockMiddleware(block)
|
||||
return self._blocks[block_name]
|
||||
assert block is not None, f"Failed to create block: {block_type}"
|
||||
block_middleware = BlockMiddleware(block)
|
||||
block_middleware.name = block_name
|
||||
self._blocks.add(block_middleware)
|
||||
return block_middleware
|
||||
|
||||
def remove_block(self, block_name: str) -> None:
|
||||
block = self._flowgraph.get_block(block_name)
|
||||
self._flowgraph.remove_element(block)
|
||||
del self._blocks[block_name]
|
||||
block_middleware = self.get_block(block_name)
|
||||
self._flowgraph.remove_element(block_middleware._block)
|
||||
self._blocks.remove(block_middleware)
|
||||
|
||||
def get_block(self, block_name: str) -> BlockMiddleware:
|
||||
# TODO: Check if calling two times you get different results
|
||||
return self._blocks[block_name]
|
||||
return next(block for block in self._blocks if block.name == block_name)
|
||||
|
||||
def connect_blocks(
|
||||
self, src_port_model: PortModel, dst_port_model: PortModel
|
||||
) -> None:
|
||||
def get_block_by_port_model(port_model: PortModel) -> Block:
|
||||
return self._flowgraph.get_block(port_model.parent)
|
||||
|
||||
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
|
||||
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
|
||||
src_port = get_port_from_port_model(self._flowgraph, src_port_model)
|
||||
dst_port = get_port_from_port_model(self._flowgraph, dst_port_model)
|
||||
self._flowgraph.connect(src_port, dst_port)
|
||||
|
||||
def disconnect_blocks(
|
||||
self, src_port_model: PortModel, dst_port_model: PortModel
|
||||
) -> None:
|
||||
def get_block_by_port_model(port_model: PortModel) -> Block:
|
||||
return self._flowgraph.get_block(port_model.parent)
|
||||
|
||||
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
|
||||
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
|
||||
src_port = get_port_from_port_model(self._flowgraph, src_port_model)
|
||||
dst_port = get_port_from_port_model(self._flowgraph, dst_port_model)
|
||||
self._flowgraph.disconnect(src_port, dst_port)
|
||||
|
||||
def get_connections(self) -> list[ConnectionModel]:
|
||||
@ -65,3 +75,12 @@ class FlowGraphMiddleware:
|
||||
ConnectionModel.from_connection(connection)
|
||||
for connection in self._flowgraph.connections
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_file(
|
||||
cls, platform: "PlatformMiddleware", filepath: str = ""
|
||||
) -> FlowGraphMiddleware:
|
||||
initial_state = platform._platform.parse_flow_graph(filepath)
|
||||
flowgraph = FlowGraph(platform._platform)
|
||||
flowgraph.import_data(initial_state)
|
||||
return cls(flowgraph)
|
||||
|
||||
@ -2,15 +2,15 @@ from __future__ import annotations
|
||||
|
||||
from gnuradio.grc.core.platform import Platform
|
||||
|
||||
from gnuradio_mcp.middlewares.base import ElementMiddleware
|
||||
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
||||
from gnuradio_mcp.models import BlockModel
|
||||
|
||||
|
||||
class PlatformMiddleware:
|
||||
class PlatformMiddleware(ElementMiddleware):
|
||||
def __init__(self, platform: Platform):
|
||||
self._platform = platform
|
||||
flowgraph = self._platform.make_flow_graph("")
|
||||
self._flowgraph_mw = FlowGraphMiddleware(flowgraph)
|
||||
super().__init__(platform)
|
||||
self._platform = self._element
|
||||
|
||||
@property
|
||||
def blocks(self) -> list[BlockModel]:
|
||||
@ -18,6 +18,5 @@ class PlatformMiddleware:
|
||||
BlockModel.from_block(block) for block in self._platform.blocks.values()
|
||||
]
|
||||
|
||||
@property
|
||||
def flowgraph(self) -> FlowGraphMiddleware:
|
||||
return self._flowgraph_mw
|
||||
def make_flowgraph(self, filepath: str = "") -> FlowGraphMiddleware:
|
||||
return FlowGraphMiddleware.from_file(self, filepath)
|
||||
|
||||
@ -77,3 +77,9 @@ class ConnectionModel(BaseModel):
|
||||
source=PortModel.from_port(connection.source_port),
|
||||
sink=PortModel.from_port(connection.sink_port),
|
||||
)
|
||||
|
||||
|
||||
class ErrorModel(BaseModel):
|
||||
type: str
|
||||
key: BaseModel
|
||||
message: str
|
||||
|
||||
@ -1,5 +1,20 @@
|
||||
import re
|
||||
from itertools import count
|
||||
|
||||
from gnuradio.grc.core.blocks.block import Block
|
||||
from gnuradio.grc.core.Connection import Connection
|
||||
from gnuradio.grc.core.params.param import Param
|
||||
from gnuradio.grc.core.ports.port import Port
|
||||
from pydantic import BaseModel
|
||||
|
||||
from gnuradio_mcp.models import (
|
||||
BlockModel,
|
||||
ConnectionModel,
|
||||
ErrorModel,
|
||||
ParamModel,
|
||||
PortModel,
|
||||
)
|
||||
|
||||
|
||||
def get_unique_id(flowgraph_blocks, base_id=""):
|
||||
block_ids = set(b.name for b in flowgraph_blocks)
|
||||
@ -8,3 +23,28 @@ def get_unique_id(flowgraph_blocks, base_id=""):
|
||||
if block_id not in block_ids:
|
||||
break
|
||||
return block_id
|
||||
|
||||
|
||||
def format_error_message(elem, msg) -> ErrorModel:
|
||||
msg = re.sub("[^A-Za-z0-9]+", " ", msg).strip()
|
||||
model: BaseModel
|
||||
match (elem):
|
||||
case Connection():
|
||||
model = ConnectionModel.from_connection(elem)
|
||||
|
||||
case Param():
|
||||
model = ParamModel.from_param(elem)
|
||||
|
||||
case Port():
|
||||
model = PortModel.from_port(elem)
|
||||
|
||||
case Block():
|
||||
model = BlockModel.from_block(elem)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Unsupported element type: {type(elem)}")
|
||||
return ErrorModel(
|
||||
type=type(model).__name__,
|
||||
key=model,
|
||||
message=msg,
|
||||
)
|
||||
|
||||
@ -4,6 +4,8 @@ import pytest
|
||||
from gnuradio import gr
|
||||
from gnuradio.grc.core.platform import Platform
|
||||
|
||||
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def platform() -> Platform:
|
||||
@ -20,6 +22,11 @@ def platform() -> Platform:
|
||||
return platform
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def platform_middleware(platform: Platform) -> PlatformMiddleware:
|
||||
return PlatformMiddleware(platform)
|
||||
|
||||
|
||||
@pytest.fixture(params=[1, 2, 10]) # Arbitrary number of blocks to test
|
||||
def block_key(platform, request):
|
||||
block_keys = list(platform.blocks.keys())
|
||||
|
||||
@ -6,7 +6,7 @@ from gnuradio.grc.core.platform import Platform
|
||||
|
||||
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
||||
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
||||
from gnuradio_mcp.models import SINK, SOURCE, ParamModel
|
||||
from gnuradio_mcp.models import SINK, SOURCE, ErrorModel, ParamModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -44,6 +44,22 @@ def test_block_middleware_set_params(block_middleware: BlockMiddleware):
|
||||
assert block_middleware._block.params["id"].get_value() == "my_custom_block_name"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"block_key, initial_errors_number", [("blocks_add_xx", 3), ("blocks_copy", 2)]
|
||||
)
|
||||
def test_block_errors(
|
||||
flowgraph_middleware: FlowGraphMiddleware,
|
||||
block_key: str,
|
||||
initial_errors_number: int,
|
||||
):
|
||||
block_middleware = flowgraph_middleware.add_block(block_key)
|
||||
for error in block_middleware.get_all_errors():
|
||||
assert isinstance(error, ErrorModel)
|
||||
assert len(block_middleware.get_all_errors()) == initial_errors_number
|
||||
# Call again to check that the errors are not duplicated
|
||||
assert len(block_middleware.get_all_errors()) == initial_errors_number
|
||||
|
||||
|
||||
def check_param_models(block: Block, params: list[ParamModel]):
|
||||
assert params
|
||||
assert len(params) == len(block.params)
|
||||
|
||||
@ -3,17 +3,16 @@ from __future__ import annotations
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from gnuradio.grc.core.platform import Platform
|
||||
|
||||
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
||||
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
||||
from gnuradio_mcp.models import BlockModel, ConnectionModel
|
||||
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
|
||||
from gnuradio_mcp.models import BlockModel, ConnectionModel, ErrorModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flowgraph_middleware(platform: Platform):
|
||||
flowgraph = platform.make_flow_graph("")
|
||||
return FlowGraphMiddleware(flowgraph)
|
||||
def flowgraph_middleware(platform_middleware: PlatformMiddleware):
|
||||
return platform_middleware.make_flowgraph()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -67,6 +66,11 @@ def test_block_unique_names_for_same_type(
|
||||
assert first_block.name != second_block.name
|
||||
|
||||
|
||||
def test_block_default_name(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
|
||||
block = flowgraph_middleware.add_block(block_key)
|
||||
assert block_key in block.name
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"block_key, sinks_number, sources_number",
|
||||
[("blocks_add_xx", 2, 1), ("blocks_copy", 1, 1), ("blocks_selector", 2, 2)],
|
||||
@ -109,6 +113,12 @@ def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_ke
|
||||
assert len(flowgraph_middleware.get_connections()) == 0
|
||||
|
||||
|
||||
def test_default_flowgraph_errors(flowgraph_middleware: FlowGraphMiddleware):
|
||||
for error in flowgraph_middleware.get_all_errors():
|
||||
assert isinstance(error, ErrorModel)
|
||||
assert len(flowgraph_middleware.get_all_errors()) == 0
|
||||
|
||||
|
||||
def util_iter_possible_connections(
|
||||
source_block: BlockMiddleware,
|
||||
dest_block: BlockMiddleware,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user