main - feat: Implement block connection
This commit is contained in:
parent
73bf514fc1
commit
5485413efd
@ -9,14 +9,39 @@ class BlockMiddleware:
|
|||||||
def __init__(self, block: Block):
|
def __init__(self, block: Block):
|
||||||
self._block = block
|
self._block = block
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._block.name
|
||||||
|
|
||||||
|
# TODO: Check if rewrite is needed
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def params(self) -> list[ParamModel]:
|
def params(self) -> list[ParamModel]:
|
||||||
return [ParamModel.from_param(param) for param in self._block.params.values()]
|
return [ParamModel.from_param(param) for param in self._block.params.values()]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sinks(self) -> list[PortModel]:
|
def sinks(self) -> list[PortModel]:
|
||||||
return [PortModel.from_port(port, SINK) for port in self._block.sinks]
|
self._rewrite()
|
||||||
|
ports = []
|
||||||
|
for port in self._block.sinks:
|
||||||
|
try:
|
||||||
|
port_model = PortModel.from_port(port, SINK)
|
||||||
|
ports.append(port_model)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return ports
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sources(self) -> list[PortModel]:
|
def sources(self) -> list[PortModel]:
|
||||||
return [PortModel.from_port(port, SOURCE) for port in self._block.sources]
|
self._rewrite()
|
||||||
|
ports = []
|
||||||
|
for port in self._block.sources:
|
||||||
|
try:
|
||||||
|
port_model = PortModel.from_port(port, SOURCE)
|
||||||
|
ports.append(port_model)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return ports
|
||||||
|
|
||||||
|
def _rewrite(self):
|
||||||
|
self._block.rewrite()
|
||||||
|
|||||||
@ -1,17 +1,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from gnuradio.grc.core.blocks.block import Block
|
||||||
from gnuradio.grc.core.FlowGraph import FlowGraph
|
from gnuradio.grc.core.FlowGraph import FlowGraph
|
||||||
|
|
||||||
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
||||||
from gnuradio_mcp.models import BlockModel
|
from gnuradio_mcp.models import BlockModel, ConnectionModel, PortModel
|
||||||
from gnuradio_mcp.utils import get_unique_id
|
from gnuradio_mcp.utils import get_unique_id
|
||||||
|
|
||||||
|
|
||||||
class FlowGraphMiddleware:
|
class FlowGraphMiddleware:
|
||||||
def __init__(self, flowgraph: FlowGraph):
|
def __init__(self, flowgraph: FlowGraph):
|
||||||
self._flowgraph = flowgraph
|
self._flowgraph = flowgraph
|
||||||
|
self._blocks: Dict[str, BlockMiddleware] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def blocks(self) -> list[BlockModel]:
|
def blocks(self) -> list[BlockModel]:
|
||||||
@ -20,16 +22,46 @@ class FlowGraphMiddleware:
|
|||||||
for block in self._flowgraph.blocks
|
for block in self._flowgraph.blocks
|
||||||
]
|
]
|
||||||
|
|
||||||
def add_block(self, block_type: str, block_name: Optional[str] = None) -> str:
|
def add_block(
|
||||||
|
self, block_type: str, block_name: Optional[str] = None
|
||||||
|
) -> BlockMiddleware:
|
||||||
block_name = block_name or get_unique_id(self._flowgraph.blocks)
|
block_name = block_name or get_unique_id(self._flowgraph.blocks)
|
||||||
block = self._flowgraph.new_block(block_type)
|
block = self._flowgraph.new_block(block_type)
|
||||||
block.params["id"].set_value(block_name)
|
block.params["id"].set_value(block_name)
|
||||||
return block_name
|
self._blocks[block_name] = BlockMiddleware(block)
|
||||||
|
return self._blocks[block_name]
|
||||||
|
|
||||||
def remove_block(self, block_name: str) -> None:
|
def remove_block(self, block_name: str) -> None:
|
||||||
block = self._flowgraph.get_block(block_name)
|
block = self._flowgraph.get_block(block_name)
|
||||||
self._flowgraph.remove_element(block)
|
self._flowgraph.remove_element(block)
|
||||||
|
del self._blocks[block_name]
|
||||||
|
|
||||||
def get_block(self, block_name: str) -> BlockMiddleware:
|
def get_block(self, block_name: str) -> BlockMiddleware:
|
||||||
block = self._flowgraph.get_block(block_name)
|
# TODO: Check if calling two times you get different results
|
||||||
return BlockMiddleware(block)
|
return self._blocks[block_name]
|
||||||
|
|
||||||
|
def connect_blocks(
|
||||||
|
self, src_port_model: PortModel, dst_port_model: PortModel
|
||||||
|
) -> None:
|
||||||
|
def get_block_by_port_model(port_model: PortModel) -> Block:
|
||||||
|
return self._flowgraph.get_block(port_model.parent)
|
||||||
|
|
||||||
|
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
|
||||||
|
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
|
||||||
|
self._flowgraph.connect(src_port, dst_port)
|
||||||
|
|
||||||
|
def disconnect_blocks(
|
||||||
|
self, src_port_model: PortModel, dst_port_model: PortModel
|
||||||
|
) -> None:
|
||||||
|
def get_block_by_port_model(port_model: PortModel) -> Block:
|
||||||
|
return self._flowgraph.get_block(port_model.parent)
|
||||||
|
|
||||||
|
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
|
||||||
|
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
|
||||||
|
self._flowgraph.disconnect(src_port, dst_port)
|
||||||
|
|
||||||
|
def get_connections(self) -> list[ConnectionModel]:
|
||||||
|
return [
|
||||||
|
ConnectionModel.from_connection(connection)
|
||||||
|
for connection in self._flowgraph.connections
|
||||||
|
]
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
from typing import Any, Literal, get_args
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from gnuradio.grc.core.blocks.block import Block
|
from gnuradio.grc.core.blocks.block import Block
|
||||||
|
from gnuradio.grc.core.Connection import Connection
|
||||||
from gnuradio.grc.core.params.param import Param
|
from gnuradio.grc.core.params.param import Param
|
||||||
from gnuradio.grc.core.ports.port import Port
|
from gnuradio.grc.core.ports.port import Port
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -41,7 +42,7 @@ SINK, SOURCE = get_args(DirectionType)
|
|||||||
|
|
||||||
class PortModel(BaseModel):
|
class PortModel(BaseModel):
|
||||||
parent: str
|
parent: str
|
||||||
key: str
|
key: int
|
||||||
name: str
|
name: str
|
||||||
dtype: str
|
dtype: str
|
||||||
direction: DirectionType
|
direction: DirectionType
|
||||||
@ -54,11 +55,25 @@ class PortModel(BaseModel):
|
|||||||
direction: DirectionType | None = None,
|
direction: DirectionType | None = None,
|
||||||
) -> PortModel:
|
) -> PortModel:
|
||||||
direction = direction or port._dir
|
direction = direction or port._dir
|
||||||
|
if not port.key.isnumeric():
|
||||||
|
raise ValueError("Currently not supporting named ports")
|
||||||
return cls(
|
return cls(
|
||||||
parent=port.parent.name,
|
parent=port.parent.name,
|
||||||
key=port.key,
|
key=int(port.key),
|
||||||
name=port.name,
|
name=port.name,
|
||||||
dtype=port.dtype,
|
dtype=port.dtype,
|
||||||
direction=direction,
|
direction=direction,
|
||||||
optional=port.optional,
|
optional=port.optional,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionModel(BaseModel):
|
||||||
|
source: PortModel
|
||||||
|
sink: PortModel
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_connection(cls, connection: Connection) -> "ConnectionModel":
|
||||||
|
return cls(
|
||||||
|
source=PortModel.from_port(connection.source_port),
|
||||||
|
sink=PortModel.from_port(connection.sink_port),
|
||||||
|
)
|
||||||
|
|||||||
@ -10,30 +10,28 @@ from gnuradio_mcp.models import SINK, SOURCE, ParamModel
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def flowgraph_middleware(platform: Platform):
|
def flowgraph_middleware(platform: Platform) -> FlowGraphMiddleware:
|
||||||
flowgraph = platform.make_flow_graph("")
|
flowgraph = platform.make_flow_graph("")
|
||||||
return FlowGraphMiddleware(flowgraph)
|
return FlowGraphMiddleware(flowgraph)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def block(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
|
def block_middleware(
|
||||||
block_name = flowgraph_middleware.add_block(block_key)
|
flowgraph_middleware: FlowGraphMiddleware, block_key: str
|
||||||
return flowgraph_middleware._flowgraph.get_block(block_name)
|
) -> BlockMiddleware:
|
||||||
|
return flowgraph_middleware.add_block(block_key)
|
||||||
|
|
||||||
|
|
||||||
def test_block_middleware_params(block: Block):
|
def test_block_middleware_params(block_middleware: BlockMiddleware):
|
||||||
middleware = BlockMiddleware(block)
|
check_param_models(block_middleware._block, block_middleware.params)
|
||||||
check_param_models(block, middleware.params)
|
|
||||||
|
|
||||||
|
|
||||||
def test_block_middleware_sinks(block: Block):
|
def test_block_middleware_sinks(block_middleware: BlockMiddleware):
|
||||||
middleware = BlockMiddleware(block)
|
check_port_models(block_middleware.sinks, block_middleware._block.sinks, SINK)
|
||||||
check_port_models(middleware.sinks, block.sinks, SINK)
|
|
||||||
|
|
||||||
|
|
||||||
def test_block_middleware_sources(block: Block):
|
def test_block_middleware_sources(block_middleware: BlockMiddleware):
|
||||||
middleware = BlockMiddleware(block)
|
check_port_models(block_middleware.sources, block_middleware._block.sources, SOURCE)
|
||||||
check_port_models(middleware.sources, block.sources, SOURCE)
|
|
||||||
|
|
||||||
|
|
||||||
def check_param_models(block: Block, params: list[ParamModel]):
|
def check_param_models(block: Block, params: list[ParamModel]):
|
||||||
@ -51,7 +49,7 @@ def check_port_models(port_models, ports, direction):
|
|||||||
assert isinstance(port_models, list)
|
assert isinstance(port_models, list)
|
||||||
assert len(port_models) == len(ports)
|
assert len(port_models) == len(ports)
|
||||||
for model, port in zip(port_models, ports):
|
for model, port in zip(port_models, ports):
|
||||||
assert model.key == port.key
|
assert model.key == int(port.key)
|
||||||
assert model.name == port.name
|
assert model.name == port.name
|
||||||
assert model.dtype == port.dtype
|
assert model.dtype == port.dtype
|
||||||
assert model.direction == direction
|
assert model.direction == direction
|
||||||
|
|||||||
@ -1,10 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from gnuradio.grc.core.platform import Platform
|
from gnuradio.grc.core.platform import Platform
|
||||||
|
|
||||||
|
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
||||||
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
||||||
from gnuradio_mcp.models import BlockModel
|
from gnuradio_mcp.models import BlockModel, ConnectionModel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -23,26 +26,21 @@ def initial_blocks(flowgraph_middleware: FlowGraphMiddleware):
|
|||||||
|
|
||||||
def test_flowgraph_block_addition_and_removal(
|
def test_flowgraph_block_addition_and_removal(
|
||||||
flowgraph_middleware: FlowGraphMiddleware,
|
flowgraph_middleware: FlowGraphMiddleware,
|
||||||
platform: Platform,
|
|
||||||
initial_blocks: list[BlockModel],
|
initial_blocks: list[BlockModel],
|
||||||
block_key: str,
|
block_key: str,
|
||||||
):
|
):
|
||||||
block_keys = platform.blocks.keys()
|
explicit_name = "my_custom_block_name"
|
||||||
assert block_keys, "No blocks available in platform library."
|
|
||||||
block_name = f"test_block_{block_key}"
|
flowgraph_middleware.add_block(block_key, explicit_name)
|
||||||
flowgraph_middleware.add_block(block_key, block_name)
|
|
||||||
blocks = flowgraph_middleware.blocks
|
blocks = flowgraph_middleware.blocks
|
||||||
assert all(b in blocks for b in initial_blocks)
|
assert all(b in blocks for b in initial_blocks)
|
||||||
assert any(b.key == block_key for b in blocks)
|
assert any(b.key == block_key for b in blocks)
|
||||||
|
|
||||||
flowgraph_middleware._flowgraph.remove_element(
|
flowgraph_middleware.remove_block(explicit_name)
|
||||||
flowgraph_middleware._flowgraph.get_block(block_name),
|
|
||||||
)
|
blocks = flowgraph_middleware.blocks
|
||||||
current_blocks = [
|
assert all(b in initial_blocks for b in blocks)
|
||||||
BlockModel(key=block.key, label=block.label)
|
|
||||||
for block in flowgraph_middleware._flowgraph.blocks
|
|
||||||
]
|
|
||||||
assert current_blocks == initial_blocks
|
|
||||||
|
|
||||||
|
|
||||||
def test_flowgraph_initial_state(
|
def test_flowgraph_initial_state(
|
||||||
@ -53,35 +51,68 @@ def test_flowgraph_initial_state(
|
|||||||
|
|
||||||
|
|
||||||
def test_block_naming(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
|
def test_block_naming(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
|
||||||
# Explicit name
|
|
||||||
explicit_name = "my_custom_block_name"
|
explicit_name = "my_custom_block_name"
|
||||||
flowgraph_middleware.add_block(block_key, explicit_name)
|
|
||||||
block = flowgraph_middleware._flowgraph.get_block(explicit_name)
|
|
||||||
assert block.params["id"].get_value() == explicit_name
|
|
||||||
# Remove for clean state
|
|
||||||
flowgraph_middleware._flowgraph.remove_element(block)
|
|
||||||
|
|
||||||
# Implicit name (should use get_unique_id logic)
|
block = flowgraph_middleware.add_block(block_key, explicit_name)
|
||||||
flowgraph_middleware.add_block(block_key)
|
|
||||||
# The last block added should be the last in the blocks list
|
assert block.name == explicit_name
|
||||||
last_block = flowgraph_middleware._flowgraph.blocks[-1]
|
|
||||||
# The id param should match the block's name
|
|
||||||
assert last_block.params["id"].get_value() == last_block.name
|
|
||||||
# Remove for clean state
|
|
||||||
flowgraph_middleware._flowgraph.remove_element(last_block)
|
|
||||||
|
|
||||||
|
|
||||||
def test_block_unique_names_for_same_type(
|
def test_block_unique_names_for_same_type(
|
||||||
flowgraph_middleware: FlowGraphMiddleware, block_key: str
|
flowgraph_middleware: FlowGraphMiddleware, block_key: str
|
||||||
):
|
):
|
||||||
# Add two blocks of the same type without explicit names
|
first_block = flowgraph_middleware.add_block(block_key)
|
||||||
flowgraph_middleware.add_block(block_key)
|
second_block = flowgraph_middleware.add_block(block_key)
|
||||||
first_block = flowgraph_middleware._flowgraph.blocks[-1]
|
|
||||||
flowgraph_middleware.add_block(block_key)
|
|
||||||
second_block = flowgraph_middleware._flowgraph.blocks[-1]
|
|
||||||
# They should have different names
|
|
||||||
assert first_block.name != second_block.name
|
assert first_block.name != second_block.name
|
||||||
assert first_block.params["id"].get_value() != second_block.params["id"].get_value()
|
|
||||||
# Clean up
|
|
||||||
flowgraph_middleware._flowgraph.remove_element(first_block)
|
@pytest.mark.parametrize(
|
||||||
flowgraph_middleware._flowgraph.remove_element(second_block)
|
"block_key, sinks_number, sources_number",
|
||||||
|
[("blocks_add_xx", 2, 1), ("blocks_copy", 1, 1), ("blocks_selector", 2, 2)],
|
||||||
|
)
|
||||||
|
def test_block_connections(
|
||||||
|
flowgraph_middleware: FlowGraphMiddleware,
|
||||||
|
block_key: str,
|
||||||
|
sources_number: int,
|
||||||
|
sinks_number: int,
|
||||||
|
):
|
||||||
|
source_block = flowgraph_middleware.add_block(block_key)
|
||||||
|
dest_block = flowgraph_middleware.add_block(block_key)
|
||||||
|
|
||||||
|
for connection in util_iter_possible_connections(source_block, dest_block):
|
||||||
|
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
|
||||||
|
connections = flowgraph_middleware.get_connections()
|
||||||
|
assert any(
|
||||||
|
c.source.key == connection.source.key and c.sink.key == connection.sink.key
|
||||||
|
for c in connections
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(flowgraph_middleware.get_connections()) == sources_number * sinks_number
|
||||||
|
|
||||||
|
|
||||||
|
def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
|
||||||
|
source_block = flowgraph_middleware.add_block(block_key)
|
||||||
|
dest_block = flowgraph_middleware.add_block(block_key)
|
||||||
|
|
||||||
|
for connection in util_iter_possible_connections(source_block, dest_block):
|
||||||
|
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
|
||||||
|
|
||||||
|
for connection in util_iter_possible_connections(source_block, dest_block):
|
||||||
|
flowgraph_middleware.disconnect_blocks(connection.source, connection.sink)
|
||||||
|
connections = flowgraph_middleware.get_connections()
|
||||||
|
assert not any(
|
||||||
|
c.source.key == connection.source.key and c.sink.key == connection.sink.key
|
||||||
|
for c in connections
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(flowgraph_middleware.get_connections()) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def util_iter_possible_connections(
|
||||||
|
source_block: BlockMiddleware,
|
||||||
|
dest_block: BlockMiddleware,
|
||||||
|
) -> Generator[ConnectionModel]:
|
||||||
|
for sink in source_block.sinks:
|
||||||
|
for source in dest_block.sources:
|
||||||
|
yield ConnectionModel(source=source, sink=sink)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user