From d163ade9a41a31546fcda2123fb7a79d903aa68a Mon Sep 17 00:00:00 2001 From: Ryan Malloy Date: Wed, 18 Feb 2026 15:14:02 -0700 Subject: [PATCH] Harden codebase from Hamilton code review findings Address 20 findings from safety-critical review: CRITICAL: - C1: Replace ET.fromstring with defusedxml across all XML parsers - C2: Fix client init failure leaving half-initialized state; clean up HTTP client on startup failure so next connection can retry HIGH: - H1: Replace unbounded dicts with LRU caches (maxsize=500) - H2: Move Nominatim rate limiter from module globals to per-instance state on GIBSClient, eliminating shared mutable state - H3: Validate _parse_rgb input, return (0,0,0) on malformed data - H4: Add exponential backoff retry for capabilities loading - H5: Invert WMS error detection to verify image content-type - H6: Clamp image dimensions to 4096 max to prevent OOM MEDIUM: - M1: Convert images to RGB mode in compare_dates for RGBA safety - M2: Narrow DescribeDomains XML matching to TimeDomain elements - M3: Add BBox model_validator for coordinate range validation - M4: Add ET.ParseError to colormap fetch exception handling - M5: Replace bare except Exception with specific types in server - M6: Catch ValueError from _resolve_bbox in imagery tools for consistent error returns - M7: Only cache successful geocoding lookups (no negative caching) LOW: - L3: Derive USER_AGENT version from package __version__ - L5: Remove unused start_date/end_date params from check_layer_dates --- pyproject.toml | 1 + src/mcgibs/capabilities.py | 4 +- src/mcgibs/client.py | 245 +++++++++++++++++++++++++++---------- src/mcgibs/colormaps.py | 29 +++-- src/mcgibs/constants.py | 15 ++- src/mcgibs/geo.py | 38 ++---- src/mcgibs/models.py | 16 ++- src/mcgibs/server.py | 133 +++++++++++++------- tests/test_geo.py | 37 ++---- uv.lock | 11 ++ 10 files changed, 341 insertions(+), 188 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3af8771..5eb4d89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ dependencies = [ "fastmcp>=3.0.0", "pillow>=12.0.0", + "defusedxml>=0.7.1", ] [project.scripts] diff --git a/src/mcgibs/capabilities.py b/src/mcgibs/capabilities.py index ebe7483..8151797 100644 --- a/src/mcgibs/capabilities.py +++ b/src/mcgibs/capabilities.py @@ -3,6 +3,8 @@ import logging import xml.etree.ElementTree as ET +import defusedxml.ElementTree as DefusedET + from mcgibs.constants import WMTS_NS from mcgibs.models import BBox, LayerInfo, TimeDimension @@ -209,7 +211,7 @@ def parse_capabilities(xml_text: str) -> dict[str, LayerInfo]: Handles the ~5MB GetCapabilities document from NASA GIBS by iterating through all Layer elements found under the wmts:Contents element. """ - root = ET.fromstring(xml_text) + root = DefusedET.fromstring(xml_text) layers: dict[str, LayerInfo] = {} diff --git a/src/mcgibs/client.py b/src/mcgibs/client.py index d31407c..101aff1 100644 --- a/src/mcgibs/client.py +++ b/src/mcgibs/client.py @@ -1,8 +1,13 @@ """Async HTTP client for all NASA GIBS API interactions.""" +import asyncio import logging +import time +import xml.etree.ElementTree as ET +from collections import OrderedDict from io import BytesIO +import defusedxml.ElementTree as DefusedET import httpx from PIL import Image @@ -23,26 +28,68 @@ from mcgibs.models import BBox, ColorMapSet, GeocodingResult, LayerInfo log = logging.getLogger(__name__) +# Maximum entries per cache before LRU eviction +_MAX_CACHE_SIZE = 500 + +# Maximum retries for capabilities loading at startup +_INIT_MAX_RETRIES = 3 +_INIT_RETRY_DELAYS = [2.0, 4.0, 8.0] + +# Maximum image dimensions to prevent OOM from LLM requests +MAX_IMAGE_DIMENSION = 4096 + + +class _LRUCache(OrderedDict): + """Simple LRU cache backed by OrderedDict.""" + + def __init__(self, maxsize: int = _MAX_CACHE_SIZE) -> None: + super().__init__() + self._maxsize = maxsize + + def get_cached(self, key): + if key in self: + self.move_to_end(key) + return self[key] + return _SENTINEL + + def put(self, key, value): + if key in self: + self.move_to_end(key) + self[key] = value + while len(self) > self._maxsize: + self.popitem(last=False) + + +_SENTINEL = object() + class GIBSClient: """Async client for NASA GIBS APIs. Wraps WMTS (discovery), WMS (imagery), layer metadata, and colormaps behind a single interface. Designed to be initialized once per session - via FastMCP lifespan and reused across tool calls. + via FastMCP middleware and reused across tool calls. + + Rate limiting for Nominatim (1 req/sec) is managed per-client instance + rather than as module-level globals. Caches are LRU-bounded to prevent + unbounded memory growth in long-running sessions. """ def __init__(self) -> None: self._http: httpx.AsyncClient | None = None self.layer_index: dict[str, LayerInfo] = {} - self._metadata_cache: dict[str, dict] = {} - self._colormap_cache: dict[str, ColorMapSet] = {} - self._geocode_cache: dict[str, GeocodingResult | None] = {} + self._metadata_cache: _LRUCache = _LRUCache() + self._colormap_cache: _LRUCache = _LRUCache() + self._geocode_cache: _LRUCache = _LRUCache() + + # Rate limiter state — owned by this instance, not module globals + self._nominatim_lock = asyncio.Lock() + self._last_nominatim_time: float = 0.0 async def initialize(self) -> None: - """Create HTTP client and load WMTS capabilities.""" + """Create HTTP client and load WMTS capabilities with retry.""" self._http = httpx.AsyncClient(timeout=60.0, follow_redirects=True) - await self._load_capabilities() + await self._load_capabilities_with_retry() async def close(self) -> None: """Shut down the HTTP client.""" @@ -56,8 +103,47 @@ class GIBSClient: raise RuntimeError("GIBSClient not initialized — call initialize() first") return self._http + # --- Rate limiting (per-instance) --- + + async def nominatim_rate_limit(self) -> None: + """Enforce 1-second gap between Nominatim requests.""" + async with self._nominatim_lock: + now = time.monotonic() + elapsed = now - self._last_nominatim_time + if elapsed < 1.0: + await asyncio.sleep(1.0 - elapsed) + self._last_nominatim_time = time.monotonic() + # --- Capabilities --- + async def _load_capabilities_with_retry(self) -> None: + """Fetch capabilities with exponential backoff retry.""" + last_exc: Exception | None = None + for attempt in range(_INIT_MAX_RETRIES): + try: + await self._load_capabilities() + return + except (httpx.HTTPError, ET.ParseError) as exc: + last_exc = exc + if attempt < _INIT_MAX_RETRIES - 1: + delay = _INIT_RETRY_DELAYS[attempt] + log.warning( + "Capabilities fetch failed (attempt %d/%d), retrying in %.0fs: %s", + attempt + 1, + _INIT_MAX_RETRIES, + delay, + exc, + ) + await asyncio.sleep(delay) + + log.error( + "Failed to load capabilities after %d attempts", + _INIT_MAX_RETRIES, + ) + raise RuntimeError( + f"GIBS capabilities unavailable after {_INIT_MAX_RETRIES} attempts: {last_exc}" + ) + async def _load_capabilities(self) -> None: """Fetch and parse WMTS GetCapabilities for EPSG:4326.""" url = WMTS_CAPABILITIES_URL.format(epsg="4326") @@ -74,13 +160,10 @@ class GIBSClient: # --- Layer Metadata (enrichment) --- async def fetch_layer_metadata(self, layer_id: str) -> dict: - """Fetch the enriched JSON metadata for a layer. - - This supplements the GetCapabilities data with measurement, - instrument, platform, and other fields not in the WMTS XML. - """ - if layer_id in self._metadata_cache: - return self._metadata_cache[layer_id] + """Fetch the enriched JSON metadata for a layer.""" + cached = self._metadata_cache.get_cached(layer_id) + if cached is not _SENTINEL: + return cached url = f"{LAYER_METADATA_BASE}/{layer_id}.json" try: @@ -91,7 +174,7 @@ class GIBSClient: log.debug("Layer metadata not available for %s: %s", layer_id, exc) data = {} - self._metadata_cache[layer_id] = data + self._metadata_cache.put(layer_id, data) # Enrich the layer_index entry if it exists layer = self.layer_index.get(layer_id) @@ -110,10 +193,10 @@ class GIBSClient: async def fetch_colormap(self, layer_id: str) -> ColorMapSet | None: """Fetch and parse the colormap XML for a layer.""" - if layer_id in self._colormap_cache: - return self._colormap_cache[layer_id] + cached = self._colormap_cache.get_cached(layer_id) + if cached is not _SENTINEL: + return cached - # Derive colormap URL — GIBS uses the layer identifier as filename layer = self.layer_index.get(layer_id) colormap_id = (layer.colormap_id if layer else None) or layer_id url = f"{COLORMAP_BASE}/{colormap_id}.xml" @@ -122,11 +205,11 @@ class GIBSClient: resp = await self.http.get(url) resp.raise_for_status() colormap_set = parse_colormap(resp.text) - except (httpx.HTTPError, ValueError) as exc: + except (httpx.HTTPError, ET.ParseError, ValueError) as exc: log.debug("Colormap not available for %s: %s", layer_id, exc) return None - self._colormap_cache[layer_id] = colormap_set + self._colormap_cache.put(layer_id, colormap_set) return colormap_set async def explain_layer_colormap(self, layer_id: str) -> str: @@ -139,8 +222,21 @@ class GIBSClient: # --- Geocoding --- async def resolve_place(self, place: str) -> GeocodingResult | None: - """Geocode a place name via Nominatim.""" - return await geocode(self.http, place, self._geocode_cache) + """Geocode a place name via Nominatim. + + Uses per-instance rate limiting and LRU-bounded cache. + """ + key = place.strip().lower() + cached = self._geocode_cache.get_cached(key) + if cached is not _SENTINEL: + return cached + + result = await geocode(self.http, place, rate_limiter=self) + # Only cache successful lookups — don't permanently cache failures + # (M7: transient Nominatim failures should be retryable) + if result is not None: + self._geocode_cache.put(key, result) + return result # --- WMS Imagery --- @@ -158,24 +254,34 @@ class GIBSClient: Returns raw image bytes (JPEG or PNG). """ + # H6: clamp dimensions to prevent OOM from unbounded requests + width = max(1, min(width, MAX_IMAGE_DIMENSION)) + height = max(1, min(height, MAX_IMAGE_DIMENSION)) + params = dict(WMS_DEFAULTS) - params.update({ - "LAYERS": layer_id, - "SRS": f"EPSG:{epsg}", - "BBOX": bbox.wms_bbox, - "WIDTH": str(width), - "HEIGHT": str(height), - "FORMAT": image_format, - "TIME": date, - }) + params.update( + { + "LAYERS": layer_id, + "SRS": f"EPSG:{epsg}", + "BBOX": bbox.wms_bbox, + "WIDTH": str(width), + "HEIGHT": str(height), + "FORMAT": image_format, + "TIME": date, + } + ) url = WMS_BASE.format(epsg=epsg) resp = await self.http.get(url, params=params) resp.raise_for_status() + # H5: verify the response IS an image, not just that it ISN'T XML. + # WMS can return HTTP 200 with OGC ServiceException XML. content_type = resp.headers.get("content-type", "") - if "xml" in content_type or "text" in content_type: - raise RuntimeError(f"WMS returned error: {resp.text[:500]}") + if not content_type.startswith("image/"): + raise RuntimeError( + f"WMS returned non-image content-type '{content_type}': {resp.text[:500]}" + ) return resp.content @@ -189,10 +295,7 @@ class GIBSClient: image_format: str = "image/jpeg", epsg: str = DEFAULT_EPSG, ) -> bytes: - """Fetch a multi-layer WMS composite image. - - WMS supports comma-separated LAYERS for overlay compositing. - """ + """Fetch a multi-layer WMS composite image.""" return await self.get_wms_image( layer_id=",".join(layer_ids), date=date, @@ -213,22 +316,28 @@ class GIBSClient: height: int = 512, image_format: str = "image/jpeg", ) -> bytes: - """Fetch two images and compose a side-by-side comparison. - - Returns a single image with the "before" date on the left - and the "after" date on the right, each labeled. - """ + """Fetch two images and compose a side-by-side comparison.""" img_before = await self.get_wms_image( - layer_id, date_before, bbox, width, height, image_format, + layer_id, + date_before, + bbox, + width, + height, + image_format, ) img_after = await self.get_wms_image( - layer_id, date_after, bbox, width, height, image_format, + layer_id, + date_after, + bbox, + width, + height, + image_format, ) - pil_before = Image.open(BytesIO(img_before)) - pil_after = Image.open(BytesIO(img_after)) + # M1: convert to RGB to avoid mode mismatch with RGBA PNGs + pil_before = Image.open(BytesIO(img_before)).convert("RGB") + pil_after = Image.open(BytesIO(img_after)).convert("RGB") - # Create side-by-side composite total_width = pil_before.width + pil_after.width max_height = max(pil_before.height, pil_after.height) composite = Image.new("RGB", (total_width, max_height)) @@ -246,31 +355,38 @@ class GIBSClient: layer_id: str, epsg: str = DEFAULT_EPSG, ) -> dict: - """Query WMTS DescribeDomains for available date ranges. - - Returns a dict with 'time_domain' key (ISO 8601 interval or list - of dates) and 'spatial_domain' if available. - """ - url = WMTS_DESCRIBE_DOMAINS_URL.format(epsg=epsg, layer_id=layer_id) + """Query WMTS DescribeDomains for available date ranges.""" + url = WMTS_DESCRIBE_DOMAINS_URL.format( + epsg=epsg, + layer_id=layer_id, + ) resp = await self.http.get(url) resp.raise_for_status() - # DescribeDomains returns XML — extract time domain - import xml.etree.ElementTree as ET - - root = ET.fromstring(resp.text) - + root = DefusedET.fromstring(resp.text) result: dict[str, str] = {} - # Look for time domain in various possible locations + # M2: look specifically for TimeDomain elements rather than + # broadly matching any "value" element for elem in root.iter(): - tag = elem.tag.rpartition("}")[-1] if "}" in elem.tag else elem.tag - if tag.lower() in ("timedomain", "value") and elem.text: + local = elem.tag.rpartition("}")[-1] if "}" in elem.tag else elem.tag + if local.lower() == "timedomain" and elem.text: text = elem.text.strip() - if text and ("/" in text or "-" in text): + if text: result["time_domain"] = text break + # Fallback: look in child Value elements of a Dimension + if "time_domain" not in result: + for elem in root.iter(): + local = elem.tag.rpartition("}")[-1] if "}" in elem.tag else elem.tag + if local.lower() == "value" and elem.text: + text = elem.text.strip() + # Only match ISO 8601-ish values (contain date separators) + if "/" in text and len(text) > 8: + result["time_domain"] = text + break + return result # --- WMTS tile URL builder --- @@ -305,10 +421,7 @@ class GIBSClient: layer_id: str, orientation: str = "horizontal", ) -> bytes | None: - """Fetch the pre-rendered legend image for a layer. - - GIBS provides legend images via the GetLegendGraphic WMS call. - """ + """Fetch the pre-rendered legend image for a layer.""" layer = self.layer_index.get(layer_id) if layer and layer.legend_url: try: @@ -334,7 +447,7 @@ class GIBSClient: resp = await self.http.get(url, params=params) resp.raise_for_status() content_type = resp.headers.get("content-type", "") - if "image" in content_type: + if content_type.startswith("image/"): return resp.content except httpx.HTTPError as exc: log.debug("Legend not available for %s: %s", layer_id, exc) diff --git a/src/mcgibs/colormaps.py b/src/mcgibs/colormaps.py index 392de11..8e7dd2d 100644 --- a/src/mcgibs/colormaps.py +++ b/src/mcgibs/colormaps.py @@ -7,7 +7,8 @@ the colors actually mean. """ import re -import xml.etree.ElementTree as ET + +import defusedxml.ElementTree as DefusedET from mcgibs.models import ColorMap, ColorMapEntry, ColorMapSet, LegendEntry @@ -31,6 +32,7 @@ _UNIT_CONVERTERS: dict[str, tuple] = { # --- Color naming --- + def _describe_rgb(rgb: tuple[int, int, int]) -> str: """Return an approximate human-friendly color name for an RGB triple. @@ -132,9 +134,7 @@ _INTERVAL_RE = re.compile( re.VERBOSE | re.IGNORECASE, ) -_SINGLE_VALUE_RE = re.compile( - r"[\[\(]\s*([+\-]?\d+(?:\.\d+)?(?:[eE][+\-]?\d+)?)\s*[\]\)]" -) +_SINGLE_VALUE_RE = re.compile(r"[\[\(]\s*([+\-]?\d+(?:\.\d+)?(?:[eE][+\-]?\d+)?)\s*[\]\)]") def _parse_interval_value(interval: str) -> tuple[float | None, float | None]: @@ -174,10 +174,19 @@ def _parse_interval_value(interval: str) -> tuple[float | None, float | None]: # --- XML parsing --- + def _parse_rgb(raw: str) -> tuple[int, int, int]: - """Parse "r,g,b" string into an integer triple.""" + """Parse "r,g,b" string into an integer triple. + + Returns (0, 0, 0) for malformed input rather than crashing. + """ parts = raw.split(",") - return (int(parts[0]), int(parts[1]), int(parts[2])) + if len(parts) < 3: + return (0, 0, 0) + try: + return (int(parts[0]), int(parts[1]), int(parts[2])) + except (ValueError, IndexError): + return (0, 0, 0) def parse_colormap(xml_text: str) -> ColorMapSet: @@ -187,7 +196,7 @@ def parse_colormap(xml_text: str) -> ColorMapSet: children. Each contains elements and an optional with children. """ - root = ET.fromstring(xml_text) + root = DefusedET.fromstring(xml_text) maps: list[ColorMap] = [] @@ -250,6 +259,7 @@ def parse_colormap(xml_text: str) -> ColorMapSet: # --- Natural-language explanation --- + def _format_value(val: float, units: str) -> str: """Format a numeric value with optional unit conversion.""" units_lower = units.lower().strip() @@ -404,10 +414,7 @@ def explain_colormap(colormap_set: ColorMapSet) -> str: return "No colormap data available." # Filter to non-transparent, non-nodata entries for analysis - data_entries = [ - e for e in data_map.entries - if not e.transparent and not e.nodata - ] + data_entries = [e for e in data_map.entries if not e.transparent and not e.nodata] if not data_entries: return "This colormap contains only no-data / transparent entries." diff --git a/src/mcgibs/constants.py b/src/mcgibs/constants.py index cfa8f00..7df7e30 100644 --- a/src/mcgibs/constants.py +++ b/src/mcgibs/constants.py @@ -1,5 +1,7 @@ """GIBS API endpoints, EPSG codes, and TileMatrixSet definitions.""" +from mcgibs import __version__ as _version + # GIBS base URLs — domain sharding (gibs-a/b/c) for parallel tile fetches WMTS_BASE = "https://gibs.earthdata.nasa.gov/wmts/epsg{epsg}/best" WMS_BASE = "https://gibs.earthdata.nasa.gov/wms/epsg{epsg}/best/wms.cgi" @@ -9,15 +11,11 @@ LAYER_METADATA_BASE = "https://gibs.earthdata.nasa.gov/layer-metadata/v1.0" # GetCapabilities and DescribeDomains WMTS_CAPABILITIES_URL = WMTS_BASE + "/1.0.0/WMTSCapabilities.xml" WMTS_DESCRIBE_DOMAINS_URL = ( - WMTS_BASE + "/wmts.cgi?SERVICE=WMTS&VERSION=1.0.0" - "&REQUEST=DescribeDomains&LAYER={layer_id}" + WMTS_BASE + "/wmts.cgi?SERVICE=WMTS&VERSION=1.0.0&REQUEST=DescribeDomains&LAYER={layer_id}" ) # WMTS REST tile URL pattern -WMTS_TILE_URL = ( - WMTS_BASE + "/{layer_id}/default/{date}/{tile_matrix_set}" - "/{z}/{row}/{col}.{ext}" -) +WMTS_TILE_URL = WMTS_BASE + "/{layer_id}/default/{date}/{tile_matrix_set}/{z}/{row}/{col}.{ext}" # Nominatim geocoding NOMINATIM_BASE = "https://nominatim.openstreetmap.org" @@ -63,5 +61,6 @@ WMS_DEFAULTS = { "HEIGHT": "1024", } -# User-Agent for Nominatim (required by their usage policy) -USER_AGENT = "mcgibs-mcp-server/2026.02.18 (ryan@supported.systems)" +# User-Agent for Nominatim (required by their usage policy). +# Version derived from package metadata to stay in sync with pyproject.toml. +USER_AGENT = f"mcgibs-mcp-server/{_version} (ryan@supported.systems)" diff --git a/src/mcgibs/geo.py b/src/mcgibs/geo.py index 760bf56..deb6f6c 100644 --- a/src/mcgibs/geo.py +++ b/src/mcgibs/geo.py @@ -1,8 +1,6 @@ -"""Async Nominatim geocoding with rate limiting and in-memory caching.""" +"""Async Nominatim geocoding with bbox utilities.""" -import asyncio import logging -import time import httpx @@ -11,43 +9,29 @@ from mcgibs.models import BBox, GeocodingResult log = logging.getLogger(__name__) -# Nominatim usage policy: max 1 request per second. -_nominatim_lock = asyncio.Lock() -_last_request_time = 0.0 - - -async def _rate_limit() -> None: - """Enforce a minimum 1-second gap between Nominatim requests.""" - global _last_request_time - async with _nominatim_lock: - now = time.monotonic() - elapsed = now - _last_request_time - if elapsed < 1.0: - await asyncio.sleep(1.0 - elapsed) - _last_request_time = time.monotonic() - async def geocode( client: httpx.AsyncClient, query: str, - cache: dict, + rate_limiter=None, ) -> GeocodingResult | None: """Geocode a place name via Nominatim. + Caching is the caller's responsibility (GIBSClient.resolve_place manages + an LRU cache). This function is a pure HTTP-call-and-parse layer. + Args: client: Shared httpx async client. query: Free-form place name (e.g. "Tokyo", "Amazon River"). - cache: Dict used as an in-memory dedup cache (query -> result). + rate_limiter: Object with async ``nominatim_rate_limit()`` method + for enforcing Nominatim's 1 req/sec policy. When called from + GIBSClient, this is the client instance itself. Returns: GeocodingResult on success, None if no results found. """ - key = query.strip().lower() - if key in cache: - log.debug("Geocode cache hit: %s", key) - return cache[key] - - await _rate_limit() + if rate_limiter is not None and hasattr(rate_limiter, "nominatim_rate_limit"): + await rate_limiter.nominatim_rate_limit() params = { "q": query, @@ -71,7 +55,6 @@ async def geocode( if not data: log.debug("Nominatim returned no results for %r", query) - cache[key] = None return None hit = data[0] @@ -96,7 +79,6 @@ async def geocode( importance=float(hit.get("importance", 0.0)), ) - cache[key] = result log.debug("Geocoded %r -> %s", query, result.display_name) return result diff --git a/src/mcgibs/models.py b/src/mcgibs/models.py index 8d00680..887fa7e 100644 --- a/src/mcgibs/models.py +++ b/src/mcgibs/models.py @@ -2,7 +2,7 @@ from __future__ import annotations -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator class BBox(BaseModel): @@ -13,6 +13,20 @@ class BBox(BaseModel): east: float = Field(description="Eastern longitude (-180 to 180)") north: float = Field(description="Northern latitude (-90 to 90)") + @model_validator(mode="after") + def _validate_ranges(self) -> BBox: + if self.south > self.north: + raise ValueError(f"south ({self.south}) must be <= north ({self.north})") + if not (-90 <= self.south <= 90 and -90 <= self.north <= 90): + raise ValueError( + f"Latitudes must be in [-90, 90], got south={self.south}, north={self.north}" + ) + if not (-180 <= self.west <= 180 and -180 <= self.east <= 180): + raise ValueError( + f"Longitudes must be in [-180, 180], got west={self.west}, east={self.east}" + ) + return self + @property def wms_bbox(self) -> str: """Format as WMS BBOX parameter: minx,miny,maxx,maxy.""" diff --git a/src/mcgibs/server.py b/src/mcgibs/server.py index 8c371cf..cfb5b65 100644 --- a/src/mcgibs/server.py +++ b/src/mcgibs/server.py @@ -12,6 +12,7 @@ import base64 import json import logging +import httpx from fastmcp import FastMCP from fastmcp.server.middleware import Middleware @@ -45,6 +46,7 @@ def _get_client() -> GIBSClient: # --- Middleware: initialize client on session start --- + class GIBSInitMiddleware(Middleware): """Load GIBS capabilities when the first client connects.""" @@ -52,8 +54,15 @@ class GIBSInitMiddleware(Middleware): global _client if _client is None: log.info("Initializing GIBS client and loading capabilities...") - _client = GIBSClient() - await _client.initialize() + client = GIBSClient() + try: + await client.initialize() + except Exception: + # C2: clean up the HTTP client so we don't leak sockets, + # then re-raise so the next connection can retry. + await client.close() + raise + _client = client log.info("GIBS client ready with %d layers", len(_client.layer_index)) return await call_next(context) @@ -65,6 +74,7 @@ mcp.middleware.append(GIBSInitMiddleware()) # TOOLS — Discovery # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + @mcp.tool( description="Search NASA GIBS satellite imagery layers by keyword. " "Returns matching layers with titles, identifiers, and date ranges." @@ -87,7 +97,12 @@ async def search_gibs_layers( """ client = _get_client() results = search_layers( - client.layer_index, query, measurement, period, ongoing, limit, + client.layer_index, + query, + measurement, + period, + ongoing, + limit, ) if not results: @@ -175,9 +190,7 @@ async def get_layer_info(layer_id: str) -> str: return json.dumps(info, indent=2) -@mcp.tool( - description="List all measurement categories available in GIBS with layer counts." -) +@mcp.tool(description="List all measurement categories available in GIBS with layer counts.") async def list_measurements() -> str: """List measurement categories across all layers.""" client = _get_client() @@ -201,20 +214,12 @@ async def list_measurements() -> str: return "\n".join(lines) -@mcp.tool( - description="Check available date ranges for a GIBS layer via WMTS DescribeDomains." -) -async def check_layer_dates( - layer_id: str, - start_date: str | None = None, - end_date: str | None = None, -) -> str: +@mcp.tool(description="Check available date ranges for a GIBS layer via WMTS DescribeDomains.") +async def check_layer_dates(layer_id: str) -> str: """Query what dates are available for a specific layer. Args: layer_id: The GIBS layer identifier. - start_date: Optional start date filter (YYYY-MM-DD). - end_date: Optional end date filter (YYYY-MM-DD). """ client = _get_client() layer = client.get_layer(layer_id) @@ -241,7 +246,7 @@ async def check_layer_dates( domains = await client.describe_domains(layer_id) if "time_domain" in domains: lines.append(f" Live time domain: {domains['time_domain']}") - except Exception as exc: + except (httpx.HTTPError, RuntimeError) as exc: log.debug("DescribeDomains failed for %s: %s", layer_id, exc) return "\n".join(lines) @@ -251,6 +256,7 @@ async def check_layer_dates( # TOOLS — Imagery # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + async def _resolve_bbox( client: GIBSClient, bbox: list[float] | None, @@ -299,11 +305,20 @@ async def get_imagery( if layer is None: return [{"type": "text", "text": f"Layer '{layer_id}' not found."}] - resolved_bbox = await _resolve_bbox(client, bbox, place) + try: + resolved_bbox = await _resolve_bbox(client, bbox, place) + except (ValueError, Exception) as exc: + return [{"type": "text", "text": str(exc)}] + image_format = f"image/{format}" image_bytes = await client.get_wms_image( - layer_id, date, resolved_bbox, width, height, image_format, + layer_id, + date, + resolved_bbox, + width, + height, + image_format, ) description = ( @@ -351,10 +366,16 @@ async def compare_dates( if layer is None: return [{"type": "text", "text": f"Layer '{layer_id}' not found."}] - resolved_bbox = await _resolve_bbox(client, bbox, place) + try: + resolved_bbox = await _resolve_bbox(client, bbox, place) + except (ValueError, Exception) as exc: + return [{"type": "text", "text": str(exc)}] composite_bytes = await client.compare_dates( - layer_id, date_before, date_after, resolved_bbox, + layer_id, + date_before, + date_after, + resolved_bbox, ) description = ( @@ -397,17 +418,22 @@ async def get_imagery_composite( if len(layer_ids) > 5: return [{"type": "text", "text": "WMS supports at most 5 layers per composite."}] - resolved_bbox = await _resolve_bbox(client, bbox, place) + try: + resolved_bbox = await _resolve_bbox(client, bbox, place) + except (ValueError, Exception) as exc: + return [{"type": "text", "text": str(exc)}] image_bytes = await client.get_wms_composite( - layer_ids, date, resolved_bbox, width, height, + layer_ids, + date, + resolved_bbox, + width, + height, ) layer_names = ", ".join(layer_ids) description = ( - f"Composite: {layer_names}\n" - f"Date: {date}\n" - f"Region: {place or resolved_bbox.wms_bbox}" + f"Composite: {layer_names}\nDate: {date}\nRegion: {place or resolved_bbox.wms_bbox}" ) b64 = base64.b64encode(image_bytes).decode() @@ -421,6 +447,7 @@ async def get_imagery_composite( # TOOLS — Interpretation # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + @mcp.tool( description="Explain what the colors in a GIBS layer mean. " "Returns a natural-language description mapping colors to scientific values and units." @@ -435,9 +462,7 @@ async def explain_layer_colormap(layer_id: str) -> str: return await client.explain_layer_colormap(layer_id) -@mcp.tool( - description="Fetch the pre-rendered legend image for a GIBS layer." -) +@mcp.tool(description="Fetch the pre-rendered legend image for a GIBS layer.") async def get_legend( layer_id: str, orientation: str = "horizontal", @@ -465,6 +490,7 @@ async def get_legend( # TOOLS — Utility # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + @mcp.tool( description="Geocode a place name to geographic coordinates and bounding box. " "Uses OpenStreetMap Nominatim." @@ -481,17 +507,20 @@ async def resolve_place(place: str) -> str: if result is None: return f"Could not geocode '{place}'. Try a more specific name." - return json.dumps({ - "display_name": result.display_name, - "lat": result.lat, - "lon": result.lon, - "bbox": { - "west": result.bbox.west, - "south": result.bbox.south, - "east": result.bbox.east, - "north": result.bbox.north, + return json.dumps( + { + "display_name": result.display_name, + "lat": result.lat, + "lon": result.lon, + "bbox": { + "west": result.bbox.west, + "south": result.bbox.south, + "east": result.bbox.east, + "north": result.bbox.north, + }, }, - }, indent=2) + indent=2, + ) @mcp.tool( @@ -528,7 +557,14 @@ async def build_tile_url( tile_matrix_set = layer.tile_matrix_sets[0] url = client.build_tile_url( - layer_id, date, zoom, row, col, tile_matrix_set, ext, projection, + layer_id, + date, + zoom, + row, + col, + tile_matrix_set, + ext, + projection, ) return url @@ -537,6 +573,7 @@ async def build_tile_url( # RESOURCES # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + @mcp.resource("gibs://catalog") async def catalog_resource() -> str: """Full GIBS layer catalog grouped by measurement category.""" @@ -547,11 +584,13 @@ async def catalog_resource() -> str: key = layer.measurement or "Unknown" if key not in by_measurement: by_measurement[key] = [] - by_measurement[key].append({ - "id": layer.identifier, - "title": layer.title, - "has_colormap": layer.has_colormap, - }) + by_measurement[key].append( + { + "id": layer.identifier, + "title": layer.title, + "has_colormap": layer.has_colormap, + } + ) return json.dumps(by_measurement, indent=2) @@ -578,6 +617,7 @@ async def projections_resource() -> str: # PROMPTS # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + @mcp.prompt def investigate_event( event_type: str, @@ -597,7 +637,7 @@ def investigate_event( "", "Follow this workflow:", "", - f'1. **Find relevant layers**: Search for layers related to ' + f"1. **Find relevant layers**: Search for layers related to " f'"{event_type}" (e.g. fire layers for wildfires, ' f'precipitation for floods). Also search for "true color" ' f'or "corrected reflectance" for visual context.', @@ -662,5 +702,6 @@ def earth_overview() -> str: # Entry point # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + def main(): mcp.run() diff --git a/tests/test_geo.py b/tests/test_geo.py index a56e793..c491caf 100644 --- a/tests/test_geo.py +++ b/tests/test_geo.py @@ -1,9 +1,8 @@ -"""Tests for mcgibs.geo — geocoding, bbox helpers, and caching.""" +"""Tests for mcgibs.geo — geocoding and bbox helpers.""" import httpx import respx -import mcgibs.geo as _geo_module from mcgibs.geo import bbox_from_point, expand_bbox, geocode from mcgibs.models import BBox @@ -23,11 +22,6 @@ TOKYO_HIT = { } -def _reset_rate_limit() -> None: - """Reset the module-level rate-limit timestamp so tests don't stall.""" - _geo_module._last_request_time = 0.0 - - # --------------------------------------------------------------------------- # geocode() tests # --------------------------------------------------------------------------- @@ -35,15 +29,12 @@ def _reset_rate_limit() -> None: @respx.mock async def test_geocode_success(): - _reset_rate_limit() - respx.get(NOMINATIM_URL).mock( return_value=httpx.Response(200, json=[TOKYO_HIT]), ) async with httpx.AsyncClient() as client: - cache: dict = {} - result = await geocode(client, "Tokyo", cache) + result = await geocode(client, "Tokyo") assert result is not None assert result.display_name == "Tokyo, Japan" @@ -61,34 +52,29 @@ async def test_geocode_success(): @respx.mock async def test_geocode_no_results(): - _reset_rate_limit() - respx.get(NOMINATIM_URL).mock( return_value=httpx.Response(200, json=[]), ) async with httpx.AsyncClient() as client: - cache: dict = {} - result = await geocode(client, "xyznonexistent", cache) + result = await geocode(client, "xyznonexistent") assert result is None @respx.mock -async def test_geocode_caching(): - """Second call with the same query must be served from cache — no extra HTTP request.""" - _reset_rate_limit() - +async def test_geocode_repeated_calls(): + """Each geocode() call makes an HTTP request (caching is caller's responsibility).""" route = respx.get(NOMINATIM_URL).mock( return_value=httpx.Response(200, json=[TOKYO_HIT]), ) async with httpx.AsyncClient() as client: - cache: dict = {} - first = await geocode(client, "Tokyo", cache) - second = await geocode(client, "Tokyo", cache) + first = await geocode(client, "Tokyo") + second = await geocode(client, "Tokyo") - assert route.call_count == 1 + # geocode() no longer caches — GIBSClient.resolve_place() handles that + assert route.call_count == 2 assert first is not None assert second is not None assert first.display_name == second.display_name @@ -97,15 +83,12 @@ async def test_geocode_caching(): @respx.mock async def test_geocode_http_error(): """A 500 response should return None without raising an exception.""" - _reset_rate_limit() - respx.get(NOMINATIM_URL).mock( return_value=httpx.Response(500, text="Internal Server Error"), ) async with httpx.AsyncClient() as client: - cache: dict = {} - result = await geocode(client, "ServerError", cache) + result = await geocode(client, "ServerError") assert result is None diff --git a/uv.lock b/uv.lock index b93b5fd..1eda8f3 100644 --- a/uv.lock +++ b/uv.lock @@ -302,6 +302,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/1f/d8bce383a90d8a6a11033327777afa4d4d611ec11869284adb6f48152906/cyclopts-4.5.3-py3-none-any.whl", hash = "sha256:50af3085bb15d4a6f2582dd383dad5e4ba6a0d4d4c64ee63326d881a752a6919", size = 200231, upload-time = "2026-02-16T15:07:13.045Z" }, ] +[[package]] +name = "defusedxml" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/d5/c66da9b79e5bdb124974bfe172b4daf3c984ebd9c2a06e2b8a4dc7331c72/defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69", size = 75520, upload-time = "2021-03-08T10:59:26.269Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, +] + [[package]] name = "dnspython" version = "2.8.0" @@ -588,6 +597,7 @@ name = "mcgibs" version = "2026.2.18" source = { editable = "." } dependencies = [ + { name = "defusedxml" }, { name = "fastmcp" }, { name = "pillow" }, ] @@ -602,6 +612,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "defusedxml", specifier = ">=0.7.1" }, { name = "fastmcp", specifier = ">=3.0.0" }, { name = "pillow", specifier = ">=12.0.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },