main - fix: Fix using already populated GRC file and connection of hidden ports

This commit is contained in:
Yoel Bassin 2025-04-27 20:51:48 +03:00
parent 6d6e4d0fc2
commit f5a0629da7
6 changed files with 84 additions and 36 deletions

View File

@ -38,11 +38,9 @@ class BlockMiddleware(ElementMiddleware):
self._rewrite() self._rewrite()
ports = [] ports = []
for port in self._block.sinks: for port in self._block.sinks:
try: port_model = PortModel.from_port(port, SINK)
port_model = PortModel.from_port(port, SINK) if not port_model.hidden:
ports.append(port_model) ports.append(port_model)
except ValueError:
pass
return ports return ports
@property @property
@ -50,9 +48,7 @@ class BlockMiddleware(ElementMiddleware):
self._rewrite() self._rewrite()
ports = [] ports = []
for port in self._block.sources: for port in self._block.sources:
try: port_model = PortModel.from_port(port, SOURCE)
port_model = PortModel.from_port(port, SOURCE) if not port_model.hidden:
ports.append(port_model) ports.append(port_model)
except ValueError:
pass
return ports return ports

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Set from functools import lru_cache
from typing import TYPE_CHECKING, Optional
from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.FlowGraph import FlowGraph from gnuradio.grc.core.FlowGraph import FlowGraph
@ -24,7 +25,7 @@ def get_port_from_port_model_in_port_list(
raise ValueError(f"Port not found: {port_model.key}") raise ValueError(f"Port not found: {port_model.key}")
def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block: def get_port_from_port_model(flowgraph, port_model: PortModel) -> Port:
block_from_port_model = flowgraph.get_block(port_model.parent) block_from_port_model = flowgraph.get_block(port_model.parent)
if port_model.direction == SOURCE: if port_model.direction == SOURCE:
return get_port_from_port_model_in_port_list( return get_port_from_port_model_in_port_list(
@ -38,11 +39,14 @@ def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block:
raise ValueError(f"Invalid port direction: {port_model.direction}") raise ValueError(f"Invalid port direction: {port_model.direction}")
def set_block_name(block: Block, name: str):
block.params["id"].set_value(name)
class FlowGraphMiddleware(ElementMiddleware): class FlowGraphMiddleware(ElementMiddleware):
def __init__(self, flowgraph: FlowGraph): def __init__(self, flowgraph: FlowGraph):
super().__init__(flowgraph) super().__init__(flowgraph)
self._flowgraph = self._element self._flowgraph = self._element
self._blocks: Set[BlockMiddleware] = set()
@property @property
def blocks(self) -> list[BlockModel]: def blocks(self) -> list[BlockModel]:
@ -50,22 +54,22 @@ class FlowGraphMiddleware(ElementMiddleware):
def add_block( def add_block(
self, block_type: str, block_name: Optional[str] = None self, block_type: str, block_name: Optional[str] = None
) -> BlockMiddleware: ) -> BlockModel:
block_name = block_name or get_unique_id(self._flowgraph.blocks, block_type) block_name = block_name or get_unique_id(self._flowgraph.blocks, block_type)
block = self._flowgraph.new_block(block_type) block = self._flowgraph.new_block(block_type)
assert block is not None, f"Failed to create block: {block_type}" assert block is not None, f"Failed to create block: {block_type}"
block_middleware = BlockMiddleware(block) set_block_name(block, block_name)
block_middleware.name = block_name return BlockModel.from_block(block)
self._blocks.add(block_middleware)
return block_middleware
def remove_block(self, block_name: str) -> None: def remove_block(self, block_name: str) -> None:
block_middleware = self.get_block(block_name) block_middleware = self.get_block(block_name)
self._flowgraph.remove_element(block_middleware._block) self._flowgraph.remove_element(block_middleware._block)
self._blocks.remove(block_middleware)
@lru_cache(maxsize=None)
def get_block(self, block_name: str) -> BlockMiddleware: def get_block(self, block_name: str) -> BlockMiddleware:
return next(block for block in self._blocks if block.name == block_name) return BlockMiddleware(
next(block for block in self._flowgraph.blocks if block.name == block_name)
)
def connect_blocks( def connect_blocks(
self, src_port_model: PortModel, dst_port_model: PortModel self, src_port_model: PortModel, dst_port_model: PortModel

View File

@ -1,12 +1,12 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Literal, Type, get_args from typing import Any, Literal, Protocol, Type, get_args
from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.Connection import Connection from gnuradio.grc.core.Connection import Connection
from gnuradio.grc.core.params.param import Param from gnuradio.grc.core.params.param import Param
from gnuradio.grc.core.ports.port import Port from gnuradio.grc.core.ports.port import Port
from pydantic import BaseModel from pydantic import BaseModel, field_validator
class BlockTypeModel(BaseModel): class BlockTypeModel(BaseModel):
@ -18,6 +18,11 @@ class BlockTypeModel(BaseModel):
return cls(label=block.label, key=block.key) return cls(label=block.label, key=block.key)
class KeyedModel(Protocol):
def to_key(self) -> str:
...
class BlockModel(BaseModel): class BlockModel(BaseModel):
label: str label: str
name: str name: str
@ -26,6 +31,9 @@ class BlockModel(BaseModel):
def from_block(cls, block: Block) -> BlockModel: def from_block(cls, block: Block) -> BlockModel:
return cls(label=block.label, name=block.name) return cls(label=block.label, name=block.name)
def to_key(self) -> str:
return f"{self.label}:{self.name}"
class ParamModel(BaseModel): class ParamModel(BaseModel):
parent: str parent: str
@ -44,6 +52,9 @@ class ParamModel(BaseModel):
value=param.get_value(), value=param.get_value(),
) )
def to_key(self) -> str:
return f"{self.parent}:{self.key}"
DirectionType = Literal["sink", "source"] DirectionType = Literal["sink", "source"]
SINK, SOURCE = get_args(DirectionType) SINK, SOURCE = get_args(DirectionType)
@ -56,6 +67,7 @@ class PortModel(BaseModel):
dtype: str dtype: str
direction: DirectionType direction: DirectionType
optional: bool = False optional: bool = False
hidden: bool = False
@classmethod @classmethod
def from_port( def from_port(
@ -64,8 +76,6 @@ class PortModel(BaseModel):
direction: DirectionType | None = None, direction: DirectionType | None = None,
) -> PortModel: ) -> PortModel:
direction = direction or port._dir direction = direction or port._dir
if not port.key.isnumeric():
raise ValueError("Currently not supporting named ports")
return cls( return cls(
parent=port.parent.name, parent=port.parent.name,
key=port.key, key=port.key,
@ -73,8 +83,12 @@ class PortModel(BaseModel):
dtype=port.dtype, dtype=port.dtype,
direction=direction, direction=direction,
optional=port.optional, optional=port.optional,
hidden=port.hidden,
) )
def to_key(self) -> str:
return f"{self.parent}:{self.direction}[{self.key}]"
class ConnectionModel(BaseModel): class ConnectionModel(BaseModel):
source: PortModel source: PortModel
@ -87,8 +101,16 @@ class ConnectionModel(BaseModel):
sink=PortModel.from_port(connection.sink_port), sink=PortModel.from_port(connection.sink_port),
) )
def to_key(self) -> str:
return f"{self.source.to_key()}-{self.sink.to_key()}"
class ErrorModel(BaseModel): class ErrorModel(BaseModel):
type: str type: str
key: BaseModel key: str
message: str message: str
@field_validator("key", mode="before")
@classmethod
def transform_key(cls, v: KeyedModel) -> str:
return v.to_key()

View File

@ -45,6 +45,6 @@ def format_error_message(elem, msg) -> ErrorModel:
raise ValueError(f"Unsupported element type: {type(elem)}") raise ValueError(f"Unsupported element type: {type(elem)}")
return ErrorModel( return ErrorModel(
type=type(model).__name__, type=type(model).__name__,
key=model, key=model, # type: ignore
message=msg, message=msg,
) )

View File

@ -19,7 +19,8 @@ def flowgraph_middleware(platform: Platform) -> FlowGraphMiddleware:
def block_middleware( def block_middleware(
flowgraph_middleware: FlowGraphMiddleware, block_key: str flowgraph_middleware: FlowGraphMiddleware, block_key: str
) -> BlockMiddleware: ) -> BlockMiddleware:
return flowgraph_middleware.add_block(block_key) block_model = flowgraph_middleware.add_block(block_key)
return flowgraph_middleware.get_block(block_model.name)
def test_block_middleware_params(block_middleware: BlockMiddleware): def test_block_middleware_params(block_middleware: BlockMiddleware):
@ -52,7 +53,8 @@ def test_block_errors(
block_key: str, block_key: str,
initial_errors_number: int, initial_errors_number: int,
): ):
block_middleware = flowgraph_middleware.add_block(block_key) block_model = flowgraph_middleware.add_block(block_key)
block_middleware = flowgraph_middleware.get_block(block_model.name)
for error in block_middleware.get_all_errors(): for error in block_middleware.get_all_errors():
assert isinstance(error, ErrorModel) assert isinstance(error, ErrorModel)
assert len(block_middleware.get_all_errors()) == initial_errors_number assert len(block_middleware.get_all_errors()) == initial_errors_number

View File

@ -66,8 +66,17 @@ def test_block_unique_names_for_same_type(
def test_block_default_name(flowgraph_middleware: FlowGraphMiddleware, block_key: str): def test_block_default_name(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
block = flowgraph_middleware.add_block(block_key) block_model = flowgraph_middleware.add_block(block_key)
assert block_key in block.name assert block_key in block_model.name
def test_remove_existing_block(flowgraph_middleware: FlowGraphMiddleware):
DEFAULT_VARIABLE_BLOCK_NAME = "samp_rate"
flowgraph_middleware.remove_block(DEFAULT_VARIABLE_BLOCK_NAME)
assert not any(
block.name == DEFAULT_VARIABLE_BLOCK_NAME
for block in flowgraph_middleware.blocks
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -80,11 +89,11 @@ def test_block_connections(
sources_number: int, sources_number: int,
sinks_number: int, sinks_number: int,
): ):
source_block = flowgraph_middleware.add_block(block_key) source_block, dest_block = create_and_connect_blocks(
dest_block = flowgraph_middleware.add_block(block_key) flowgraph_middleware, block_key, block_key
)
for connection in util_iter_possible_connections(source_block, dest_block): for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
connections = flowgraph_middleware.get_connections() connections = flowgraph_middleware.get_connections()
assert any( assert any(
c.source.key == connection.source.key and c.sink.key == connection.sink.key c.source.key == connection.source.key and c.sink.key == connection.sink.key
@ -95,11 +104,9 @@ def test_block_connections(
def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_key: str): def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
source_block = flowgraph_middleware.add_block(block_key) source_block, dest_block = create_and_connect_blocks(
dest_block = flowgraph_middleware.add_block(block_key) flowgraph_middleware, block_key, block_key
)
for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
for connection in util_iter_possible_connections(source_block, dest_block): for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.disconnect_blocks(connection.source, connection.sink) flowgraph_middleware.disconnect_blocks(connection.source, connection.sink)
@ -118,6 +125,23 @@ def test_default_flowgraph_errors(flowgraph_middleware: FlowGraphMiddleware):
assert len(flowgraph_middleware.get_all_errors()) == 0 assert len(flowgraph_middleware.get_all_errors()) == 0
def create_and_connect_blocks(
flowgraph_middleware: FlowGraphMiddleware,
source_block_key: str,
dest_block_key: str,
):
source_block_model = flowgraph_middleware.add_block(source_block_key)
dest_block_model = flowgraph_middleware.add_block(dest_block_key)
source_block = flowgraph_middleware.get_block(source_block_model.name)
dest_block = flowgraph_middleware.get_block(dest_block_model.name)
for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
return source_block, dest_block
def util_iter_possible_connections( def util_iter_possible_connections(
source_block: BlockMiddleware, source_block: BlockMiddleware,
dest_block: BlockMiddleware, dest_block: BlockMiddleware,