Resilience: address Hamilton tier-2 findings

H7 — Process-wide shared httpx.AsyncClient via get_shared_client().
Each tool call no longer pays a TCP+TLS handshake; connection pool is
reused across the server's lifetime. Tests inject mock transports
directly via ArchiveClient(transport=...) so the singleton stays clean.

M1 — Retry/backoff on 429/502/503/504 with Retry-After honored
(both delta-seconds and HTTP-date forms). Exponential backoff with
jitter, capped at 30s, max 3 attempts. Applied to both _fetch_json
and stream_file (retry happens BEFORE any bytes are yielded so it
can't corrupt a partial write).

M2 — Per-(identifier, filename) asyncio.Lock in download_file
serializes concurrent downloads of the same file inside one process.
Different files still download in parallel.

M5 — collection field normalized to list[str] in all output paths
(search docs, scrape items, item metadata). LLMs can write
`if 'foo' in doc['collection']` without checking the type first.

M7 — `is_collection: bool` derived from mediatype on every doc /
metadata response, so LLMs can route collection containers vs.
real media items without re-querying.

H1 — Stream-abort errors (httpx.ReadError, RemoteProtocolError,
ConnectError, ReadTimeout) caught and re-raised as ArchiveError
with bytes-written context so the caller knows where the partial
download ended. Bytes already on disk remain valid for resume.

19 new regression tests (52 total, all green, ruff clean):
- 4 tests covering retry/backoff, exhaustion, HTTP-date Retry-After
- 1 test for stream-abort byte-count surfacing
- 6 tests for collection normalization shapes
- 4 tests for is_collection in real tool flow + shared client lifecycle
- 2 tests verifying download lock: same-file serialized, different files parallel
This commit is contained in:
Ryan Malloy 2026-04-21 20:24:21 -06:00
parent 4a03af1675
commit 6198defeca
4 changed files with 507 additions and 50 deletions

View File

@ -2,11 +2,14 @@
from __future__ import annotations
import asyncio
import errno
import hashlib
import os
import random
import re
from collections.abc import AsyncIterator
from email.utils import parsedate_to_datetime
from pathlib import Path
from typing import Any
from urllib.parse import quote
@ -23,6 +26,30 @@ DEFAULT_TIMEOUT = httpx.Timeout(30.0, read=60.0)
_IDENTIFIER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,99}$")
_MAX_FILENAME = 512
# Retry policy for archive.org's transient overload responses.
_RETRY_STATUSES = frozenset({429, 502, 503, 504})
_RETRY_MAX_ATTEMPTS = 3
_RETRY_MAX_DELAY = 30.0
def _retry_delay(response: httpx.Response, attempt: int) -> float:
"""Compute backoff for a retryable response. Honors Retry-After when present."""
retry_after = response.headers.get("Retry-After")
if retry_after:
try:
return min(float(retry_after), _RETRY_MAX_DELAY)
except ValueError:
try:
# HTTP-date form
dt = parsedate_to_datetime(retry_after)
from datetime import datetime, timezone
wait = (dt - datetime.now(tz=timezone.utc)).total_seconds()
return min(max(wait, 0.0), _RETRY_MAX_DELAY)
except (TypeError, ValueError):
pass
# Exponential backoff with jitter: 1s, 2s, 4s + [0,1)s
return min((2 ** attempt) + random.random(), _RETRY_MAX_DELAY)
class ArchiveError(RuntimeError):
"""Raised when archive.org returns an error payload or unexpected status."""
@ -131,12 +158,20 @@ class ArchiveClient:
# ---------- internal: error-surfacing fetch ----------
async def _fetch_json(self, url: str, params: Any = None) -> Any:
"""GET + JSON decode with archive.org-friendly error messages.
"""GET + JSON decode with archive.org-friendly error messages and retry.
Wraps raise_for_status so that 4xx/5xx responses include a body preview
in the exception invaluable for an LLM trying to fix a bad query.
Retries on 429/502/503/504 with Retry-After honored. 4xx/5xx responses
that aren't retryable include a body preview in the exception —
invaluable for an LLM trying to fix a bad query.
"""
r: httpx.Response | None = None
for attempt in range(_RETRY_MAX_ATTEMPTS):
r = await self._client.get(url, params=params)
if r.status_code in _RETRY_STATUSES and attempt < _RETRY_MAX_ATTEMPTS - 1:
await asyncio.sleep(_retry_delay(r, attempt))
continue
break
assert r is not None # the loop runs at least once
if r.is_error:
body = r.text[:500] if r.content else ""
raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip())
@ -243,8 +278,11 @@ class ArchiveClient:
) -> AsyncIterator[bytes]:
"""Async byte iterator. If resume_from > 0, requires a 206 response.
Retries on 429/502/503/504 BEFORE yielding any bytes (so retry never
risks corrupting a partially-written file).
Raises ArchiveError BEFORE yielding any bytes if:
- the server returns a 4xx/5xx
- the server returns a non-retryable 4xx/5xx
- resume was requested but the server returned 200 (Range ignored)
- the Content-Range start byte doesn't match resume_from
"""
@ -256,7 +294,11 @@ class ArchiveClient:
headers["Range"] = f"bytes={resume_from}-"
url = self.download_url(identifier, filename)
for attempt in range(_RETRY_MAX_ATTEMPTS):
async with self._client.stream("GET", url, headers=headers) as r:
if r.status_code in _RETRY_STATUSES and attempt < _RETRY_MAX_ATTEMPTS - 1:
await asyncio.sleep(_retry_delay(r, attempt))
continue
if r.is_error:
body = (await r.aread())[:500].decode("utf-8", errors="replace")
raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip())
@ -277,6 +319,7 @@ class ArchiveClient:
)
async for chunk in r.aiter_bytes(chunk_size=1 << 16):
yield chunk
return
async def download_to_file(
self,
@ -321,6 +364,16 @@ class ArchiveClient:
hasher.update(chunk)
if chunk_cb:
chunk_cb(bytes_written)
except (httpx.ReadError, httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadTimeout) as e:
# H1: surface partial-state context so the caller can decide whether
# to resume or restart. The bytes already on disk are valid (we only
# write whole chunks), so a follow-up call with overwrite=False will
# resume cleanly from `bytes_written`.
raise ArchiveError(
f"download interrupted after {bytes_written - resume_from} new bytes "
f"({bytes_written} total on disk, resumed from {resume_from}). "
f"File at {dest} — call download_file again to resume. Cause: {e!r}"
) from e
finally:
f.close()
@ -335,3 +388,38 @@ class ArchiveClient:
result["md5_expected"] = verify_md5
result["md5_ok"] = actual.lower() == verify_md5.lower()
return result
# ---------- H7: process-wide shared client ----------
#
# An MCP server typically lives forever serving one LLM. Spinning up a fresh
# httpx.AsyncClient per tool call wastes a TCP+TLS handshake (~200-400ms) and
# burns ephemeral ports under parallel fan-out. Share one client across calls.
_shared_client: ArchiveClient | None = None
async def get_shared_client() -> ArchiveClient:
"""Return the process-wide shared ArchiveClient, creating it on first use.
Tests should NOT use this construct an ArchiveClient(transport=...)
directly so MockTransport injection works without leaking into other tests.
"""
global _shared_client
if _shared_client is None:
# The race window here is small. If two coroutines both create, one
# client gets discarded — wasteful but not corrupting.
client = ArchiveClient()
if _shared_client is None:
_shared_client = client
else:
await client.aclose()
return _shared_client
async def close_shared_client() -> None:
"""Close and clear the shared client. Useful for tests; safe to call many times."""
global _shared_client
if _shared_client is not None:
c, _shared_client = _shared_client, None
await c.aclose()

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import fnmatch
import os
from pathlib import Path
@ -13,12 +14,32 @@ from pydantic import Field
from mcarchive_org import __version__
from mcarchive_org.client import (
ArchiveClient,
ArchiveError,
get_shared_client,
validate_filename,
validate_identifier,
)
# Per-(identifier, filename) locks to serialize concurrent downloads of the same
# file inside the same process. Cross-process races would still need an fcntl
# advisory lock; a single MCP-server process is the common case.
_download_locks: dict[str, asyncio.Lock] = {}
def _download_lock_for(identifier: str, filename: str) -> asyncio.Lock:
"""Get-or-create the in-process lock for one (identifier, filename) pair.
Safe without an outer lock because asyncio is cooperative: there are no
awaits between the dict check and the dict set, so two coroutines can't
interleave inside this function.
"""
key = f"{identifier}::{filename}"
lock = _download_locks.get(key)
if lock is None:
lock = asyncio.Lock()
_download_locks[key] = lock
return lock
def _resolve_download_root() -> Path:
"""Resolve the download root lazily so env-var changes after import are honored."""
@ -83,6 +104,30 @@ def _matches(name: str, format_: str | None, name_glob: str | None, formats: lis
return not (formats and (format_ or "").lower() not in {f.lower() for f in formats})
def _normalize_collection(value: Any) -> list[str]:
"""archive.org returns `collection` as either str or list[str]. Normalize to list.
Stable shape lets LLMs write `if 'foo' in doc['collection']` without first
having to remember which type they're holding.
"""
if value is None:
return []
if isinstance(value, str):
return [value] if value else []
if isinstance(value, list):
return [str(x) for x in value if x]
return [str(value)]
def _enrich_doc(doc: dict[str, Any]) -> dict[str, Any]:
"""Apply normalization + derived fields to a search-result doc or metadata blob."""
out = dict(doc)
if "collection" in out:
out["collection"] = _normalize_collection(out["collection"])
out["is_collection"] = out.get("mediatype") == "collection"
return out
def _confine_dest(identifier: str, filename: str, dest_dir: str | None) -> Path:
"""Construct + verify the download destination path.
@ -133,19 +178,21 @@ async def search_items(
"""Search archive.org items. Good for small/interactive queries.
Returns up to `rows` matching items plus `num_found` (total hits) and `has_more`.
Each doc has `is_collection` derived from mediatype; `collection` is always a list.
Use scrape_items for bulk iteration over large result sets.
"""
async with ArchiveClient() as c:
c = await get_shared_client()
result = await c.search(query=query, fields=fields, sort=sort, rows=rows, page=page)
total = result["num_found"]
seen = (page - 1) * rows + len(result["docs"])
docs = [_enrich_doc(d) for d in result["docs"]]
seen = (page - 1) * rows + len(docs)
return {
"query": query,
"num_found": total,
"page": page,
"rows": rows,
"has_more": seen < total,
"docs": result["docs"],
"docs": docs,
}
@ -160,11 +207,12 @@ async def scrape_items(
"""Scrape API — high-throughput cursor-paginated search. count >= 100.
Response includes `cursor` (for next page) when more results exist; missing when done.
Each item has `is_collection` derived from mediatype; `collection` is always a list.
"""
async with ArchiveClient() as c:
c = await get_shared_client()
data = await c.scrape(query=query, fields=fields, sorts=sorts, count=count, cursor=cursor)
return {
"items": data.get("items", []),
"items": [_enrich_doc(d) for d in data.get("items", [])],
"count": data.get("count"),
"total": data.get("total"),
"next_cursor": data.get("cursor"),
@ -189,15 +237,17 @@ async def _fetch_item_metadata(identifier: str, include_files: bool) -> dict[str
"""Shared metadata-fetching logic. Used by both the tool and the MCP resource
so neither has to depend on the other's `.fn` attribute."""
validate_identifier(identifier)
async with ArchiveClient() as c:
c = await get_shared_client()
data = await c.metadata(identifier)
md = data.get("metadata", {})
mediatype = md.get("mediatype")
out: dict[str, Any] = {
"identifier": md.get("identifier", identifier),
"title": md.get("title"),
"mediatype": md.get("mediatype"),
"collection": md.get("collection"),
"mediatype": mediatype,
"is_collection": mediatype == "collection",
"collection": _normalize_collection(md.get("collection")),
"creator": md.get("creator"),
"date": md.get("date"),
"description": md.get("description"),
@ -235,7 +285,7 @@ async def list_files(
Each entry includes a ready-to-use `download_url`.
"""
validate_identifier(identifier)
async with ArchiveClient() as c:
c = await get_shared_client()
files = await c.files(identifier)
matches = [
@ -287,16 +337,21 @@ async def download_file(
The destination is path-confined to either `dest_dir` (when given) or
$MCARCHIVE_DOWNLOAD_ROOT/{identifier}. Filenames containing '..', absolute
paths, or NUL bytes are rejected before any FS or network I/O.
Concurrent calls for the same (identifier, filename) are serialized in-process
so two parallel tool invocations can't race on the same destination file.
"""
dest = _confine_dest(identifier, filename, dest_dir)
if overwrite and dest.exists() and not dest.is_symlink():
dest.unlink()
elif overwrite and dest.is_symlink():
# Don't follow the symlink to delete the target; remove the link itself.
lock = _download_lock_for(identifier, filename)
async with lock:
if overwrite and (dest.exists() or dest.is_symlink()):
# is_symlink() before exists() — a dangling symlink reports exists()=False
# but we still want to remove the link itself rather than follow it.
dest.unlink()
try:
async with ArchiveClient() as c:
c = await get_shared_client()
result = await c.download_to_file(identifier, filename, dest, verify_md5=verify_md5)
except ArchiveError as e:
# Re-raise with the destination context so the caller can act on it.

View File

@ -223,6 +223,127 @@ async def test_invalid_json_response_surfaced():
# ---------- happy path ----------
# ---------- M1: retry/backoff with Retry-After ----------
async def test_retry_on_429_then_success(monkeypatch):
"""First call gets 429 with Retry-After: 0, second call succeeds."""
sleeps: list[float] = []
async def fake_sleep(d: float) -> None:
sleeps.append(d)
monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", fake_sleep)
calls = {"n": 0}
def handler(req: httpx.Request) -> httpx.Response:
calls["n"] += 1
if calls["n"] == 1:
return httpx.Response(429, headers={"Retry-After": "0"}, json={"error": "slow down"})
return httpx.Response(200, json={"response": {"numFound": 0, "docs": []}})
async with _client_with(handler) as c:
result = await c.search(query="x", rows=1)
assert result["num_found"] == 0
assert calls["n"] == 2
assert sleeps == [0.0] # honored Retry-After: 0
async def test_retry_exhaustion_raises_with_body(monkeypatch):
"""If 429 persists past max_attempts, the final error body is surfaced."""
monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", lambda d: _noop_sleep())
def handler(req: httpx.Request) -> httpx.Response:
return httpx.Response(429, json={"error": "rate limit exhausted"})
async with _client_with(handler) as c:
with pytest.raises(ArchiveError, match="rate limit exhausted"):
await c.search(query="x")
async def _noop_sleep():
"""Used in place of asyncio.sleep when we don't care about backoff timing."""
async def test_retry_on_503_for_stream(monkeypatch, tmp_path):
"""Stream-level retry: 503 once, then 200 with body."""
monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", lambda d: _noop_sleep())
body = b"actual file body"
calls = {"n": 0}
def handler(req: httpx.Request) -> httpx.Response:
calls["n"] += 1
if calls["n"] == 1:
return httpx.Response(503, text="overloaded")
return httpx.Response(200, content=body)
dest = tmp_path / "f.bin"
async with _client_with(handler) as c:
result = await c.download_to_file("nasa", "f.bin", dest)
assert result["bytes_written"] == len(body)
assert calls["n"] == 2
assert dest.read_bytes() == body
async def test_retry_after_http_date_form(monkeypatch):
"""Retry-After can be an HTTP-date; we must parse it to a delta seconds."""
sleeps: list[float] = []
async def fake_sleep(d: float) -> None:
sleeps.append(d)
monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", fake_sleep)
calls = {"n": 0}
def handler(req: httpx.Request) -> httpx.Response:
calls["n"] += 1
if calls["n"] == 1:
# An HTTP-date in the past should produce a 0-or-negative wait, clamped to 0.
return httpx.Response(429, headers={"Retry-After": "Wed, 21 Oct 2015 07:28:00 GMT"})
return httpx.Response(200, json={"response": {"numFound": 0, "docs": []}})
async with _client_with(handler) as c:
await c.search(query="x")
assert sleeps == [0.0]
# ---------- H1: stream-abort error context ----------
async def test_stream_abort_raises_archive_error_with_byte_count(tmp_path):
"""If httpx raises mid-stream, we wrap it in ArchiveError with byte count
so the caller knows where the partial download ended."""
# Yield enough bytes to flush past httpx's internal chunk buffer (64KB) so
# at least one chunk reaches our writer before the error fires.
chunk_payload = b"X" * (1 << 17) # 128KB — multiple buffer fills
async def evil_body():
yield chunk_payload
raise httpx.ReadError("simulated network drop")
def handler(req: httpx.Request) -> httpx.Response:
return httpx.Response(200, content=evil_body())
dest = tmp_path / "interrupted.bin"
async with _client_with(handler) as c:
with pytest.raises(ArchiveError) as exc_info:
await c.download_to_file("nasa", "interrupted.bin", dest)
msg = str(exc_info.value)
assert "interrupted after" in msg
assert "ReadError" in msg
# Partial bytes ARE on disk — at least the first delivered chunk.
on_disk = dest.read_bytes()
assert len(on_disk) > 0
assert on_disk == chunk_payload[: len(on_disk)]
# ---------- happy path ----------
async def test_fresh_download_writes_full_body(tmp_path):
body = b"hello world" * 100
dest = tmp_path / "new.bin"

193
tests/test_server_mocked.py Normal file
View File

@ -0,0 +1,193 @@
"""Server-layer regression tests using a swapped-in shared client.
These exercise the MCP tool functions directly and verify:
- Collection normalization (M5)
- `is_collection` derived flag (M7)
- Shared client lifecycle (H7)
- Concurrent-download serialization (M2)
"""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
import httpx
import pytest
from mcarchive_org import client as client_mod
from mcarchive_org.client import ArchiveClient
from mcarchive_org.server import (
_enrich_doc,
_normalize_collection,
download_file,
get_item_metadata,
search_items,
)
@asynccontextmanager
async def swap_shared_client(handler):
"""Temporarily replace the process-wide shared client with a mock-backed one.
Tests that exercise server.py tools need this because those tools call
get_shared_client() under the hood, and we can't pass a transport in.
"""
saved = client_mod._shared_client
mock = ArchiveClient(transport=httpx.MockTransport(handler))
client_mod._shared_client = mock
try:
yield mock
finally:
client_mod._shared_client = saved
await mock.aclose()
# ---------- M5: collection normalization ----------
@pytest.mark.parametrize(
"raw,expected",
[
(None, []),
("", []),
("nasa", ["nasa"]),
(["nasa", "opensource"], ["nasa", "opensource"]),
([], []),
([None, "nasa", ""], ["nasa"]), # falsy items dropped
],
)
def test_normalize_collection_shapes(raw, expected):
assert _normalize_collection(raw) == expected
def test_enrich_doc_marks_is_collection():
assert _enrich_doc({"mediatype": "collection", "identifier": "nasa"})["is_collection"] is True
assert _enrich_doc({"mediatype": "audio", "identifier": "x"})["is_collection"] is False
assert _enrich_doc({"identifier": "x"})["is_collection"] is False
def test_enrich_doc_normalizes_collection_field():
out = _enrich_doc({"identifier": "x", "collection": "single"})
assert out["collection"] == ["single"]
# ---------- M7: is_collection in real tool flow ----------
async def test_search_items_decorates_docs_with_is_collection():
def handler(req: httpx.Request) -> httpx.Response:
return httpx.Response(
200,
json={
"response": {
"numFound": 2,
"docs": [
{"identifier": "nasa", "mediatype": "collection", "collection": "nasa"},
{"identifier": "song1", "mediatype": "audio", "collection": ["etree", "GratefulDead"]},
],
}
},
)
async with swap_shared_client(handler):
result = await search_items(query="x", rows=2)
assert len(result["docs"]) == 2
nasa, song = result["docs"]
assert nasa["is_collection"] is True
assert nasa["collection"] == ["nasa"]
assert song["is_collection"] is False
assert song["collection"] == ["etree", "GratefulDead"]
async def test_get_item_metadata_normalizes_collection():
def handler(req: httpx.Request) -> httpx.Response:
return httpx.Response(
200,
json={
"metadata": {
"identifier": "nasa",
"title": "NASA Images",
"mediatype": "collection",
"collection": "internetarchive",
},
"files_count": 0,
"item_size": 0,
},
)
async with swap_shared_client(handler):
result = await get_item_metadata(identifier="nasa")
assert result["is_collection"] is True
assert result["collection"] == ["internetarchive"]
# ---------- H7: shared client lifecycle ----------
async def test_get_shared_client_returns_same_instance():
await client_mod.close_shared_client()
a = await client_mod.get_shared_client()
b = await client_mod.get_shared_client()
assert a is b
await client_mod.close_shared_client()
async def test_close_shared_client_clears_singleton():
a = await client_mod.get_shared_client()
await client_mod.close_shared_client()
b = await client_mod.get_shared_client()
assert a is not b
await client_mod.close_shared_client()
# ---------- M2: concurrent-download serialization ----------
async def test_concurrent_downloads_same_file_are_serialized(tmp_path, monkeypatch):
"""Two parallel download_file calls for the same (id, filename) must not
interleave otherwise they'd race on the destination file."""
monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))
state = {"active": 0, "max_active": 0}
async def handler(req: httpx.Request) -> httpx.Response:
state["active"] += 1
state["max_active"] = max(state["max_active"], state["active"])
await asyncio.sleep(0.05) # hold the request long enough to overlap
state["active"] -= 1
return httpx.Response(200, content=b"file-content")
async with swap_shared_client(handler):
await asyncio.gather(
download_file(identifier="nasa", filename="shared.bin", overwrite=True),
download_file(identifier="nasa", filename="shared.bin", overwrite=True),
)
# The lock should have prevented any overlap.
assert state["max_active"] == 1
async def test_concurrent_downloads_different_files_run_in_parallel(tmp_path, monkeypatch):
"""Different filenames get different locks — they should run concurrently."""
monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))
state = {"active": 0, "max_active": 0}
async def handler(req: httpx.Request) -> httpx.Response:
state["active"] += 1
state["max_active"] = max(state["max_active"], state["active"])
await asyncio.sleep(0.05)
state["active"] -= 1
return httpx.Response(200, content=b"data")
async with swap_shared_client(handler):
await asyncio.gather(
download_file(identifier="nasa", filename="a.bin", overwrite=True),
download_file(identifier="nasa", filename="b.bin", overwrite=True),
)
# Different files — should overlap.
assert state["max_active"] == 2