main - feat: Create new BlockTypeModel and allow named ports

This commit is contained in:
Yoel Bassin 2025-04-27 16:56:49 +03:00
parent 42484d6c7d
commit ac4105b210
6 changed files with 45 additions and 17 deletions

View File

@ -11,9 +11,10 @@ class ElementMiddleware:
def _rewrite(self): def _rewrite(self):
self._element.rewrite() self._element.rewrite()
def validate(self): def validate(self) -> bool:
self._rewrite() self._rewrite()
self._element.validate() self._element.validate()
return self._element.is_valid()
def get_all_errors(self) -> list[ErrorModel]: def get_all_errors(self) -> list[ErrorModel]:
self.validate() self.validate()

View File

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional, Set
from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.FlowGraph import FlowGraph from gnuradio.grc.core.FlowGraph import FlowGraph
from gnuradio.grc.core.ports.port import Port
from gnuradio_mcp.middlewares.base import ElementMiddleware from gnuradio_mcp.middlewares.base import ElementMiddleware
from gnuradio_mcp.middlewares.block import BlockMiddleware from gnuradio_mcp.middlewares.block import BlockMiddleware
@ -14,12 +15,25 @@ if TYPE_CHECKING:
from gnuradio_mcp.middlewares.platform import PlatformMiddleware from gnuradio_mcp.middlewares.platform import PlatformMiddleware
def get_port_from_port_model_in_port_list(
port_list: list[Port], port_model: PortModel
) -> Block:
for port in port_list:
if port.key == port_model.key:
return port
raise ValueError(f"Port not found: {port_model.key}")
def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block: def get_port_from_port_model(flowgraph, port_model: PortModel) -> Block:
block_from_port_model = flowgraph.get_block(port_model.parent) block_from_port_model = flowgraph.get_block(port_model.parent)
if port_model.direction == SOURCE: if port_model.direction == SOURCE:
return block_from_port_model.sources[port_model.key] return get_port_from_port_model_in_port_list(
block_from_port_model.sources, port_model
)
elif port_model.direction == SINK: elif port_model.direction == SINK:
return block_from_port_model.sinks[port_model.key] return get_port_from_port_model_in_port_list(
block_from_port_model.sinks, port_model
)
else: else:
raise ValueError(f"Invalid port direction: {port_model.direction}") raise ValueError(f"Invalid port direction: {port_model.direction}")

View File

@ -4,7 +4,7 @@ from gnuradio.grc.core.platform import Platform
from gnuradio_mcp.middlewares.base import ElementMiddleware from gnuradio_mcp.middlewares.base import ElementMiddleware
from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware from gnuradio_mcp.middlewares.flowgraph import FlowGraphMiddleware
from gnuradio_mcp.models import BlockModel from gnuradio_mcp.models import BlockTypeModel
class PlatformMiddleware(ElementMiddleware): class PlatformMiddleware(ElementMiddleware):
@ -13,10 +13,14 @@ class PlatformMiddleware(ElementMiddleware):
self._platform = self._element self._platform = self._element
@property @property
def blocks(self) -> list[BlockModel]: def blocks(self) -> list[BlockTypeModel]:
return [ return [
BlockModel.from_block(block) for block in self._platform.blocks.values() BlockTypeModel.from_block_type(block)
for block in self._platform.blocks.values()
] ]
def make_flowgraph(self, filepath: str = "") -> FlowGraphMiddleware: def make_flowgraph(self, filepath: str = "") -> FlowGraphMiddleware:
return FlowGraphMiddleware.from_file(self, filepath) return FlowGraphMiddleware.from_file(self, filepath)
def save_flowgraph(self, filepath: str, flowgraph: FlowGraphMiddleware) -> None:
self._platform.save_flow_graph(filepath, flowgraph._flowgraph)

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Literal, get_args from typing import Any, Literal, Type, get_args
from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.Connection import Connection from gnuradio.grc.core.Connection import Connection
@ -9,13 +9,22 @@ from gnuradio.grc.core.ports.port import Port
from pydantic import BaseModel from pydantic import BaseModel
class BlockTypeModel(BaseModel):
label: str
key: str
@classmethod
def from_block_type(cls, block: Type[Block]) -> BlockTypeModel:
return cls(label=block.label, key=block.key)
class BlockModel(BaseModel): class BlockModel(BaseModel):
label: str label: str
key: str key: str
@classmethod @classmethod
def from_block(cls, block: Block) -> BlockModel: def from_block(cls, block: Block) -> BlockModel:
return cls(label=block.label, key=block.key) return cls(label=block.label, key=block.name)
class ParamModel(BaseModel): class ParamModel(BaseModel):
@ -42,7 +51,7 @@ SINK, SOURCE = get_args(DirectionType)
class PortModel(BaseModel): class PortModel(BaseModel):
parent: str parent: str
key: int key: str
name: str name: str
dtype: str dtype: str
direction: DirectionType direction: DirectionType
@ -59,7 +68,7 @@ class PortModel(BaseModel):
raise ValueError("Currently not supporting named ports") raise ValueError("Currently not supporting named ports")
return cls( return cls(
parent=port.parent.name, parent=port.parent.name,
key=int(port.key), key=port.key,
name=port.name, name=port.name,
dtype=port.dtype, dtype=port.dtype,
direction=direction, direction=direction,

View File

@ -75,7 +75,7 @@ def check_port_models(port_models, ports, direction):
assert isinstance(port_models, list) assert isinstance(port_models, list)
assert len(port_models) == len(ports) assert len(port_models) == len(ports)
for model, port in zip(port_models, ports): for model, port in zip(port_models, ports):
assert model.key == int(port.key) assert model.key == port.key
assert model.name == port.name assert model.name == port.name
assert model.dtype == port.dtype assert model.dtype == port.dtype
assert model.direction == direction assert model.direction == direction

View File

@ -3,18 +3,18 @@ from __future__ import annotations
from gnuradio.grc.core.blocks.block import Block from gnuradio.grc.core.blocks.block import Block
from gnuradio.grc.core.platform import Platform from gnuradio.grc.core.platform import Platform
from gnuradio_mcp.middlewares.platform import BlockModel, PlatformMiddleware from gnuradio_mcp.middlewares.platform import BlockTypeModel, PlatformMiddleware
def test_block_model_from_block(platform: Platform): def test_block_model_from_block(platform: Platform):
block = Block(platform) block_type = Block
model = BlockModel.from_block(block) model = BlockTypeModel.from_block_type(block_type)
assert model.label == block.label assert model.label == block_type.label
assert model.key == block.key assert model.key == block_type.key
def test_platform_middleware_blocks(platform: Platform): def test_platform_middleware_blocks(platform: Platform):
middleware = PlatformMiddleware(platform) middleware = PlatformMiddleware(platform)
block_models = middleware.blocks block_models = middleware.blocks
assert block_models # Checks that the list is not empty assert block_models # Checks that the list is not empty
assert all(isinstance(block_model, BlockModel) for block_model in block_models) assert all(isinstance(block_model, BlockTypeModel) for block_model in block_models)