Add MCP progress reporting to all multi-fetch tools
ctx.report_progress() gives clients real-time visibility into long-running tool calls. Per-fetch counters for parallel gather calls, stage-based milestones for linear pipelines. No-op when client doesn't send a progressToken.
This commit is contained in:
parent
9d25f5efe3
commit
fb574d26b8
@ -38,6 +38,7 @@ def register(mcp: FastMCP) -> None:
|
||||
today = datetime.now(timezone.utc).strftime("%Y%m%d")
|
||||
|
||||
# Fetch predictions (6-minute interval for smooth curve) + hilo for markers
|
||||
await ctx.report_progress(1, 4, "Fetching predictions")
|
||||
predictions_raw, hilo_raw = await asyncio.gather(
|
||||
noaa.get_data(
|
||||
station_id, product="predictions", begin_date=today,
|
||||
@ -62,6 +63,7 @@ def register(mcp: FastMCP) -> None:
|
||||
# Fetch observed water levels if requested
|
||||
observed = None
|
||||
if include_observed:
|
||||
await ctx.report_progress(2, 4, "Fetching observed data")
|
||||
try:
|
||||
obs_raw = await noaa.get_data(
|
||||
station_id, product="water_level", hours=hours,
|
||||
@ -71,6 +73,7 @@ def register(mcp: FastMCP) -> None:
|
||||
pass # observed overlay is optional — skip on failure
|
||||
|
||||
# Look up station name
|
||||
await ctx.report_progress(3, 4, "Looking up station")
|
||||
station_name = ""
|
||||
try:
|
||||
stations = await noaa.get_stations()
|
||||
@ -80,6 +83,7 @@ def register(mcp: FastMCP) -> None:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await ctx.report_progress(4, 4, "Rendering chart")
|
||||
if format == "png":
|
||||
from mcnoaa_tides.charts.tides import render_tide_chart_png
|
||||
|
||||
@ -132,12 +136,19 @@ def register(mcp: FastMCP) -> None:
|
||||
"air_pressure": {"product": "air_pressure", "hours": hours},
|
||||
}
|
||||
|
||||
completed = 0
|
||||
total_steps = len(requests) + 2 # 6 fetches + station lookup + render
|
||||
|
||||
async def fetch(name: str, params: dict) -> tuple[str, dict | None]:
|
||||
nonlocal completed
|
||||
try:
|
||||
data = await noaa.get_data(station_id, **params)
|
||||
return name, data
|
||||
result = name, data
|
||||
except Exception:
|
||||
return name, None
|
||||
result = name, None
|
||||
completed += 1
|
||||
await ctx.report_progress(completed, total_steps, f"Fetched {name}")
|
||||
return result
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[fetch(name, params) for name, params in requests.items()]
|
||||
@ -149,6 +160,7 @@ def register(mcp: FastMCP) -> None:
|
||||
snapshot[name] = data
|
||||
|
||||
# Look up station name
|
||||
await ctx.report_progress(total_steps - 1, total_steps, "Looking up station")
|
||||
station_name = ""
|
||||
try:
|
||||
stations = await noaa.get_stations()
|
||||
@ -158,6 +170,7 @@ def register(mcp: FastMCP) -> None:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await ctx.report_progress(total_steps, total_steps, "Rendering dashboard")
|
||||
if format == "png":
|
||||
from mcnoaa_tides.charts.conditions import render_conditions_png
|
||||
|
||||
|
||||
@ -47,13 +47,20 @@ def register(mcp: FastMCP) -> None:
|
||||
"air_pressure": {"product": "air_pressure", "hours": hours},
|
||||
}
|
||||
|
||||
completed = 0
|
||||
total = len(requests)
|
||||
|
||||
async def fetch(name: str, params: dict) -> tuple[str, dict | str]:
|
||||
nonlocal completed
|
||||
try:
|
||||
data = await noaa.get_data(station_id, **params)
|
||||
return name, data
|
||||
result = name, data
|
||||
except Exception as exc:
|
||||
msg = str(exc) or type(exc).__name__
|
||||
return name, f"{type(exc).__name__}: {msg}"
|
||||
result = name, f"{type(exc).__name__}: {msg}"
|
||||
completed += 1
|
||||
await ctx.report_progress(completed, total, f"Fetched {name}")
|
||||
return result
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[fetch(name, params) for name, params in requests.items()]
|
||||
|
||||
@ -72,6 +72,8 @@ def register(mcp: FastMCP) -> None:
|
||||
GPS coordinates find the nearest tidal station automatically.
|
||||
"""
|
||||
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
|
||||
|
||||
await ctx.report_progress(1, 3, "Resolving station")
|
||||
station, distance_nm = await _resolve_station(
|
||||
noaa, station_id, latitude, longitude,
|
||||
)
|
||||
@ -81,6 +83,7 @@ def register(mcp: FastMCP) -> None:
|
||||
begin = (now_utc - timedelta(hours=12)).strftime("%Y%m%d %H:%M")
|
||||
end = (now_utc + timedelta(hours=12)).strftime("%Y%m%d %H:%M")
|
||||
|
||||
await ctx.report_progress(2, 3, "Fetching tide data")
|
||||
hilo_data, obs_data = await asyncio.gather(
|
||||
noaa.get_data(
|
||||
station["id"], "predictions",
|
||||
@ -94,6 +97,8 @@ def register(mcp: FastMCP) -> None:
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
await ctx.report_progress(3, 3, "Classifying phase")
|
||||
|
||||
# Parse hilo predictions
|
||||
hilo_preds = []
|
||||
if isinstance(hilo_data, dict):
|
||||
@ -142,6 +147,8 @@ def register(mcp: FastMCP) -> None:
|
||||
Returns station info, tide schedule, conditions, and assessment.
|
||||
"""
|
||||
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
|
||||
|
||||
await ctx.report_progress(1, 6, "Resolving station")
|
||||
station, distance_nm = await _resolve_station(
|
||||
noaa, latitude=latitude, longitude=longitude,
|
||||
)
|
||||
@ -151,6 +158,9 @@ def register(mcp: FastMCP) -> None:
|
||||
end = (now_utc + timedelta(hours=soak_hours)).strftime("%Y%m%d %H:%M")
|
||||
|
||||
# Parallel fetch: hilo predictions for soak window + current conditions
|
||||
await ctx.report_progress(2, 6, "Fetching tide and conditions data")
|
||||
completed = 0
|
||||
|
||||
hilo_fut = noaa.get_data(
|
||||
station["id"], "predictions",
|
||||
begin_date=begin, end_date=end,
|
||||
@ -158,10 +168,14 @@ def register(mcp: FastMCP) -> None:
|
||||
)
|
||||
|
||||
async def _safe_fetch(product, **kwargs):
|
||||
nonlocal completed
|
||||
try:
|
||||
return await noaa.get_data(station["id"], product, **kwargs)
|
||||
result = await noaa.get_data(station["id"], product, **kwargs)
|
||||
except Exception:
|
||||
return None
|
||||
result = None
|
||||
completed += 1
|
||||
await ctx.report_progress(2 + completed, 6, f"Fetched {product}")
|
||||
return result
|
||||
|
||||
hilo_data, wind_data, temp_data, pressure_data = await asyncio.gather(
|
||||
hilo_fut,
|
||||
@ -170,6 +184,8 @@ def register(mcp: FastMCP) -> None:
|
||||
_safe_fetch("air_pressure", hours=3, time_zone="gmt"),
|
||||
)
|
||||
|
||||
await ctx.report_progress(6, 6, "Assessing conditions")
|
||||
|
||||
# Build tide schedule
|
||||
hilo_events = parse_hilo_predictions(hilo_data.get("predictions", []))
|
||||
tide_schedule = [
|
||||
@ -270,6 +286,8 @@ def register(mcp: FastMCP) -> None:
|
||||
)
|
||||
|
||||
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
|
||||
num_events = len(events)
|
||||
num_stations = 0 # determined after resolution
|
||||
|
||||
# Group events by nearest station to batch hilo fetches
|
||||
station_groups: dict[str, list[tuple[int, dict, float | None]]] = {}
|
||||
@ -279,6 +297,9 @@ def register(mcp: FastMCP) -> None:
|
||||
lon = event.get("longitude")
|
||||
if lat is None or lon is None:
|
||||
continue
|
||||
await ctx.report_progress(
|
||||
idx + 1, num_events, f"Resolving station for event {idx + 1}/{num_events}",
|
||||
)
|
||||
nearest = await noaa.find_nearest(float(lat), float(lon), limit=1)
|
||||
if not nearest:
|
||||
continue
|
||||
@ -290,8 +311,16 @@ def register(mcp: FastMCP) -> None:
|
||||
|
||||
# For each station group, fetch hilo once for the full time window
|
||||
enriched = [None] * len(events)
|
||||
num_stations = len(station_groups)
|
||||
station_idx = 0
|
||||
|
||||
for station_id, group in station_groups.items():
|
||||
station_idx += 1
|
||||
await ctx.report_progress(
|
||||
station_idx, num_stations,
|
||||
f"Processing station {station_id}",
|
||||
)
|
||||
|
||||
timestamps = []
|
||||
for _, event, _ in group:
|
||||
ts_str = event.get("timestamp", "")
|
||||
@ -370,6 +399,7 @@ def register(mcp: FastMCP) -> None:
|
||||
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
|
||||
|
||||
# Parallel fetch: observed levels + 6-minute predictions
|
||||
await ctx.report_progress(1, 3, "Fetching observations and predictions")
|
||||
obs_data, pred_data = await asyncio.gather(
|
||||
noaa.get_data(
|
||||
station_id, "water_level",
|
||||
@ -392,6 +422,7 @@ def register(mcp: FastMCP) -> None:
|
||||
}
|
||||
|
||||
# Parse predictions into parallel lists for interpolation
|
||||
await ctx.report_progress(2, 3, "Computing deviations")
|
||||
pred_times = []
|
||||
pred_values = []
|
||||
for p in pred_records:
|
||||
@ -437,6 +468,7 @@ def register(mcp: FastMCP) -> None:
|
||||
direction = "above" if mean_signed > 0 else "below"
|
||||
|
||||
# Risk classification
|
||||
await ctx.report_progress(3, 3, "Classifying risk")
|
||||
if max_dev >= threshold_ft * 2:
|
||||
risk = "high"
|
||||
explanation = (
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user