Resilience: address Hamilton tier-2 findings
H7 — Process-wide shared httpx.AsyncClient via get_shared_client(). Each tool call no longer pays a TCP+TLS handshake; connection pool is reused across the server's lifetime. Tests inject mock transports directly via ArchiveClient(transport=...) so the singleton stays clean. M1 — Retry/backoff on 429/502/503/504 with Retry-After honored (both delta-seconds and HTTP-date forms). Exponential backoff with jitter, capped at 30s, max 3 attempts. Applied to both _fetch_json and stream_file (retry happens BEFORE any bytes are yielded so it can't corrupt a partial write). M2 — Per-(identifier, filename) asyncio.Lock in download_file serializes concurrent downloads of the same file inside one process. Different files still download in parallel. M5 — collection field normalized to list[str] in all output paths (search docs, scrape items, item metadata). LLMs can write `if 'foo' in doc['collection']` without checking the type first. M7 — `is_collection: bool` derived from mediatype on every doc / metadata response, so LLMs can route collection containers vs. real media items without re-querying. H1 — Stream-abort errors (httpx.ReadError, RemoteProtocolError, ConnectError, ReadTimeout) caught and re-raised as ArchiveError with bytes-written context so the caller knows where the partial download ended. Bytes already on disk remain valid for resume. 19 new regression tests (52 total, all green, ruff clean): - 4 tests covering retry/backoff, exhaustion, HTTP-date Retry-After - 1 test for stream-abort byte-count surfacing - 6 tests for collection normalization shapes - 4 tests for is_collection in real tool flow + shared client lifecycle - 2 tests verifying download lock: same-file serialized, different files parallel
This commit is contained in:
parent
4a03af1675
commit
6198defeca
@ -2,11 +2,14 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import errno
|
import errno
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import re
|
import re
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
from email.utils import parsedate_to_datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
@ -23,6 +26,30 @@ DEFAULT_TIMEOUT = httpx.Timeout(30.0, read=60.0)
|
|||||||
_IDENTIFIER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,99}$")
|
_IDENTIFIER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,99}$")
|
||||||
_MAX_FILENAME = 512
|
_MAX_FILENAME = 512
|
||||||
|
|
||||||
|
# Retry policy for archive.org's transient overload responses.
|
||||||
|
_RETRY_STATUSES = frozenset({429, 502, 503, 504})
|
||||||
|
_RETRY_MAX_ATTEMPTS = 3
|
||||||
|
_RETRY_MAX_DELAY = 30.0
|
||||||
|
|
||||||
|
|
||||||
|
def _retry_delay(response: httpx.Response, attempt: int) -> float:
|
||||||
|
"""Compute backoff for a retryable response. Honors Retry-After when present."""
|
||||||
|
retry_after = response.headers.get("Retry-After")
|
||||||
|
if retry_after:
|
||||||
|
try:
|
||||||
|
return min(float(retry_after), _RETRY_MAX_DELAY)
|
||||||
|
except ValueError:
|
||||||
|
try:
|
||||||
|
# HTTP-date form
|
||||||
|
dt = parsedate_to_datetime(retry_after)
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
wait = (dt - datetime.now(tz=timezone.utc)).total_seconds()
|
||||||
|
return min(max(wait, 0.0), _RETRY_MAX_DELAY)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
# Exponential backoff with jitter: 1s, 2s, 4s + [0,1)s
|
||||||
|
return min((2 ** attempt) + random.random(), _RETRY_MAX_DELAY)
|
||||||
|
|
||||||
|
|
||||||
class ArchiveError(RuntimeError):
|
class ArchiveError(RuntimeError):
|
||||||
"""Raised when archive.org returns an error payload or unexpected status."""
|
"""Raised when archive.org returns an error payload or unexpected status."""
|
||||||
@ -131,12 +158,20 @@ class ArchiveClient:
|
|||||||
# ---------- internal: error-surfacing fetch ----------
|
# ---------- internal: error-surfacing fetch ----------
|
||||||
|
|
||||||
async def _fetch_json(self, url: str, params: Any = None) -> Any:
|
async def _fetch_json(self, url: str, params: Any = None) -> Any:
|
||||||
"""GET + JSON decode with archive.org-friendly error messages.
|
"""GET + JSON decode with archive.org-friendly error messages and retry.
|
||||||
|
|
||||||
Wraps raise_for_status so that 4xx/5xx responses include a body preview
|
Retries on 429/502/503/504 with Retry-After honored. 4xx/5xx responses
|
||||||
in the exception — invaluable for an LLM trying to fix a bad query.
|
that aren't retryable include a body preview in the exception —
|
||||||
|
invaluable for an LLM trying to fix a bad query.
|
||||||
"""
|
"""
|
||||||
r = await self._client.get(url, params=params)
|
r: httpx.Response | None = None
|
||||||
|
for attempt in range(_RETRY_MAX_ATTEMPTS):
|
||||||
|
r = await self._client.get(url, params=params)
|
||||||
|
if r.status_code in _RETRY_STATUSES and attempt < _RETRY_MAX_ATTEMPTS - 1:
|
||||||
|
await asyncio.sleep(_retry_delay(r, attempt))
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
assert r is not None # the loop runs at least once
|
||||||
if r.is_error:
|
if r.is_error:
|
||||||
body = r.text[:500] if r.content else ""
|
body = r.text[:500] if r.content else ""
|
||||||
raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip())
|
raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip())
|
||||||
@ -243,8 +278,11 @@ class ArchiveClient:
|
|||||||
) -> AsyncIterator[bytes]:
|
) -> AsyncIterator[bytes]:
|
||||||
"""Async byte iterator. If resume_from > 0, requires a 206 response.
|
"""Async byte iterator. If resume_from > 0, requires a 206 response.
|
||||||
|
|
||||||
|
Retries on 429/502/503/504 BEFORE yielding any bytes (so retry never
|
||||||
|
risks corrupting a partially-written file).
|
||||||
|
|
||||||
Raises ArchiveError BEFORE yielding any bytes if:
|
Raises ArchiveError BEFORE yielding any bytes if:
|
||||||
- the server returns a 4xx/5xx
|
- the server returns a non-retryable 4xx/5xx
|
||||||
- resume was requested but the server returned 200 (Range ignored)
|
- resume was requested but the server returned 200 (Range ignored)
|
||||||
- the Content-Range start byte doesn't match resume_from
|
- the Content-Range start byte doesn't match resume_from
|
||||||
"""
|
"""
|
||||||
@ -256,27 +294,32 @@ class ArchiveClient:
|
|||||||
headers["Range"] = f"bytes={resume_from}-"
|
headers["Range"] = f"bytes={resume_from}-"
|
||||||
url = self.download_url(identifier, filename)
|
url = self.download_url(identifier, filename)
|
||||||
|
|
||||||
async with self._client.stream("GET", url, headers=headers) as r:
|
for attempt in range(_RETRY_MAX_ATTEMPTS):
|
||||||
if r.is_error:
|
async with self._client.stream("GET", url, headers=headers) as r:
|
||||||
body = (await r.aread())[:500].decode("utf-8", errors="replace")
|
if r.status_code in _RETRY_STATUSES and attempt < _RETRY_MAX_ATTEMPTS - 1:
|
||||||
raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip())
|
await asyncio.sleep(_retry_delay(r, attempt))
|
||||||
if resume_from > 0:
|
continue
|
||||||
if r.status_code != 206:
|
if r.is_error:
|
||||||
raise ArchiveError(
|
body = (await r.aread())[:500].decode("utf-8", errors="replace")
|
||||||
f"server ignored Range request (got HTTP {r.status_code}); "
|
raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip())
|
||||||
f"local file may be stale — retry download_file with overwrite=True"
|
if resume_from > 0:
|
||||||
)
|
if r.status_code != 206:
|
||||||
# Verify the byte range starts where we expect. archive.org's CDN
|
raise ArchiveError(
|
||||||
# is normally well-behaved here, but trust-but-verify.
|
f"server ignored Range request (got HTTP {r.status_code}); "
|
||||||
cr = r.headers.get("Content-Range", "")
|
f"local file may be stale — retry download_file with overwrite=True"
|
||||||
m = re.match(r"bytes\s+(\d+)-", cr)
|
)
|
||||||
if m and int(m.group(1)) != resume_from:
|
# Verify the byte range starts where we expect. archive.org's CDN
|
||||||
raise ArchiveError(
|
# is normally well-behaved here, but trust-but-verify.
|
||||||
f"Content-Range start {m.group(1)} != resume_from {resume_from}; "
|
cr = r.headers.get("Content-Range", "")
|
||||||
f"refusing to corrupt {filename}"
|
m = re.match(r"bytes\s+(\d+)-", cr)
|
||||||
)
|
if m and int(m.group(1)) != resume_from:
|
||||||
async for chunk in r.aiter_bytes(chunk_size=1 << 16):
|
raise ArchiveError(
|
||||||
yield chunk
|
f"Content-Range start {m.group(1)} != resume_from {resume_from}; "
|
||||||
|
f"refusing to corrupt {filename}"
|
||||||
|
)
|
||||||
|
async for chunk in r.aiter_bytes(chunk_size=1 << 16):
|
||||||
|
yield chunk
|
||||||
|
return
|
||||||
|
|
||||||
async def download_to_file(
|
async def download_to_file(
|
||||||
self,
|
self,
|
||||||
@ -321,6 +364,16 @@ class ArchiveClient:
|
|||||||
hasher.update(chunk)
|
hasher.update(chunk)
|
||||||
if chunk_cb:
|
if chunk_cb:
|
||||||
chunk_cb(bytes_written)
|
chunk_cb(bytes_written)
|
||||||
|
except (httpx.ReadError, httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadTimeout) as e:
|
||||||
|
# H1: surface partial-state context so the caller can decide whether
|
||||||
|
# to resume or restart. The bytes already on disk are valid (we only
|
||||||
|
# write whole chunks), so a follow-up call with overwrite=False will
|
||||||
|
# resume cleanly from `bytes_written`.
|
||||||
|
raise ArchiveError(
|
||||||
|
f"download interrupted after {bytes_written - resume_from} new bytes "
|
||||||
|
f"({bytes_written} total on disk, resumed from {resume_from}). "
|
||||||
|
f"File at {dest} — call download_file again to resume. Cause: {e!r}"
|
||||||
|
) from e
|
||||||
finally:
|
finally:
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
@ -335,3 +388,38 @@ class ArchiveClient:
|
|||||||
result["md5_expected"] = verify_md5
|
result["md5_expected"] = verify_md5
|
||||||
result["md5_ok"] = actual.lower() == verify_md5.lower()
|
result["md5_ok"] = actual.lower() == verify_md5.lower()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- H7: process-wide shared client ----------
|
||||||
|
#
|
||||||
|
# An MCP server typically lives forever serving one LLM. Spinning up a fresh
|
||||||
|
# httpx.AsyncClient per tool call wastes a TCP+TLS handshake (~200-400ms) and
|
||||||
|
# burns ephemeral ports under parallel fan-out. Share one client across calls.
|
||||||
|
|
||||||
|
_shared_client: ArchiveClient | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_shared_client() -> ArchiveClient:
|
||||||
|
"""Return the process-wide shared ArchiveClient, creating it on first use.
|
||||||
|
|
||||||
|
Tests should NOT use this — construct an ArchiveClient(transport=...)
|
||||||
|
directly so MockTransport injection works without leaking into other tests.
|
||||||
|
"""
|
||||||
|
global _shared_client
|
||||||
|
if _shared_client is None:
|
||||||
|
# The race window here is small. If two coroutines both create, one
|
||||||
|
# client gets discarded — wasteful but not corrupting.
|
||||||
|
client = ArchiveClient()
|
||||||
|
if _shared_client is None:
|
||||||
|
_shared_client = client
|
||||||
|
else:
|
||||||
|
await client.aclose()
|
||||||
|
return _shared_client
|
||||||
|
|
||||||
|
|
||||||
|
async def close_shared_client() -> None:
|
||||||
|
"""Close and clear the shared client. Useful for tests; safe to call many times."""
|
||||||
|
global _shared_client
|
||||||
|
if _shared_client is not None:
|
||||||
|
c, _shared_client = _shared_client, None
|
||||||
|
await c.aclose()
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -13,12 +14,32 @@ from pydantic import Field
|
|||||||
|
|
||||||
from mcarchive_org import __version__
|
from mcarchive_org import __version__
|
||||||
from mcarchive_org.client import (
|
from mcarchive_org.client import (
|
||||||
ArchiveClient,
|
|
||||||
ArchiveError,
|
ArchiveError,
|
||||||
|
get_shared_client,
|
||||||
validate_filename,
|
validate_filename,
|
||||||
validate_identifier,
|
validate_identifier,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Per-(identifier, filename) locks to serialize concurrent downloads of the same
|
||||||
|
# file inside the same process. Cross-process races would still need an fcntl
|
||||||
|
# advisory lock; a single MCP-server process is the common case.
|
||||||
|
_download_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _download_lock_for(identifier: str, filename: str) -> asyncio.Lock:
|
||||||
|
"""Get-or-create the in-process lock for one (identifier, filename) pair.
|
||||||
|
|
||||||
|
Safe without an outer lock because asyncio is cooperative: there are no
|
||||||
|
awaits between the dict check and the dict set, so two coroutines can't
|
||||||
|
interleave inside this function.
|
||||||
|
"""
|
||||||
|
key = f"{identifier}::{filename}"
|
||||||
|
lock = _download_locks.get(key)
|
||||||
|
if lock is None:
|
||||||
|
lock = asyncio.Lock()
|
||||||
|
_download_locks[key] = lock
|
||||||
|
return lock
|
||||||
|
|
||||||
|
|
||||||
def _resolve_download_root() -> Path:
|
def _resolve_download_root() -> Path:
|
||||||
"""Resolve the download root lazily so env-var changes after import are honored."""
|
"""Resolve the download root lazily so env-var changes after import are honored."""
|
||||||
@ -83,6 +104,30 @@ def _matches(name: str, format_: str | None, name_glob: str | None, formats: lis
|
|||||||
return not (formats and (format_ or "").lower() not in {f.lower() for f in formats})
|
return not (formats and (format_ or "").lower() not in {f.lower() for f in formats})
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_collection(value: Any) -> list[str]:
|
||||||
|
"""archive.org returns `collection` as either str or list[str]. Normalize to list.
|
||||||
|
|
||||||
|
Stable shape lets LLMs write `if 'foo' in doc['collection']` without first
|
||||||
|
having to remember which type they're holding.
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return []
|
||||||
|
if isinstance(value, str):
|
||||||
|
return [value] if value else []
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [str(x) for x in value if x]
|
||||||
|
return [str(value)]
|
||||||
|
|
||||||
|
|
||||||
|
def _enrich_doc(doc: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Apply normalization + derived fields to a search-result doc or metadata blob."""
|
||||||
|
out = dict(doc)
|
||||||
|
if "collection" in out:
|
||||||
|
out["collection"] = _normalize_collection(out["collection"])
|
||||||
|
out["is_collection"] = out.get("mediatype") == "collection"
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _confine_dest(identifier: str, filename: str, dest_dir: str | None) -> Path:
|
def _confine_dest(identifier: str, filename: str, dest_dir: str | None) -> Path:
|
||||||
"""Construct + verify the download destination path.
|
"""Construct + verify the download destination path.
|
||||||
|
|
||||||
@ -133,19 +178,21 @@ async def search_items(
|
|||||||
"""Search archive.org items. Good for small/interactive queries.
|
"""Search archive.org items. Good for small/interactive queries.
|
||||||
|
|
||||||
Returns up to `rows` matching items plus `num_found` (total hits) and `has_more`.
|
Returns up to `rows` matching items plus `num_found` (total hits) and `has_more`.
|
||||||
|
Each doc has `is_collection` derived from mediatype; `collection` is always a list.
|
||||||
Use scrape_items for bulk iteration over large result sets.
|
Use scrape_items for bulk iteration over large result sets.
|
||||||
"""
|
"""
|
||||||
async with ArchiveClient() as c:
|
c = await get_shared_client()
|
||||||
result = await c.search(query=query, fields=fields, sort=sort, rows=rows, page=page)
|
result = await c.search(query=query, fields=fields, sort=sort, rows=rows, page=page)
|
||||||
total = result["num_found"]
|
total = result["num_found"]
|
||||||
seen = (page - 1) * rows + len(result["docs"])
|
docs = [_enrich_doc(d) for d in result["docs"]]
|
||||||
|
seen = (page - 1) * rows + len(docs)
|
||||||
return {
|
return {
|
||||||
"query": query,
|
"query": query,
|
||||||
"num_found": total,
|
"num_found": total,
|
||||||
"page": page,
|
"page": page,
|
||||||
"rows": rows,
|
"rows": rows,
|
||||||
"has_more": seen < total,
|
"has_more": seen < total,
|
||||||
"docs": result["docs"],
|
"docs": docs,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -160,11 +207,12 @@ async def scrape_items(
|
|||||||
"""Scrape API — high-throughput cursor-paginated search. count >= 100.
|
"""Scrape API — high-throughput cursor-paginated search. count >= 100.
|
||||||
|
|
||||||
Response includes `cursor` (for next page) when more results exist; missing when done.
|
Response includes `cursor` (for next page) when more results exist; missing when done.
|
||||||
|
Each item has `is_collection` derived from mediatype; `collection` is always a list.
|
||||||
"""
|
"""
|
||||||
async with ArchiveClient() as c:
|
c = await get_shared_client()
|
||||||
data = await c.scrape(query=query, fields=fields, sorts=sorts, count=count, cursor=cursor)
|
data = await c.scrape(query=query, fields=fields, sorts=sorts, count=count, cursor=cursor)
|
||||||
return {
|
return {
|
||||||
"items": data.get("items", []),
|
"items": [_enrich_doc(d) for d in data.get("items", [])],
|
||||||
"count": data.get("count"),
|
"count": data.get("count"),
|
||||||
"total": data.get("total"),
|
"total": data.get("total"),
|
||||||
"next_cursor": data.get("cursor"),
|
"next_cursor": data.get("cursor"),
|
||||||
@ -189,15 +237,17 @@ async def _fetch_item_metadata(identifier: str, include_files: bool) -> dict[str
|
|||||||
"""Shared metadata-fetching logic. Used by both the tool and the MCP resource
|
"""Shared metadata-fetching logic. Used by both the tool and the MCP resource
|
||||||
so neither has to depend on the other's `.fn` attribute."""
|
so neither has to depend on the other's `.fn` attribute."""
|
||||||
validate_identifier(identifier)
|
validate_identifier(identifier)
|
||||||
async with ArchiveClient() as c:
|
c = await get_shared_client()
|
||||||
data = await c.metadata(identifier)
|
data = await c.metadata(identifier)
|
||||||
|
|
||||||
md = data.get("metadata", {})
|
md = data.get("metadata", {})
|
||||||
|
mediatype = md.get("mediatype")
|
||||||
out: dict[str, Any] = {
|
out: dict[str, Any] = {
|
||||||
"identifier": md.get("identifier", identifier),
|
"identifier": md.get("identifier", identifier),
|
||||||
"title": md.get("title"),
|
"title": md.get("title"),
|
||||||
"mediatype": md.get("mediatype"),
|
"mediatype": mediatype,
|
||||||
"collection": md.get("collection"),
|
"is_collection": mediatype == "collection",
|
||||||
|
"collection": _normalize_collection(md.get("collection")),
|
||||||
"creator": md.get("creator"),
|
"creator": md.get("creator"),
|
||||||
"date": md.get("date"),
|
"date": md.get("date"),
|
||||||
"description": md.get("description"),
|
"description": md.get("description"),
|
||||||
@ -235,8 +285,8 @@ async def list_files(
|
|||||||
Each entry includes a ready-to-use `download_url`.
|
Each entry includes a ready-to-use `download_url`.
|
||||||
"""
|
"""
|
||||||
validate_identifier(identifier)
|
validate_identifier(identifier)
|
||||||
async with ArchiveClient() as c:
|
c = await get_shared_client()
|
||||||
files = await c.files(identifier)
|
files = await c.files(identifier)
|
||||||
|
|
||||||
matches = [
|
matches = [
|
||||||
_enrich_file(identifier, f)
|
_enrich_file(identifier, f)
|
||||||
@ -287,20 +337,25 @@ async def download_file(
|
|||||||
The destination is path-confined to either `dest_dir` (when given) or
|
The destination is path-confined to either `dest_dir` (when given) or
|
||||||
$MCARCHIVE_DOWNLOAD_ROOT/{identifier}. Filenames containing '..', absolute
|
$MCARCHIVE_DOWNLOAD_ROOT/{identifier}. Filenames containing '..', absolute
|
||||||
paths, or NUL bytes are rejected before any FS or network I/O.
|
paths, or NUL bytes are rejected before any FS or network I/O.
|
||||||
|
|
||||||
|
Concurrent calls for the same (identifier, filename) are serialized in-process
|
||||||
|
so two parallel tool invocations can't race on the same destination file.
|
||||||
"""
|
"""
|
||||||
dest = _confine_dest(identifier, filename, dest_dir)
|
dest = _confine_dest(identifier, filename, dest_dir)
|
||||||
if overwrite and dest.exists() and not dest.is_symlink():
|
|
||||||
dest.unlink()
|
|
||||||
elif overwrite and dest.is_symlink():
|
|
||||||
# Don't follow the symlink to delete the target; remove the link itself.
|
|
||||||
dest.unlink()
|
|
||||||
|
|
||||||
try:
|
lock = _download_lock_for(identifier, filename)
|
||||||
async with ArchiveClient() as c:
|
async with lock:
|
||||||
|
if overwrite and (dest.exists() or dest.is_symlink()):
|
||||||
|
# is_symlink() before exists() — a dangling symlink reports exists()=False
|
||||||
|
# but we still want to remove the link itself rather than follow it.
|
||||||
|
dest.unlink()
|
||||||
|
|
||||||
|
try:
|
||||||
|
c = await get_shared_client()
|
||||||
result = await c.download_to_file(identifier, filename, dest, verify_md5=verify_md5)
|
result = await c.download_to_file(identifier, filename, dest, verify_md5=verify_md5)
|
||||||
except ArchiveError as e:
|
except ArchiveError as e:
|
||||||
# Re-raise with the destination context so the caller can act on it.
|
# Re-raise with the destination context so the caller can act on it.
|
||||||
raise ArchiveError(f"{e} (dest={dest})") from e
|
raise ArchiveError(f"{e} (dest={dest})") from e
|
||||||
|
|
||||||
result["identifier"] = identifier
|
result["identifier"] = identifier
|
||||||
result["filename"] = filename
|
result["filename"] = filename
|
||||||
|
|||||||
@ -223,6 +223,127 @@ async def test_invalid_json_response_surfaced():
|
|||||||
# ---------- happy path ----------
|
# ---------- happy path ----------
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- M1: retry/backoff with Retry-After ----------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_retry_on_429_then_success(monkeypatch):
|
||||||
|
"""First call gets 429 with Retry-After: 0, second call succeeds."""
|
||||||
|
sleeps: list[float] = []
|
||||||
|
|
||||||
|
async def fake_sleep(d: float) -> None:
|
||||||
|
sleeps.append(d)
|
||||||
|
|
||||||
|
monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", fake_sleep)
|
||||||
|
|
||||||
|
calls = {"n": 0}
|
||||||
|
|
||||||
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
|
calls["n"] += 1
|
||||||
|
if calls["n"] == 1:
|
||||||
|
return httpx.Response(429, headers={"Retry-After": "0"}, json={"error": "slow down"})
|
||||||
|
return httpx.Response(200, json={"response": {"numFound": 0, "docs": []}})
|
||||||
|
|
||||||
|
async with _client_with(handler) as c:
|
||||||
|
result = await c.search(query="x", rows=1)
|
||||||
|
assert result["num_found"] == 0
|
||||||
|
assert calls["n"] == 2
|
||||||
|
assert sleeps == [0.0] # honored Retry-After: 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_retry_exhaustion_raises_with_body(monkeypatch):
|
||||||
|
"""If 429 persists past max_attempts, the final error body is surfaced."""
|
||||||
|
monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", lambda d: _noop_sleep())
|
||||||
|
|
||||||
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(429, json={"error": "rate limit exhausted"})
|
||||||
|
|
||||||
|
async with _client_with(handler) as c:
|
||||||
|
with pytest.raises(ArchiveError, match="rate limit exhausted"):
|
||||||
|
await c.search(query="x")
|
||||||
|
|
||||||
|
|
||||||
|
async def _noop_sleep():
|
||||||
|
"""Used in place of asyncio.sleep when we don't care about backoff timing."""
|
||||||
|
|
||||||
|
|
||||||
|
async def test_retry_on_503_for_stream(monkeypatch, tmp_path):
|
||||||
|
"""Stream-level retry: 503 once, then 200 with body."""
|
||||||
|
monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", lambda d: _noop_sleep())
|
||||||
|
|
||||||
|
body = b"actual file body"
|
||||||
|
calls = {"n": 0}
|
||||||
|
|
||||||
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
|
calls["n"] += 1
|
||||||
|
if calls["n"] == 1:
|
||||||
|
return httpx.Response(503, text="overloaded")
|
||||||
|
return httpx.Response(200, content=body)
|
||||||
|
|
||||||
|
dest = tmp_path / "f.bin"
|
||||||
|
async with _client_with(handler) as c:
|
||||||
|
result = await c.download_to_file("nasa", "f.bin", dest)
|
||||||
|
assert result["bytes_written"] == len(body)
|
||||||
|
assert calls["n"] == 2
|
||||||
|
assert dest.read_bytes() == body
|
||||||
|
|
||||||
|
|
||||||
|
async def test_retry_after_http_date_form(monkeypatch):
|
||||||
|
"""Retry-After can be an HTTP-date; we must parse it to a delta seconds."""
|
||||||
|
sleeps: list[float] = []
|
||||||
|
|
||||||
|
async def fake_sleep(d: float) -> None:
|
||||||
|
sleeps.append(d)
|
||||||
|
|
||||||
|
monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", fake_sleep)
|
||||||
|
calls = {"n": 0}
|
||||||
|
|
||||||
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
|
calls["n"] += 1
|
||||||
|
if calls["n"] == 1:
|
||||||
|
# An HTTP-date in the past should produce a 0-or-negative wait, clamped to 0.
|
||||||
|
return httpx.Response(429, headers={"Retry-After": "Wed, 21 Oct 2015 07:28:00 GMT"})
|
||||||
|
return httpx.Response(200, json={"response": {"numFound": 0, "docs": []}})
|
||||||
|
|
||||||
|
async with _client_with(handler) as c:
|
||||||
|
await c.search(query="x")
|
||||||
|
assert sleeps == [0.0]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- H1: stream-abort error context ----------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stream_abort_raises_archive_error_with_byte_count(tmp_path):
|
||||||
|
"""If httpx raises mid-stream, we wrap it in ArchiveError with byte count
|
||||||
|
so the caller knows where the partial download ended."""
|
||||||
|
|
||||||
|
# Yield enough bytes to flush past httpx's internal chunk buffer (64KB) so
|
||||||
|
# at least one chunk reaches our writer before the error fires.
|
||||||
|
chunk_payload = b"X" * (1 << 17) # 128KB — multiple buffer fills
|
||||||
|
|
||||||
|
async def evil_body():
|
||||||
|
yield chunk_payload
|
||||||
|
raise httpx.ReadError("simulated network drop")
|
||||||
|
|
||||||
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(200, content=evil_body())
|
||||||
|
|
||||||
|
dest = tmp_path / "interrupted.bin"
|
||||||
|
async with _client_with(handler) as c:
|
||||||
|
with pytest.raises(ArchiveError) as exc_info:
|
||||||
|
await c.download_to_file("nasa", "interrupted.bin", dest)
|
||||||
|
|
||||||
|
msg = str(exc_info.value)
|
||||||
|
assert "interrupted after" in msg
|
||||||
|
assert "ReadError" in msg
|
||||||
|
# Partial bytes ARE on disk — at least the first delivered chunk.
|
||||||
|
on_disk = dest.read_bytes()
|
||||||
|
assert len(on_disk) > 0
|
||||||
|
assert on_disk == chunk_payload[: len(on_disk)]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- happy path ----------
|
||||||
|
|
||||||
|
|
||||||
async def test_fresh_download_writes_full_body(tmp_path):
|
async def test_fresh_download_writes_full_body(tmp_path):
|
||||||
body = b"hello world" * 100
|
body = b"hello world" * 100
|
||||||
dest = tmp_path / "new.bin"
|
dest = tmp_path / "new.bin"
|
||||||
|
|||||||
193
tests/test_server_mocked.py
Normal file
193
tests/test_server_mocked.py
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
"""Server-layer regression tests using a swapped-in shared client.
|
||||||
|
|
||||||
|
These exercise the MCP tool functions directly and verify:
|
||||||
|
- Collection normalization (M5)
|
||||||
|
- `is_collection` derived flag (M7)
|
||||||
|
- Shared client lifecycle (H7)
|
||||||
|
- Concurrent-download serialization (M2)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mcarchive_org import client as client_mod
|
||||||
|
from mcarchive_org.client import ArchiveClient
|
||||||
|
from mcarchive_org.server import (
|
||||||
|
_enrich_doc,
|
||||||
|
_normalize_collection,
|
||||||
|
download_file,
|
||||||
|
get_item_metadata,
|
||||||
|
search_items,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def swap_shared_client(handler):
|
||||||
|
"""Temporarily replace the process-wide shared client with a mock-backed one.
|
||||||
|
|
||||||
|
Tests that exercise server.py tools need this because those tools call
|
||||||
|
get_shared_client() under the hood, and we can't pass a transport in.
|
||||||
|
"""
|
||||||
|
saved = client_mod._shared_client
|
||||||
|
mock = ArchiveClient(transport=httpx.MockTransport(handler))
|
||||||
|
client_mod._shared_client = mock
|
||||||
|
try:
|
||||||
|
yield mock
|
||||||
|
finally:
|
||||||
|
client_mod._shared_client = saved
|
||||||
|
await mock.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- M5: collection normalization ----------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"raw,expected",
|
||||||
|
[
|
||||||
|
(None, []),
|
||||||
|
("", []),
|
||||||
|
("nasa", ["nasa"]),
|
||||||
|
(["nasa", "opensource"], ["nasa", "opensource"]),
|
||||||
|
([], []),
|
||||||
|
([None, "nasa", ""], ["nasa"]), # falsy items dropped
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_normalize_collection_shapes(raw, expected):
|
||||||
|
assert _normalize_collection(raw) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_enrich_doc_marks_is_collection():
|
||||||
|
assert _enrich_doc({"mediatype": "collection", "identifier": "nasa"})["is_collection"] is True
|
||||||
|
assert _enrich_doc({"mediatype": "audio", "identifier": "x"})["is_collection"] is False
|
||||||
|
assert _enrich_doc({"identifier": "x"})["is_collection"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_enrich_doc_normalizes_collection_field():
|
||||||
|
out = _enrich_doc({"identifier": "x", "collection": "single"})
|
||||||
|
assert out["collection"] == ["single"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- M7: is_collection in real tool flow ----------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_search_items_decorates_docs_with_is_collection():
|
||||||
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"response": {
|
||||||
|
"numFound": 2,
|
||||||
|
"docs": [
|
||||||
|
{"identifier": "nasa", "mediatype": "collection", "collection": "nasa"},
|
||||||
|
{"identifier": "song1", "mediatype": "audio", "collection": ["etree", "GratefulDead"]},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async with swap_shared_client(handler):
|
||||||
|
result = await search_items(query="x", rows=2)
|
||||||
|
|
||||||
|
assert len(result["docs"]) == 2
|
||||||
|
nasa, song = result["docs"]
|
||||||
|
assert nasa["is_collection"] is True
|
||||||
|
assert nasa["collection"] == ["nasa"]
|
||||||
|
assert song["is_collection"] is False
|
||||||
|
assert song["collection"] == ["etree", "GratefulDead"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_item_metadata_normalizes_collection():
|
||||||
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"metadata": {
|
||||||
|
"identifier": "nasa",
|
||||||
|
"title": "NASA Images",
|
||||||
|
"mediatype": "collection",
|
||||||
|
"collection": "internetarchive",
|
||||||
|
},
|
||||||
|
"files_count": 0,
|
||||||
|
"item_size": 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async with swap_shared_client(handler):
|
||||||
|
result = await get_item_metadata(identifier="nasa")
|
||||||
|
|
||||||
|
assert result["is_collection"] is True
|
||||||
|
assert result["collection"] == ["internetarchive"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- H7: shared client lifecycle ----------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_shared_client_returns_same_instance():
|
||||||
|
await client_mod.close_shared_client()
|
||||||
|
a = await client_mod.get_shared_client()
|
||||||
|
b = await client_mod.get_shared_client()
|
||||||
|
assert a is b
|
||||||
|
await client_mod.close_shared_client()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_close_shared_client_clears_singleton():
|
||||||
|
a = await client_mod.get_shared_client()
|
||||||
|
await client_mod.close_shared_client()
|
||||||
|
b = await client_mod.get_shared_client()
|
||||||
|
assert a is not b
|
||||||
|
await client_mod.close_shared_client()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- M2: concurrent-download serialization ----------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_concurrent_downloads_same_file_are_serialized(tmp_path, monkeypatch):
|
||||||
|
"""Two parallel download_file calls for the same (id, filename) must not
|
||||||
|
interleave — otherwise they'd race on the destination file."""
|
||||||
|
monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))
|
||||||
|
|
||||||
|
state = {"active": 0, "max_active": 0}
|
||||||
|
|
||||||
|
async def handler(req: httpx.Request) -> httpx.Response:
|
||||||
|
state["active"] += 1
|
||||||
|
state["max_active"] = max(state["max_active"], state["active"])
|
||||||
|
await asyncio.sleep(0.05) # hold the request long enough to overlap
|
||||||
|
state["active"] -= 1
|
||||||
|
return httpx.Response(200, content=b"file-content")
|
||||||
|
|
||||||
|
async with swap_shared_client(handler):
|
||||||
|
await asyncio.gather(
|
||||||
|
download_file(identifier="nasa", filename="shared.bin", overwrite=True),
|
||||||
|
download_file(identifier="nasa", filename="shared.bin", overwrite=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# The lock should have prevented any overlap.
|
||||||
|
assert state["max_active"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_concurrent_downloads_different_files_run_in_parallel(tmp_path, monkeypatch):
|
||||||
|
"""Different filenames get different locks — they should run concurrently."""
|
||||||
|
monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))
|
||||||
|
|
||||||
|
state = {"active": 0, "max_active": 0}
|
||||||
|
|
||||||
|
async def handler(req: httpx.Request) -> httpx.Response:
|
||||||
|
state["active"] += 1
|
||||||
|
state["max_active"] = max(state["max_active"], state["active"])
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
state["active"] -= 1
|
||||||
|
return httpx.Response(200, content=b"data")
|
||||||
|
|
||||||
|
async with swap_shared_client(handler):
|
||||||
|
await asyncio.gather(
|
||||||
|
download_file(identifier="nasa", filename="a.bin", overwrite=True),
|
||||||
|
download_file(identifier="nasa", filename="b.bin", overwrite=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Different files — should overlap.
|
||||||
|
assert state["max_active"] == 2
|
||||||
Loading…
x
Reference in New Issue
Block a user