Harden codebase from Hamilton code review findings

Address 17 of the 20 findings from safety-critical review (L1, L2, L4 not included in this commit):

CRITICAL:
- C1: Replace ET.fromstring with defusedxml across all XML parsers
- C2: Fix client init failure leaving half-initialized state; clean
  up HTTP client on startup failure so next connection can retry

HIGH:
- H1: Replace unbounded dicts with LRU caches (maxsize=500)
- H2: Move Nominatim rate limiter from module globals to per-instance
  state on GIBSClient, eliminating shared mutable state
- H3: Validate _parse_rgb input, return (0,0,0) on malformed data
- H4: Add exponential backoff retry for capabilities loading
- H5: Invert WMS error detection to verify image content-type
- H6: Clamp image dimensions to 4096 max to prevent OOM

MEDIUM:
- M1: Convert images to RGB mode in compare_dates for RGBA safety
- M2: Narrow DescribeDomains XML matching to TimeDomain elements
- M3: Add BBox model_validator for coordinate range validation
- M4: Add ET.ParseError to colormap fetch exception handling
- M5: Replace bare except Exception with specific types in server
- M6: Catch ValueError from _resolve_bbox in imagery tools for
  consistent error returns
- M7: Only cache successful geocoding lookups (no negative caching)

LOW:
- L3: Derive USER_AGENT version from package __version__
- L5: Remove unused start_date/end_date params from check_layer_dates
This commit is contained in:
Ryan Malloy 2026-02-18 15:14:02 -07:00
parent f7fad32a9e
commit d163ade9a4
10 changed files with 341 additions and 188 deletions

View File

@ -18,6 +18,7 @@ classifiers = [
dependencies = [ dependencies = [
"fastmcp>=3.0.0", "fastmcp>=3.0.0",
"pillow>=12.0.0", "pillow>=12.0.0",
"defusedxml>=0.7.1",
] ]
[project.scripts] [project.scripts]

View File

@ -3,6 +3,8 @@
import logging import logging
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import defusedxml.ElementTree as DefusedET
from mcgibs.constants import WMTS_NS from mcgibs.constants import WMTS_NS
from mcgibs.models import BBox, LayerInfo, TimeDimension from mcgibs.models import BBox, LayerInfo, TimeDimension
@ -209,7 +211,7 @@ def parse_capabilities(xml_text: str) -> dict[str, LayerInfo]:
Handles the ~5MB GetCapabilities document from NASA GIBS by iterating Handles the ~5MB GetCapabilities document from NASA GIBS by iterating
through all Layer elements found under the wmts:Contents element. through all Layer elements found under the wmts:Contents element.
""" """
root = ET.fromstring(xml_text) root = DefusedET.fromstring(xml_text)
layers: dict[str, LayerInfo] = {} layers: dict[str, LayerInfo] = {}

View File

@ -1,8 +1,13 @@
"""Async HTTP client for all NASA GIBS API interactions.""" """Async HTTP client for all NASA GIBS API interactions."""
import asyncio
import logging import logging
import time
import xml.etree.ElementTree as ET
from collections import OrderedDict
from io import BytesIO from io import BytesIO
import defusedxml.ElementTree as DefusedET
import httpx import httpx
from PIL import Image from PIL import Image
@ -23,26 +28,68 @@ from mcgibs.models import BBox, ColorMapSet, GeocodingResult, LayerInfo
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Maximum entries per cache before LRU eviction
_MAX_CACHE_SIZE = 500
# Maximum retries for capabilities loading at startup
_INIT_MAX_RETRIES = 3
_INIT_RETRY_DELAYS = [2.0, 4.0, 8.0]
# Maximum image dimensions to prevent OOM from LLM requests
MAX_IMAGE_DIMENSION = 4096
class _LRUCache(OrderedDict):
"""Simple LRU cache backed by OrderedDict."""
def __init__(self, maxsize: int = _MAX_CACHE_SIZE) -> None:
super().__init__()
self._maxsize = maxsize
def get_cached(self, key):
if key in self:
self.move_to_end(key)
return self[key]
return _SENTINEL
def put(self, key, value):
if key in self:
self.move_to_end(key)
self[key] = value
while len(self) > self._maxsize:
self.popitem(last=False)
_SENTINEL = object()
class GIBSClient: class GIBSClient:
"""Async client for NASA GIBS APIs. """Async client for NASA GIBS APIs.
Wraps WMTS (discovery), WMS (imagery), layer metadata, and colormaps Wraps WMTS (discovery), WMS (imagery), layer metadata, and colormaps
behind a single interface. Designed to be initialized once per session behind a single interface. Designed to be initialized once per session
via FastMCP lifespan and reused across tool calls. via FastMCP middleware and reused across tool calls.
Rate limiting for Nominatim (1 req/sec) is managed per-client instance
rather than as module-level globals. Caches are LRU-bounded to prevent
unbounded memory growth in long-running sessions.
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._http: httpx.AsyncClient | None = None self._http: httpx.AsyncClient | None = None
self.layer_index: dict[str, LayerInfo] = {} self.layer_index: dict[str, LayerInfo] = {}
self._metadata_cache: dict[str, dict] = {} self._metadata_cache: _LRUCache = _LRUCache()
self._colormap_cache: dict[str, ColorMapSet] = {} self._colormap_cache: _LRUCache = _LRUCache()
self._geocode_cache: dict[str, GeocodingResult | None] = {} self._geocode_cache: _LRUCache = _LRUCache()
# Rate limiter state — owned by this instance, not module globals
self._nominatim_lock = asyncio.Lock()
self._last_nominatim_time: float = 0.0
async def initialize(self) -> None: async def initialize(self) -> None:
"""Create HTTP client and load WMTS capabilities.""" """Create HTTP client and load WMTS capabilities with retry."""
self._http = httpx.AsyncClient(timeout=60.0, follow_redirects=True) self._http = httpx.AsyncClient(timeout=60.0, follow_redirects=True)
await self._load_capabilities() await self._load_capabilities_with_retry()
async def close(self) -> None: async def close(self) -> None:
"""Shut down the HTTP client.""" """Shut down the HTTP client."""
@ -56,8 +103,47 @@ class GIBSClient:
raise RuntimeError("GIBSClient not initialized — call initialize() first") raise RuntimeError("GIBSClient not initialized — call initialize() first")
return self._http return self._http
# --- Rate limiting (per-instance) ---
async def nominatim_rate_limit(self) -> None:
"""Enforce 1-second gap between Nominatim requests."""
async with self._nominatim_lock:
now = time.monotonic()
elapsed = now - self._last_nominatim_time
if elapsed < 1.0:
await asyncio.sleep(1.0 - elapsed)
self._last_nominatim_time = time.monotonic()
# --- Capabilities --- # --- Capabilities ---
async def _load_capabilities_with_retry(self) -> None:
"""Fetch capabilities with exponential backoff retry."""
last_exc: Exception | None = None
for attempt in range(_INIT_MAX_RETRIES):
try:
await self._load_capabilities()
return
except (httpx.HTTPError, ET.ParseError) as exc:
last_exc = exc
if attempt < _INIT_MAX_RETRIES - 1:
delay = _INIT_RETRY_DELAYS[attempt]
log.warning(
"Capabilities fetch failed (attempt %d/%d), retrying in %.0fs: %s",
attempt + 1,
_INIT_MAX_RETRIES,
delay,
exc,
)
await asyncio.sleep(delay)
log.error(
"Failed to load capabilities after %d attempts",
_INIT_MAX_RETRIES,
)
raise RuntimeError(
f"GIBS capabilities unavailable after {_INIT_MAX_RETRIES} attempts: {last_exc}"
)
async def _load_capabilities(self) -> None: async def _load_capabilities(self) -> None:
"""Fetch and parse WMTS GetCapabilities for EPSG:4326.""" """Fetch and parse WMTS GetCapabilities for EPSG:4326."""
url = WMTS_CAPABILITIES_URL.format(epsg="4326") url = WMTS_CAPABILITIES_URL.format(epsg="4326")
@ -74,13 +160,10 @@ class GIBSClient:
# --- Layer Metadata (enrichment) --- # --- Layer Metadata (enrichment) ---
async def fetch_layer_metadata(self, layer_id: str) -> dict: async def fetch_layer_metadata(self, layer_id: str) -> dict:
"""Fetch the enriched JSON metadata for a layer. """Fetch the enriched JSON metadata for a layer."""
cached = self._metadata_cache.get_cached(layer_id)
This supplements the GetCapabilities data with measurement, if cached is not _SENTINEL:
instrument, platform, and other fields not in the WMTS XML. return cached
"""
if layer_id in self._metadata_cache:
return self._metadata_cache[layer_id]
url = f"{LAYER_METADATA_BASE}/{layer_id}.json" url = f"{LAYER_METADATA_BASE}/{layer_id}.json"
try: try:
@ -91,7 +174,7 @@ class GIBSClient:
log.debug("Layer metadata not available for %s: %s", layer_id, exc) log.debug("Layer metadata not available for %s: %s", layer_id, exc)
data = {} data = {}
self._metadata_cache[layer_id] = data self._metadata_cache.put(layer_id, data)
# Enrich the layer_index entry if it exists # Enrich the layer_index entry if it exists
layer = self.layer_index.get(layer_id) layer = self.layer_index.get(layer_id)
@ -110,10 +193,10 @@ class GIBSClient:
async def fetch_colormap(self, layer_id: str) -> ColorMapSet | None: async def fetch_colormap(self, layer_id: str) -> ColorMapSet | None:
"""Fetch and parse the colormap XML for a layer.""" """Fetch and parse the colormap XML for a layer."""
if layer_id in self._colormap_cache: cached = self._colormap_cache.get_cached(layer_id)
return self._colormap_cache[layer_id] if cached is not _SENTINEL:
return cached
# Derive colormap URL — GIBS uses the layer identifier as filename
layer = self.layer_index.get(layer_id) layer = self.layer_index.get(layer_id)
colormap_id = (layer.colormap_id if layer else None) or layer_id colormap_id = (layer.colormap_id if layer else None) or layer_id
url = f"{COLORMAP_BASE}/{colormap_id}.xml" url = f"{COLORMAP_BASE}/{colormap_id}.xml"
@ -122,11 +205,11 @@ class GIBSClient:
resp = await self.http.get(url) resp = await self.http.get(url)
resp.raise_for_status() resp.raise_for_status()
colormap_set = parse_colormap(resp.text) colormap_set = parse_colormap(resp.text)
except (httpx.HTTPError, ValueError) as exc: except (httpx.HTTPError, ET.ParseError, ValueError) as exc:
log.debug("Colormap not available for %s: %s", layer_id, exc) log.debug("Colormap not available for %s: %s", layer_id, exc)
return None return None
self._colormap_cache[layer_id] = colormap_set self._colormap_cache.put(layer_id, colormap_set)
return colormap_set return colormap_set
async def explain_layer_colormap(self, layer_id: str) -> str: async def explain_layer_colormap(self, layer_id: str) -> str:
@ -139,8 +222,21 @@ class GIBSClient:
# --- Geocoding --- # --- Geocoding ---
async def resolve_place(self, place: str) -> GeocodingResult | None: async def resolve_place(self, place: str) -> GeocodingResult | None:
"""Geocode a place name via Nominatim.""" """Geocode a place name via Nominatim.
return await geocode(self.http, place, self._geocode_cache)
Uses per-instance rate limiting and LRU-bounded cache.
"""
key = place.strip().lower()
cached = self._geocode_cache.get_cached(key)
if cached is not _SENTINEL:
return cached
result = await geocode(self.http, place, rate_limiter=self)
# Only cache successful lookups — don't permanently cache failures
# (M7: transient Nominatim failures should be retryable)
if result is not None:
self._geocode_cache.put(key, result)
return result
# --- WMS Imagery --- # --- WMS Imagery ---
@ -158,8 +254,13 @@ class GIBSClient:
Returns raw image bytes (JPEG or PNG). Returns raw image bytes (JPEG or PNG).
""" """
# H6: clamp dimensions to prevent OOM from unbounded requests
width = max(1, min(width, MAX_IMAGE_DIMENSION))
height = max(1, min(height, MAX_IMAGE_DIMENSION))
params = dict(WMS_DEFAULTS) params = dict(WMS_DEFAULTS)
params.update({ params.update(
{
"LAYERS": layer_id, "LAYERS": layer_id,
"SRS": f"EPSG:{epsg}", "SRS": f"EPSG:{epsg}",
"BBOX": bbox.wms_bbox, "BBOX": bbox.wms_bbox,
@ -167,15 +268,20 @@ class GIBSClient:
"HEIGHT": str(height), "HEIGHT": str(height),
"FORMAT": image_format, "FORMAT": image_format,
"TIME": date, "TIME": date,
}) }
)
url = WMS_BASE.format(epsg=epsg) url = WMS_BASE.format(epsg=epsg)
resp = await self.http.get(url, params=params) resp = await self.http.get(url, params=params)
resp.raise_for_status() resp.raise_for_status()
# H5: verify the response IS an image, not just that it ISN'T XML.
# WMS can return HTTP 200 with OGC ServiceException XML.
content_type = resp.headers.get("content-type", "") content_type = resp.headers.get("content-type", "")
if "xml" in content_type or "text" in content_type: if not content_type.startswith("image/"):
raise RuntimeError(f"WMS returned error: {resp.text[:500]}") raise RuntimeError(
f"WMS returned non-image content-type '{content_type}': {resp.text[:500]}"
)
return resp.content return resp.content
@ -189,10 +295,7 @@ class GIBSClient:
image_format: str = "image/jpeg", image_format: str = "image/jpeg",
epsg: str = DEFAULT_EPSG, epsg: str = DEFAULT_EPSG,
) -> bytes: ) -> bytes:
"""Fetch a multi-layer WMS composite image. """Fetch a multi-layer WMS composite image."""
WMS supports comma-separated LAYERS for overlay compositing.
"""
return await self.get_wms_image( return await self.get_wms_image(
layer_id=",".join(layer_ids), layer_id=",".join(layer_ids),
date=date, date=date,
@ -213,22 +316,28 @@ class GIBSClient:
height: int = 512, height: int = 512,
image_format: str = "image/jpeg", image_format: str = "image/jpeg",
) -> bytes: ) -> bytes:
"""Fetch two images and compose a side-by-side comparison. """Fetch two images and compose a side-by-side comparison."""
Returns a single image with the "before" date on the left
and the "after" date on the right, each labeled.
"""
img_before = await self.get_wms_image( img_before = await self.get_wms_image(
layer_id, date_before, bbox, width, height, image_format, layer_id,
date_before,
bbox,
width,
height,
image_format,
) )
img_after = await self.get_wms_image( img_after = await self.get_wms_image(
layer_id, date_after, bbox, width, height, image_format, layer_id,
date_after,
bbox,
width,
height,
image_format,
) )
pil_before = Image.open(BytesIO(img_before)) # M1: convert to RGB to avoid mode mismatch with RGBA PNGs
pil_after = Image.open(BytesIO(img_after)) pil_before = Image.open(BytesIO(img_before)).convert("RGB")
pil_after = Image.open(BytesIO(img_after)).convert("RGB")
# Create side-by-side composite
total_width = pil_before.width + pil_after.width total_width = pil_before.width + pil_after.width
max_height = max(pil_before.height, pil_after.height) max_height = max(pil_before.height, pil_after.height)
composite = Image.new("RGB", (total_width, max_height)) composite = Image.new("RGB", (total_width, max_height))
@ -246,28 +355,35 @@ class GIBSClient:
layer_id: str, layer_id: str,
epsg: str = DEFAULT_EPSG, epsg: str = DEFAULT_EPSG,
) -> dict: ) -> dict:
"""Query WMTS DescribeDomains for available date ranges. """Query WMTS DescribeDomains for available date ranges."""
url = WMTS_DESCRIBE_DOMAINS_URL.format(
Returns a dict with 'time_domain' key (ISO 8601 interval or list epsg=epsg,
of dates) and 'spatial_domain' if available. layer_id=layer_id,
""" )
url = WMTS_DESCRIBE_DOMAINS_URL.format(epsg=epsg, layer_id=layer_id)
resp = await self.http.get(url) resp = await self.http.get(url)
resp.raise_for_status() resp.raise_for_status()
# DescribeDomains returns XML — extract time domain root = DefusedET.fromstring(resp.text)
import xml.etree.ElementTree as ET
root = ET.fromstring(resp.text)
result: dict[str, str] = {} result: dict[str, str] = {}
# Look for time domain in various possible locations # M2: look specifically for TimeDomain elements rather than
# broadly matching any "value" element
for elem in root.iter(): for elem in root.iter():
tag = elem.tag.rpartition("}")[-1] if "}" in elem.tag else elem.tag local = elem.tag.rpartition("}")[-1] if "}" in elem.tag else elem.tag
if tag.lower() in ("timedomain", "value") and elem.text: if local.lower() == "timedomain" and elem.text:
text = elem.text.strip() text = elem.text.strip()
if text and ("/" in text or "-" in text): if text:
result["time_domain"] = text
break
# Fallback: look in child Value elements of a Dimension
if "time_domain" not in result:
for elem in root.iter():
local = elem.tag.rpartition("}")[-1] if "}" in elem.tag else elem.tag
if local.lower() == "value" and elem.text:
text = elem.text.strip()
# Only match ISO 8601-ish values (contain date separators)
if "/" in text and len(text) > 8:
result["time_domain"] = text result["time_domain"] = text
break break
@ -305,10 +421,7 @@ class GIBSClient:
layer_id: str, layer_id: str,
orientation: str = "horizontal", orientation: str = "horizontal",
) -> bytes | None: ) -> bytes | None:
"""Fetch the pre-rendered legend image for a layer. """Fetch the pre-rendered legend image for a layer."""
GIBS provides legend images via the GetLegendGraphic WMS call.
"""
layer = self.layer_index.get(layer_id) layer = self.layer_index.get(layer_id)
if layer and layer.legend_url: if layer and layer.legend_url:
try: try:
@ -334,7 +447,7 @@ class GIBSClient:
resp = await self.http.get(url, params=params) resp = await self.http.get(url, params=params)
resp.raise_for_status() resp.raise_for_status()
content_type = resp.headers.get("content-type", "") content_type = resp.headers.get("content-type", "")
if "image" in content_type: if content_type.startswith("image/"):
return resp.content return resp.content
except httpx.HTTPError as exc: except httpx.HTTPError as exc:
log.debug("Legend not available for %s: %s", layer_id, exc) log.debug("Legend not available for %s: %s", layer_id, exc)

View File

@ -7,7 +7,8 @@ the colors actually mean.
""" """
import re import re
import xml.etree.ElementTree as ET
import defusedxml.ElementTree as DefusedET
from mcgibs.models import ColorMap, ColorMapEntry, ColorMapSet, LegendEntry from mcgibs.models import ColorMap, ColorMapEntry, ColorMapSet, LegendEntry
@ -31,6 +32,7 @@ _UNIT_CONVERTERS: dict[str, tuple] = {
# --- Color naming --- # --- Color naming ---
def _describe_rgb(rgb: tuple[int, int, int]) -> str: def _describe_rgb(rgb: tuple[int, int, int]) -> str:
"""Return an approximate human-friendly color name for an RGB triple. """Return an approximate human-friendly color name for an RGB triple.
@ -132,9 +134,7 @@ _INTERVAL_RE = re.compile(
re.VERBOSE | re.IGNORECASE, re.VERBOSE | re.IGNORECASE,
) )
_SINGLE_VALUE_RE = re.compile( _SINGLE_VALUE_RE = re.compile(r"[\[\(]\s*([+\-]?\d+(?:\.\d+)?(?:[eE][+\-]?\d+)?)\s*[\]\)]")
r"[\[\(]\s*([+\-]?\d+(?:\.\d+)?(?:[eE][+\-]?\d+)?)\s*[\]\)]"
)
def _parse_interval_value(interval: str) -> tuple[float | None, float | None]: def _parse_interval_value(interval: str) -> tuple[float | None, float | None]:
@ -174,10 +174,19 @@ def _parse_interval_value(interval: str) -> tuple[float | None, float | None]:
# --- XML parsing --- # --- XML parsing ---
def _parse_rgb(raw: str) -> tuple[int, int, int]: def _parse_rgb(raw: str) -> tuple[int, int, int]:
"""Parse "r,g,b" string into an integer triple.""" """Parse "r,g,b" string into an integer triple.
Returns (0, 0, 0) for malformed input rather than crashing.
"""
parts = raw.split(",") parts = raw.split(",")
if len(parts) < 3:
return (0, 0, 0)
try:
return (int(parts[0]), int(parts[1]), int(parts[2])) return (int(parts[0]), int(parts[1]), int(parts[2]))
except (ValueError, IndexError):
return (0, 0, 0)
def parse_colormap(xml_text: str) -> ColorMapSet: def parse_colormap(xml_text: str) -> ColorMapSet:
@ -187,7 +196,7 @@ def parse_colormap(xml_text: str) -> ColorMapSet:
<ColorMap> children. Each <ColorMap> contains <ColorMapEntry> elements <ColorMap> children. Each <ColorMap> contains <ColorMapEntry> elements
and an optional <Legend> with <LegendEntry> children. and an optional <Legend> with <LegendEntry> children.
""" """
root = ET.fromstring(xml_text) root = DefusedET.fromstring(xml_text)
maps: list[ColorMap] = [] maps: list[ColorMap] = []
@ -250,6 +259,7 @@ def parse_colormap(xml_text: str) -> ColorMapSet:
# --- Natural-language explanation --- # --- Natural-language explanation ---
def _format_value(val: float, units: str) -> str: def _format_value(val: float, units: str) -> str:
"""Format a numeric value with optional unit conversion.""" """Format a numeric value with optional unit conversion."""
units_lower = units.lower().strip() units_lower = units.lower().strip()
@ -404,10 +414,7 @@ def explain_colormap(colormap_set: ColorMapSet) -> str:
return "No colormap data available." return "No colormap data available."
# Filter to non-transparent, non-nodata entries for analysis # Filter to non-transparent, non-nodata entries for analysis
data_entries = [ data_entries = [e for e in data_map.entries if not e.transparent and not e.nodata]
e for e in data_map.entries
if not e.transparent and not e.nodata
]
if not data_entries: if not data_entries:
return "This colormap contains only no-data / transparent entries." return "This colormap contains only no-data / transparent entries."

View File

@ -1,5 +1,7 @@
"""GIBS API endpoints, EPSG codes, and TileMatrixSet definitions.""" """GIBS API endpoints, EPSG codes, and TileMatrixSet definitions."""
from mcgibs import __version__ as _version
# GIBS base URLs — domain sharding (gibs-a/b/c) for parallel tile fetches # GIBS base URLs — domain sharding (gibs-a/b/c) for parallel tile fetches
WMTS_BASE = "https://gibs.earthdata.nasa.gov/wmts/epsg{epsg}/best" WMTS_BASE = "https://gibs.earthdata.nasa.gov/wmts/epsg{epsg}/best"
WMS_BASE = "https://gibs.earthdata.nasa.gov/wms/epsg{epsg}/best/wms.cgi" WMS_BASE = "https://gibs.earthdata.nasa.gov/wms/epsg{epsg}/best/wms.cgi"
@ -9,15 +11,11 @@ LAYER_METADATA_BASE = "https://gibs.earthdata.nasa.gov/layer-metadata/v1.0"
# GetCapabilities and DescribeDomains # GetCapabilities and DescribeDomains
WMTS_CAPABILITIES_URL = WMTS_BASE + "/1.0.0/WMTSCapabilities.xml" WMTS_CAPABILITIES_URL = WMTS_BASE + "/1.0.0/WMTSCapabilities.xml"
WMTS_DESCRIBE_DOMAINS_URL = ( WMTS_DESCRIBE_DOMAINS_URL = (
WMTS_BASE + "/wmts.cgi?SERVICE=WMTS&VERSION=1.0.0" WMTS_BASE + "/wmts.cgi?SERVICE=WMTS&VERSION=1.0.0&REQUEST=DescribeDomains&LAYER={layer_id}"
"&REQUEST=DescribeDomains&LAYER={layer_id}"
) )
# WMTS REST tile URL pattern # WMTS REST tile URL pattern
WMTS_TILE_URL = ( WMTS_TILE_URL = WMTS_BASE + "/{layer_id}/default/{date}/{tile_matrix_set}/{z}/{row}/{col}.{ext}"
WMTS_BASE + "/{layer_id}/default/{date}/{tile_matrix_set}"
"/{z}/{row}/{col}.{ext}"
)
# Nominatim geocoding # Nominatim geocoding
NOMINATIM_BASE = "https://nominatim.openstreetmap.org" NOMINATIM_BASE = "https://nominatim.openstreetmap.org"
@ -63,5 +61,6 @@ WMS_DEFAULTS = {
"HEIGHT": "1024", "HEIGHT": "1024",
} }
# User-Agent for Nominatim (required by their usage policy) # User-Agent for Nominatim (required by their usage policy).
USER_AGENT = "mcgibs-mcp-server/2026.02.18 (ryan@supported.systems)" # Version derived from package metadata to stay in sync with pyproject.toml.
USER_AGENT = f"mcgibs-mcp-server/{_version} (ryan@supported.systems)"

View File

@ -1,8 +1,6 @@
"""Async Nominatim geocoding with rate limiting and in-memory caching.""" """Async Nominatim geocoding with bbox utilities."""
import asyncio
import logging import logging
import time
import httpx import httpx
@ -11,43 +9,29 @@ from mcgibs.models import BBox, GeocodingResult
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Nominatim usage policy: max 1 request per second.
_nominatim_lock = asyncio.Lock()
_last_request_time = 0.0
async def _rate_limit() -> None:
"""Enforce a minimum 1-second gap between Nominatim requests."""
global _last_request_time
async with _nominatim_lock:
now = time.monotonic()
elapsed = now - _last_request_time
if elapsed < 1.0:
await asyncio.sleep(1.0 - elapsed)
_last_request_time = time.monotonic()
async def geocode( async def geocode(
client: httpx.AsyncClient, client: httpx.AsyncClient,
query: str, query: str,
cache: dict, rate_limiter=None,
) -> GeocodingResult | None: ) -> GeocodingResult | None:
"""Geocode a place name via Nominatim. """Geocode a place name via Nominatim.
Caching is the caller's responsibility (GIBSClient.resolve_place manages
an LRU cache). This function is a pure HTTP-call-and-parse layer.
Args: Args:
client: Shared httpx async client. client: Shared httpx async client.
query: Free-form place name (e.g. "Tokyo", "Amazon River"). query: Free-form place name (e.g. "Tokyo", "Amazon River").
cache: Dict used as an in-memory dedup cache (query -> result). rate_limiter: Object with async ``nominatim_rate_limit()`` method
for enforcing Nominatim's 1 req/sec policy. When called from
GIBSClient, this is the client instance itself.
Returns: Returns:
GeocodingResult on success, None if no results found. GeocodingResult on success, None if no results found.
""" """
key = query.strip().lower() if rate_limiter is not None and hasattr(rate_limiter, "nominatim_rate_limit"):
if key in cache: await rate_limiter.nominatim_rate_limit()
log.debug("Geocode cache hit: %s", key)
return cache[key]
await _rate_limit()
params = { params = {
"q": query, "q": query,
@ -71,7 +55,6 @@ async def geocode(
if not data: if not data:
log.debug("Nominatim returned no results for %r", query) log.debug("Nominatim returned no results for %r", query)
cache[key] = None
return None return None
hit = data[0] hit = data[0]
@ -96,7 +79,6 @@ async def geocode(
importance=float(hit.get("importance", 0.0)), importance=float(hit.get("importance", 0.0)),
) )
cache[key] = result
log.debug("Geocoded %r -> %s", query, result.display_name) log.debug("Geocoded %r -> %s", query, result.display_name)
return result return result

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, model_validator
class BBox(BaseModel): class BBox(BaseModel):
@ -13,6 +13,20 @@ class BBox(BaseModel):
east: float = Field(description="Eastern longitude (-180 to 180)") east: float = Field(description="Eastern longitude (-180 to 180)")
north: float = Field(description="Northern latitude (-90 to 90)") north: float = Field(description="Northern latitude (-90 to 90)")
@model_validator(mode="after")
def _validate_ranges(self) -> BBox:
if self.south > self.north:
raise ValueError(f"south ({self.south}) must be <= north ({self.north})")
if not (-90 <= self.south <= 90 and -90 <= self.north <= 90):
raise ValueError(
f"Latitudes must be in [-90, 90], got south={self.south}, north={self.north}"
)
if not (-180 <= self.west <= 180 and -180 <= self.east <= 180):
raise ValueError(
f"Longitudes must be in [-180, 180], got west={self.west}, east={self.east}"
)
return self
@property @property
def wms_bbox(self) -> str: def wms_bbox(self) -> str:
"""Format as WMS BBOX parameter: minx,miny,maxx,maxy.""" """Format as WMS BBOX parameter: minx,miny,maxx,maxy."""

View File

@ -12,6 +12,7 @@ import base64
import json import json
import logging import logging
import httpx
from fastmcp import FastMCP from fastmcp import FastMCP
from fastmcp.server.middleware import Middleware from fastmcp.server.middleware import Middleware
@ -45,6 +46,7 @@ def _get_client() -> GIBSClient:
# --- Middleware: initialize client on session start --- # --- Middleware: initialize client on session start ---
class GIBSInitMiddleware(Middleware): class GIBSInitMiddleware(Middleware):
"""Load GIBS capabilities when the first client connects.""" """Load GIBS capabilities when the first client connects."""
@ -52,8 +54,15 @@ class GIBSInitMiddleware(Middleware):
global _client global _client
if _client is None: if _client is None:
log.info("Initializing GIBS client and loading capabilities...") log.info("Initializing GIBS client and loading capabilities...")
_client = GIBSClient() client = GIBSClient()
await _client.initialize() try:
await client.initialize()
except Exception:
# C2: clean up the HTTP client so we don't leak sockets,
# then re-raise so the next connection can retry.
await client.close()
raise
_client = client
log.info("GIBS client ready with %d layers", len(_client.layer_index)) log.info("GIBS client ready with %d layers", len(_client.layer_index))
return await call_next(context) return await call_next(context)
@ -65,6 +74,7 @@ mcp.middleware.append(GIBSInitMiddleware())
# TOOLS — Discovery # TOOLS — Discovery
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
@mcp.tool( @mcp.tool(
description="Search NASA GIBS satellite imagery layers by keyword. " description="Search NASA GIBS satellite imagery layers by keyword. "
"Returns matching layers with titles, identifiers, and date ranges." "Returns matching layers with titles, identifiers, and date ranges."
@ -87,7 +97,12 @@ async def search_gibs_layers(
""" """
client = _get_client() client = _get_client()
results = search_layers( results = search_layers(
client.layer_index, query, measurement, period, ongoing, limit, client.layer_index,
query,
measurement,
period,
ongoing,
limit,
) )
if not results: if not results:
@ -175,9 +190,7 @@ async def get_layer_info(layer_id: str) -> str:
return json.dumps(info, indent=2) return json.dumps(info, indent=2)
@mcp.tool( @mcp.tool(description="List all measurement categories available in GIBS with layer counts.")
description="List all measurement categories available in GIBS with layer counts."
)
async def list_measurements() -> str: async def list_measurements() -> str:
"""List measurement categories across all layers.""" """List measurement categories across all layers."""
client = _get_client() client = _get_client()
@ -201,20 +214,12 @@ async def list_measurements() -> str:
return "\n".join(lines) return "\n".join(lines)
@mcp.tool( @mcp.tool(description="Check available date ranges for a GIBS layer via WMTS DescribeDomains.")
description="Check available date ranges for a GIBS layer via WMTS DescribeDomains." async def check_layer_dates(layer_id: str) -> str:
)
async def check_layer_dates(
layer_id: str,
start_date: str | None = None,
end_date: str | None = None,
) -> str:
"""Query what dates are available for a specific layer. """Query what dates are available for a specific layer.
Args: Args:
layer_id: The GIBS layer identifier. layer_id: The GIBS layer identifier.
start_date: Optional start date filter (YYYY-MM-DD).
end_date: Optional end date filter (YYYY-MM-DD).
""" """
client = _get_client() client = _get_client()
layer = client.get_layer(layer_id) layer = client.get_layer(layer_id)
@ -241,7 +246,7 @@ async def check_layer_dates(
domains = await client.describe_domains(layer_id) domains = await client.describe_domains(layer_id)
if "time_domain" in domains: if "time_domain" in domains:
lines.append(f" Live time domain: {domains['time_domain']}") lines.append(f" Live time domain: {domains['time_domain']}")
except Exception as exc: except (httpx.HTTPError, RuntimeError) as exc:
log.debug("DescribeDomains failed for %s: %s", layer_id, exc) log.debug("DescribeDomains failed for %s: %s", layer_id, exc)
return "\n".join(lines) return "\n".join(lines)
@ -251,6 +256,7 @@ async def check_layer_dates(
# TOOLS — Imagery # TOOLS — Imagery
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
async def _resolve_bbox( async def _resolve_bbox(
client: GIBSClient, client: GIBSClient,
bbox: list[float] | None, bbox: list[float] | None,
@ -299,11 +305,20 @@ async def get_imagery(
if layer is None: if layer is None:
return [{"type": "text", "text": f"Layer '{layer_id}' not found."}] return [{"type": "text", "text": f"Layer '{layer_id}' not found."}]
try:
resolved_bbox = await _resolve_bbox(client, bbox, place) resolved_bbox = await _resolve_bbox(client, bbox, place)
except (ValueError, Exception) as exc:
return [{"type": "text", "text": str(exc)}]
image_format = f"image/{format}" image_format = f"image/{format}"
image_bytes = await client.get_wms_image( image_bytes = await client.get_wms_image(
layer_id, date, resolved_bbox, width, height, image_format, layer_id,
date,
resolved_bbox,
width,
height,
image_format,
) )
description = ( description = (
@ -351,10 +366,16 @@ async def compare_dates(
if layer is None: if layer is None:
return [{"type": "text", "text": f"Layer '{layer_id}' not found."}] return [{"type": "text", "text": f"Layer '{layer_id}' not found."}]
try:
resolved_bbox = await _resolve_bbox(client, bbox, place) resolved_bbox = await _resolve_bbox(client, bbox, place)
except (ValueError, Exception) as exc:
return [{"type": "text", "text": str(exc)}]
composite_bytes = await client.compare_dates( composite_bytes = await client.compare_dates(
layer_id, date_before, date_after, resolved_bbox, layer_id,
date_before,
date_after,
resolved_bbox,
) )
description = ( description = (
@ -397,17 +418,22 @@ async def get_imagery_composite(
if len(layer_ids) > 5: if len(layer_ids) > 5:
return [{"type": "text", "text": "WMS supports at most 5 layers per composite."}] return [{"type": "text", "text": "WMS supports at most 5 layers per composite."}]
try:
resolved_bbox = await _resolve_bbox(client, bbox, place) resolved_bbox = await _resolve_bbox(client, bbox, place)
except (ValueError, Exception) as exc:
return [{"type": "text", "text": str(exc)}]
image_bytes = await client.get_wms_composite( image_bytes = await client.get_wms_composite(
layer_ids, date, resolved_bbox, width, height, layer_ids,
date,
resolved_bbox,
width,
height,
) )
layer_names = ", ".join(layer_ids) layer_names = ", ".join(layer_ids)
description = ( description = (
f"Composite: {layer_names}\n" f"Composite: {layer_names}\nDate: {date}\nRegion: {place or resolved_bbox.wms_bbox}"
f"Date: {date}\n"
f"Region: {place or resolved_bbox.wms_bbox}"
) )
b64 = base64.b64encode(image_bytes).decode() b64 = base64.b64encode(image_bytes).decode()
@ -421,6 +447,7 @@ async def get_imagery_composite(
# TOOLS — Interpretation # TOOLS — Interpretation
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
@mcp.tool( @mcp.tool(
description="Explain what the colors in a GIBS layer mean. " description="Explain what the colors in a GIBS layer mean. "
"Returns a natural-language description mapping colors to scientific values and units." "Returns a natural-language description mapping colors to scientific values and units."
@ -435,9 +462,7 @@ async def explain_layer_colormap(layer_id: str) -> str:
return await client.explain_layer_colormap(layer_id) return await client.explain_layer_colormap(layer_id)
@mcp.tool( @mcp.tool(description="Fetch the pre-rendered legend image for a GIBS layer.")
description="Fetch the pre-rendered legend image for a GIBS layer."
)
async def get_legend( async def get_legend(
layer_id: str, layer_id: str,
orientation: str = "horizontal", orientation: str = "horizontal",
@ -465,6 +490,7 @@ async def get_legend(
# TOOLS — Utility # TOOLS — Utility
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
@mcp.tool( @mcp.tool(
description="Geocode a place name to geographic coordinates and bounding box. " description="Geocode a place name to geographic coordinates and bounding box. "
"Uses OpenStreetMap Nominatim." "Uses OpenStreetMap Nominatim."
@ -481,7 +507,8 @@ async def resolve_place(place: str) -> str:
if result is None: if result is None:
return f"Could not geocode '{place}'. Try a more specific name." return f"Could not geocode '{place}'. Try a more specific name."
return json.dumps({ return json.dumps(
{
"display_name": result.display_name, "display_name": result.display_name,
"lat": result.lat, "lat": result.lat,
"lon": result.lon, "lon": result.lon,
@ -491,7 +518,9 @@ async def resolve_place(place: str) -> str:
"east": result.bbox.east, "east": result.bbox.east,
"north": result.bbox.north, "north": result.bbox.north,
}, },
}, indent=2) },
indent=2,
)
@mcp.tool( @mcp.tool(
@ -528,7 +557,14 @@ async def build_tile_url(
tile_matrix_set = layer.tile_matrix_sets[0] tile_matrix_set = layer.tile_matrix_sets[0]
url = client.build_tile_url( url = client.build_tile_url(
layer_id, date, zoom, row, col, tile_matrix_set, ext, projection, layer_id,
date,
zoom,
row,
col,
tile_matrix_set,
ext,
projection,
) )
return url return url
@ -537,6 +573,7 @@ async def build_tile_url(
# RESOURCES # RESOURCES
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
@mcp.resource("gibs://catalog") @mcp.resource("gibs://catalog")
async def catalog_resource() -> str: async def catalog_resource() -> str:
"""Full GIBS layer catalog grouped by measurement category.""" """Full GIBS layer catalog grouped by measurement category."""
@ -547,11 +584,13 @@ async def catalog_resource() -> str:
key = layer.measurement or "Unknown" key = layer.measurement or "Unknown"
if key not in by_measurement: if key not in by_measurement:
by_measurement[key] = [] by_measurement[key] = []
by_measurement[key].append({ by_measurement[key].append(
{
"id": layer.identifier, "id": layer.identifier,
"title": layer.title, "title": layer.title,
"has_colormap": layer.has_colormap, "has_colormap": layer.has_colormap,
}) }
)
return json.dumps(by_measurement, indent=2) return json.dumps(by_measurement, indent=2)
@ -578,6 +617,7 @@ async def projections_resource() -> str:
# PROMPTS # PROMPTS
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
@mcp.prompt @mcp.prompt
def investigate_event( def investigate_event(
event_type: str, event_type: str,
@ -597,7 +637,7 @@ def investigate_event(
"", "",
"Follow this workflow:", "Follow this workflow:",
"", "",
f'1. **Find relevant layers**: Search for layers related to ' f"1. **Find relevant layers**: Search for layers related to "
f'"{event_type}" (e.g. fire layers for wildfires, ' f'"{event_type}" (e.g. fire layers for wildfires, '
f'precipitation for floods). Also search for "true color" ' f'precipitation for floods). Also search for "true color" '
f'or "corrected reflectance" for visual context.', f'or "corrected reflectance" for visual context.',
@ -662,5 +702,6 @@ def earth_overview() -> str:
# Entry point # Entry point
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def main(): def main():
mcp.run() mcp.run()

View File

@ -1,9 +1,8 @@
"""Tests for mcgibs.geo — geocoding, bbox helpers, and caching.""" """Tests for mcgibs.geo — geocoding and bbox helpers."""
import httpx import httpx
import respx import respx
import mcgibs.geo as _geo_module
from mcgibs.geo import bbox_from_point, expand_bbox, geocode from mcgibs.geo import bbox_from_point, expand_bbox, geocode
from mcgibs.models import BBox from mcgibs.models import BBox
@ -23,11 +22,6 @@ TOKYO_HIT = {
} }
def _reset_rate_limit() -> None:
"""Reset the module-level rate-limit timestamp so tests don't stall."""
_geo_module._last_request_time = 0.0
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# geocode() tests # geocode() tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -35,15 +29,12 @@ def _reset_rate_limit() -> None:
@respx.mock @respx.mock
async def test_geocode_success(): async def test_geocode_success():
_reset_rate_limit()
respx.get(NOMINATIM_URL).mock( respx.get(NOMINATIM_URL).mock(
return_value=httpx.Response(200, json=[TOKYO_HIT]), return_value=httpx.Response(200, json=[TOKYO_HIT]),
) )
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
cache: dict = {} result = await geocode(client, "Tokyo")
result = await geocode(client, "Tokyo", cache)
assert result is not None assert result is not None
assert result.display_name == "Tokyo, Japan" assert result.display_name == "Tokyo, Japan"
@ -61,34 +52,29 @@ async def test_geocode_success():
@respx.mock @respx.mock
async def test_geocode_no_results(): async def test_geocode_no_results():
_reset_rate_limit()
respx.get(NOMINATIM_URL).mock( respx.get(NOMINATIM_URL).mock(
return_value=httpx.Response(200, json=[]), return_value=httpx.Response(200, json=[]),
) )
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
cache: dict = {} result = await geocode(client, "xyznonexistent")
result = await geocode(client, "xyznonexistent", cache)
assert result is None assert result is None
@respx.mock @respx.mock
async def test_geocode_caching(): async def test_geocode_repeated_calls():
"""Second call with the same query must be served from cache — no extra HTTP request.""" """Each geocode() call makes an HTTP request (caching is caller's responsibility)."""
_reset_rate_limit()
route = respx.get(NOMINATIM_URL).mock( route = respx.get(NOMINATIM_URL).mock(
return_value=httpx.Response(200, json=[TOKYO_HIT]), return_value=httpx.Response(200, json=[TOKYO_HIT]),
) )
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
cache: dict = {} first = await geocode(client, "Tokyo")
first = await geocode(client, "Tokyo", cache) second = await geocode(client, "Tokyo")
second = await geocode(client, "Tokyo", cache)
assert route.call_count == 1 # geocode() no longer caches — GIBSClient.resolve_place() handles that
assert route.call_count == 2
assert first is not None assert first is not None
assert second is not None assert second is not None
assert first.display_name == second.display_name assert first.display_name == second.display_name
@ -97,15 +83,12 @@ async def test_geocode_caching():
@respx.mock @respx.mock
async def test_geocode_http_error(): async def test_geocode_http_error():
"""A 500 response should return None without raising an exception.""" """A 500 response should return None without raising an exception."""
_reset_rate_limit()
respx.get(NOMINATIM_URL).mock( respx.get(NOMINATIM_URL).mock(
return_value=httpx.Response(500, text="Internal Server Error"), return_value=httpx.Response(500, text="Internal Server Error"),
) )
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
cache: dict = {} result = await geocode(client, "ServerError")
result = await geocode(client, "ServerError", cache)
assert result is None assert result is None

11
uv.lock generated
View File

@ -302,6 +302,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3a/1f/d8bce383a90d8a6a11033327777afa4d4d611ec11869284adb6f48152906/cyclopts-4.5.3-py3-none-any.whl", hash = "sha256:50af3085bb15d4a6f2582dd383dad5e4ba6a0d4d4c64ee63326d881a752a6919", size = 200231, upload-time = "2026-02-16T15:07:13.045Z" }, { url = "https://files.pythonhosted.org/packages/3a/1f/d8bce383a90d8a6a11033327777afa4d4d611ec11869284adb6f48152906/cyclopts-4.5.3-py3-none-any.whl", hash = "sha256:50af3085bb15d4a6f2582dd383dad5e4ba6a0d4d4c64ee63326d881a752a6919", size = 200231, upload-time = "2026-02-16T15:07:13.045Z" },
] ]
[[package]]
name = "defusedxml"
version = "0.7.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/0f/d5/c66da9b79e5bdb124974bfe172b4daf3c984ebd9c2a06e2b8a4dc7331c72/defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69", size = 75520, upload-time = "2021-03-08T10:59:26.269Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" },
]
[[package]] [[package]]
name = "dnspython" name = "dnspython"
version = "2.8.0" version = "2.8.0"
@ -588,6 +597,7 @@ name = "mcgibs"
version = "2026.2.18" version = "2026.2.18"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "defusedxml" },
{ name = "fastmcp" }, { name = "fastmcp" },
{ name = "pillow" }, { name = "pillow" },
] ]
@ -602,6 +612,7 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "defusedxml", specifier = ">=0.7.1" },
{ name = "fastmcp", specifier = ">=3.0.0" }, { name = "fastmcp", specifier = ">=3.0.0" },
{ name = "pillow", specifier = ">=12.0.0" }, { name = "pillow", specifier = ">=12.0.0" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },