main - feat: Implement block connection

This commit is contained in:
Yoel Bassin 2025-04-27 00:02:19 +03:00
parent 73bf514fc1
commit 5485413efd
5 changed files with 163 additions and 62 deletions

View File

@ -9,14 +9,39 @@ class BlockMiddleware:
def __init__(self, block: Block): def __init__(self, block: Block):
self._block = block self._block = block
@property
def name(self) -> str:
return self._block.name
# TODO: Check if rewrite is needed
@property @property
def params(self) -> list[ParamModel]: def params(self) -> list[ParamModel]:
return [ParamModel.from_param(param) for param in self._block.params.values()] return [ParamModel.from_param(param) for param in self._block.params.values()]
@property @property
def sinks(self) -> list[PortModel]: def sinks(self) -> list[PortModel]:
return [PortModel.from_port(port, SINK) for port in self._block.sinks] self._rewrite()
ports = []
for port in self._block.sinks:
try:
port_model = PortModel.from_port(port, SINK)
ports.append(port_model)
except ValueError:
pass
return ports
@property @property
def sources(self) -> list[PortModel]: def sources(self) -> list[PortModel]:
return [PortModel.from_port(port, SOURCE) for port in self._block.sources] self._rewrite()
ports = []
for port in self._block.sources:
try:
port_model = PortModel.from_port(port, SOURCE)
ports.append(port_model)
except ValueError:
pass
return ports
def _rewrite(self):
self._block.rewrite()

View File

@ -1,17 +1,19 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional from typing import Dict, Optional
from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.FlowGraph import FlowGraph from gnuradio.grc.core.FlowGraph import FlowGraph
from gnuradio_mcp.middlewares.block import BlockMiddleware from gnuradio_mcp.middlewares.block import BlockMiddleware
from gnuradio_mcp.models import BlockModel from gnuradio_mcp.models import BlockModel, ConnectionModel, PortModel
from gnuradio_mcp.utils import get_unique_id from gnuradio_mcp.utils import get_unique_id
class FlowGraphMiddleware: class FlowGraphMiddleware:
def __init__(self, flowgraph: FlowGraph): def __init__(self, flowgraph: FlowGraph):
self._flowgraph = flowgraph self._flowgraph = flowgraph
self._blocks: Dict[str, BlockMiddleware] = {}
@property @property
def blocks(self) -> list[BlockModel]: def blocks(self) -> list[BlockModel]:
@ -20,16 +22,46 @@ class FlowGraphMiddleware:
for block in self._flowgraph.blocks for block in self._flowgraph.blocks
] ]
def add_block(self, block_type: str, block_name: Optional[str] = None) -> str: def add_block(
self, block_type: str, block_name: Optional[str] = None
) -> BlockMiddleware:
block_name = block_name or get_unique_id(self._flowgraph.blocks) block_name = block_name or get_unique_id(self._flowgraph.blocks)
block = self._flowgraph.new_block(block_type) block = self._flowgraph.new_block(block_type)
block.params["id"].set_value(block_name) block.params["id"].set_value(block_name)
return block_name self._blocks[block_name] = BlockMiddleware(block)
return self._blocks[block_name]
def remove_block(self, block_name: str) -> None: def remove_block(self, block_name: str) -> None:
block = self._flowgraph.get_block(block_name) block = self._flowgraph.get_block(block_name)
self._flowgraph.remove_element(block) self._flowgraph.remove_element(block)
del self._blocks[block_name]
def get_block(self, block_name: str) -> BlockMiddleware: def get_block(self, block_name: str) -> BlockMiddleware:
block = self._flowgraph.get_block(block_name) # TODO: Check if calling two times you get different results
return BlockMiddleware(block) return self._blocks[block_name]
def connect_blocks(
self, src_port_model: PortModel, dst_port_model: PortModel
) -> None:
def get_block_by_port_model(port_model: PortModel) -> Block:
return self._flowgraph.get_block(port_model.parent)
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
self._flowgraph.connect(src_port, dst_port)
def disconnect_blocks(
self, src_port_model: PortModel, dst_port_model: PortModel
) -> None:
def get_block_by_port_model(port_model: PortModel) -> Block:
return self._flowgraph.get_block(port_model.parent)
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
self._flowgraph.disconnect(src_port, dst_port)
def get_connections(self) -> list[ConnectionModel]:
return [
ConnectionModel.from_connection(connection)
for connection in self._flowgraph.connections
]

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from typing import Any, Literal, get_args from typing import Any, Literal, get_args
from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.Connection import Connection
from gnuradio.grc.core.params.param import Param from gnuradio.grc.core.params.param import Param
from gnuradio.grc.core.ports.port import Port from gnuradio.grc.core.ports.port import Port
from pydantic import BaseModel from pydantic import BaseModel
@ -41,7 +42,7 @@ SINK, SOURCE = get_args(DirectionType)
class PortModel(BaseModel): class PortModel(BaseModel):
parent: str parent: str
key: str key: int
name: str name: str
dtype: str dtype: str
direction: DirectionType direction: DirectionType
@ -54,11 +55,25 @@ class PortModel(BaseModel):
direction: DirectionType | None = None, direction: DirectionType | None = None,
) -> PortModel: ) -> PortModel:
direction = direction or port._dir direction = direction or port._dir
if not port.key.isnumeric():
raise ValueError("Currently not supporting named ports")
return cls( return cls(
parent=port.parent.name, parent=port.parent.name,
key=port.key, key=int(port.key),
name=port.name, name=port.name,
dtype=port.dtype, dtype=port.dtype,
direction=direction, direction=direction,
optional=port.optional, optional=port.optional,
) )
class ConnectionModel(BaseModel):
source: PortModel
sink: PortModel
@classmethod
def from_connection(cls, connection: Connection) -> "ConnectionModel":
return cls(
source=PortModel.from_port(connection.source_port),
sink=PortModel.from_port(connection.sink_port),
)

View File

@ -10,30 +10,28 @@ from gnuradio_mcp.models import SINK, SOURCE, ParamModel
@pytest.fixture @pytest.fixture
def flowgraph_middleware(platform: Platform): def flowgraph_middleware(platform: Platform) -> FlowGraphMiddleware:
flowgraph = platform.make_flow_graph("") flowgraph = platform.make_flow_graph("")
return FlowGraphMiddleware(flowgraph) return FlowGraphMiddleware(flowgraph)
@pytest.fixture @pytest.fixture
def block(flowgraph_middleware: FlowGraphMiddleware, block_key: str): def block_middleware(
block_name = flowgraph_middleware.add_block(block_key) flowgraph_middleware: FlowGraphMiddleware, block_key: str
return flowgraph_middleware._flowgraph.get_block(block_name) ) -> BlockMiddleware:
return flowgraph_middleware.add_block(block_key)
def test_block_middleware_params(block: Block): def test_block_middleware_params(block_middleware: BlockMiddleware):
middleware = BlockMiddleware(block) check_param_models(block_middleware._block, block_middleware.params)
check_param_models(block, middleware.params)
def test_block_middleware_sinks(block: Block): def test_block_middleware_sinks(block_middleware: BlockMiddleware):
middleware = BlockMiddleware(block) check_port_models(block_middleware.sinks, block_middleware._block.sinks, SINK)
check_port_models(middleware.sinks, block.sinks, SINK)
def test_block_middleware_sources(block: Block): def test_block_middleware_sources(block_middleware: BlockMiddleware):
middleware = BlockMiddleware(block) check_port_models(block_middleware.sources, block_middleware._block.sources, SOURCE)
check_port_models(middleware.sources, block.sources, SOURCE)
def check_param_models(block: Block, params: list[ParamModel]): def check_param_models(block: Block, params: list[ParamModel]):
@ -51,7 +49,7 @@ def check_port_models(port_models, ports, direction):
assert isinstance(port_models, list) assert isinstance(port_models, list)
assert len(port_models) == len(ports) assert len(port_models) == len(ports)
for model, port in zip(port_models, ports): for model, port in zip(port_models, ports):
assert model.key == port.key assert model.key == int(port.key)
assert model.name == port.name assert model.name == port.name
assert model.dtype == port.dtype assert model.dtype == port.dtype
assert model.direction == direction assert model.direction == direction

View File

@ -1,10 +1,13 @@
from __future__ import annotations from __future__ import annotations
from typing import Generator
import pytest import pytest
from gnuradio.grc.core.platform import Platform from gnuradio.grc.core.platform import Platform
from gnuradio_mcp.middlewares.block import BlockMiddleware
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
from gnuradio_mcp.models import BlockModel from gnuradio_mcp.models import BlockModel, ConnectionModel
@pytest.fixture @pytest.fixture
@ -23,26 +26,21 @@ def initial_blocks(flowgraph_middleware: FlowGraphMiddleware):
def test_flowgraph_block_addition_and_removal( def test_flowgraph_block_addition_and_removal(
flowgraph_middleware: FlowGraphMiddleware, flowgraph_middleware: FlowGraphMiddleware,
platform: Platform,
initial_blocks: list[BlockModel], initial_blocks: list[BlockModel],
block_key: str, block_key: str,
): ):
block_keys = platform.blocks.keys() explicit_name = "my_custom_block_name"
assert block_keys, "No blocks available in platform library."
block_name = f"test_block_{block_key}" flowgraph_middleware.add_block(block_key, explicit_name)
flowgraph_middleware.add_block(block_key, block_name)
blocks = flowgraph_middleware.blocks blocks = flowgraph_middleware.blocks
assert all(b in blocks for b in initial_blocks) assert all(b in blocks for b in initial_blocks)
assert any(b.key == block_key for b in blocks) assert any(b.key == block_key for b in blocks)
flowgraph_middleware._flowgraph.remove_element( flowgraph_middleware.remove_block(explicit_name)
flowgraph_middleware._flowgraph.get_block(block_name),
) blocks = flowgraph_middleware.blocks
current_blocks = [ assert all(b in initial_blocks for b in blocks)
BlockModel(key=block.key, label=block.label)
for block in flowgraph_middleware._flowgraph.blocks
]
assert current_blocks == initial_blocks
def test_flowgraph_initial_state( def test_flowgraph_initial_state(
@ -53,35 +51,68 @@ def test_flowgraph_initial_state(
def test_block_naming(flowgraph_middleware: FlowGraphMiddleware, block_key: str): def test_block_naming(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
# Explicit name
explicit_name = "my_custom_block_name" explicit_name = "my_custom_block_name"
flowgraph_middleware.add_block(block_key, explicit_name)
block = flowgraph_middleware._flowgraph.get_block(explicit_name)
assert block.params["id"].get_value() == explicit_name
# Remove for clean state
flowgraph_middleware._flowgraph.remove_element(block)
# Implicit name (should use get_unique_id logic) block = flowgraph_middleware.add_block(block_key, explicit_name)
flowgraph_middleware.add_block(block_key)
# The last block added should be the last in the blocks list assert block.name == explicit_name
last_block = flowgraph_middleware._flowgraph.blocks[-1]
# The id param should match the block's name
assert last_block.params["id"].get_value() == last_block.name
# Remove for clean state
flowgraph_middleware._flowgraph.remove_element(last_block)
def test_block_unique_names_for_same_type( def test_block_unique_names_for_same_type(
flowgraph_middleware: FlowGraphMiddleware, block_key: str flowgraph_middleware: FlowGraphMiddleware, block_key: str
): ):
# Add two blocks of the same type without explicit names first_block = flowgraph_middleware.add_block(block_key)
flowgraph_middleware.add_block(block_key) second_block = flowgraph_middleware.add_block(block_key)
first_block = flowgraph_middleware._flowgraph.blocks[-1]
flowgraph_middleware.add_block(block_key)
second_block = flowgraph_middleware._flowgraph.blocks[-1]
# They should have different names
assert first_block.name != second_block.name assert first_block.name != second_block.name
assert first_block.params["id"].get_value() != second_block.params["id"].get_value()
# Clean up
flowgraph_middleware._flowgraph.remove_element(first_block) @pytest.mark.parametrize(
flowgraph_middleware._flowgraph.remove_element(second_block) "block_key, sinks_number, sources_number",
[("blocks_add_xx", 2, 1), ("blocks_copy", 1, 1), ("blocks_selector", 2, 2)],
)
def test_block_connections(
flowgraph_middleware: FlowGraphMiddleware,
block_key: str,
sources_number: int,
sinks_number: int,
):
source_block = flowgraph_middleware.add_block(block_key)
dest_block = flowgraph_middleware.add_block(block_key)
for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
connections = flowgraph_middleware.get_connections()
assert any(
c.source.key == connection.source.key and c.sink.key == connection.sink.key
for c in connections
)
assert len(flowgraph_middleware.get_connections()) == sources_number * sinks_number
def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
source_block = flowgraph_middleware.add_block(block_key)
dest_block = flowgraph_middleware.add_block(block_key)
for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.disconnect_blocks(connection.source, connection.sink)
connections = flowgraph_middleware.get_connections()
assert not any(
c.source.key == connection.source.key and c.sink.key == connection.sink.key
for c in connections
)
assert len(flowgraph_middleware.get_connections()) == 0
def util_iter_possible_connections(
source_block: BlockMiddleware,
dest_block: BlockMiddleware,
) -> Generator[ConnectionModel]:
for sink in source_block.sinks:
for source in dest_block.sources:
yield ConnectionModel(source=source, sink=sink)