Add MCP progress reporting to all multi-fetch tools
ctx.report_progress() gives clients real-time visibility into long-running tool calls. Per-fetch counters for parallel gather calls, stage-based milestones for linear pipelines. No-op when client doesn't send a progressToken.
This commit is contained in:
parent
9d25f5efe3
commit
fb574d26b8
@ -38,6 +38,7 @@ def register(mcp: FastMCP) -> None:
|
|||||||
today = datetime.now(timezone.utc).strftime("%Y%m%d")
|
today = datetime.now(timezone.utc).strftime("%Y%m%d")
|
||||||
|
|
||||||
# Fetch predictions (6-minute interval for smooth curve) + hilo for markers
|
# Fetch predictions (6-minute interval for smooth curve) + hilo for markers
|
||||||
|
await ctx.report_progress(1, 4, "Fetching predictions")
|
||||||
predictions_raw, hilo_raw = await asyncio.gather(
|
predictions_raw, hilo_raw = await asyncio.gather(
|
||||||
noaa.get_data(
|
noaa.get_data(
|
||||||
station_id, product="predictions", begin_date=today,
|
station_id, product="predictions", begin_date=today,
|
||||||
@ -62,6 +63,7 @@ def register(mcp: FastMCP) -> None:
|
|||||||
# Fetch observed water levels if requested
|
# Fetch observed water levels if requested
|
||||||
observed = None
|
observed = None
|
||||||
if include_observed:
|
if include_observed:
|
||||||
|
await ctx.report_progress(2, 4, "Fetching observed data")
|
||||||
try:
|
try:
|
||||||
obs_raw = await noaa.get_data(
|
obs_raw = await noaa.get_data(
|
||||||
station_id, product="water_level", hours=hours,
|
station_id, product="water_level", hours=hours,
|
||||||
@ -71,6 +73,7 @@ def register(mcp: FastMCP) -> None:
|
|||||||
pass # observed overlay is optional — skip on failure
|
pass # observed overlay is optional — skip on failure
|
||||||
|
|
||||||
# Look up station name
|
# Look up station name
|
||||||
|
await ctx.report_progress(3, 4, "Looking up station")
|
||||||
station_name = ""
|
station_name = ""
|
||||||
try:
|
try:
|
||||||
stations = await noaa.get_stations()
|
stations = await noaa.get_stations()
|
||||||
@ -80,6 +83,7 @@ def register(mcp: FastMCP) -> None:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
await ctx.report_progress(4, 4, "Rendering chart")
|
||||||
if format == "png":
|
if format == "png":
|
||||||
from mcnoaa_tides.charts.tides import render_tide_chart_png
|
from mcnoaa_tides.charts.tides import render_tide_chart_png
|
||||||
|
|
||||||
@ -132,12 +136,19 @@ def register(mcp: FastMCP) -> None:
|
|||||||
"air_pressure": {"product": "air_pressure", "hours": hours},
|
"air_pressure": {"product": "air_pressure", "hours": hours},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
completed = 0
|
||||||
|
total_steps = len(requests) + 2 # 6 fetches + station lookup + render
|
||||||
|
|
||||||
async def fetch(name: str, params: dict) -> tuple[str, dict | None]:
|
async def fetch(name: str, params: dict) -> tuple[str, dict | None]:
|
||||||
|
nonlocal completed
|
||||||
try:
|
try:
|
||||||
data = await noaa.get_data(station_id, **params)
|
data = await noaa.get_data(station_id, **params)
|
||||||
return name, data
|
result = name, data
|
||||||
except Exception:
|
except Exception:
|
||||||
return name, None
|
result = name, None
|
||||||
|
completed += 1
|
||||||
|
await ctx.report_progress(completed, total_steps, f"Fetched {name}")
|
||||||
|
return result
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[fetch(name, params) for name, params in requests.items()]
|
*[fetch(name, params) for name, params in requests.items()]
|
||||||
@ -149,6 +160,7 @@ def register(mcp: FastMCP) -> None:
|
|||||||
snapshot[name] = data
|
snapshot[name] = data
|
||||||
|
|
||||||
# Look up station name
|
# Look up station name
|
||||||
|
await ctx.report_progress(total_steps - 1, total_steps, "Looking up station")
|
||||||
station_name = ""
|
station_name = ""
|
||||||
try:
|
try:
|
||||||
stations = await noaa.get_stations()
|
stations = await noaa.get_stations()
|
||||||
@ -158,6 +170,7 @@ def register(mcp: FastMCP) -> None:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
await ctx.report_progress(total_steps, total_steps, "Rendering dashboard")
|
||||||
if format == "png":
|
if format == "png":
|
||||||
from mcnoaa_tides.charts.conditions import render_conditions_png
|
from mcnoaa_tides.charts.conditions import render_conditions_png
|
||||||
|
|
||||||
|
|||||||
@ -47,13 +47,20 @@ def register(mcp: FastMCP) -> None:
|
|||||||
"air_pressure": {"product": "air_pressure", "hours": hours},
|
"air_pressure": {"product": "air_pressure", "hours": hours},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
completed = 0
|
||||||
|
total = len(requests)
|
||||||
|
|
||||||
async def fetch(name: str, params: dict) -> tuple[str, dict | str]:
|
async def fetch(name: str, params: dict) -> tuple[str, dict | str]:
|
||||||
|
nonlocal completed
|
||||||
try:
|
try:
|
||||||
data = await noaa.get_data(station_id, **params)
|
data = await noaa.get_data(station_id, **params)
|
||||||
return name, data
|
result = name, data
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
msg = str(exc) or type(exc).__name__
|
msg = str(exc) or type(exc).__name__
|
||||||
return name, f"{type(exc).__name__}: {msg}"
|
result = name, f"{type(exc).__name__}: {msg}"
|
||||||
|
completed += 1
|
||||||
|
await ctx.report_progress(completed, total, f"Fetched {name}")
|
||||||
|
return result
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[fetch(name, params) for name, params in requests.items()]
|
*[fetch(name, params) for name, params in requests.items()]
|
||||||
|
|||||||
@ -72,6 +72,8 @@ def register(mcp: FastMCP) -> None:
|
|||||||
GPS coordinates find the nearest tidal station automatically.
|
GPS coordinates find the nearest tidal station automatically.
|
||||||
"""
|
"""
|
||||||
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
|
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
|
||||||
|
|
||||||
|
await ctx.report_progress(1, 3, "Resolving station")
|
||||||
station, distance_nm = await _resolve_station(
|
station, distance_nm = await _resolve_station(
|
||||||
noaa, station_id, latitude, longitude,
|
noaa, station_id, latitude, longitude,
|
||||||
)
|
)
|
||||||
@ -81,6 +83,7 @@ def register(mcp: FastMCP) -> None:
|
|||||||
begin = (now_utc - timedelta(hours=12)).strftime("%Y%m%d %H:%M")
|
begin = (now_utc - timedelta(hours=12)).strftime("%Y%m%d %H:%M")
|
||||||
end = (now_utc + timedelta(hours=12)).strftime("%Y%m%d %H:%M")
|
end = (now_utc + timedelta(hours=12)).strftime("%Y%m%d %H:%M")
|
||||||
|
|
||||||
|
await ctx.report_progress(2, 3, "Fetching tide data")
|
||||||
hilo_data, obs_data = await asyncio.gather(
|
hilo_data, obs_data = await asyncio.gather(
|
||||||
noaa.get_data(
|
noaa.get_data(
|
||||||
station["id"], "predictions",
|
station["id"], "predictions",
|
||||||
@ -94,6 +97,8 @@ def register(mcp: FastMCP) -> None:
|
|||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await ctx.report_progress(3, 3, "Classifying phase")
|
||||||
|
|
||||||
# Parse hilo predictions
|
# Parse hilo predictions
|
||||||
hilo_preds = []
|
hilo_preds = []
|
||||||
if isinstance(hilo_data, dict):
|
if isinstance(hilo_data, dict):
|
||||||
@ -142,6 +147,8 @@ def register(mcp: FastMCP) -> None:
|
|||||||
Returns station info, tide schedule, conditions, and assessment.
|
Returns station info, tide schedule, conditions, and assessment.
|
||||||
"""
|
"""
|
||||||
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
|
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
|
||||||
|
|
||||||
|
await ctx.report_progress(1, 6, "Resolving station")
|
||||||
station, distance_nm = await _resolve_station(
|
station, distance_nm = await _resolve_station(
|
||||||
noaa, latitude=latitude, longitude=longitude,
|
noaa, latitude=latitude, longitude=longitude,
|
||||||
)
|
)
|
||||||
@ -151,6 +158,9 @@ def register(mcp: FastMCP) -> None:
|
|||||||
end = (now_utc + timedelta(hours=soak_hours)).strftime("%Y%m%d %H:%M")
|
end = (now_utc + timedelta(hours=soak_hours)).strftime("%Y%m%d %H:%M")
|
||||||
|
|
||||||
# Parallel fetch: hilo predictions for soak window + current conditions
|
# Parallel fetch: hilo predictions for soak window + current conditions
|
||||||
|
await ctx.report_progress(2, 6, "Fetching tide and conditions data")
|
||||||
|
completed = 0
|
||||||
|
|
||||||
hilo_fut = noaa.get_data(
|
hilo_fut = noaa.get_data(
|
||||||
station["id"], "predictions",
|
station["id"], "predictions",
|
||||||
begin_date=begin, end_date=end,
|
begin_date=begin, end_date=end,
|
||||||
@ -158,10 +168,14 @@ def register(mcp: FastMCP) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _safe_fetch(product, **kwargs):
|
async def _safe_fetch(product, **kwargs):
|
||||||
|
nonlocal completed
|
||||||
try:
|
try:
|
||||||
return await noaa.get_data(station["id"], product, **kwargs)
|
result = await noaa.get_data(station["id"], product, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
result = None
|
||||||
|
completed += 1
|
||||||
|
await ctx.report_progress(2 + completed, 6, f"Fetched {product}")
|
||||||
|
return result
|
||||||
|
|
||||||
hilo_data, wind_data, temp_data, pressure_data = await asyncio.gather(
|
hilo_data, wind_data, temp_data, pressure_data = await asyncio.gather(
|
||||||
hilo_fut,
|
hilo_fut,
|
||||||
@ -170,6 +184,8 @@ def register(mcp: FastMCP) -> None:
|
|||||||
_safe_fetch("air_pressure", hours=3, time_zone="gmt"),
|
_safe_fetch("air_pressure", hours=3, time_zone="gmt"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await ctx.report_progress(6, 6, "Assessing conditions")
|
||||||
|
|
||||||
# Build tide schedule
|
# Build tide schedule
|
||||||
hilo_events = parse_hilo_predictions(hilo_data.get("predictions", []))
|
hilo_events = parse_hilo_predictions(hilo_data.get("predictions", []))
|
||||||
tide_schedule = [
|
tide_schedule = [
|
||||||
@ -270,6 +286,8 @@ def register(mcp: FastMCP) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
|
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
|
||||||
|
num_events = len(events)
|
||||||
|
num_stations = 0 # determined after resolution
|
||||||
|
|
||||||
# Group events by nearest station to batch hilo fetches
|
# Group events by nearest station to batch hilo fetches
|
||||||
station_groups: dict[str, list[tuple[int, dict, float | None]]] = {}
|
station_groups: dict[str, list[tuple[int, dict, float | None]]] = {}
|
||||||
@ -279,6 +297,9 @@ def register(mcp: FastMCP) -> None:
|
|||||||
lon = event.get("longitude")
|
lon = event.get("longitude")
|
||||||
if lat is None or lon is None:
|
if lat is None or lon is None:
|
||||||
continue
|
continue
|
||||||
|
await ctx.report_progress(
|
||||||
|
idx + 1, num_events, f"Resolving station for event {idx + 1}/{num_events}",
|
||||||
|
)
|
||||||
nearest = await noaa.find_nearest(float(lat), float(lon), limit=1)
|
nearest = await noaa.find_nearest(float(lat), float(lon), limit=1)
|
||||||
if not nearest:
|
if not nearest:
|
||||||
continue
|
continue
|
||||||
@ -290,8 +311,16 @@ def register(mcp: FastMCP) -> None:
|
|||||||
|
|
||||||
# For each station group, fetch hilo once for the full time window
|
# For each station group, fetch hilo once for the full time window
|
||||||
enriched = [None] * len(events)
|
enriched = [None] * len(events)
|
||||||
|
num_stations = len(station_groups)
|
||||||
|
station_idx = 0
|
||||||
|
|
||||||
for station_id, group in station_groups.items():
|
for station_id, group in station_groups.items():
|
||||||
|
station_idx += 1
|
||||||
|
await ctx.report_progress(
|
||||||
|
station_idx, num_stations,
|
||||||
|
f"Processing station {station_id}",
|
||||||
|
)
|
||||||
|
|
||||||
timestamps = []
|
timestamps = []
|
||||||
for _, event, _ in group:
|
for _, event, _ in group:
|
||||||
ts_str = event.get("timestamp", "")
|
ts_str = event.get("timestamp", "")
|
||||||
@ -370,6 +399,7 @@ def register(mcp: FastMCP) -> None:
|
|||||||
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
|
noaa: NOAAClient = ctx.lifespan_context["noaa_client"]
|
||||||
|
|
||||||
# Parallel fetch: observed levels + 6-minute predictions
|
# Parallel fetch: observed levels + 6-minute predictions
|
||||||
|
await ctx.report_progress(1, 3, "Fetching observations and predictions")
|
||||||
obs_data, pred_data = await asyncio.gather(
|
obs_data, pred_data = await asyncio.gather(
|
||||||
noaa.get_data(
|
noaa.get_data(
|
||||||
station_id, "water_level",
|
station_id, "water_level",
|
||||||
@ -392,6 +422,7 @@ def register(mcp: FastMCP) -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Parse predictions into parallel lists for interpolation
|
# Parse predictions into parallel lists for interpolation
|
||||||
|
await ctx.report_progress(2, 3, "Computing deviations")
|
||||||
pred_times = []
|
pred_times = []
|
||||||
pred_values = []
|
pred_values = []
|
||||||
for p in pred_records:
|
for p in pred_records:
|
||||||
@ -437,6 +468,7 @@ def register(mcp: FastMCP) -> None:
|
|||||||
direction = "above" if mean_signed > 0 else "below"
|
direction = "above" if mean_signed > 0 else "below"
|
||||||
|
|
||||||
# Risk classification
|
# Risk classification
|
||||||
|
await ctx.report_progress(3, 3, "Classifying risk")
|
||||||
if max_dev >= threshold_ft * 2:
|
if max_dev >= threshold_ft * 2:
|
||||||
risk = "high"
|
risk = "high"
|
||||||
explanation = (
|
explanation = (
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user