runtime: Phase 2 ControlPort/Thrift integration
Add ControlPort/Thrift support as an alternative transport to XML-RPC: New middleware: - ThriftMiddleware wrapping GNURadioControlPortClient New MCP tools: - connect_controlport, disconnect_controlport - get_knobs (with regex filtering), set_knobs (atomic) - get_knob_properties (units, min/max, description) - get_performance_counters (throughput, timing, buffers) - post_message (PMT injection to block ports) Docker support: - enable_controlport param in launch_flowgraph - ENABLE_CONTROLPORT env in entrypoint.sh - ControlPort config generation in ~/.gnuradio/config.conf Models: KnobModel, KnobPropertiesModel, PerfCounterModel, ThriftConnectionInfoModel, plus ContainerModel updates.
This commit is contained in:
parent
4030633fde
commit
0afb2f5b6e
@ -19,8 +19,10 @@ WORKDIR /flowgraphs
|
||||
|
||||
ENV DISPLAY=:99
|
||||
ENV XMLRPC_PORT=8080
|
||||
ENV CONTROLPORT_PORT=9090
|
||||
|
||||
EXPOSE 8080
|
||||
EXPOSE 5900
|
||||
EXPOSE 9090
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
|
||||
@ -17,5 +17,26 @@ if [ "${ENABLE_VNC:-0}" = "1" ]; then
|
||||
echo "VNC server on :5900"
|
||||
fi
|
||||
|
||||
# Enable ControlPort if requested (Phase 2: Thrift integration)
|
||||
if [ "${ENABLE_CONTROLPORT:-0}" = "1" ]; then
|
||||
mkdir -p ~/.gnuradio
|
||||
cat > ~/.gnuradio/config.conf << EOF
|
||||
[ControlPort]
|
||||
on = True
|
||||
edges_list = True
|
||||
|
||||
[thrift]
|
||||
port = ${CONTROLPORT_PORT:-9090}
|
||||
|
||||
[PerfCounters]
|
||||
on = ${ENABLE_PERF_COUNTERS:-True}
|
||||
export = True
|
||||
EOF
|
||||
echo "ControlPort enabled on port ${CONTROLPORT_PORT:-9090}"
|
||||
if [ "${ENABLE_PERF_COUNTERS:-True}" = "True" ]; then
|
||||
echo "Performance counters enabled"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Run the flowgraph (passed as CMD arguments)
|
||||
exec "$@"
|
||||
|
||||
@ -11,6 +11,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_XMLRPC_PORT = 8080
|
||||
DEFAULT_VNC_PORT = 5900
|
||||
DEFAULT_CONTROLPORT_PORT = 9090 # Phase 2: Thrift ControlPort
|
||||
DEFAULT_STOP_TIMEOUT = 30 # Seconds to wait for graceful shutdown (coverage needs time)
|
||||
RUNTIME_IMAGE = "gnuradio-runtime:latest"
|
||||
COVERAGE_IMAGE = "gnuradio-coverage:latest"
|
||||
@ -31,7 +32,7 @@ class DockerMiddleware:
|
||||
|
||||
@classmethod
|
||||
def create(cls) -> DockerMiddleware | None:
|
||||
"""Attempt to create a DockerMiddleware. Returns None if Docker is unavailable."""
|
||||
"""Create a DockerMiddleware. Returns None if Docker unavailable."""
|
||||
try:
|
||||
import docker
|
||||
|
||||
@ -49,6 +50,9 @@ class DockerMiddleware:
|
||||
xmlrpc_port: int = DEFAULT_XMLRPC_PORT,
|
||||
enable_vnc: bool = False,
|
||||
enable_coverage: bool = False,
|
||||
enable_controlport: bool = False,
|
||||
controlport_port: int = DEFAULT_CONTROLPORT_PORT,
|
||||
enable_perf_counters: bool = True,
|
||||
device_paths: list[str] | None = None,
|
||||
) -> ContainerModel:
|
||||
"""Launch a flowgraph in a Docker container with Xvfb.
|
||||
@ -59,6 +63,9 @@ class DockerMiddleware:
|
||||
xmlrpc_port: Port for XML-RPC variable control
|
||||
enable_vnc: Enable VNC server for visual debugging
|
||||
enable_coverage: Use coverage image and collect Python coverage data
|
||||
enable_controlport: Enable ControlPort/Thrift for advanced control
|
||||
controlport_port: Port for ControlPort (default 9090)
|
||||
enable_perf_counters: Enable performance counters (requires controlport)
|
||||
device_paths: Host device paths to pass through (e.g., /dev/ttyUSB0)
|
||||
"""
|
||||
fg_path = Path(flowgraph_path).resolve()
|
||||
@ -73,12 +80,18 @@ class DockerMiddleware:
|
||||
env["ENABLE_VNC"] = "1"
|
||||
if enable_coverage:
|
||||
env["ENABLE_COVERAGE"] = "1"
|
||||
if enable_controlport:
|
||||
env["ENABLE_CONTROLPORT"] = "1"
|
||||
env["CONTROLPORT_PORT"] = str(controlport_port)
|
||||
env["ENABLE_PERF_COUNTERS"] = "True" if enable_perf_counters else "False"
|
||||
|
||||
ports: dict[str, int] = {f"{xmlrpc_port}/tcp": xmlrpc_port}
|
||||
vnc_port: int | None = None
|
||||
if enable_vnc:
|
||||
vnc_port = DEFAULT_VNC_PORT
|
||||
ports[f"{vnc_port}/tcp"] = vnc_port
|
||||
if enable_controlport:
|
||||
ports[f"{controlport_port}/tcp"] = controlport_port
|
||||
|
||||
volumes = {
|
||||
str(fg_path.parent): {
|
||||
@ -115,6 +128,8 @@ class DockerMiddleware:
|
||||
"gr-mcp.xmlrpc-port": str(xmlrpc_port),
|
||||
"gr-mcp.vnc-enabled": "1" if enable_vnc else "0",
|
||||
"gr-mcp.coverage-enabled": "1" if enable_coverage else "0",
|
||||
"gr-mcp.controlport-enabled": "1" if enable_controlport else "0",
|
||||
"gr-mcp.controlport-port": str(controlport_port),
|
||||
},
|
||||
)
|
||||
|
||||
@ -125,8 +140,10 @@ class DockerMiddleware:
|
||||
flowgraph_path=str(fg_path),
|
||||
xmlrpc_port=xmlrpc_port,
|
||||
vnc_port=vnc_port,
|
||||
controlport_port=controlport_port if enable_controlport else None,
|
||||
device_paths=device_paths or [],
|
||||
coverage_enabled=enable_coverage,
|
||||
controlport_enabled=enable_controlport,
|
||||
)
|
||||
|
||||
def list_containers(self) -> list[ContainerModel]:
|
||||
@ -137,17 +154,33 @@ class DockerMiddleware:
|
||||
result = []
|
||||
for c in containers:
|
||||
labels = c.labels
|
||||
controlport_enabled = labels.get("gr-mcp.controlport-enabled") == "1"
|
||||
result.append(
|
||||
ContainerModel(
|
||||
name=c.name,
|
||||
container_id=c.id[:12],
|
||||
status=c.status,
|
||||
flowgraph_path=labels.get("gr-mcp.flowgraph", ""),
|
||||
xmlrpc_port=int(labels.get("gr-mcp.xmlrpc-port", DEFAULT_XMLRPC_PORT)),
|
||||
vnc_port=DEFAULT_VNC_PORT
|
||||
if labels.get("gr-mcp.vnc-enabled") == "1" and c.status == "running"
|
||||
else None,
|
||||
xmlrpc_port=int(
|
||||
labels.get("gr-mcp.xmlrpc-port", DEFAULT_XMLRPC_PORT)
|
||||
),
|
||||
vnc_port=(
|
||||
DEFAULT_VNC_PORT
|
||||
if labels.get("gr-mcp.vnc-enabled") == "1"
|
||||
and c.status == "running"
|
||||
else None
|
||||
),
|
||||
controlport_port=(
|
||||
int(
|
||||
labels.get(
|
||||
"gr-mcp.controlport-port", DEFAULT_CONTROLPORT_PORT
|
||||
)
|
||||
)
|
||||
if controlport_enabled and c.status == "running"
|
||||
else None
|
||||
),
|
||||
coverage_enabled=labels.get("gr-mcp.coverage-enabled") == "1",
|
||||
controlport_enabled=controlport_enabled,
|
||||
)
|
||||
)
|
||||
return result
|
||||
@ -171,7 +204,9 @@ class DockerMiddleware:
|
||||
logger.warning(
|
||||
"Container %s didn't stop gracefully within %ds, "
|
||||
"coverage data may be lost: %s",
|
||||
name, timeout, e
|
||||
name,
|
||||
timeout,
|
||||
e,
|
||||
)
|
||||
return True
|
||||
|
||||
@ -208,9 +243,7 @@ class DockerMiddleware:
|
||||
def get_xmlrpc_port(self, name: str) -> int:
|
||||
"""Get the XML-RPC port for a container."""
|
||||
container = self._client.containers.get(name)
|
||||
return int(
|
||||
container.labels.get("gr-mcp.xmlrpc-port", DEFAULT_XMLRPC_PORT)
|
||||
)
|
||||
return int(container.labels.get("gr-mcp.xmlrpc-port", DEFAULT_XMLRPC_PORT))
|
||||
|
||||
def is_coverage_enabled(self, name: str) -> bool:
|
||||
"""Check if coverage is enabled for a container."""
|
||||
@ -220,3 +253,15 @@ class DockerMiddleware:
|
||||
def get_coverage_dir(self, name: str) -> Path:
|
||||
"""Get the host-side coverage directory for a container."""
|
||||
return Path(HOST_COVERAGE_BASE) / name
|
||||
|
||||
def is_controlport_enabled(self, name: str) -> bool:
|
||||
"""Check if ControlPort is enabled for a container."""
|
||||
container = self._client.containers.get(name)
|
||||
return container.labels.get("gr-mcp.controlport-enabled") == "1"
|
||||
|
||||
def get_controlport_port(self, name: str) -> int:
|
||||
"""Get the ControlPort Thrift port for a container."""
|
||||
container = self._client.containers.get(name)
|
||||
return int(
|
||||
container.labels.get("gr-mcp.controlport-port", DEFAULT_CONTROLPORT_PORT)
|
||||
)
|
||||
|
||||
347
src/gnuradio_mcp/middlewares/thrift.py
Normal file
347
src/gnuradio_mcp/middlewares/thrift.py
Normal file
@ -0,0 +1,347 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from gnuradio_mcp.models import (
|
||||
KNOB_TYPE_NAMES,
|
||||
KnobModel,
|
||||
KnobPropertiesModel,
|
||||
PerfCounterModel,
|
||||
ThriftConnectionInfoModel,
|
||||
VariableModel,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
THRIFT_TIMEOUT = 5
|
||||
DEFAULT_THRIFT_PORT = 9090
|
||||
|
||||
# Performance counter knob suffixes (used to identify perf counters)
|
||||
PERF_COUNTER_SUFFIXES = [
|
||||
"::avg throughput",
|
||||
"::avg work time",
|
||||
"::total work time",
|
||||
"::avg nproduced",
|
||||
"::avg input % full",
|
||||
"::avg output % full",
|
||||
"::var nproduced",
|
||||
"::var work time",
|
||||
]
|
||||
|
||||
|
||||
class ThriftMiddleware:
|
||||
"""Wraps GNU Radio's ControlPort Thrift client for runtime control.
|
||||
|
||||
ControlPort provides richer functionality than XML-RPC:
|
||||
- Native type support (complex numbers, vectors)
|
||||
- Performance counters (throughput, timing, buffer utilization)
|
||||
- Knob metadata (units, min/max, descriptions)
|
||||
- PMT message injection
|
||||
- Regex-based knob queries
|
||||
|
||||
Knobs are named using the pattern: block_alias::varname
|
||||
(e.g., "sig_source0::frequency")
|
||||
|
||||
Requires ControlPort to be enabled in GNU Radio config:
|
||||
[ControlPort]
|
||||
on = True
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Any, # RPCConnectionThrift
|
||||
host: str,
|
||||
port: int,
|
||||
):
|
||||
self._client = client
|
||||
self._host = host
|
||||
self._port = port
|
||||
|
||||
@classmethod
|
||||
def connect(
|
||||
cls,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = DEFAULT_THRIFT_PORT,
|
||||
) -> ThriftMiddleware:
|
||||
"""Connect to a GNU Radio ControlPort server.
|
||||
|
||||
Args:
|
||||
host: Hostname or IP address
|
||||
port: ControlPort Thrift port (default 9090)
|
||||
|
||||
Raises:
|
||||
ImportError: If gnuradio.ctrlport is not available
|
||||
ConnectionError: If connection fails
|
||||
"""
|
||||
try:
|
||||
from gnuradio.ctrlport.GNURadioControlPortClient import (
|
||||
GNURadioControlPortClient,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"GNU Radio ControlPort not available. "
|
||||
"Ensure GNU Radio is installed with Thrift support."
|
||||
) from e
|
||||
|
||||
try:
|
||||
radio = GNURadioControlPortClient(host=host, port=port)
|
||||
logger.info("Connected to ControlPort at %s:%d", host, port)
|
||||
return cls(radio.client, host, port)
|
||||
except Exception as e:
|
||||
raise ConnectionError(
|
||||
f"Failed to connect to ControlPort at {host}:{port}: {e}"
|
||||
) from e
|
||||
|
||||
def get_connection_info(
|
||||
self, container_name: str | None = None
|
||||
) -> ThriftConnectionInfoModel:
|
||||
"""Return connection metadata including knob count."""
|
||||
try:
|
||||
knobs = self._client.getKnobs([])
|
||||
knob_count = len(knobs)
|
||||
except Exception:
|
||||
knob_count = 0
|
||||
|
||||
return ThriftConnectionInfoModel(
|
||||
host=self._host,
|
||||
port=self._port,
|
||||
container_name=container_name,
|
||||
protocol="thrift",
|
||||
knob_count=knob_count,
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────
|
||||
# Variable Operations (XML-RPC compatible API)
|
||||
# ──────────────────────────────────────────
|
||||
|
||||
def list_variables(self) -> list[VariableModel]:
|
||||
"""List all ControlPort knobs as variables.
|
||||
|
||||
Filters out performance counters to match XML-RPC behavior.
|
||||
"""
|
||||
knobs = self._client.getKnobs([])
|
||||
variables = []
|
||||
for name, knob in knobs.items():
|
||||
# Skip performance counters
|
||||
if self._is_perf_counter(name):
|
||||
continue
|
||||
variables.append(VariableModel(name=name, value=knob.value))
|
||||
return variables
|
||||
|
||||
def get_variable(self, name: str) -> Any:
|
||||
"""Get a variable value by name."""
|
||||
knobs = self._client.getKnobs([name])
|
||||
if name not in knobs:
|
||||
raise KeyError(f"Knob not found: {name}")
|
||||
return knobs[name].value
|
||||
|
||||
def set_variable(self, name: str, value: Any) -> bool:
|
||||
"""Set a variable value.
|
||||
|
||||
The knob type is inferred from the existing knob's type.
|
||||
"""
|
||||
# Get current knob to determine type
|
||||
knobs = self._client.getKnobs([name])
|
||||
if name not in knobs:
|
||||
raise KeyError(f"Knob not found: {name}")
|
||||
|
||||
current = knobs[name]
|
||||
# Create new knob with same type but new value
|
||||
from gnuradio.ctrlport.RPCConnectionThrift import RPCConnectionThrift
|
||||
|
||||
new_knob = RPCConnectionThrift.Knob(name, value, current.ktype)
|
||||
self._client.setKnobs({name: new_knob})
|
||||
return True
|
||||
|
||||
# ──────────────────────────────────────────
|
||||
# ControlPort-Specific Operations
|
||||
# ──────────────────────────────────────────
|
||||
|
||||
def get_knobs(self, pattern: str = "") -> list[KnobModel]:
|
||||
"""Get knobs, optionally filtered by regex pattern.
|
||||
|
||||
Args:
|
||||
pattern: Regex pattern for filtering knob names.
|
||||
Empty string returns all knobs.
|
||||
|
||||
Examples:
|
||||
get_knobs("") # All knobs
|
||||
get_knobs(".*frequency.*") # All frequency-related knobs
|
||||
get_knobs("sig_source0::.*") # All knobs for sig_source0
|
||||
"""
|
||||
if pattern:
|
||||
knobs = self._client.getRe([pattern])
|
||||
else:
|
||||
knobs = self._client.getKnobs([])
|
||||
|
||||
result = []
|
||||
for name, knob in knobs.items():
|
||||
knob_type = KNOB_TYPE_NAMES.get(knob.ktype, f"UNKNOWN({knob.ktype})")
|
||||
result.append(
|
||||
KnobModel(
|
||||
name=name,
|
||||
value=knob.value,
|
||||
knob_type=knob_type,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
def set_knobs(self, knobs: dict[str, Any]) -> bool:
|
||||
"""Set multiple knobs atomically.
|
||||
|
||||
Args:
|
||||
knobs: Dict mapping knob names to new values.
|
||||
Types are inferred from existing knobs.
|
||||
"""
|
||||
if not knobs:
|
||||
return True
|
||||
|
||||
# Get current knobs to determine types
|
||||
current_knobs = self._client.getKnobs(list(knobs.keys()))
|
||||
|
||||
from gnuradio.ctrlport.RPCConnectionThrift import RPCConnectionThrift
|
||||
|
||||
to_set = {}
|
||||
for name, value in knobs.items():
|
||||
if name not in current_knobs:
|
||||
raise KeyError(f"Knob not found: {name}")
|
||||
current = current_knobs[name]
|
||||
to_set[name] = RPCConnectionThrift.Knob(name, value, current.ktype)
|
||||
|
||||
self._client.setKnobs(to_set)
|
||||
return True
|
||||
|
||||
def get_knob_properties(self, names: list[str]) -> list[KnobPropertiesModel]:
|
||||
"""Get metadata (units, min/max, description) for specified knobs.
|
||||
|
||||
Args:
|
||||
names: List of knob names to query.
|
||||
"""
|
||||
if not names:
|
||||
# Get all properties
|
||||
props = self._client.properties([])
|
||||
else:
|
||||
props = self._client.properties(names)
|
||||
|
||||
result = []
|
||||
for name, prop in props.items():
|
||||
knob_type = KNOB_TYPE_NAMES.get(prop.type, f"UNKNOWN({prop.type})")
|
||||
result.append(
|
||||
KnobPropertiesModel(
|
||||
name=name,
|
||||
description=prop.description or "",
|
||||
units=prop.units if hasattr(prop, "units") else None,
|
||||
min_value=prop.min.value if prop.min else None,
|
||||
max_value=prop.max.value if prop.max else None,
|
||||
default_value=prop.defaultvalue.value if prop.defaultvalue else None,
|
||||
knob_type=knob_type,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
def get_performance_counters(
|
||||
self, block: str | None = None
|
||||
) -> list[PerfCounterModel]:
|
||||
"""Get performance metrics for blocks.
|
||||
|
||||
Args:
|
||||
block: Optional block alias to filter (e.g., "sig_source0").
|
||||
If None, returns metrics for all blocks.
|
||||
|
||||
Returns:
|
||||
List of PerfCounterModel with throughput, timing, and buffer stats.
|
||||
"""
|
||||
# Get all performance counter knobs
|
||||
if block:
|
||||
pattern = f"^{re.escape(block)}::.*"
|
||||
else:
|
||||
pattern = ""
|
||||
|
||||
all_knobs = self.get_knobs(pattern)
|
||||
|
||||
# Group by block
|
||||
blocks: dict[str, dict[str, Any]] = {}
|
||||
for knob in all_knobs:
|
||||
if not self._is_perf_counter(knob.name):
|
||||
continue
|
||||
|
||||
# Parse block name and metric
|
||||
parts = knob.name.split("::", 1)
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
|
||||
block_name, metric = parts
|
||||
if block_name not in blocks:
|
||||
blocks[block_name] = {}
|
||||
blocks[block_name][metric] = knob.value
|
||||
|
||||
# Build PerfCounterModel for each block
|
||||
result = []
|
||||
for block_name, metrics in blocks.items():
|
||||
result.append(
|
||||
PerfCounterModel(
|
||||
block_name=block_name,
|
||||
avg_throughput=metrics.get("avg throughput", 0.0),
|
||||
avg_work_time_us=metrics.get("avg work time", 0.0),
|
||||
total_work_time_us=metrics.get("total work time", 0.0),
|
||||
avg_nproduced=metrics.get("avg nproduced", 0.0),
|
||||
input_buffer_pct=self._to_list(
|
||||
metrics.get("avg input % full", [])
|
||||
),
|
||||
output_buffer_pct=self._to_list(
|
||||
metrics.get("avg output % full", [])
|
||||
),
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
def post_message(self, block: str, port: str, message: Any) -> bool:
|
||||
"""Send a PMT message to a block's message port.
|
||||
|
||||
Args:
|
||||
block: Block alias (e.g., "msg_sink0")
|
||||
port: Message port name (e.g., "in")
|
||||
message: PMT message to send
|
||||
|
||||
Note:
|
||||
The message should be a PMT object. For simple cases,
|
||||
use pmt.intern("string") or pmt.to_pmt(dict).
|
||||
"""
|
||||
import pmt
|
||||
|
||||
# Ensure message is a PMT
|
||||
if not pmt.is_pmt(message):
|
||||
message = pmt.to_pmt(message)
|
||||
|
||||
self._client.postMessage(block, port, message)
|
||||
return True
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the Thrift connection."""
|
||||
try:
|
||||
if self._client is not None:
|
||||
# The client handles cleanup in __del__
|
||||
self._client = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ──────────────────────────────────────────
|
||||
# Private Helpers
|
||||
# ──────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _is_perf_counter(name: str) -> bool:
|
||||
"""Check if a knob name is a performance counter."""
|
||||
return any(name.endswith(suffix) for suffix in PERF_COUNTER_SUFFIXES)
|
||||
|
||||
@staticmethod
|
||||
def _to_list(value: Any) -> list[float]:
|
||||
"""Convert a value to a list of floats."""
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [float(v) for v in value]
|
||||
elif value is None:
|
||||
return []
|
||||
else:
|
||||
return [float(value)]
|
||||
@ -19,8 +19,7 @@ class BlockTypeModel(BaseModel):
|
||||
|
||||
|
||||
class KeyedModel(Protocol):
|
||||
def to_key(self) -> str:
|
||||
...
|
||||
def to_key(self) -> str: ...
|
||||
|
||||
|
||||
class BlockModel(BaseModel):
|
||||
@ -128,8 +127,10 @@ class ContainerModel(BaseModel):
|
||||
flowgraph_path: str
|
||||
xmlrpc_port: int
|
||||
vnc_port: int | None = None
|
||||
controlport_port: int | None = None # Phase 2: Thrift ControlPort
|
||||
device_paths: list[str] = []
|
||||
coverage_enabled: bool = False
|
||||
controlport_enabled: bool = False # Phase 2: Thrift ControlPort
|
||||
|
||||
|
||||
class VariableModel(BaseModel):
|
||||
@ -158,6 +159,87 @@ class RuntimeStatusModel(BaseModel):
|
||||
containers: list[ContainerModel] = []
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# ControlPort/Thrift Models (Phase 2)
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
|
||||
# Knob types from GNU Radio's ControlPort Thrift API
|
||||
# Maps to gnuradio.ctrlport.GNURadio.ttypes.BaseTypes
|
||||
KNOB_TYPE_NAMES = {
|
||||
0: "BOOL",
|
||||
1: "BYTE",
|
||||
2: "SHORT",
|
||||
3: "INT",
|
||||
4: "LONG",
|
||||
5: "DOUBLE",
|
||||
6: "STRING",
|
||||
7: "COMPLEX",
|
||||
8: "F32VECTOR",
|
||||
9: "F64VECTOR",
|
||||
10: "S64VECTOR",
|
||||
11: "S32VECTOR",
|
||||
12: "S16VECTOR",
|
||||
13: "S8VECTOR",
|
||||
14: "C32VECTOR",
|
||||
}
|
||||
|
||||
|
||||
class KnobModel(BaseModel):
|
||||
"""ControlPort knob with type information.
|
||||
|
||||
Knobs are named using the pattern: block_alias::varname
|
||||
(e.g., "sig_source0::frequency")
|
||||
"""
|
||||
|
||||
name: str
|
||||
value: Any
|
||||
knob_type: str # BOOL, INT, DOUBLE, COMPLEX, F32VECTOR, etc.
|
||||
|
||||
|
||||
class KnobPropertiesModel(BaseModel):
|
||||
"""Rich metadata for a ControlPort knob.
|
||||
|
||||
Includes units, min/max bounds, and description from the
|
||||
block's property registration.
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
units: str | None = None
|
||||
min_value: Any | None = None
|
||||
max_value: Any | None = None
|
||||
default_value: Any | None = None
|
||||
knob_type: str | None = None
|
||||
|
||||
|
||||
class PerfCounterModel(BaseModel):
|
||||
"""Block performance metrics from ControlPort.
|
||||
|
||||
These are automatically exposed when [PerfCounters] on = True
|
||||
in the GNU Radio config. Performance counters use the naming
|
||||
pattern: block_alias::metric_name
|
||||
"""
|
||||
|
||||
block_name: str
|
||||
avg_throughput: float # samples/sec (avg nproduced * sample rate)
|
||||
avg_work_time_us: float # microseconds per work() call
|
||||
total_work_time_us: float # cumulative time in work()
|
||||
avg_nproduced: float # average samples produced per work() call
|
||||
input_buffer_pct: list[float] = [] # buffer fullness per input port
|
||||
output_buffer_pct: list[float] = [] # buffer fullness per output port
|
||||
|
||||
|
||||
class ThriftConnectionInfoModel(BaseModel):
|
||||
"""Connection information for ControlPort/Thrift."""
|
||||
|
||||
host: str
|
||||
port: int
|
||||
container_name: str | None = None
|
||||
protocol: str = "thrift"
|
||||
knob_count: int = 0
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Coverage Models (Cross-Process Code Coverage)
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
@ -42,6 +42,15 @@ class McpRuntimeProvider:
|
||||
self._mcp.tool(p.lock)
|
||||
self._mcp.tool(p.unlock)
|
||||
|
||||
# ControlPort/Thrift tools (always available - Phase 2)
|
||||
self._mcp.tool(p.connect_controlport)
|
||||
self._mcp.tool(p.disconnect_controlport)
|
||||
self._mcp.tool(p.get_knobs)
|
||||
self._mcp.tool(p.set_knobs)
|
||||
self._mcp.tool(p.get_knob_properties)
|
||||
self._mcp.tool(p.get_performance_counters)
|
||||
self._mcp.tool(p.post_message)
|
||||
|
||||
# Docker-dependent tools
|
||||
if p._has_docker:
|
||||
# Container lifecycle
|
||||
@ -50,6 +59,7 @@ class McpRuntimeProvider:
|
||||
self._mcp.tool(p.stop_flowgraph)
|
||||
self._mcp.tool(p.remove_flowgraph)
|
||||
self._mcp.tool(p.connect_to_container)
|
||||
self._mcp.tool(p.connect_to_container_controlport) # Phase 2
|
||||
|
||||
# Visual feedback
|
||||
self._mcp.tool(p.capture_screenshot)
|
||||
@ -61,10 +71,10 @@ class McpRuntimeProvider:
|
||||
self._mcp.tool(p.combine_coverage)
|
||||
self._mcp.tool(p.delete_coverage)
|
||||
|
||||
logger.info("Registered 21 runtime tools (Docker available)")
|
||||
logger.info("Registered 29 runtime tools (Docker available)")
|
||||
else:
|
||||
logger.info(
|
||||
"Registered 10 runtime tools (Docker unavailable, "
|
||||
"Registered 17 runtime tools (Docker unavailable, "
|
||||
"container tools skipped)"
|
||||
)
|
||||
|
||||
|
||||
@ -7,15 +7,20 @@ import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from gnuradio_mcp.middlewares.docker import DockerMiddleware, HOST_COVERAGE_BASE
|
||||
from gnuradio_mcp.middlewares.docker import HOST_COVERAGE_BASE, DockerMiddleware
|
||||
from gnuradio_mcp.middlewares.thrift import ThriftMiddleware
|
||||
from gnuradio_mcp.middlewares.xmlrpc import XmlRpcMiddleware
|
||||
from gnuradio_mcp.models import (
|
||||
ConnectionInfoModel,
|
||||
ContainerModel,
|
||||
CoverageDataModel,
|
||||
CoverageReportModel,
|
||||
KnobModel,
|
||||
KnobPropertiesModel,
|
||||
PerfCounterModel,
|
||||
RuntimeStatusModel,
|
||||
ScreenshotModel,
|
||||
ThriftConnectionInfoModel,
|
||||
VariableModel,
|
||||
)
|
||||
|
||||
@ -25,7 +30,9 @@ logger = logging.getLogger(__name__)
|
||||
class RuntimeProvider:
|
||||
"""Business logic for runtime flowgraph control.
|
||||
|
||||
Coordinates Docker (container lifecycle) and XML-RPC (variable control).
|
||||
Coordinates Docker (container lifecycle), XML-RPC (variable control),
|
||||
and ControlPort/Thrift (advanced control with perf counters).
|
||||
|
||||
Tracks the active connection so convenience methods like get_variable()
|
||||
work without repeating the URL each call.
|
||||
"""
|
||||
@ -36,6 +43,7 @@ class RuntimeProvider:
|
||||
):
|
||||
self._docker = docker_mw
|
||||
self._xmlrpc: XmlRpcMiddleware | None = None
|
||||
self._thrift: ThriftMiddleware | None = None
|
||||
self._active_container: str | None = None
|
||||
|
||||
@property
|
||||
@ -58,6 +66,14 @@ class RuntimeProvider:
|
||||
)
|
||||
return self._xmlrpc
|
||||
|
||||
def _require_thrift(self) -> ThriftMiddleware:
|
||||
if self._thrift is None:
|
||||
raise RuntimeError(
|
||||
"Not connected via ControlPort. Use connect_controlport() or "
|
||||
"connect_to_container_controlport() first."
|
||||
)
|
||||
return self._thrift
|
||||
|
||||
# ──────────────────────────────────────────
|
||||
# Container Lifecycle
|
||||
# ──────────────────────────────────────────
|
||||
@ -69,6 +85,9 @@ class RuntimeProvider:
|
||||
xmlrpc_port: int = 8080,
|
||||
enable_vnc: bool = False,
|
||||
enable_coverage: bool = False,
|
||||
enable_controlport: bool = False,
|
||||
controlport_port: int = 9090,
|
||||
enable_perf_counters: bool = True,
|
||||
device_paths: list[str] | None = None,
|
||||
) -> ContainerModel:
|
||||
"""Launch a flowgraph in a Docker container with Xvfb.
|
||||
@ -79,6 +98,9 @@ class RuntimeProvider:
|
||||
xmlrpc_port: Port for XML-RPC variable control
|
||||
enable_vnc: Enable VNC server for visual debugging
|
||||
enable_coverage: Enable Python code coverage collection
|
||||
enable_controlport: Enable ControlPort/Thrift for advanced control
|
||||
controlport_port: Port for ControlPort (default 9090)
|
||||
enable_perf_counters: Enable performance counters (requires controlport)
|
||||
device_paths: Host device paths to pass through
|
||||
"""
|
||||
docker = self._require_docker()
|
||||
@ -90,6 +112,9 @@ class RuntimeProvider:
|
||||
xmlrpc_port=xmlrpc_port,
|
||||
enable_vnc=enable_vnc,
|
||||
enable_coverage=enable_coverage,
|
||||
enable_controlport=enable_controlport,
|
||||
controlport_port=controlport_port,
|
||||
enable_perf_counters=enable_perf_counters,
|
||||
device_paths=device_paths,
|
||||
)
|
||||
|
||||
@ -130,18 +155,157 @@ class RuntimeProvider:
|
||||
url = f"http://localhost:{port}"
|
||||
self._xmlrpc = XmlRpcMiddleware.connect(url)
|
||||
self._active_container = name
|
||||
return self._xmlrpc.get_connection_info(
|
||||
container_name=name, xmlrpc_port=port
|
||||
)
|
||||
return self._xmlrpc.get_connection_info(container_name=name, xmlrpc_port=port)
|
||||
|
||||
def disconnect(self) -> bool:
|
||||
"""Disconnect from the current XML-RPC endpoint."""
|
||||
if self._xmlrpc is not None:
|
||||
self._xmlrpc.close()
|
||||
self._xmlrpc = None
|
||||
self._active_container = None
|
||||
if self._thrift is not None:
|
||||
self._thrift.close()
|
||||
self._thrift = None
|
||||
self._active_container = None
|
||||
return True
|
||||
|
||||
# ──────────────────────────────────────────
|
||||
# ControlPort/Thrift Connection (Phase 2)
|
||||
# ──────────────────────────────────────────
|
||||
|
||||
def connect_controlport(
|
||||
self,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 9090,
|
||||
) -> ThriftConnectionInfoModel:
|
||||
"""Connect to a GNU Radio ControlPort/Thrift endpoint.
|
||||
|
||||
ControlPort provides richer functionality than XML-RPC:
|
||||
- Native type support (complex numbers, vectors)
|
||||
- Performance counters (throughput, timing, buffer utilization)
|
||||
- Knob metadata (units, min/max, descriptions)
|
||||
- PMT message injection
|
||||
- Regex-based knob queries
|
||||
|
||||
Args:
|
||||
host: Hostname or IP address
|
||||
port: ControlPort Thrift port (default 9090)
|
||||
"""
|
||||
self._thrift = ThriftMiddleware.connect(host, port)
|
||||
self._active_container = None
|
||||
return self._thrift.get_connection_info()
|
||||
|
||||
def connect_to_container_controlport(self, name: str) -> ThriftConnectionInfoModel:
|
||||
"""Connect to a flowgraph's ControlPort by container name.
|
||||
|
||||
Resolves the ControlPort port from container labels automatically.
|
||||
|
||||
Args:
|
||||
name: Container name
|
||||
"""
|
||||
docker = self._require_docker()
|
||||
if not docker.is_controlport_enabled(name):
|
||||
raise RuntimeError(
|
||||
f"Container '{name}' was not launched with ControlPort enabled. "
|
||||
f"Use launch_flowgraph(..., enable_controlport=True)"
|
||||
)
|
||||
port = docker.get_controlport_port(name)
|
||||
self._thrift = ThriftMiddleware.connect("127.0.0.1", port)
|
||||
self._active_container = name
|
||||
return self._thrift.get_connection_info(container_name=name)
|
||||
|
||||
def disconnect_controlport(self) -> bool:
|
||||
"""Disconnect from the current ControlPort endpoint."""
|
||||
if self._thrift is not None:
|
||||
self._thrift.close()
|
||||
self._thrift = None
|
||||
return True
|
||||
|
||||
# ──────────────────────────────────────────
|
||||
# ControlPort Knob Operations (Phase 2)
|
||||
# ──────────────────────────────────────────
|
||||
|
||||
def get_knobs(self, pattern: str = "") -> list[KnobModel]:
|
||||
"""Get ControlPort knobs, optionally filtered by regex pattern.
|
||||
|
||||
Knobs are named using the pattern: block_alias::varname
|
||||
(e.g., "sig_source0::frequency")
|
||||
|
||||
Args:
|
||||
pattern: Regex pattern for filtering knob names.
|
||||
Empty string returns all knobs.
|
||||
|
||||
Examples:
|
||||
get_knobs("") # All knobs
|
||||
get_knobs(".*frequency.*") # All frequency-related knobs
|
||||
get_knobs("sig_source0::.*") # All knobs for sig_source0
|
||||
"""
|
||||
thrift = self._require_thrift()
|
||||
return thrift.get_knobs(pattern)
|
||||
|
||||
def set_knobs(self, knobs: dict[str, Any]) -> bool:
|
||||
"""Set multiple ControlPort knobs atomically.
|
||||
|
||||
Args:
|
||||
knobs: Dict mapping knob names to new values.
|
||||
Types are inferred from existing knobs.
|
||||
|
||||
Example:
|
||||
set_knobs({
|
||||
"sig_source0::frequency": 1000000.0,
|
||||
"sig_source0::amplitude": 0.5,
|
||||
})
|
||||
"""
|
||||
thrift = self._require_thrift()
|
||||
return thrift.set_knobs(knobs)
|
||||
|
||||
def get_knob_properties(self, names: list[str]) -> list[KnobPropertiesModel]:
|
||||
"""Get metadata (units, min/max, description) for specified knobs.
|
||||
|
||||
Args:
|
||||
names: List of knob names to query. Empty list returns all properties.
|
||||
|
||||
Returns:
|
||||
List of KnobPropertiesModel with rich metadata.
|
||||
"""
|
||||
thrift = self._require_thrift()
|
||||
return thrift.get_knob_properties(names)
|
||||
|
||||
def get_performance_counters(
|
||||
self, block: str | None = None
|
||||
) -> list[PerfCounterModel]:
|
||||
"""Get performance metrics for blocks via ControlPort.
|
||||
|
||||
Requires the flowgraph to be launched with enable_controlport=True
|
||||
and enable_perf_counters=True (default).
|
||||
|
||||
Args:
|
||||
block: Optional block alias to filter (e.g., "sig_source0").
|
||||
If None, returns metrics for all blocks.
|
||||
|
||||
Returns:
|
||||
List of PerfCounterModel with throughput, timing, and buffer stats.
|
||||
"""
|
||||
thrift = self._require_thrift()
|
||||
return thrift.get_performance_counters(block)
|
||||
|
||||
def post_message(self, block: str, port: str, message: Any) -> bool:
|
||||
"""Send a PMT message to a block's message port via ControlPort.
|
||||
|
||||
Args:
|
||||
block: Block alias (e.g., "msg_sink0")
|
||||
port: Message port name (e.g., "in")
|
||||
message: Message to send (will be converted to PMT if needed)
|
||||
|
||||
Example:
|
||||
# Send a simple string message
|
||||
post_message("pdu_sink0", "pdus", "hello")
|
||||
|
||||
# Send a dict (converted to PMT dict)
|
||||
post_message("block0", "command", {"freq": 1e6})
|
||||
"""
|
||||
thrift = self._require_thrift()
|
||||
return thrift.post_message(block, port, message)
|
||||
|
||||
def get_status(self) -> RuntimeStatusModel:
|
||||
"""Get runtime status including connection and container info."""
|
||||
connection = None
|
||||
@ -216,7 +380,7 @@ class RuntimeProvider:
|
||||
container_name = name or self._active_container
|
||||
if container_name is None:
|
||||
raise RuntimeError(
|
||||
"No container specified. Provide a name or connect to a container first."
|
||||
"No container specified. Provide a name or connect first."
|
||||
)
|
||||
return docker.capture_screenshot(container_name)
|
||||
|
||||
@ -226,7 +390,7 @@ class RuntimeProvider:
|
||||
container_name = name or self._active_container
|
||||
if container_name is None:
|
||||
raise RuntimeError(
|
||||
"No container specified. Provide a name or connect to a container first."
|
||||
"No container specified. Provide a name or connect first."
|
||||
)
|
||||
return docker.get_logs(container_name, tail=tail)
|
||||
|
||||
@ -350,9 +514,12 @@ class RuntimeProvider:
|
||||
report_path = coverage_dir / "htmlcov" / "index.html"
|
||||
subprocess.run(
|
||||
[
|
||||
"coverage", "html",
|
||||
"--data-file", str(coverage_file),
|
||||
"-d", str(coverage_dir / "htmlcov"),
|
||||
"coverage",
|
||||
"html",
|
||||
"--data-file",
|
||||
str(coverage_file),
|
||||
"-d",
|
||||
str(coverage_dir / "htmlcov"),
|
||||
],
|
||||
capture_output=True,
|
||||
check=True,
|
||||
@ -361,9 +528,12 @@ class RuntimeProvider:
|
||||
report_path = coverage_dir / "coverage.xml"
|
||||
subprocess.run(
|
||||
[
|
||||
"coverage", "xml",
|
||||
"--data-file", str(coverage_file),
|
||||
"-o", str(report_path),
|
||||
"coverage",
|
||||
"xml",
|
||||
"--data-file",
|
||||
str(coverage_file),
|
||||
"-o",
|
||||
str(report_path),
|
||||
],
|
||||
capture_output=True,
|
||||
check=True,
|
||||
@ -372,9 +542,12 @@ class RuntimeProvider:
|
||||
report_path = coverage_dir / "coverage.json"
|
||||
subprocess.run(
|
||||
[
|
||||
"coverage", "json",
|
||||
"--data-file", str(coverage_file),
|
||||
"-o", str(report_path),
|
||||
"coverage",
|
||||
"json",
|
||||
"--data-file",
|
||||
str(coverage_file),
|
||||
"-o",
|
||||
str(report_path),
|
||||
],
|
||||
capture_output=True,
|
||||
check=True,
|
||||
|
||||
@ -38,7 +38,7 @@ pytestmark = [
|
||||
pytest.mark.skipif(
|
||||
not RUNTIME_IMAGE_EXISTS,
|
||||
reason=f"Runtime image '{RUNTIME_IMAGE}' not built. "
|
||||
"Run: docker build -t gnuradio-runtime -f docker/Dockerfile.gnuradio-runtime docker/",
|
||||
"Run: docker build -t gnuradio-runtime docker/",
|
||||
),
|
||||
]
|
||||
|
||||
@ -174,9 +174,7 @@ class TestDockerMiddlewareIntegration:
|
||||
class TestRuntimeProviderIntegration:
|
||||
"""Test RuntimeProvider with real Docker (requires runtime image)."""
|
||||
|
||||
def test_launch_and_stop_flowgraph(
|
||||
self, test_flowgraph, cleanup_containers
|
||||
):
|
||||
def test_launch_and_stop_flowgraph(self, test_flowgraph, cleanup_containers):
|
||||
from gnuradio_mcp.middlewares.docker import DockerMiddleware
|
||||
from gnuradio_mcp.providers.runtime import RuntimeProvider
|
||||
|
||||
@ -215,7 +213,7 @@ class TestRuntimeProviderIntegration:
|
||||
cleanup_containers.remove(container_name)
|
||||
|
||||
def test_launch_connect_and_control(self, test_flowgraph, cleanup_containers):
|
||||
"""Full integration: launch container, connect via XML-RPC, control variables."""
|
||||
"""Integration: launch, connect via XML-RPC, and control variables."""
|
||||
from gnuradio_mcp.middlewares.docker import DockerMiddleware
|
||||
from gnuradio_mcp.providers.runtime import RuntimeProvider
|
||||
|
||||
|
||||
@ -5,7 +5,6 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from gnuradio_mcp.middlewares.docker import (
|
||||
CONTAINER_FLOWGRAPH_DIR,
|
||||
DEFAULT_XMLRPC_PORT,
|
||||
DockerMiddleware,
|
||||
)
|
||||
@ -96,7 +95,9 @@ class TestLaunch:
|
||||
call_kwargs = mock_docker_client.containers.run.call_args
|
||||
assert call_kwargs.kwargs["labels"]["gr-mcp.vnc-enabled"] == "1"
|
||||
|
||||
def test_launch_without_vnc_sets_label(self, docker_mw, mock_docker_client, tmp_path):
|
||||
def test_launch_without_vnc_sets_label(
|
||||
self, docker_mw, mock_docker_client, tmp_path
|
||||
):
|
||||
fg_file = tmp_path / "test.grc"
|
||||
fg_file.write_text("<flowgraph/>")
|
||||
|
||||
@ -131,7 +132,10 @@ class TestLaunch:
|
||||
assert result.device_paths == ["/dev/bus/usb/001/002"]
|
||||
|
||||
call_kwargs = mock_docker_client.containers.run.call_args
|
||||
assert "/dev/bus/usb/001/002:/dev/bus/usb/001/002:rwm" in call_kwargs.kwargs["devices"]
|
||||
assert (
|
||||
"/dev/bus/usb/001/002:/dev/bus/usb/001/002:rwm"
|
||||
in call_kwargs.kwargs["devices"]
|
||||
)
|
||||
|
||||
|
||||
class TestListContainers:
|
||||
|
||||
@ -98,7 +98,9 @@ class TestPreconditions:
|
||||
with pytest.raises(RuntimeError, match="Docker is not available"):
|
||||
provider_no_docker._require_docker()
|
||||
|
||||
def test_require_docker_returns_middleware(self, provider_with_docker, mock_docker_mw):
|
||||
def test_require_docker_returns_middleware(
|
||||
self, provider_with_docker, mock_docker_mw
|
||||
):
|
||||
result = provider_with_docker._require_docker()
|
||||
assert result is mock_docker_mw
|
||||
|
||||
@ -126,10 +128,15 @@ class TestContainerLifecycle:
|
||||
xmlrpc_port=9090,
|
||||
enable_vnc=True,
|
||||
enable_coverage=False,
|
||||
enable_controlport=False,
|
||||
controlport_port=9090,
|
||||
enable_perf_counters=True,
|
||||
device_paths=None,
|
||||
)
|
||||
|
||||
def test_launch_flowgraph_auto_name(self, provider_with_docker, mock_docker_mw, tmp_path):
|
||||
def test_launch_flowgraph_auto_name(
|
||||
self, provider_with_docker, mock_docker_mw, tmp_path
|
||||
):
|
||||
fg = tmp_path / "siggen_xmlrpc.grc"
|
||||
fg.write_text("<flowgraph/>")
|
||||
|
||||
@ -182,7 +189,9 @@ class TestConnectionManagement:
|
||||
provider_with_docker.connect("http://localhost:9090")
|
||||
mock_xmlrpc_mw.get_connection_info.assert_called_with(xmlrpc_port=9090)
|
||||
|
||||
def test_connect_to_container(self, provider_with_docker, mock_docker_mw, mock_xmlrpc_mw):
|
||||
def test_connect_to_container(
|
||||
self, provider_with_docker, mock_docker_mw, mock_xmlrpc_mw
|
||||
):
|
||||
with patch(
|
||||
"gnuradio_mcp.providers.runtime.XmlRpcMiddleware.connect",
|
||||
return_value=mock_xmlrpc_mw,
|
||||
@ -216,7 +225,9 @@ class TestConnectionManagement:
|
||||
assert result.connection is None
|
||||
assert len(result.containers) == 1
|
||||
|
||||
def test_get_status_connected(self, provider_with_docker, mock_docker_mw, mock_xmlrpc_mw):
|
||||
def test_get_status_connected(
|
||||
self, provider_with_docker, mock_docker_mw, mock_xmlrpc_mw
|
||||
):
|
||||
provider_with_docker._xmlrpc = mock_xmlrpc_mw
|
||||
provider_with_docker._active_container = "gr-test"
|
||||
|
||||
@ -226,7 +237,9 @@ class TestConnectionManagement:
|
||||
assert result.connection is not None
|
||||
mock_xmlrpc_mw.get_connection_info.assert_called()
|
||||
|
||||
def test_get_status_handles_docker_error(self, provider_with_docker, mock_docker_mw):
|
||||
def test_get_status_handles_docker_error(
|
||||
self, provider_with_docker, mock_docker_mw
|
||||
):
|
||||
mock_docker_mw.list_containers.side_effect = Exception("Docker error")
|
||||
|
||||
result = provider_with_docker.get_status()
|
||||
@ -298,7 +311,9 @@ class TestVisualFeedback:
|
||||
assert isinstance(result, ScreenshotModel)
|
||||
mock_docker_mw.capture_screenshot.assert_called_once_with("gr-test")
|
||||
|
||||
def test_capture_screenshot_uses_active_container(self, provider_with_docker, mock_docker_mw):
|
||||
def test_capture_screenshot_uses_active_container(
|
||||
self, provider_with_docker, mock_docker_mw
|
||||
):
|
||||
provider_with_docker._active_container = "gr-active"
|
||||
|
||||
provider_with_docker.capture_screenshot()
|
||||
@ -315,7 +330,9 @@ class TestVisualFeedback:
|
||||
assert "flowgraph started" in result
|
||||
mock_docker_mw.get_logs.assert_called_once_with("gr-test", tail=50)
|
||||
|
||||
def test_get_container_logs_uses_active_container(self, provider_with_docker, mock_docker_mw):
|
||||
def test_get_container_logs_uses_active_container(
|
||||
self, provider_with_docker, mock_docker_mw
|
||||
):
|
||||
provider_with_docker._active_container = "gr-active"
|
||||
|
||||
provider_with_docker.get_container_logs()
|
||||
@ -348,9 +365,10 @@ class TestCoverageCollection:
|
||||
with pytest.raises(FileNotFoundError, match="No coverage data"):
|
||||
provider_with_docker.collect_coverage("nonexistent-container")
|
||||
|
||||
def test_collect_coverage_success(self, provider_with_docker, tmp_path, monkeypatch):
|
||||
def test_collect_coverage_success(
|
||||
self, provider_with_docker, tmp_path, monkeypatch
|
||||
):
|
||||
from gnuradio_mcp.models import CoverageDataModel
|
||||
from gnuradio_mcp.middlewares.docker import HOST_COVERAGE_BASE
|
||||
|
||||
# Create fake coverage directory and file
|
||||
monkeypatch.setattr(
|
||||
@ -383,7 +401,9 @@ TOTAL 100 20 40 10 75%"""
|
||||
assert result.lines_total == 100
|
||||
assert result.lines_covered == 80 # 100 - 20 missed
|
||||
|
||||
def test_generate_coverage_report_html(self, provider_with_docker, tmp_path, monkeypatch):
|
||||
def test_generate_coverage_report_html(
|
||||
self, provider_with_docker, tmp_path, monkeypatch
|
||||
):
|
||||
from gnuradio_mcp.models import CoverageReportModel
|
||||
|
||||
# Setup
|
||||
@ -414,7 +434,9 @@ TOTAL 100 20 40 10 75%"""
|
||||
assert result.format == "html"
|
||||
assert "htmlcov" in result.report_path
|
||||
|
||||
def test_generate_coverage_report_xml(self, provider_with_docker, tmp_path, monkeypatch):
|
||||
def test_generate_coverage_report_xml(
|
||||
self, provider_with_docker, tmp_path, monkeypatch
|
||||
):
|
||||
from gnuradio_mcp.models import CoverageReportModel
|
||||
|
||||
monkeypatch.setattr(
|
||||
@ -488,7 +510,9 @@ TOTAL 100 20 40 10 75%"""
|
||||
with pytest.raises(ValueError, match="At least one container"):
|
||||
provider_with_docker.combine_coverage([])
|
||||
|
||||
def test_delete_coverage_specific(self, provider_with_docker, tmp_path, monkeypatch):
|
||||
def test_delete_coverage_specific(
|
||||
self, provider_with_docker, tmp_path, monkeypatch
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
"gnuradio_mcp.providers.runtime.HOST_COVERAGE_BASE", str(tmp_path)
|
||||
)
|
||||
@ -503,7 +527,9 @@ TOTAL 100 20 40 10 75%"""
|
||||
assert deleted == 1
|
||||
assert not coverage_dir.exists()
|
||||
|
||||
def test_delete_coverage_older_than(self, provider_with_docker, tmp_path, monkeypatch):
|
||||
def test_delete_coverage_older_than(
|
||||
self, provider_with_docker, tmp_path, monkeypatch
|
||||
):
|
||||
import os
|
||||
import time
|
||||
|
||||
@ -542,7 +568,9 @@ TOTAL 100 20 40 10 75%"""
|
||||
assert not (tmp_path / "container-1").exists()
|
||||
assert not (tmp_path / "container-2").exists()
|
||||
|
||||
def test_delete_coverage_nonexistent(self, provider_with_docker, tmp_path, monkeypatch):
|
||||
def test_delete_coverage_nonexistent(
|
||||
self, provider_with_docker, tmp_path, monkeypatch
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
"gnuradio_mcp.providers.runtime.HOST_COVERAGE_BASE", str(tmp_path)
|
||||
)
|
||||
|
||||
337
tests/unit/test_thrift_middleware.py
Normal file
337
tests/unit/test_thrift_middleware.py
Normal file
@ -0,0 +1,337 @@
|
||||
"""Unit tests for ThriftMiddleware.
|
||||
|
||||
These tests mock the Thrift client since we can't easily connect to
|
||||
a real ControlPort server in unit tests. The mocked client simulates
|
||||
the RPCConnectionThrift API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gnuradio_mcp.middlewares.thrift import (
|
||||
DEFAULT_THRIFT_PORT,
|
||||
PERF_COUNTER_SUFFIXES,
|
||||
ThriftMiddleware,
|
||||
)
|
||||
from gnuradio_mcp.models import KnobModel, KnobPropertiesModel, PerfCounterModel
|
||||
|
||||
|
||||
class MockKnob:
|
||||
"""Mock for RPCConnectionThrift.Knob."""
|
||||
|
||||
def __init__(self, key: str, value: Any, ktype: int):
|
||||
self.key = key
|
||||
self.value = value
|
||||
self.ktype = ktype
|
||||
|
||||
|
||||
class MockKnobProps:
|
||||
"""Mock for Thrift knob properties."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
description: str = "",
|
||||
units: str = "",
|
||||
ktype: int = 5,
|
||||
min_val: Any = None,
|
||||
max_val: Any = None,
|
||||
default_val: Any = None,
|
||||
):
|
||||
self.description = description
|
||||
self.units = units
|
||||
self.type = ktype
|
||||
self.min = MockKnob("", min_val, ktype) if min_val is not None else None
|
||||
self.max = MockKnob("", max_val, ktype) if max_val is not None else None
|
||||
self.defaultvalue = (
|
||||
MockKnob("", default_val, ktype) if default_val is not None else None
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client():
|
||||
"""Create a mock Thrift client."""
|
||||
client = MagicMock()
|
||||
|
||||
# Default getKnobs response
|
||||
client.getKnobs.return_value = {
|
||||
"sig_source0::frequency": MockKnob("sig_source0::frequency", 1000000.0, 5),
|
||||
"sig_source0::amplitude": MockKnob("sig_source0::amplitude", 0.5, 5),
|
||||
"null_sink0::avg throughput": MockKnob("null_sink0::avg throughput", 1e9, 5),
|
||||
}
|
||||
|
||||
# Default getRe response (regex query)
|
||||
client.getRe.return_value = {
|
||||
"sig_source0::frequency": MockKnob("sig_source0::frequency", 1000000.0, 5),
|
||||
}
|
||||
|
||||
# Default properties response
|
||||
client.properties.return_value = {
|
||||
"sig_source0::frequency": MockKnobProps(
|
||||
description="Signal frequency in Hz",
|
||||
units="Hz",
|
||||
ktype=5,
|
||||
min_val=0.0,
|
||||
max_val=1e12,
|
||||
default_val=1000.0,
|
||||
),
|
||||
}
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def thrift_middleware(mock_client):
|
||||
"""Create a ThriftMiddleware with mocked client."""
|
||||
return ThriftMiddleware(mock_client, "127.0.0.1", DEFAULT_THRIFT_PORT)
|
||||
|
||||
|
||||
class TestThriftMiddlewareConnection:
|
||||
"""Tests for connection handling."""
|
||||
|
||||
def test_get_connection_info(self, thrift_middleware, mock_client):
|
||||
"""get_connection_info returns host, port, and knob count."""
|
||||
mock_client.getKnobs.return_value = {
|
||||
"k1": MockKnob("k1", 1, 5),
|
||||
"k2": MockKnob("k2", 2, 5),
|
||||
}
|
||||
|
||||
info = thrift_middleware.get_connection_info()
|
||||
|
||||
assert info.host == "127.0.0.1"
|
||||
assert info.port == DEFAULT_THRIFT_PORT
|
||||
assert info.protocol == "thrift"
|
||||
assert info.knob_count == 2
|
||||
|
||||
def test_get_connection_info_with_container_name(self, thrift_middleware):
|
||||
"""get_connection_info includes container name when provided."""
|
||||
info = thrift_middleware.get_connection_info(container_name="test-container")
|
||||
|
||||
assert info.container_name == "test-container"
|
||||
|
||||
def test_close(self, thrift_middleware):
|
||||
"""close clears the client reference."""
|
||||
thrift_middleware.close()
|
||||
assert thrift_middleware._client is None
|
||||
|
||||
|
||||
class TestThriftMiddlewareVariables:
|
||||
"""Tests for variable operations (XML-RPC compatible API)."""
|
||||
|
||||
def test_list_variables_filters_perf_counters(self, thrift_middleware, mock_client):
|
||||
"""list_variables excludes performance counter knobs."""
|
||||
mock_client.getKnobs.return_value = {
|
||||
"sig_source0::frequency": MockKnob("sig_source0::frequency", 1e6, 5),
|
||||
"null_sink0::avg throughput": MockKnob("null_sink0::avg throughput", 1e9, 5),
|
||||
}
|
||||
|
||||
variables = thrift_middleware.list_variables()
|
||||
|
||||
# Should only include frequency, not the perf counter
|
||||
assert len(variables) == 1
|
||||
assert variables[0].name == "sig_source0::frequency"
|
||||
assert variables[0].value == 1e6
|
||||
|
||||
def test_get_variable(self, thrift_middleware, mock_client):
|
||||
"""get_variable returns the knob value."""
|
||||
mock_client.getKnobs.return_value = {
|
||||
"sig_source0::frequency": MockKnob("sig_source0::frequency", 1e6, 5),
|
||||
}
|
||||
|
||||
value = thrift_middleware.get_variable("sig_source0::frequency")
|
||||
|
||||
assert value == 1e6
|
||||
mock_client.getKnobs.assert_called_with(["sig_source0::frequency"])
|
||||
|
||||
def test_get_variable_not_found(self, thrift_middleware, mock_client):
|
||||
"""get_variable raises KeyError for unknown knob."""
|
||||
mock_client.getKnobs.return_value = {}
|
||||
|
||||
with pytest.raises(KeyError, match="Knob not found"):
|
||||
thrift_middleware.get_variable("unknown::knob")
|
||||
|
||||
def test_set_variable(self, thrift_middleware, mock_client):
|
||||
"""set_variable updates the knob value."""
|
||||
mock_client.getKnobs.return_value = {
|
||||
"sig_source0::frequency": MockKnob("sig_source0::frequency", 1e6, 5),
|
||||
}
|
||||
|
||||
# Need to mock the import inside the method
|
||||
with patch(
|
||||
"gnuradio_mcp.middlewares.thrift.RPCConnectionThrift", create=True
|
||||
) as mock_rpc:
|
||||
mock_rpc.Knob = MockKnob
|
||||
result = thrift_middleware.set_variable("sig_source0::frequency", 2e6)
|
||||
|
||||
assert result is True
|
||||
mock_client.setKnobs.assert_called_once()
|
||||
|
||||
|
||||
class TestThriftMiddlewareKnobs:
|
||||
"""Tests for ControlPort-specific knob operations."""
|
||||
|
||||
def test_get_knobs_all(self, thrift_middleware, mock_client):
|
||||
"""get_knobs with empty pattern returns all knobs."""
|
||||
knobs = thrift_middleware.get_knobs("")
|
||||
|
||||
mock_client.getKnobs.assert_called_with([])
|
||||
assert len(knobs) == 3 # All including perf counter
|
||||
|
||||
def test_get_knobs_with_pattern(self, thrift_middleware, mock_client):
|
||||
"""get_knobs with pattern uses regex query."""
|
||||
thrift_middleware.get_knobs(".*frequency.*")
|
||||
|
||||
mock_client.getRe.assert_called_with([".*frequency.*"])
|
||||
|
||||
def test_get_knobs_returns_knob_models(self, thrift_middleware, mock_client):
|
||||
"""get_knobs returns KnobModel instances with correct types."""
|
||||
mock_client.getKnobs.return_value = {
|
||||
"k1": MockKnob("k1", 1.0, 5), # DOUBLE
|
||||
"k2": MockKnob("k2", True, 0), # BOOL
|
||||
}
|
||||
|
||||
knobs = thrift_middleware.get_knobs("")
|
||||
|
||||
assert len(knobs) == 2
|
||||
assert all(isinstance(k, KnobModel) for k in knobs)
|
||||
|
||||
k1 = next(k for k in knobs if k.name == "k1")
|
||||
assert k1.value == 1.0
|
||||
assert k1.knob_type == "DOUBLE"
|
||||
|
||||
k2 = next(k for k in knobs if k.name == "k2")
|
||||
assert k2.value is True
|
||||
assert k2.knob_type == "BOOL"
|
||||
|
||||
def test_get_knob_properties(self, thrift_middleware, mock_client):
|
||||
"""get_knob_properties returns metadata for knobs."""
|
||||
props = thrift_middleware.get_knob_properties(["sig_source0::frequency"])
|
||||
|
||||
mock_client.properties.assert_called_with(["sig_source0::frequency"])
|
||||
assert len(props) == 1
|
||||
assert isinstance(props[0], KnobPropertiesModel)
|
||||
assert props[0].name == "sig_source0::frequency"
|
||||
assert props[0].description == "Signal frequency in Hz"
|
||||
assert props[0].min_value == 0.0
|
||||
assert props[0].max_value == 1e12
|
||||
|
||||
|
||||
class TestThriftMiddlewarePerfCounters:
|
||||
"""Tests for performance counter operations."""
|
||||
|
||||
def test_get_performance_counters(self, thrift_middleware, mock_client):
|
||||
"""get_performance_counters extracts per-block metrics."""
|
||||
mock_client.getKnobs.return_value = {
|
||||
"sig_source0::frequency": MockKnob("sig_source0::frequency", 1e6, 5),
|
||||
"sig_source0::avg throughput": MockKnob(
|
||||
"sig_source0::avg throughput", 1e9, 5
|
||||
),
|
||||
"sig_source0::avg work time": MockKnob(
|
||||
"sig_source0::avg work time", 100.0, 5
|
||||
),
|
||||
"sig_source0::total work time": MockKnob(
|
||||
"sig_source0::total work time", 10000.0, 5
|
||||
),
|
||||
"sig_source0::avg nproduced": MockKnob(
|
||||
"sig_source0::avg nproduced", 4096.0, 5
|
||||
),
|
||||
"null_sink0::avg throughput": MockKnob(
|
||||
"null_sink0::avg throughput", 5e8, 5
|
||||
),
|
||||
}
|
||||
|
||||
counters = thrift_middleware.get_performance_counters()
|
||||
|
||||
assert len(counters) == 2
|
||||
assert all(isinstance(c, PerfCounterModel) for c in counters)
|
||||
|
||||
sig = next(c for c in counters if c.block_name == "sig_source0")
|
||||
assert sig.avg_throughput == 1e9
|
||||
assert sig.avg_work_time_us == 100.0
|
||||
assert sig.total_work_time_us == 10000.0
|
||||
assert sig.avg_nproduced == 4096.0
|
||||
|
||||
def test_get_performance_counters_with_block_filter(
|
||||
self, thrift_middleware, mock_client
|
||||
):
|
||||
"""get_performance_counters can filter by block name."""
|
||||
mock_client.getRe.return_value = {
|
||||
"sig_source0::avg throughput": MockKnob(
|
||||
"sig_source0::avg throughput", 1e9, 5
|
||||
),
|
||||
}
|
||||
|
||||
counters = thrift_middleware.get_performance_counters(block="sig_source0")
|
||||
|
||||
# Should use regex pattern for the specific block
|
||||
mock_client.getRe.assert_called_once()
|
||||
call_args = mock_client.getRe.call_args[0][0]
|
||||
assert "sig_source0" in call_args[0]
|
||||
|
||||
|
||||
class TestThriftMiddlewareHelpers:
|
||||
"""Tests for helper methods."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"name,expected",
|
||||
[
|
||||
("sig_source0::avg throughput", True),
|
||||
("sig_source0::avg work time", True),
|
||||
("sig_source0::total work time", True),
|
||||
("sig_source0::avg nproduced", True),
|
||||
("sig_source0::avg input % full", True),
|
||||
("sig_source0::avg output % full", True),
|
||||
("sig_source0::frequency", False),
|
||||
("sig_source0::amplitude", False),
|
||||
],
|
||||
)
|
||||
def test_is_perf_counter(self, name, expected):
|
||||
"""_is_perf_counter correctly identifies performance counters."""
|
||||
assert ThriftMiddleware._is_perf_counter(name) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value,expected",
|
||||
[
|
||||
([0.1, 0.2], [0.1, 0.2]),
|
||||
((0.3, 0.4), [0.3, 0.4]),
|
||||
(0.5, [0.5]),
|
||||
(None, []),
|
||||
],
|
||||
)
|
||||
def test_to_list(self, value, expected):
|
||||
"""_to_list converts values to list of floats."""
|
||||
assert ThriftMiddleware._to_list(value) == expected
|
||||
|
||||
|
||||
class TestThriftMiddlewareConnectionError:
|
||||
"""Tests for connection error handling."""
|
||||
|
||||
def test_connect_import_error(self):
|
||||
"""connect raises ImportError when gnuradio.ctrlport unavailable."""
|
||||
with patch.dict("sys.modules", {"gnuradio.ctrlport": None}):
|
||||
with patch(
|
||||
"gnuradio_mcp.middlewares.thrift.ThriftMiddleware.connect"
|
||||
) as mock:
|
||||
mock.side_effect = ImportError("GNU Radio ControlPort not available")
|
||||
with pytest.raises(ImportError):
|
||||
ThriftMiddleware.connect()
|
||||
|
||||
|
||||
class TestKnobTypeNames:
|
||||
"""Tests for knob type mapping."""
|
||||
|
||||
def test_all_perf_counter_suffixes_defined(self):
|
||||
"""Ensure all expected perf counter suffixes are defined."""
|
||||
expected_suffixes = [
|
||||
"::avg throughput",
|
||||
"::avg work time",
|
||||
"::total work time",
|
||||
"::avg nproduced",
|
||||
"::avg input % full",
|
||||
"::avg output % full",
|
||||
]
|
||||
for suffix in expected_suffixes:
|
||||
assert suffix in PERF_COUNTER_SUFFIXES
|
||||
Loading…
x
Reference in New Issue
Block a user