main - fix: Fix using already populated GRC file and connection of hidden ports

This commit is contained in:
Yoel Bassin 2025-04-27 20:51:48 +03:00
parent 6d6e4d0fc2
commit f5a0629da7
6 changed files with 84 additions and 36 deletions

View File

@ -38,11 +38,9 @@ class BlockMiddleware(ElementMiddleware):
self._rewrite()
ports = []
for port in self._block.sinks:
try:
port_model = PortModel.from_port(port, SINK)
if not port_model.hidden:
ports.append(port_model)
except ValueError:
pass
return ports
@property
@ -50,9 +48,7 @@ class BlockMiddleware(ElementMiddleware):
self._rewrite()
ports = []
for port in self._block.sources:
try:
port_model = PortModel.from_port(port, SOURCE)
if not port_model.hidden:
ports.append(port_model)
except ValueError:
pass
return ports

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Set
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.FlowGraph import FlowGraph
@ -24,7 +25,7 @@ def get_port_from_port_model_in_port_list(
raise ValueError(f"Port not found: {port_model.key}")
def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block:
def get_port_from_port_model(flowgraph, port_model: PortModel) -> Port:
block_from_port_model = flowgraph.get_block(port_model.parent)
if port_model.direction == SOURCE:
return get_port_from_port_model_in_port_list(
@ -38,11 +39,14 @@ def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block:
raise ValueError(f"Invalid port direction: {port_model.direction}")
def set_block_name(block: Block, name: str):
block.params["id"].set_value(name)
class FlowGraphMiddleware(ElementMiddleware):
def __init__(self, flowgraph: FlowGraph):
super().__init__(flowgraph)
self._flowgraph = self._element
self._blocks: Set[BlockMiddleware] = set()
@property
def blocks(self) -> list[BlockModel]:
@ -50,22 +54,22 @@ class FlowGraphMiddleware(ElementMiddleware):
def add_block(
self, block_type: str, block_name: Optional[str] = None
) -> BlockMiddleware:
) -> BlockModel:
block_name = block_name or get_unique_id(self._flowgraph.blocks, block_type)
block = self._flowgraph.new_block(block_type)
assert block is not None, f"Failed to create block: {block_type}"
block_middleware = BlockMiddleware(block)
block_middleware.name = block_name
self._blocks.add(block_middleware)
return block_middleware
set_block_name(block, block_name)
return BlockModel.from_block(block)
def remove_block(self, block_name: str) -> None:
block_middleware = self.get_block(block_name)
self._flowgraph.remove_element(block_middleware._block)
self._blocks.remove(block_middleware)
@lru_cache(maxsize=None)
def get_block(self, block_name: str) -> BlockMiddleware:
return next(block for block in self._blocks if block.name == block_name)
return BlockMiddleware(
next(block for block in self._flowgraph.blocks if block.name == block_name)
)
def connect_blocks(
self, src_port_model: PortModel, dst_port_model: PortModel

View File

@ -1,12 +1,12 @@
from __future__ import annotations
from typing import Any, Literal, Type, get_args
from typing import Any, Literal, Protocol, Type, get_args
from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.Connection import Connection
from gnuradio.grc.core.params.param import Param
from gnuradio.grc.core.ports.port import Port
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
class BlockTypeModel(BaseModel):
@ -18,6 +18,11 @@ class BlockTypeModel(BaseModel):
return cls(label=block.label, key=block.key)
class KeyedModel(Protocol):
def to_key(self) -> str:
...
class BlockModel(BaseModel):
label: str
name: str
@ -26,6 +31,9 @@ class BlockModel(BaseModel):
def from_block(cls, block: Block) -> BlockModel:
return cls(label=block.label, name=block.name)
def to_key(self) -> str:
return f"{self.label}:{self.name}"
class ParamModel(BaseModel):
parent: str
@ -44,6 +52,9 @@ class ParamModel(BaseModel):
value=param.get_value(),
)
def to_key(self) -> str:
return f"{self.parent}:{self.key}"
DirectionType = Literal["sink", "source"]
SINK, SOURCE = get_args(DirectionType)
@ -56,6 +67,7 @@ class PortModel(BaseModel):
dtype: str
direction: DirectionType
optional: bool = False
hidden: bool = False
@classmethod
def from_port(
@ -64,8 +76,6 @@ class PortModel(BaseModel):
direction: DirectionType | None = None,
) -> PortModel:
direction = direction or port._dir
if not port.key.isnumeric():
raise ValueError("Currently not supporting named ports")
return cls(
parent=port.parent.name,
key=port.key,
@ -73,8 +83,12 @@ class PortModel(BaseModel):
dtype=port.dtype,
direction=direction,
optional=port.optional,
hidden=port.hidden,
)
def to_key(self) -> str:
return f"{self.parent}:{self.direction}[{self.key}]"
class ConnectionModel(BaseModel):
source: PortModel
@ -87,8 +101,16 @@ class ConnectionModel(BaseModel):
sink=PortModel.from_port(connection.sink_port),
)
def to_key(self) -> str:
return f"{self.source.to_key()}-{self.sink.to_key()}"
class ErrorModel(BaseModel):
type: str
key: BaseModel
key: str
message: str
@field_validator("key", mode="before")
@classmethod
def transform_key(cls, v: KeyedModel) -> str:
return v.to_key()

View File

@ -45,6 +45,6 @@ def format_error_message(elem, msg) -> ErrorModel:
raise ValueError(f"Unsupported element type: {type(elem)}")
return ErrorModel(
type=type(model).__name__,
key=model,
key=model, # type: ignore
message=msg,
)

View File

@ -19,7 +19,8 @@ def flowgraph_middleware(platform: Platform) -> FlowGraphMiddleware:
def block_middleware(
flowgraph_middleware: FlowGraphMiddleware, block_key: str
) -> BlockMiddleware:
return flowgraph_middleware.add_block(block_key)
block_model = flowgraph_middleware.add_block(block_key)
return flowgraph_middleware.get_block(block_model.name)
def test_block_middleware_params(block_middleware: BlockMiddleware):
@ -52,7 +53,8 @@ def test_block_errors(
block_key: str,
initial_errors_number: int,
):
block_middleware = flowgraph_middleware.add_block(block_key)
block_model = flowgraph_middleware.add_block(block_key)
block_middleware = flowgraph_middleware.get_block(block_model.name)
for error in block_middleware.get_all_errors():
assert isinstance(error, ErrorModel)
assert len(block_middleware.get_all_errors()) == initial_errors_number

View File

@ -66,8 +66,17 @@ def test_block_unique_names_for_same_type(
def test_block_default_name(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
block = flowgraph_middleware.add_block(block_key)
assert block_key in block.name
block_model = flowgraph_middleware.add_block(block_key)
assert block_key in block_model.name
def test_remove_existing_block(flowgraph_middleware: FlowGraphMiddleware):
DEFAULT_VARIABLE_BLOCK_NAME = "samp_rate"
flowgraph_middleware.remove_block(DEFAULT_VARIABLE_BLOCK_NAME)
assert not any(
block.name == DEFAULT_VARIABLE_BLOCK_NAME
for block in flowgraph_middleware.blocks
)
@pytest.mark.parametrize(
@ -80,11 +89,11 @@ def test_block_connections(
sources_number: int,
sinks_number: int,
):
source_block = flowgraph_middleware.add_block(block_key)
dest_block = flowgraph_middleware.add_block(block_key)
source_block, dest_block = create_and_connect_blocks(
flowgraph_middleware, block_key, block_key
)
for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
connections = flowgraph_middleware.get_connections()
assert any(
c.source.key == connection.source.key and c.sink.key == connection.sink.key
@ -95,11 +104,9 @@ def test_block_connections(
def test_block_disconnection(flowgraph_middleware: FlowGraphMiddleware, block_key: str):
source_block = flowgraph_middleware.add_block(block_key)
dest_block = flowgraph_middleware.add_block(block_key)
for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
source_block, dest_block = create_and_connect_blocks(
flowgraph_middleware, block_key, block_key
)
for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.disconnect_blocks(connection.source, connection.sink)
@ -118,6 +125,23 @@ def test_default_flowgraph_errors(flowgraph_middleware: FlowGraphMiddleware):
assert len(flowgraph_middleware.get_all_errors()) == 0
def create_and_connect_blocks(
flowgraph_middleware: FlowGraphMiddleware,
source_block_key: str,
dest_block_key: str,
):
source_block_model = flowgraph_middleware.add_block(source_block_key)
dest_block_model = flowgraph_middleware.add_block(dest_block_key)
source_block = flowgraph_middleware.get_block(source_block_model.name)
dest_block = flowgraph_middleware.get_block(dest_block_model.name)
for connection in util_iter_possible_connections(source_block, dest_block):
flowgraph_middleware.connect_blocks(connection.source, connection.sink)
return source_block, dest_block
def util_iter_possible_connections(
source_block: BlockMiddleware,
dest_block: BlockMiddleware,