diff --git a/src/mcarchive_org/client.py b/src/mcarchive_org/client.py index a822cdf..6aa8b6e 100644 --- a/src/mcarchive_org/client.py +++ b/src/mcarchive_org/client.py @@ -2,11 +2,14 @@ from __future__ import annotations +import asyncio import errno import hashlib import os +import random import re from collections.abc import AsyncIterator +from email.utils import parsedate_to_datetime from pathlib import Path from typing import Any from urllib.parse import quote @@ -23,6 +26,30 @@ DEFAULT_TIMEOUT = httpx.Timeout(30.0, read=60.0) _IDENTIFIER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,99}$") _MAX_FILENAME = 512 +# Retry policy for archive.org's transient overload responses. +_RETRY_STATUSES = frozenset({429, 502, 503, 504}) +_RETRY_MAX_ATTEMPTS = 3 +_RETRY_MAX_DELAY = 30.0 + + +def _retry_delay(response: httpx.Response, attempt: int) -> float: + """Compute backoff for a retryable response. Honors Retry-After when present.""" + retry_after = response.headers.get("Retry-After") + if retry_after: + try: + return min(float(retry_after), _RETRY_MAX_DELAY) + except ValueError: + try: + # HTTP-date form + dt = parsedate_to_datetime(retry_after) + from datetime import datetime, timezone + wait = (dt - datetime.now(tz=timezone.utc)).total_seconds() + return min(max(wait, 0.0), _RETRY_MAX_DELAY) + except (TypeError, ValueError): + pass + # Exponential backoff with jitter: 1s, 2s, 4s + [0,1)s + return min((2 ** attempt) + random.random(), _RETRY_MAX_DELAY) + class ArchiveError(RuntimeError): """Raised when archive.org returns an error payload or unexpected status.""" @@ -131,12 +158,20 @@ class ArchiveClient: # ---------- internal: error-surfacing fetch ---------- async def _fetch_json(self, url: str, params: Any = None) -> Any: - """GET + JSON decode with archive.org-friendly error messages. + """GET + JSON decode with archive.org-friendly error messages and retry. - Wraps raise_for_status so that 4xx/5xx responses include a body preview - in the exception — invaluable for an LLM trying to fix a bad query. + Retries on 429/502/503/504 with Retry-After honored. 4xx/5xx responses + that aren't retryable include a body preview in the exception — + invaluable for an LLM trying to fix a bad query. """ - r = await self._client.get(url, params=params) + r: httpx.Response | None = None + for attempt in range(_RETRY_MAX_ATTEMPTS): + r = await self._client.get(url, params=params) + if r.status_code in _RETRY_STATUSES and attempt < _RETRY_MAX_ATTEMPTS - 1: + await asyncio.sleep(_retry_delay(r, attempt)) + continue + break + assert r is not None # the loop runs at least once if r.is_error: body = r.text[:500] if r.content else "" raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip()) @@ -243,8 +278,11 @@ class ArchiveClient: ) -> AsyncIterator[bytes]: """Async byte iterator. If resume_from > 0, requires a 206 response. + Retries on 429/502/503/504 BEFORE yielding any bytes (so retry never + risks corrupting a partially-written file). + Raises ArchiveError BEFORE yielding any bytes if: - - the server returns a 4xx/5xx + - the server returns a non-retryable 4xx/5xx - resume was requested but the server returned 200 (Range ignored) - the Content-Range start byte doesn't match resume_from """ @@ -256,27 +294,32 @@ class ArchiveClient: headers["Range"] = f"bytes={resume_from}-" url = self.download_url(identifier, filename) - async with self._client.stream("GET", url, headers=headers) as r: - if r.is_error: - body = (await r.aread())[:500].decode("utf-8", errors="replace") - raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip()) - if resume_from > 0: - if r.status_code != 206: - raise ArchiveError( - f"server ignored Range request (got HTTP {r.status_code}); " - f"local file may be stale — retry download_file with overwrite=True" - ) - # Verify the byte range starts where we expect. archive.org's CDN - # is normally well-behaved here, but trust-but-verify. - cr = r.headers.get("Content-Range", "") - m = re.match(r"bytes\s+(\d+)-", cr) - if m and int(m.group(1)) != resume_from: - raise ArchiveError( - f"Content-Range start {m.group(1)} != resume_from {resume_from}; " - f"refusing to corrupt {filename}" - ) - async for chunk in r.aiter_bytes(chunk_size=1 << 16): - yield chunk + for attempt in range(_RETRY_MAX_ATTEMPTS): + async with self._client.stream("GET", url, headers=headers) as r: + if r.status_code in _RETRY_STATUSES and attempt < _RETRY_MAX_ATTEMPTS - 1: + await asyncio.sleep(_retry_delay(r, attempt)) + continue + if r.is_error: + body = (await r.aread())[:500].decode("utf-8", errors="replace") + raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip()) + if resume_from > 0: + if r.status_code != 206: + raise ArchiveError( + f"server ignored Range request (got HTTP {r.status_code}); " + f"local file may be stale — retry download_file with overwrite=True" + ) + # Verify the byte range starts where we expect. archive.org's CDN + # is normally well-behaved here, but trust-but-verify. + cr = r.headers.get("Content-Range", "") + m = re.match(r"bytes\s+(\d+)-", cr) + if m and int(m.group(1)) != resume_from: + raise ArchiveError( + f"Content-Range start {m.group(1)} != resume_from {resume_from}; " + f"refusing to corrupt {filename}" + ) + async for chunk in r.aiter_bytes(chunk_size=1 << 16): + yield chunk + return async def download_to_file( self, @@ -321,6 +364,16 @@ class ArchiveClient: hasher.update(chunk) if chunk_cb: chunk_cb(bytes_written) + except (httpx.ReadError, httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadTimeout) as e: + # H1: surface partial-state context so the caller can decide whether + # to resume or restart. The bytes already on disk are valid (we only + # write whole chunks), so a follow-up call with overwrite=False will + # resume cleanly from `bytes_written`. + raise ArchiveError( + f"download interrupted after {bytes_written - resume_from} new bytes " + f"({bytes_written} total on disk, resumed from {resume_from}). " + f"File at {dest} — call download_file again to resume. Cause: {e!r}" + ) from e finally: f.close() @@ -335,3 +388,38 @@ class ArchiveClient: result["md5_expected"] = verify_md5 result["md5_ok"] = actual.lower() == verify_md5.lower() return result + + +# ---------- H7: process-wide shared client ---------- +# +# An MCP server typically lives forever serving one LLM. Spinning up a fresh +# httpx.AsyncClient per tool call wastes a TCP+TLS handshake (~200-400ms) and +# burns ephemeral ports under parallel fan-out. Share one client across calls. + +_shared_client: ArchiveClient | None = None + + +async def get_shared_client() -> ArchiveClient: + """Return the process-wide shared ArchiveClient, creating it on first use. + + Tests should NOT use this — construct an ArchiveClient(transport=...) + directly so MockTransport injection works without leaking into other tests. + """ + global _shared_client + if _shared_client is None: + # The race window here is small. If two coroutines both create, one + # client gets discarded — wasteful but not corrupting. + client = ArchiveClient() + if _shared_client is None: + _shared_client = client + else: + await client.aclose() + return _shared_client + + +async def close_shared_client() -> None: + """Close and clear the shared client. Useful for tests; safe to call many times.""" + global _shared_client + if _shared_client is not None: + c, _shared_client = _shared_client, None + await c.aclose() diff --git a/src/mcarchive_org/server.py b/src/mcarchive_org/server.py index ceb33df..f6a3e56 100644 --- a/src/mcarchive_org/server.py +++ b/src/mcarchive_org/server.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import fnmatch import os from pathlib import Path @@ -13,12 +14,32 @@ from pydantic import Field from mcarchive_org import __version__ from mcarchive_org.client import ( - ArchiveClient, ArchiveError, + get_shared_client, validate_filename, validate_identifier, ) +# Per-(identifier, filename) locks to serialize concurrent downloads of the same +# file inside the same process. Cross-process races would still need an fcntl +# advisory lock; a single MCP-server process is the common case. +_download_locks: dict[str, asyncio.Lock] = {} + + +def _download_lock_for(identifier: str, filename: str) -> asyncio.Lock: + """Get-or-create the in-process lock for one (identifier, filename) pair. + + Safe without an outer lock because asyncio is cooperative: there are no + awaits between the dict check and the dict set, so two coroutines can't + interleave inside this function. + """ + key = f"{identifier}::{filename}" + lock = _download_locks.get(key) + if lock is None: + lock = asyncio.Lock() + _download_locks[key] = lock + return lock + def _resolve_download_root() -> Path: """Resolve the download root lazily so env-var changes after import are honored.""" @@ -83,6 +104,30 @@ def _matches(name: str, format_: str | None, name_glob: str | None, formats: lis return not (formats and (format_ or "").lower() not in {f.lower() for f in formats}) +def _normalize_collection(value: Any) -> list[str]: + """archive.org returns `collection` as either str or list[str]. Normalize to list. + + Stable shape lets LLMs write `if 'foo' in doc['collection']` without first + having to remember which type they're holding. + """ + if value is None: + return [] + if isinstance(value, str): + return [value] if value else [] + if isinstance(value, list): + return [str(x) for x in value if x] + return [str(value)] + + +def _enrich_doc(doc: dict[str, Any]) -> dict[str, Any]: + """Apply normalization + derived fields to a search-result doc or metadata blob.""" + out = dict(doc) + if "collection" in out: + out["collection"] = _normalize_collection(out["collection"]) + out["is_collection"] = out.get("mediatype") == "collection" + return out + + def _confine_dest(identifier: str, filename: str, dest_dir: str | None) -> Path: """Construct + verify the download destination path. @@ -133,19 +178,21 @@ async def search_items( """Search archive.org items. Good for small/interactive queries. Returns up to `rows` matching items plus `num_found` (total hits) and `has_more`. + Each doc has `is_collection` derived from mediatype; `collection` is always a list. Use scrape_items for bulk iteration over large result sets. """ - async with ArchiveClient() as c: - result = await c.search(query=query, fields=fields, sort=sort, rows=rows, page=page) + c = await get_shared_client() + result = await c.search(query=query, fields=fields, sort=sort, rows=rows, page=page) total = result["num_found"] - seen = (page - 1) * rows + len(result["docs"]) + docs = [_enrich_doc(d) for d in result["docs"]] + seen = (page - 1) * rows + len(docs) return { "query": query, "num_found": total, "page": page, "rows": rows, "has_more": seen < total, - "docs": result["docs"], + "docs": docs, } @@ -160,11 +207,12 @@ async def scrape_items( """Scrape API — high-throughput cursor-paginated search. count >= 100. Response includes `cursor` (for next page) when more results exist; missing when done. + Each item has `is_collection` derived from mediatype; `collection` is always a list. """ - async with ArchiveClient() as c: - data = await c.scrape(query=query, fields=fields, sorts=sorts, count=count, cursor=cursor) + c = await get_shared_client() + data = await c.scrape(query=query, fields=fields, sorts=sorts, count=count, cursor=cursor) return { - "items": data.get("items", []), + "items": [_enrich_doc(d) for d in data.get("items", [])], "count": data.get("count"), "total": data.get("total"), "next_cursor": data.get("cursor"), @@ -189,15 +237,17 @@ async def _fetch_item_metadata(identifier: str, include_files: bool) -> dict[str """Shared metadata-fetching logic. Used by both the tool and the MCP resource so neither has to depend on the other's `.fn` attribute.""" validate_identifier(identifier) - async with ArchiveClient() as c: - data = await c.metadata(identifier) + c = await get_shared_client() + data = await c.metadata(identifier) md = data.get("metadata", {}) + mediatype = md.get("mediatype") out: dict[str, Any] = { "identifier": md.get("identifier", identifier), "title": md.get("title"), - "mediatype": md.get("mediatype"), - "collection": md.get("collection"), + "mediatype": mediatype, + "is_collection": mediatype == "collection", + "collection": _normalize_collection(md.get("collection")), "creator": md.get("creator"), "date": md.get("date"), "description": md.get("description"), @@ -235,8 +285,8 @@ async def list_files( Each entry includes a ready-to-use `download_url`. """ validate_identifier(identifier) - async with ArchiveClient() as c: - files = await c.files(identifier) + c = await get_shared_client() + files = await c.files(identifier) matches = [ _enrich_file(identifier, f) @@ -287,20 +337,25 @@ async def download_file( The destination is path-confined to either `dest_dir` (when given) or $MCARCHIVE_DOWNLOAD_ROOT/{identifier}. Filenames containing '..', absolute paths, or NUL bytes are rejected before any FS or network I/O. + + Concurrent calls for the same (identifier, filename) are serialized in-process + so two parallel tool invocations can't race on the same destination file. """ dest = _confine_dest(identifier, filename, dest_dir) - if overwrite and dest.exists() and not dest.is_symlink(): - dest.unlink() - elif overwrite and dest.is_symlink(): - # Don't follow the symlink to delete the target; remove the link itself. - dest.unlink() - try: - async with ArchiveClient() as c: + lock = _download_lock_for(identifier, filename) + async with lock: + if overwrite and (dest.exists() or dest.is_symlink()): + # is_symlink() before exists() — a dangling symlink reports exists()=False + # but we still want to remove the link itself rather than follow it. + dest.unlink() + + try: + c = await get_shared_client() result = await c.download_to_file(identifier, filename, dest, verify_md5=verify_md5) - except ArchiveError as e: - # Re-raise with the destination context so the caller can act on it. - raise ArchiveError(f"{e} (dest={dest})") from e + except ArchiveError as e: + # Re-raise with the destination context so the caller can act on it. + raise ArchiveError(f"{e} (dest={dest})") from e result["identifier"] = identifier result["filename"] = filename diff --git a/tests/test_client_mocked.py b/tests/test_client_mocked.py index 0957353..7a64769 100644 --- a/tests/test_client_mocked.py +++ b/tests/test_client_mocked.py @@ -223,6 +223,127 @@ async def test_invalid_json_response_surfaced(): # ---------- happy path ---------- +# ---------- M1: retry/backoff with Retry-After ---------- + + +async def test_retry_on_429_then_success(monkeypatch): + """First call gets 429 with Retry-After: 0, second call succeeds.""" + sleeps: list[float] = [] + + async def fake_sleep(d: float) -> None: + sleeps.append(d) + + monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", fake_sleep) + + calls = {"n": 0} + + def handler(req: httpx.Request) -> httpx.Response: + calls["n"] += 1 + if calls["n"] == 1: + return httpx.Response(429, headers={"Retry-After": "0"}, json={"error": "slow down"}) + return httpx.Response(200, json={"response": {"numFound": 0, "docs": []}}) + + async with _client_with(handler) as c: + result = await c.search(query="x", rows=1) + assert result["num_found"] == 0 + assert calls["n"] == 2 + assert sleeps == [0.0] # honored Retry-After: 0 + + +async def test_retry_exhaustion_raises_with_body(monkeypatch): + """If 429 persists past max_attempts, the final error body is surfaced.""" + monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", lambda d: _noop_sleep()) + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(429, json={"error": "rate limit exhausted"}) + + async with _client_with(handler) as c: + with pytest.raises(ArchiveError, match="rate limit exhausted"): + await c.search(query="x") + + +async def _noop_sleep(): + """Used in place of asyncio.sleep when we don't care about backoff timing.""" + + +async def test_retry_on_503_for_stream(monkeypatch, tmp_path): + """Stream-level retry: 503 once, then 200 with body.""" + monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", lambda d: _noop_sleep()) + + body = b"actual file body" + calls = {"n": 0} + + def handler(req: httpx.Request) -> httpx.Response: + calls["n"] += 1 + if calls["n"] == 1: + return httpx.Response(503, text="overloaded") + return httpx.Response(200, content=body) + + dest = tmp_path / "f.bin" + async with _client_with(handler) as c: + result = await c.download_to_file("nasa", "f.bin", dest) + assert result["bytes_written"] == len(body) + assert calls["n"] == 2 + assert dest.read_bytes() == body + + +async def test_retry_after_http_date_form(monkeypatch): + """Retry-After can be an HTTP-date; we must parse it to a delta seconds.""" + sleeps: list[float] = [] + + async def fake_sleep(d: float) -> None: + sleeps.append(d) + + monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", fake_sleep) + calls = {"n": 0} + + def handler(req: httpx.Request) -> httpx.Response: + calls["n"] += 1 + if calls["n"] == 1: + # An HTTP-date in the past should produce a 0-or-negative wait, clamped to 0. + return httpx.Response(429, headers={"Retry-After": "Wed, 21 Oct 2015 07:28:00 GMT"}) + return httpx.Response(200, json={"response": {"numFound": 0, "docs": []}}) + + async with _client_with(handler) as c: + await c.search(query="x") + assert sleeps == [0.0] + + +# ---------- H1: stream-abort error context ---------- + + +async def test_stream_abort_raises_archive_error_with_byte_count(tmp_path): + """If httpx raises mid-stream, we wrap it in ArchiveError with byte count + so the caller knows where the partial download ended.""" + + # Yield enough bytes to flush past httpx's internal chunk buffer (64KB) so + # at least one chunk reaches our writer before the error fires. + chunk_payload = b"X" * (1 << 17) # 128KB — multiple buffer fills + + async def evil_body(): + yield chunk_payload + raise httpx.ReadError("simulated network drop") + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, content=evil_body()) + + dest = tmp_path / "interrupted.bin" + async with _client_with(handler) as c: + with pytest.raises(ArchiveError) as exc_info: + await c.download_to_file("nasa", "interrupted.bin", dest) + + msg = str(exc_info.value) + assert "interrupted after" in msg + assert "ReadError" in msg + # Partial bytes ARE on disk — at least the first delivered chunk. + on_disk = dest.read_bytes() + assert len(on_disk) > 0 + assert on_disk == chunk_payload[: len(on_disk)] + + +# ---------- happy path ---------- + + async def test_fresh_download_writes_full_body(tmp_path): body = b"hello world" * 100 dest = tmp_path / "new.bin" diff --git a/tests/test_server_mocked.py b/tests/test_server_mocked.py new file mode 100644 index 0000000..4c48328 --- /dev/null +++ b/tests/test_server_mocked.py @@ -0,0 +1,193 @@ +"""Server-layer regression tests using a swapped-in shared client. + +These exercise the MCP tool functions directly and verify: +- Collection normalization (M5) +- `is_collection` derived flag (M7) +- Shared client lifecycle (H7) +- Concurrent-download serialization (M2) +""" + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager + +import httpx +import pytest + +from mcarchive_org import client as client_mod +from mcarchive_org.client import ArchiveClient +from mcarchive_org.server import ( + _enrich_doc, + _normalize_collection, + download_file, + get_item_metadata, + search_items, +) + + +@asynccontextmanager +async def swap_shared_client(handler): + """Temporarily replace the process-wide shared client with a mock-backed one. + + Tests that exercise server.py tools need this because those tools call + get_shared_client() under the hood, and we can't pass a transport in. + """ + saved = client_mod._shared_client + mock = ArchiveClient(transport=httpx.MockTransport(handler)) + client_mod._shared_client = mock + try: + yield mock + finally: + client_mod._shared_client = saved + await mock.aclose() + + +# ---------- M5: collection normalization ---------- + + +@pytest.mark.parametrize( + "raw,expected", + [ + (None, []), + ("", []), + ("nasa", ["nasa"]), + (["nasa", "opensource"], ["nasa", "opensource"]), + ([], []), + ([None, "nasa", ""], ["nasa"]), # falsy items dropped + ], +) +def test_normalize_collection_shapes(raw, expected): + assert _normalize_collection(raw) == expected + + +def test_enrich_doc_marks_is_collection(): + assert _enrich_doc({"mediatype": "collection", "identifier": "nasa"})["is_collection"] is True + assert _enrich_doc({"mediatype": "audio", "identifier": "x"})["is_collection"] is False + assert _enrich_doc({"identifier": "x"})["is_collection"] is False + + +def test_enrich_doc_normalizes_collection_field(): + out = _enrich_doc({"identifier": "x", "collection": "single"}) + assert out["collection"] == ["single"] + + +# ---------- M7: is_collection in real tool flow ---------- + + +async def test_search_items_decorates_docs_with_is_collection(): + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={ + "response": { + "numFound": 2, + "docs": [ + {"identifier": "nasa", "mediatype": "collection", "collection": "nasa"}, + {"identifier": "song1", "mediatype": "audio", "collection": ["etree", "GratefulDead"]}, + ], + } + }, + ) + + async with swap_shared_client(handler): + result = await search_items(query="x", rows=2) + + assert len(result["docs"]) == 2 + nasa, song = result["docs"] + assert nasa["is_collection"] is True + assert nasa["collection"] == ["nasa"] + assert song["is_collection"] is False + assert song["collection"] == ["etree", "GratefulDead"] + + +async def test_get_item_metadata_normalizes_collection(): + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={ + "metadata": { + "identifier": "nasa", + "title": "NASA Images", + "mediatype": "collection", + "collection": "internetarchive", + }, + "files_count": 0, + "item_size": 0, + }, + ) + + async with swap_shared_client(handler): + result = await get_item_metadata(identifier="nasa") + + assert result["is_collection"] is True + assert result["collection"] == ["internetarchive"] + + +# ---------- H7: shared client lifecycle ---------- + + +async def test_get_shared_client_returns_same_instance(): + await client_mod.close_shared_client() + a = await client_mod.get_shared_client() + b = await client_mod.get_shared_client() + assert a is b + await client_mod.close_shared_client() + + +async def test_close_shared_client_clears_singleton(): + a = await client_mod.get_shared_client() + await client_mod.close_shared_client() + b = await client_mod.get_shared_client() + assert a is not b + await client_mod.close_shared_client() + + +# ---------- M2: concurrent-download serialization ---------- + + +async def test_concurrent_downloads_same_file_are_serialized(tmp_path, monkeypatch): + """Two parallel download_file calls for the same (id, filename) must not + interleave — otherwise they'd race on the destination file.""" + monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path)) + + state = {"active": 0, "max_active": 0} + + async def handler(req: httpx.Request) -> httpx.Response: + state["active"] += 1 + state["max_active"] = max(state["max_active"], state["active"]) + await asyncio.sleep(0.05) # hold the request long enough to overlap + state["active"] -= 1 + return httpx.Response(200, content=b"file-content") + + async with swap_shared_client(handler): + await asyncio.gather( + download_file(identifier="nasa", filename="shared.bin", overwrite=True), + download_file(identifier="nasa", filename="shared.bin", overwrite=True), + ) + + # The lock should have prevented any overlap. + assert state["max_active"] == 1 + + +async def test_concurrent_downloads_different_files_run_in_parallel(tmp_path, monkeypatch): + """Different filenames get different locks — they should run concurrently.""" + monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path)) + + state = {"active": 0, "max_active": 0} + + async def handler(req: httpx.Request) -> httpx.Response: + state["active"] += 1 + state["max_active"] = max(state["max_active"], state["active"]) + await asyncio.sleep(0.05) + state["active"] -= 1 + return httpx.Response(200, content=b"data") + + async with swap_shared_client(handler): + await asyncio.gather( + download_file(identifier="nasa", filename="a.bin", overwrite=True), + download_file(identifier="nasa", filename="b.bin", overwrite=True), + ) + + # Different files — should overlap. + assert state["max_active"] == 2