main - feat: Implement gnuradio element validation and errors
This commit is contained in:
parent
0be8e77596
commit
42484d6c7d
7
main.py
7
main.py
@ -22,11 +22,12 @@ platform = Platform(
|
|||||||
version=gr.version(),
|
version=gr.version(),
|
||||||
version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()),
|
version_parts=(gr.major_version(), gr.api_version(), gr.minor_version()),
|
||||||
prefs=gr.prefs(),
|
prefs=gr.prefs(),
|
||||||
# install_prefix=gr.prefix()
|
|
||||||
)
|
)
|
||||||
platform.build_library()
|
platform.build_library()
|
||||||
|
|
||||||
|
|
||||||
platform_middleware = PlatformMiddleware(platform)
|
platform_middleware = PlatformMiddleware(platform)
|
||||||
|
flowgraph_mw = platform_middleware.make_flowgraph()
|
||||||
print(platform_middleware.blocks)
|
flowgraph_mw.add_block("blocks_add_xx")
|
||||||
|
for error in flowgraph_mw.get_all_errors():
|
||||||
|
print(error)
|
||||||
|
|||||||
23
src/gnuradio_mcp/middlewares/base.py
Normal file
23
src/gnuradio_mcp/middlewares/base.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from gnuradio.grc.core.base import Element
|
||||||
|
|
||||||
|
from gnuradio_mcp.models import ErrorModel
|
||||||
|
from gnuradio_mcp.utils import format_error_message
|
||||||
|
|
||||||
|
|
||||||
|
class ElementMiddleware:
|
||||||
|
def __init__(self, element: Element):
|
||||||
|
self._element = element
|
||||||
|
|
||||||
|
def _rewrite(self):
|
||||||
|
self._element.rewrite()
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
self._rewrite()
|
||||||
|
self._element.validate()
|
||||||
|
|
||||||
|
def get_all_errors(self) -> list[ErrorModel]:
|
||||||
|
self.validate()
|
||||||
|
return [
|
||||||
|
format_error_message(elem, msg)
|
||||||
|
for elem, msg in self._element.iter_error_messages()
|
||||||
|
]
|
||||||
@ -4,17 +4,23 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
from gnuradio.grc.core.blocks.block import Block
|
from gnuradio.grc.core.blocks.block import Block
|
||||||
|
|
||||||
|
from gnuradio_mcp.middlewares.base import ElementMiddleware
|
||||||
from gnuradio_mcp.models import SINK, SOURCE, ParamModel, PortModel
|
from gnuradio_mcp.models import SINK, SOURCE, ParamModel, PortModel
|
||||||
|
|
||||||
|
|
||||||
class BlockMiddleware:
|
class BlockMiddleware(ElementMiddleware):
|
||||||
def __init__(self, block: Block):
|
def __init__(self, block: Block):
|
||||||
self._block = block
|
super().__init__(block)
|
||||||
|
self._block = self._element
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return self._block.name
|
return self._block.name
|
||||||
|
|
||||||
|
@name.setter
|
||||||
|
def name(self, name: str):
|
||||||
|
self._block.params["id"].set_value(name)
|
||||||
|
|
||||||
def set_param(self, param_name: str, param_value: Any):
|
def set_param(self, param_name: str, param_value: Any):
|
||||||
self._block.params[param_name].set_value(param_value)
|
self._block.params[param_name].set_value(param_value)
|
||||||
|
|
||||||
@ -50,6 +56,3 @@ class BlockMiddleware:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
return ports
|
return ports
|
||||||
|
|
||||||
def _rewrite(self):
|
|
||||||
self._block.rewrite()
|
|
||||||
|
|||||||
@ -1,19 +1,34 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Dict, Optional
|
from typing import TYPE_CHECKING, Optional, Set
|
||||||
|
|
||||||
from gnuradio.grc.core.blocks.block import Block
|
from gnuradio.grc.core.blocks.block import Block
|
||||||
from gnuradio.grc.core.FlowGraph import FlowGraph
|
from gnuradio.grc.core.FlowGraph import FlowGraph
|
||||||
|
|
||||||
|
from gnuradio_mcp.middlewares.base import ElementMiddleware
|
||||||
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
||||||
from gnuradio_mcp.models import BlockModel, ConnectionModel, PortModel
|
from gnuradio_mcp.models import SINK, SOURCE, BlockModel, ConnectionModel, PortModel
|
||||||
from gnuradio_mcp.utils import get_unique_id
|
from gnuradio_mcp.utils import get_unique_id
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
|
||||||
|
|
||||||
class FlowGraphMiddleware:
|
|
||||||
|
def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block:
|
||||||
|
block_from_port_model = flowgraph.get_block(port_model.parent)
|
||||||
|
if port_model.direction == SOURCE:
|
||||||
|
return block_from_port_model.sources[port_model.key]
|
||||||
|
elif port_model.direction == SINK:
|
||||||
|
return block_from_port_model.sinks[port_model.key]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid port direction: {port_model.direction}")
|
||||||
|
|
||||||
|
|
||||||
|
class FlowGraphMiddleware(ElementMiddleware):
|
||||||
def __init__(self, flowgraph: FlowGraph):
|
def __init__(self, flowgraph: FlowGraph):
|
||||||
self._flowgraph = flowgraph
|
super().__init__(flowgraph)
|
||||||
self._blocks: Dict[str, BlockMiddleware] = {}
|
self._flowgraph = self._element
|
||||||
|
self._blocks: Set[BlockMiddleware] = set()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def blocks(self) -> list[BlockModel]:
|
def blocks(self) -> list[BlockModel]:
|
||||||
@ -25,39 +40,34 @@ class FlowGraphMiddleware:
|
|||||||
def add_block(
|
def add_block(
|
||||||
self, block_type: str, block_name: Optional[str] = None
|
self, block_type: str, block_name: Optional[str] = None
|
||||||
) -> BlockMiddleware:
|
) -> BlockMiddleware:
|
||||||
block_name = block_name or get_unique_id(self._flowgraph.blocks)
|
block_name = block_name or get_unique_id(self._flowgraph.blocks, block_type)
|
||||||
block = self._flowgraph.new_block(block_type)
|
block = self._flowgraph.new_block(block_type)
|
||||||
block.params["id"].set_value(block_name)
|
assert block is not None, f"Failed to create block: {block_type}"
|
||||||
self._blocks[block_name] = BlockMiddleware(block)
|
block_middleware = BlockMiddleware(block)
|
||||||
return self._blocks[block_name]
|
block_middleware.name = block_name
|
||||||
|
self._blocks.add(block_middleware)
|
||||||
|
return block_middleware
|
||||||
|
|
||||||
def remove_block(self, block_name: str) -> None:
|
def remove_block(self, block_name: str) -> None:
|
||||||
block = self._flowgraph.get_block(block_name)
|
block_middleware = self.get_block(block_name)
|
||||||
self._flowgraph.remove_element(block)
|
self._flowgraph.remove_element(block_middleware._block)
|
||||||
del self._blocks[block_name]
|
self._blocks.remove(block_middleware)
|
||||||
|
|
||||||
def get_block(self, block_name: str) -> BlockMiddleware:
|
def get_block(self, block_name: str) -> BlockMiddleware:
|
||||||
# TODO: Check if calling two times you get different results
|
return next(block for block in self._blocks if block.name == block_name)
|
||||||
return self._blocks[block_name]
|
|
||||||
|
|
||||||
def connect_blocks(
|
def connect_blocks(
|
||||||
self, src_port_model: PortModel, dst_port_model: PortModel
|
self, src_port_model: PortModel, dst_port_model: PortModel
|
||||||
) -> None:
|
) -> None:
|
||||||
def get_block_by_port_model(port_model: PortModel) -> Block:
|
src_port = get_port_from_port_model(self._flowgraph, src_port_model)
|
||||||
return self._flowgraph.get_block(port_model.parent)
|
dst_port = get_port_from_port_model(self._flowgraph, dst_port_model)
|
||||||
|
|
||||||
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
|
|
||||||
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
|
|
||||||
self._flowgraph.connect(src_port, dst_port)
|
self._flowgraph.connect(src_port, dst_port)
|
||||||
|
|
||||||
def disconnect_blocks(
|
def disconnect_blocks(
|
||||||
self, src_port_model: PortModel, dst_port_model: PortModel
|
self, src_port_model: PortModel, dst_port_model: PortModel
|
||||||
) -> None:
|
) -> None:
|
||||||
def get_block_by_port_model(port_model: PortModel) -> Block:
|
src_port = get_port_from_port_model(self._flowgraph, src_port_model)
|
||||||
return self._flowgraph.get_block(port_model.parent)
|
dst_port = get_port_from_port_model(self._flowgraph, dst_port_model)
|
||||||
|
|
||||||
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
|
|
||||||
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
|
|
||||||
self._flowgraph.disconnect(src_port, dst_port)
|
self._flowgraph.disconnect(src_port, dst_port)
|
||||||
|
|
||||||
def get_connections(self) -> list[ConnectionModel]:
|
def get_connections(self) -> list[ConnectionModel]:
|
||||||
@ -65,3 +75,12 @@ class FlowGraphMiddleware:
|
|||||||
ConnectionModel.from_connection(connection)
|
ConnectionModel.from_connection(connection)
|
||||||
for connection in self._flowgraph.connections
|
for connection in self._flowgraph.connections
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_file(
|
||||||
|
cls, platform: "PlatformMiddleware", filepath: str = ""
|
||||||
|
) -> FlowGraphMiddleware:
|
||||||
|
initial_state = platform._platform.parse_flow_graph(filepath)
|
||||||
|
flowgraph = FlowGraph(platform._platform)
|
||||||
|
flowgraph.import_data(initial_state)
|
||||||
|
return cls(flowgraph)
|
||||||
|
|||||||
@ -2,15 +2,15 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from gnuradio.grc.core.platform import Platform
|
from gnuradio.grc.core.platform import Platform
|
||||||
|
|
||||||
|
from gnuradio_mcp.middlewares.base import ElementMiddleware
|
||||||
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
||||||
from gnuradio_mcp.models import BlockModel
|
from gnuradio_mcp.models import BlockModel
|
||||||
|
|
||||||
|
|
||||||
class PlatformMiddleware:
|
class PlatformMiddleware(ElementMiddleware):
|
||||||
def __init__(self, platform: Platform):
|
def __init__(self, platform: Platform):
|
||||||
self._platform = platform
|
super().__init__(platform)
|
||||||
flowgraph = self._platform.make_flow_graph("")
|
self._platform = self._element
|
||||||
self._flowgraph_mw = FlowGraphMiddleware(flowgraph)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def blocks(self) -> list[BlockModel]:
|
def blocks(self) -> list[BlockModel]:
|
||||||
@ -18,6 +18,5 @@ class PlatformMiddleware:
|
|||||||
BlockModel.from_block(block) for block in self._platform.blocks.values()
|
BlockModel.from_block(block) for block in self._platform.blocks.values()
|
||||||
]
|
]
|
||||||
|
|
||||||
@property
|
def make_flowgraph(self, filepath: str = "") -> FlowGraphMiddleware:
|
||||||
def flowgraph(self) -> FlowGraphMiddleware:
|
return FlowGraphMiddleware.from_file(self, filepath)
|
||||||
return self._flowgraph_mw
|
|
||||||
|
|||||||
@ -77,3 +77,9 @@ class ConnectionModel(BaseModel):
|
|||||||
source=PortModel.from_port(connection.source_port),
|
source=PortModel.from_port(connection.source_port),
|
||||||
sink=PortModel.from_port(connection.sink_port),
|
sink=PortModel.from_port(connection.sink_port),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorModel(BaseModel):
|
||||||
|
type: str
|
||||||
|
key: BaseModel
|
||||||
|
message: str
|
||||||
|
|||||||
@ -1,5 +1,20 @@
|
|||||||
|
import re
|
||||||
from itertools import count
|
from itertools import count
|
||||||
|
|
||||||
|
from gnuradio.grc.core.blocks.block import Block
|
||||||
|
from gnuradio.grc.core.Connection import Connection
|
||||||
|
from gnuradio.grc.core.params.param import Param
|
||||||
|
from gnuradio.grc.core.ports.port import Port
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from gnuradio_mcp.models import (
|
||||||
|
BlockModel,
|
||||||
|
ConnectionModel,
|
||||||
|
ErrorModel,
|
||||||
|
ParamModel,
|
||||||
|
PortModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_unique_id(flowgraph_blocks, base_id=""):
|
def get_unique_id(flowgraph_blocks, base_id=""):
|
||||||
block_ids = set(b.name for b in flowgraph_blocks)
|
block_ids = set(b.name for b in flowgraph_blocks)
|
||||||
@ -8,3 +23,28 @@ def get_unique_id(flowgraph_blocks, base_id=""):
|
|||||||
if block_id not in block_ids:
|
if block_id not in block_ids:
|
||||||
break
|
break
|
||||||
return block_id
|
return block_id
|
||||||
|
|
||||||
|
|
||||||
|
def format_error_message(elem, msg) -> ErrorModel:
|
||||||
|
msg = re.sub("[^A-Za-z0-9]+", " ", msg).strip()
|
||||||
|
model: BaseModel
|
||||||
|
match (elem):
|
||||||
|
case Connection():
|
||||||
|
model = ConnectionModel.from_connection(elem)
|
||||||
|
|
||||||
|
case Param():
|
||||||
|
model = ParamModel.from_param(elem)
|
||||||
|
|
||||||
|
case Port():
|
||||||
|
model = PortModel.from_port(elem)
|
||||||
|
|
||||||
|
case Block():
|
||||||
|
model = BlockModel.from_block(elem)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported element type: {type(elem)}")
|
||||||
|
return ErrorModel(
|
||||||
|
type=type(model).__name__,
|
||||||
|
key=model,
|
||||||
|
message=msg,
|
||||||
|
)
|
||||||
|
|||||||
@ -4,6 +4,8 @@ import pytest
|
|||||||
from gnuradio import gr
|
from gnuradio import gr
|
||||||
from gnuradio.grc.core.platform import Platform
|
from gnuradio.grc.core.platform import Platform
|
||||||
|
|
||||||
|
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def platform() -> Platform:
|
def platform() -> Platform:
|
||||||
@ -20,6 +22,11 @@ def platform() -> Platform:
|
|||||||
return platform
|
return platform
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def platform_middleware(platform: Platform) -> PlatformMiddleware:
|
||||||
|
return PlatformMiddleware(platform)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(params=[1, 2, 10]) # Arbitrary number of blocks to test
|
@pytest.fixture(params=[1, 2, 10]) # Arbitrary number of blocks to test
|
||||||
def block_key(platform, request):
|
def block_key(platform, request):
|
||||||
block_keys = list(platform.blocks.keys())
|
block_keys = list(platform.blocks.keys())
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from gnuradio.grc.core.platform import Platform
|
|||||||
|
|
||||||
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
||||||
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
||||||
from gnuradio_mcp.models import SINK, SOURCE, ParamModel
|
from gnuradio_mcp.models import SINK, SOURCE, ErrorModel, ParamModel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -44,6 +44,22 @@ def test_block_middleware_set_params(block_middleware: BlockMiddleware):
|
|||||||
assert block_middleware._block.params["id"].get_value() == "my_custom_block_name"
|
assert block_middleware._block.params["id"].get_value() == "my_custom_block_name"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"block_key, initial_errors_number", [("blocks_add_xx", 3), ("blocks_copy", 2)]
|
||||||
|
)
|
||||||
|
def test_block_errors(
|
||||||
|
flowgraph_middleware: FlowGraphMiddleware,
|
||||||
|
block_key: str,
|
||||||
|
initial_errors_number: int,
|
||||||
|
):
|
||||||
|
block_middleware = flowgraph_middleware.add_block(block_key)
|
||||||
|
for error in block_middleware.get_all_errors():
|
||||||
|
assert isinstance(error, ErrorModel)
|
||||||
|
assert len(block_middleware.get_all_errors()) == initial_errors_number
|
||||||
|
# Call again to check that the errors are not duplicated
|
||||||
|
assert len(block_middleware.get_all_errors()) == initial_errors_number
|
||||||
|
|
||||||
|
|
||||||
def check_param_models(block: Block, params: list[ParamModel]):
|
def check_param_models(block: Block, params: list[ParamModel]):
|
||||||
assert params
|
assert params
|
||||||
assert len(params) == len(block.params)
|
assert len(params) == len(block.params)
|
||||||
|
|||||||
@ -3,17 +3,16 @@ from __future__ import annotations
|
|||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from gnuradio.grc.core.platform import Platform
|
|
||||||
|
|
||||||
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
||||||
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
||||||
from gnuradio_mcp.models import BlockModel, ConnectionModel
|
from gnuradio_mcp.middlewares.platform import PlatformMiddleware
|
||||||
|
from gnuradio_mcp.models import BlockModel, ConnectionModel, ErrorModel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def flowgraph_middleware(platform: Platform):
|
def flowgraph_middleware(platform_middleware: PlatformMiddleware):
|
||||||
flowgraph = platform.make_flow_graph("")
|
return platform_middleware.make_flowgraph()
|
||||||
return FlowGraphMiddleware(flowgraph)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -67,6 +66,11 @@ def test_block_unique_names_for_same_type(
|
|||||||
assert first_block.name != second_block.name
|
assert first_block.name != second_block.name
|
||||||
|
|
||||||
|
|
||||||
|
def test_block_default_name(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
|
||||||
|
block = flowgraph_middleware.add_block(block_key)
|
||||||
|
assert block_key in block.name
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"block_key, sinks_number, sources_number",
|
"block_key, sinks_number, sources_number",
|
||||||
[("blocks_add_xx", 2, 1), ("blocks_copy", 1, 1), ("blocks_selector", 2, 2)],
|
[("blocks_add_xx", 2, 1), ("blocks_copy", 1, 1), ("blocks_selector", 2, 2)],
|
||||||
@ -109,6 +113,12 @@ def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_ke
|
|||||||
assert len(flowgraph_middleware.get_connections()) == 0
|
assert len(flowgraph_middleware.get_connections()) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_flowgraph_errors(flowgraph_middleware: FlowGraphMiddleware):
|
||||||
|
for error in flowgraph_middleware.get_all_errors():
|
||||||
|
assert isinstance(error, ErrorModel)
|
||||||
|
assert len(flowgraph_middleware.get_all_errors()) == 0
|
||||||
|
|
||||||
|
|
||||||
def util_iter_possible_connections(
|
def util_iter_possible_connections(
|
||||||
source_block: BlockMiddleware,
|
source_block: BlockMiddleware,
|
||||||
dest_block: BlockMiddleware,
|
dest_block: BlockMiddleware,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user