Add MCP progress reporting to all multi-fetch tools

ctx.report_progress() gives clients real-time visibility into
long-running tool calls. Per-fetch counters for parallel gather
calls, stage-based milestones for linear pipelines. No-op when
client doesn't send a progressToken.
This commit is contained in:
Ryan Malloy 2026-02-23 13:28:56 -07:00
parent 9d25f5efe3
commit fb574d26b8
3 changed files with 58 additions and 6 deletions

View File

@ -38,6 +38,7 @@ def register(mcp: FastMCP) -> None:
today = datetime.now(timezone.utc).strftime("%Y%m%d")
# Fetch predictions (6-minute interval for smooth curve) + hilo for markers
await ctx.report_progress(1, 4, "Fetching predictions")
predictions_raw, hilo_raw = await asyncio.gather(
noaa.get_data(
station_id, product="predictions", begin_date=today,
@ -62,6 +63,7 @@ def register(mcp: FastMCP) -> None:
# Fetch observed water levels if requested
observed = None
if include_observed:
await ctx.report_progress(2, 4, "Fetching observed data")
try:
obs_raw = await noaa.get_data(
station_id, product="water_level", hours=hours,
@ -71,6 +73,7 @@ def register(mcp: FastMCP) -> None:
pass # observed overlay is optional — skip on failure
# Look up station name
await ctx.report_progress(3, 4, "Looking up station")
station_name = ""
try:
stations = await noaa.get_stations()
@ -80,6 +83,7 @@ def register(mcp: FastMCP) -> None:
except Exception:
pass
await ctx.report_progress(4, 4, "Rendering chart")
if format == "png":
from mcnoaa_tides.charts.tides import render_tide_chart_png
@ -132,12 +136,19 @@ def register(mcp: FastMCP) -> None:
"air_pressure": {"product": "air_pressure", "hours": hours},
}
completed = 0
total_steps = len(requests) + 2 # 6 fetches + station lookup + render
async def fetch(name: str, params: dict) -> tuple[str, dict | None]:
nonlocal completed
try:
data = await noaa.get_data(station_id, **params)
return name, data
result = name, data
except Exception:
return name, None
result = name, None
completed += 1
await ctx.report_progress(completed, total_steps, f"Fetched {name}")
return result
results = await asyncio.gather(
*[fetch(name, params) for name, params in requests.items()]
@ -149,6 +160,7 @@ def register(mcp: FastMCP) -> None:
snapshot[name] = data
# Look up station name
await ctx.report_progress(total_steps - 1, total_steps, "Looking up station")
station_name = ""
try:
stations = await noaa.get_stations()
@ -158,6 +170,7 @@ def register(mcp: FastMCP) -> None:
except Exception:
pass
await ctx.report_progress(total_steps, total_steps, "Rendering dashboard")
if format == "png":
from mcnoaa_tides.charts.conditions import render_conditions_png

View File

@ -47,13 +47,20 @@ def register(mcp: FastMCP) -> None:
"air_pressure": {"product": "air_pressure", "hours": hours},
}
completed = 0
total = len(requests)
async def fetch(name: str, params: dict) -> tuple[str, dict | str]:
nonlocal completed
try:
data = await noaa.get_data(station_id, **params)
return name, data
result = name, data
except Exception as exc:
msg = str(exc) or type(exc).__name__
return name, f"{type(exc).__name__}: {msg}"
result = name, f"{type(exc).__name__}: {msg}"
completed += 1
await ctx.report_progress(completed, total, f"Fetched {name}")
return result
results = await asyncio.gather(
*[fetch(name, params) for name, params in requests.items()]

View File

@ -72,6 +72,8 @@ def register(mcp: FastMCP) -> None:
GPS coordinates find the nearest tidal station automatically.
"""
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
await ctx.report_progress(1, 3, "Resolving station")
station, distance_nm = await _resolve_station(
noaa, station_id, latitude, longitude,
)
@ -81,6 +83,7 @@ def register(mcp: FastMCP) -> None:
begin = (now_utc - timedelta(hours=12)).strftime("%Y%m%d %H:%M")
end = (now_utc + timedelta(hours=12)).strftime("%Y%m%d %H:%M")
await ctx.report_progress(2, 3, "Fetching tide data")
hilo_data, obs_data = await asyncio.gather(
noaa.get_data(
station["id"], "predictions",
@ -94,6 +97,8 @@ def register(mcp: FastMCP) -> None:
return_exceptions=True,
)
await ctx.report_progress(3, 3, "Classifying phase")
# Parse hilo predictions
hilo_preds = []
if isinstance(hilo_data, dict):
@ -142,6 +147,8 @@ def register(mcp: FastMCP) -> None:
Returns station info, tide schedule, conditions, and assessment.
"""
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
await ctx.report_progress(1, 6, "Resolving station")
station, distance_nm = await _resolve_station(
noaa, latitude=latitude, longitude=longitude,
)
@ -151,6 +158,9 @@ def register(mcp: FastMCP) -> None:
end = (now_utc + timedelta(hours=soak_hours)).strftime("%Y%m%d %H:%M")
# Parallel fetch: hilo predictions for soak window + current conditions
await ctx.report_progress(2, 6, "Fetching tide and conditions data")
completed = 0
hilo_fut = noaa.get_data(
station["id"], "predictions",
begin_date=begin, end_date=end,
@ -158,10 +168,14 @@ def register(mcp: FastMCP) -> None:
)
async def _safe_fetch(product, **kwargs):
nonlocal completed
try:
return await noaa.get_data(station["id"], product, **kwargs)
result = await noaa.get_data(station["id"], product, **kwargs)
except Exception:
return None
result = None
completed += 1
await ctx.report_progress(2 + completed, 6, f"Fetched {product}")
return result
hilo_data, wind_data, temp_data, pressure_data = await asyncio.gather(
hilo_fut,
@ -170,6 +184,8 @@ def register(mcp: FastMCP) -> None:
_safe_fetch("air_pressure", hours=3, time_zone="gmt"),
)
await ctx.report_progress(6, 6, "Assessing conditions")
# Build tide schedule
hilo_events = parse_hilo_predictions(hilo_data.get("predictions", []))
tide_schedule = [
@ -270,6 +286,8 @@ def register(mcp: FastMCP) -> None:
)
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
num_events = len(events)
num_stations = 0 # determined after resolution
# Group events by nearest station to batch hilo fetches
station_groups: dict[str, list[tuple[int, dict, float | None]]] = {}
@ -279,6 +297,9 @@ def register(mcp: FastMCP) -> None:
lon = event.get("longitude")
if lat is None or lon is None:
continue
await ctx.report_progress(
idx + 1, num_events, f"Resolving station for event {idx + 1}/{num_events}",
)
nearest = await noaa.find_nearest(float(lat), float(lon), limit=1)
if not nearest:
continue
@ -290,8 +311,16 @@ def register(mcp: FastMCP) -> None:
# For each station group, fetch hilo once for the full time window
enriched = [None] * len(events)
num_stations = len(station_groups)
station_idx = 0
for station_id, group in station_groups.items():
station_idx += 1
await ctx.report_progress(
station_idx, num_stations,
f"Processing station {station_id}",
)
timestamps = []
for _, event, _ in group:
ts_str = event.get("timestamp", "")
@ -370,6 +399,7 @@ def register(mcp: FastMCP) -> None:
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
# Parallel fetch: observed levels + 6-minute predictions
await ctx.report_progress(1, 3, "Fetching observations and predictions")
obs_data, pred_data = await asyncio.gather(
noaa.get_data(
station_id, "water_level",
@ -392,6 +422,7 @@ def register(mcp: FastMCP) -> None:
}
# Parse predictions into parallel lists for interpolation
await ctx.report_progress(2, 3, "Computing deviations")
pred_times = []
pred_values = []
for p in pred_records:
@ -437,6 +468,7 @@ def register(mcp: FastMCP) -> None:
direction = "above" if mean_signed > 0 else "below"
# Risk classification
await ctx.report_progress(3, 3, "Classifying risk")
if max_dev >= threshold_ft * 2:
risk = "high"
explanation = (