main - feat: Implement block connection
This commit is contained in:
parent
73bf514fc1
commit
5485413efd
@ -9,14 +9,39 @@ class BlockMiddleware:
|
||||
def __init__(self, block: Block):
|
||||
self._block = block
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._block.name
|
||||
|
||||
# TODO: Check if rewrite is needed
|
||||
|
||||
@property
|
||||
def params(self) -> list[ParamModel]:
|
||||
return [ParamModel.from_param(param) for param in self._block.params.values()]
|
||||
|
||||
@property
|
||||
def sinks(self) -> list[PortModel]:
|
||||
return [PortModel.from_port(port, SINK) for port in self._block.sinks]
|
||||
self._rewrite()
|
||||
ports = []
|
||||
for port in self._block.sinks:
|
||||
try:
|
||||
port_model = PortModel.from_port(port, SINK)
|
||||
ports.append(port_model)
|
||||
except ValueError:
|
||||
pass
|
||||
return ports
|
||||
|
||||
@property
|
||||
def sources(self) -> list[PortModel]:
|
||||
return [PortModel.from_port(port, SOURCE) for port in self._block.sources]
|
||||
self._rewrite()
|
||||
ports = []
|
||||
for port in self._block.sources:
|
||||
try:
|
||||
port_model = PortModel.from_port(port, SOURCE)
|
||||
ports.append(port_model)
|
||||
except ValueError:
|
||||
pass
|
||||
return ports
|
||||
|
||||
def _rewrite(self):
|
||||
self._block.rewrite()
|
||||
|
||||
@ -1,17 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from gnuradio.grc.core.blocks.block import Block
|
||||
from gnuradio.grc.core.FlowGraph import FlowGraph
|
||||
|
||||
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
||||
from gnuradio_mcp.models import BlockModel
|
||||
from gnuradio_mcp.models import BlockModel, ConnectionModel, PortModel
|
||||
from gnuradio_mcp.utils import get_unique_id
|
||||
|
||||
|
||||
class FlowGraphMiddleware:
|
||||
def __init__(self, flowgraph: FlowGraph):
|
||||
self._flowgraph = flowgraph
|
||||
self._blocks: Dict[str, BlockMiddleware] = {}
|
||||
|
||||
@property
|
||||
def blocks(self) -> list[BlockModel]:
|
||||
@ -20,16 +22,46 @@ class FlowGraphMiddleware:
|
||||
for block in self._flowgraph.blocks
|
||||
]
|
||||
|
||||
def add_block(self, block_type: str, block_name: Optional[str] = None) -> str:
|
||||
def add_block(
|
||||
self, block_type: str, block_name: Optional[str] = None
|
||||
) -> BlockMiddleware:
|
||||
block_name = block_name or get_unique_id(self._flowgraph.blocks)
|
||||
block = self._flowgraph.new_block(block_type)
|
||||
block.params["id"].set_value(block_name)
|
||||
return block_name
|
||||
self._blocks[block_name] = BlockMiddleware(block)
|
||||
return self._blocks[block_name]
|
||||
|
||||
def remove_block(self, block_name: str) -> None:
|
||||
block = self._flowgraph.get_block(block_name)
|
||||
self._flowgraph.remove_element(block)
|
||||
del self._blocks[block_name]
|
||||
|
||||
def get_block(self, block_name: str) -> BlockMiddleware:
|
||||
block = self._flowgraph.get_block(block_name)
|
||||
return BlockMiddleware(block)
|
||||
# TODO: Check if calling two times you get different results
|
||||
return self._blocks[block_name]
|
||||
|
||||
def connect_blocks(
|
||||
self, src_port_model: PortModel, dst_port_model: PortModel
|
||||
) -> None:
|
||||
def get_block_by_port_model(port_model: PortModel) -> Block:
|
||||
return self._flowgraph.get_block(port_model.parent)
|
||||
|
||||
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
|
||||
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
|
||||
self._flowgraph.connect(src_port, dst_port)
|
||||
|
||||
def disconnect_blocks(
|
||||
self, src_port_model: PortModel, dst_port_model: PortModel
|
||||
) -> None:
|
||||
def get_block_by_port_model(port_model: PortModel) -> Block:
|
||||
return self._flowgraph.get_block(port_model.parent)
|
||||
|
||||
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
|
||||
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
|
||||
self._flowgraph.disconnect(src_port, dst_port)
|
||||
|
||||
def get_connections(self) -> list[ConnectionModel]:
|
||||
return [
|
||||
ConnectionModel.from_connection(connection)
|
||||
for connection in self._flowgraph.connections
|
||||
]
|
||||
|
||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
from gnuradio.grc.core.blocks.block import Block
|
||||
from gnuradio.grc.core.Connection import Connection
|
||||
from gnuradio.grc.core.params.param import Param
|
||||
from gnuradio.grc.core.ports.port import Port
|
||||
from pydantic import BaseModel
|
||||
@ -41,7 +42,7 @@ SINK, SOURCE = get_args(DirectionType)
|
||||
|
||||
class PortModel(BaseModel):
|
||||
parent: str
|
||||
key: str
|
||||
key: int
|
||||
name: str
|
||||
dtype: str
|
||||
direction: DirectionType
|
||||
@ -54,11 +55,25 @@ class PortModel(BaseModel):
|
||||
direction: DirectionType | None = None,
|
||||
) -> PortModel:
|
||||
direction = direction or port._dir
|
||||
if not port.key.isnumeric():
|
||||
raise ValueError("Currently not supporting named ports")
|
||||
return cls(
|
||||
parent=port.parent.name,
|
||||
key=port.key,
|
||||
key=int(port.key),
|
||||
name=port.name,
|
||||
dtype=port.dtype,
|
||||
direction=direction,
|
||||
optional=port.optional,
|
||||
)
|
||||
|
||||
|
||||
class ConnectionModel(BaseModel):
|
||||
source: PortModel
|
||||
sink: PortModel
|
||||
|
||||
@classmethod
|
||||
def from_connection(cls, connection: Connection) -> "ConnectionModel":
|
||||
return cls(
|
||||
source=PortModel.from_port(connection.source_port),
|
||||
sink=PortModel.from_port(connection.sink_port),
|
||||
)
|
||||
|
||||
@ -10,30 +10,28 @@ from gnuradio_mcp.models import SINK, SOURCE, ParamModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flowgraph_middleware(platform: Platform):
|
||||
def flowgraph_middleware(platform: Platform) -> FlowGraphMiddleware:
|
||||
flowgraph = platform.make_flow_graph("")
|
||||
return FlowGraphMiddleware(flowgraph)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def block(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
|
||||
block_name = flowgraph_middleware.add_block(block_key)
|
||||
return flowgraph_middleware._flowgraph.get_block(block_name)
|
||||
def block_middleware(
|
||||
flowgraph_middleware: FlowGraphMiddleware, block_key: str
|
||||
) -> BlockMiddleware:
|
||||
return flowgraph_middleware.add_block(block_key)
|
||||
|
||||
|
||||
def test_block_middleware_params(block: Block):
|
||||
middleware = BlockMiddleware(block)
|
||||
check_param_models(block, middleware.params)
|
||||
def test_block_middleware_params(block_middleware: BlockMiddleware):
|
||||
check_param_models(block_middleware._block, block_middleware.params)
|
||||
|
||||
|
||||
def test_block_middleware_sinks(block: Block):
|
||||
middleware = BlockMiddleware(block)
|
||||
check_port_models(middleware.sinks, block.sinks, SINK)
|
||||
def test_block_middleware_sinks(block_middleware: BlockMiddleware):
|
||||
check_port_models(block_middleware.sinks, block_middleware._block.sinks, SINK)
|
||||
|
||||
|
||||
def test_block_middleware_sources(block: Block):
|
||||
middleware = BlockMiddleware(block)
|
||||
check_port_models(middleware.sources, block.sources, SOURCE)
|
||||
def test_block_middleware_sources(block_middleware: BlockMiddleware):
|
||||
check_port_models(block_middleware.sources, block_middleware._block.sources, SOURCE)
|
||||
|
||||
|
||||
def check_param_models(block: Block, params: list[ParamModel]):
|
||||
@ -51,7 +49,7 @@ def check_port_models(port_models, ports, direction):
|
||||
assert isinstance(port_models, list)
|
||||
assert len(port_models) == len(ports)
|
||||
for model, port in zip(port_models, ports):
|
||||
assert model.key == port.key
|
||||
assert model.key == int(port.key)
|
||||
assert model.name == port.name
|
||||
assert model.dtype == port.dtype
|
||||
assert model.direction == direction
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from gnuradio.grc.core.platform import Platform
|
||||
|
||||
from gnuradio_mcp.middlewares.block import BlockMiddleware
|
||||
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
|
||||
from gnuradio_mcp.models import BlockModel
|
||||
from gnuradio_mcp.models import BlockModel, ConnectionModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -23,26 +26,21 @@ def initial_blocks(flowgraph_middleware: FlowGraphMiddleware):
|
||||
|
||||
def test_flowgraph_block_addition_and_removal(
|
||||
flowgraph_middleware: FlowGraphMiddleware,
|
||||
platform: Platform,
|
||||
initial_blocks: list[BlockModel],
|
||||
block_key: str,
|
||||
):
|
||||
block_keys = platform.blocks.keys()
|
||||
assert block_keys, "No blocks available in platform library."
|
||||
block_name = f"test_block_{block_key}"
|
||||
flowgraph_middleware.add_block(block_key, block_name)
|
||||
explicit_name = "my_custom_block_name"
|
||||
|
||||
flowgraph_middleware.add_block(block_key, explicit_name)
|
||||
|
||||
blocks = flowgraph_middleware.blocks
|
||||
assert all(b in blocks for b in initial_blocks)
|
||||
assert any(b.key == block_key for b in blocks)
|
||||
|
||||
flowgraph_middleware._flowgraph.remove_element(
|
||||
flowgraph_middleware._flowgraph.get_block(block_name),
|
||||
)
|
||||
current_blocks = [
|
||||
BlockModel(key=block.key, label=block.label)
|
||||
for block in flowgraph_middleware._flowgraph.blocks
|
||||
]
|
||||
assert current_blocks == initial_blocks
|
||||
flowgraph_middleware.remove_block(explicit_name)
|
||||
|
||||
blocks = flowgraph_middleware.blocks
|
||||
assert all(b in initial_blocks for b in blocks)
|
||||
|
||||
|
||||
def test_flowgraph_initial_state(
|
||||
@ -53,35 +51,68 @@ def test_flowgraph_initial_state(
|
||||
|
||||
|
||||
def test_block_naming(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
|
||||
# Explicit name
|
||||
explicit_name = "my_custom_block_name"
|
||||
flowgraph_middleware.add_block(block_key, explicit_name)
|
||||
block = flowgraph_middleware._flowgraph.get_block(explicit_name)
|
||||
assert block.params["id"].get_value() == explicit_name
|
||||
# Remove for clean state
|
||||
flowgraph_middleware._flowgraph.remove_element(block)
|
||||
|
||||
# Implicit name (should use get_unique_id logic)
|
||||
flowgraph_middleware.add_block(block_key)
|
||||
# The last block added should be the last in the blocks list
|
||||
last_block = flowgraph_middleware._flowgraph.blocks[-1]
|
||||
# The id param should match the block's name
|
||||
assert last_block.params["id"].get_value() == last_block.name
|
||||
# Remove for clean state
|
||||
flowgraph_middleware._flowgraph.remove_element(last_block)
|
||||
block = flowgraph_middleware.add_block(block_key, explicit_name)
|
||||
|
||||
assert block.name == explicit_name
|
||||
|
||||
|
||||
def test_block_unique_names_for_same_type(
|
||||
flowgraph_middleware: FlowGraphMiddleware, block_key: str
|
||||
):
|
||||
# Add two blocks of the same type without explicit names
|
||||
flowgraph_middleware.add_block(block_key)
|
||||
first_block = flowgraph_middleware._flowgraph.blocks[-1]
|
||||
flowgraph_middleware.add_block(block_key)
|
||||
second_block = flowgraph_middleware._flowgraph.blocks[-1]
|
||||
# They should have different names
|
||||
first_block = flowgraph_middleware.add_block(block_key)
|
||||
second_block = flowgraph_middleware.add_block(block_key)
|
||||
|
||||
assert first_block.name != second_block.name
|
||||
assert first_block.params["id"].get_value() != second_block.params["id"].get_value()
|
||||
# Clean up
|
||||
flowgraph_middleware._flowgraph.remove_element(first_block)
|
||||
flowgraph_middleware._flowgraph.remove_element(second_block)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"block_key, sinks_number, sources_number",
|
||||
[("blocks_add_xx", 2, 1), ("blocks_copy", 1, 1), ("blocks_selector", 2, 2)],
|
||||
)
|
||||
def test_block_connections(
|
||||
flowgraph_middleware: FlowGraphMiddleware,
|
||||
block_key: str,
|
||||
sources_number: int,
|
||||
sinks_number: int,
|
||||
):
|
||||
source_block = flowgraph_middleware.add_block(block_key)
|
||||
dest_block = flowgraph_middleware.add_block(block_key)
|
||||
|
||||
for connection in util_iter_possible_connections(source_block, dest_block):
|
||||
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
|
||||
connections = flowgraph_middleware.get_connections()
|
||||
assert any(
|
||||
c.source.key == connection.source.key and c.sink.key == connection.sink.key
|
||||
for c in connections
|
||||
)
|
||||
|
||||
assert len(flowgraph_middleware.get_connections()) == sources_number * sinks_number
|
||||
|
||||
|
||||
def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
|
||||
source_block = flowgraph_middleware.add_block(block_key)
|
||||
dest_block = flowgraph_middleware.add_block(block_key)
|
||||
|
||||
for connection in util_iter_possible_connections(source_block, dest_block):
|
||||
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
|
||||
|
||||
for connection in util_iter_possible_connections(source_block, dest_block):
|
||||
flowgraph_middleware.disconnect_blocks(connection.source, connection.sink)
|
||||
connections = flowgraph_middleware.get_connections()
|
||||
assert not any(
|
||||
c.source.key == connection.source.key and c.sink.key == connection.sink.key
|
||||
for c in connections
|
||||
)
|
||||
|
||||
assert len(flowgraph_middleware.get_connections()) == 0
|
||||
|
||||
|
||||
def util_iter_possible_connections(
|
||||
source_block: BlockMiddleware,
|
||||
dest_block: BlockMiddleware,
|
||||
) -> Generator[ConnectionModel]:
|
||||
for sink in source_block.sinks:
|
||||
for source in dest_block.sources:
|
||||
yield ConnectionModel(source=source, sink=sink)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user