main - feat: Implement block connection

This commit is contained in:
Yoel Bassin 2025-04-27 00:02:19 +03:00
parent 73bf514fc1
commit 5485413efd
5 changed files with 163 additions and 62 deletions

View File

@ -9,14 +9,39 @@ class BlockMiddleware:
def __init__(self, block: Block):
self._block = block
@property
def name(self) -> str:
return self._block.name
# TODO: Check if rewrite is needed
@property
def params(self) -> list[ParamModel]:
return [ParamModel.from_param(param) for param in self._block.params.values()]
@property
def sinks(self) -> list[PortModel]:
return [PortModel.from_port(port, SINK) for port in self._block.sinks]
self._rewrite()
ports = []
for port in self._block.sinks:
try:
port_model = PortModel.from_port(port, SINK)
ports.append(port_model)
except ValueError:
pass
return ports
@property
def sources(self) -> list[PortModel]:
return [PortModel.from_port(port, SOURCE) for port in self._block.sources]
self._rewrite()
ports = []
for port in self._block.sources:
try:
port_model = PortModel.from_port(port, SOURCE)
ports.append(port_model)
except ValueError:
pass
return ports
def _rewrite(self):
self._block.rewrite()

View File

@ -1,17 +1,19 @@
from __future__ import annotations
from typing import Optional
from typing import Dict, Optional
from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.FlowGraph import FlowGraph
from gnuradio_mcp.middlewares.block import BlockMiddleware
from gnuradio_mcp.models import BlockModel
from gnuradio_mcp.models import BlockModel, ConnectionModel, PortModel
from gnuradio_mcp.utils import get_unique_id
class FlowGraphMiddleware:
def __init__(self, flowgraph: FlowGraph):
self._flowgraph = flowgraph
self._blocks: Dict[str, BlockMiddleware] = {}
@property
def blocks(self) -> list[BlockModel]:
@ -20,16 +22,46 @@ class FlowGraphMiddleware:
for block in self._flowgraph.blocks
]
def add_block(self, block_type: str, block_name: Optional[str] = None) -> str:
def add_block(
self, block_type: str, block_name: Optional[str] = None
) -> BlockMiddleware:
block_name = block_name or get_unique_id(self._flowgraph.blocks)
block = self._flowgraph.new_block(block_type)
block.params["id"].set_value(block_name)
return block_name
self._blocks[block_name] = BlockMiddleware(block)
return self._blocks[block_name]
def remove_block(self, block_name: str) -> None:
block = self._flowgraph.get_block(block_name)
self._flowgraph.remove_element(block)
del self._blocks[block_name]
def get_block(self, block_name: str) -> BlockMiddleware:
block = self._flowgraph.get_block(block_name)
return BlockMiddleware(block)
# TODO: Check if calling two times you get different results
return self._blocks[block_name]
def connect_blocks(
self, src_port_model: PortModel, dst_port_model: PortModel
) -> None:
def get_block_by_port_model(port_model: PortModel) -> Block:
return self._flowgraph.get_block(port_model.parent)
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
self._flowgraph.connect(src_port, dst_port)
def disconnect_blocks(
self, src_port_model: PortModel, dst_port_model: PortModel
) -> None:
def get_block_by_port_model(port_model: PortModel) -> Block:
return self._flowgraph.get_block(port_model.parent)
src_port = get_block_by_port_model(src_port_model).sources[src_port_model.key]
dst_port = get_block_by_port_model(dst_port_model).sinks[dst_port_model.key]
self._flowgraph.disconnect(src_port, dst_port)
def get_connections(self) -> list[ConnectionModel]:
return [
ConnectionModel.from_connection(connection)
for connection in self._flowgraph.connections
]

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from typing import Any, Literal, get_args
from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.Connection import Connection
from gnuradio.grc.core.params.param import Param
from gnuradio.grc.core.ports.port import Port
from pydantic import BaseModel
@ -41,7 +42,7 @@ SINK, SOURCE = get_args(DirectionType)
class PortModel(BaseModel):
parent: str
key: str
key: int
name: str
dtype: str
direction: DirectionType
@ -54,11 +55,25 @@ class PortModel(BaseModel):
direction: DirectionType | None = None,
) -> PortModel:
direction = direction or port._dir
if not port.key.isnumeric():
raise ValueError("Currently not supporting named ports")
return cls(
parent=port.parent.name,
key=port.key,
key=int(port.key),
name=port.name,
dtype=port.dtype,
direction=direction,
optional=port.optional,
)
class ConnectionModel(BaseModel):
source: PortModel
sink: PortModel
@classmethod
def from_connection(cls, connection: Connection) -> "ConnectionModel":
return cls(
source=PortModel.from_port(connection.source_port),
sink=PortModel.from_port(connection.sink_port),
)

View File

@ -10,30 +10,28 @@ from gnuradio_mcp.models import SINK, SOURCE, ParamModel
@pytest.fixture
def flowgraph_middleware(platform: Platform):
def flowgraph_middleware(platform: Platform) -> FlowGraphMiddleware:
flowgraph = platform.make_flow_graph("")
return FlowGraphMiddleware(flowgraph)
@pytest.fixture
def block(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
block_name = flowgraph_middleware.add_block(block_key)
return flowgraph_middleware._flowgraph.get_block(block_name)
def block_middleware(
flowgraph_middleware: FlowGraphMiddleware, block_key: str
) -> BlockMiddleware:
return flowgraph_middleware.add_block(block_key)
def test_block_middleware_params(block: Block):
middleware = BlockMiddleware(block)
check_param_models(block, middleware.params)
def test_block_middleware_params(block_middleware: BlockMiddleware):
check_param_models(block_middleware._block, block_middleware.params)
def test_block_middleware_sinks(block: Block):
middleware = BlockMiddleware(block)
check_port_models(middleware.sinks, block.sinks, SINK)
def test_block_middleware_sinks(block_middleware: BlockMiddleware):
check_port_models(block_middleware.sinks, block_middleware._block.sinks, SINK)
def test_block_middleware_sources(block: Block):
middleware = BlockMiddleware(block)
check_port_models(middleware.sources, block.sources, SOURCE)
def test_block_middleware_sources(block_middleware: BlockMiddleware):
check_port_models(block_middleware.sources, block_middleware._block.sources, SOURCE)
def check_param_models(block: Block, params: list[ParamModel]):
@ -51,7 +49,7 @@ def check_port_models(port_models, ports, direction):
assert isinstance(port_models, list)
assert len(port_models) == len(ports)
for model, port in zip(port_models, ports):
assert model.key == port.key
assert model.key == int(port.key)
assert model.name == port.name
assert model.dtype == port.dtype
assert model.direction == direction

View File

@ -1,10 +1,13 @@
from __future__ import annotations
from typing import Generator
import pytest
from gnuradio.grc.core.platform import Platform
from gnuradio_mcp.middlewares.block import BlockMiddleware
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
from gnuradio_mcp.models import BlockModel
from gnuradio_mcp.models import BlockModel, ConnectionModel
@pytest.fixture
@ -23,26 +26,21 @@ def initial_blocks(flowgraph_middleware: FlowGraphMiddleware):
def test_flowgraph_block_addition_and_removal(
flowgraph_middleware: FlowGraphMiddleware,
platform: Platform,
initial_blocks: list[BlockModel],
block_key: str,
):
block_keys = platform.blocks.keys()
assert block_keys, "No blocks available in platform library."
block_name = f"test_block_{block_key}"
flowgraph_middleware.add_block(block_key, block_name)
explicit_name = "my_custom_block_name"
flowgraph_middleware.add_block(block_key, explicit_name)
blocks = flowgraph_middleware.blocks
assert all(b in blocks for b in initial_blocks)
assert any(b.key == block_key for b in blocks)
flowgraph_middleware._flowgraph.remove_element(
flowgraph_middleware._flowgraph.get_block(block_name),
)
current_blocks = [
BlockModel(key=block.key, label=block.label)
for block in flowgraph_middleware._flowgraph.blocks
]
assert current_blocks == initial_blocks
flowgraph_middleware.remove_block(explicit_name)
blocks = flowgraph_middleware.blocks
assert all(b in initial_blocks for b in blocks)
def test_flowgraph_initial_state(
@ -53,35 +51,68 @@ def test_flowgraph_initial_state(
def test_block_naming(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
# Explicit name
explicit_name = "my_custom_block_name"
flowgraph_middleware.add_block(block_key, explicit_name)
block = flowgraph_middleware._flowgraph.get_block(explicit_name)
assert block.params["id"].get_value() == explicit_name
# Remove for clean state
flowgraph_middleware._flowgraph.remove_element(block)
# Implicit name (should use get_unique_id logic)
flowgraph_middleware.add_block(block_key)
# The last block added should be the last in the blocks list
last_block = flowgraph_middleware._flowgraph.blocks[-1]
# The id param should match the block's name
assert last_block.params["id"].get_value() == last_block.name
# Remove for clean state
flowgraph_middleware._flowgraph.remove_element(last_block)
block = flowgraph_middleware.add_block(block_key, explicit_name)
assert block.name == explicit_name
def test_block_unique_names_for_same_type(
flowgraph_middleware: FlowGraphMiddleware, block_key: str
):
# Add two blocks of the same type without explicit names
flowgraph_middleware.add_block(block_key)
first_block = flowgraph_middleware._flowgraph.blocks[-1]
flowgraph_middleware.add_block(block_key)
second_block = flowgraph_middleware._flowgraph.blocks[-1]
# They should have different names
first_block = flowgraph_middleware.add_block(block_key)
second_block = flowgraph_middleware.add_block(block_key)
assert first_block.name != second_block.name
assert first_block.params["id"].get_value() != second_block.params["id"].get_value()
# Clean up
flowgraph_middleware._flowgraph.remove_element(first_block)
flowgraph_middleware._flowgraph.remove_element(second_block)
@pytest.mark.parametrize(
"block_key, sinks_number, sources_number",
[("blocks_add_xx", 2, 1), ("blocks_copy", 1, 1), ("blocks_selector", 2, 2)],
)
def test_block_connections(
flowgraph_middleware: FlowGraphMiddleware,
block_key: str,
sources_number: int,
sinks_number: int,
):
source_block = flowgraph_middleware.add_block(block_key)
dest_block = flowgraph_middleware.add_block(block_key)
for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
connections = flowgraph_middleware.get_connections()
assert any(
c.source.key == connection.source.key and c.sink.key == connection.sink.key
for c in connections
)
assert len(flowgraph_middleware.get_connections()) == sources_number * sinks_number
def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
source_block = flowgraph_middleware.add_block(block_key)
dest_block = flowgraph_middleware.add_block(block_key)
for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.disconnect_blocks(connection.source, connection.sink)
connections = flowgraph_middleware.get_connections()
assert not any(
c.source.key == connection.source.key and c.sink.key == connection.sink.key
for c in connections
)
assert len(flowgraph_middleware.get_connections()) == 0
def util_iter_possible_connections(
source_block: BlockMiddleware,
dest_block: BlockMiddleware,
) -> Generator[ConnectionModel]:
for sink in source_block.sinks:
for source in dest_block.sources:
yield ConnectionModel(source=source, sink=sink)