Resilience: address Hamilton tier-2 findings

H7 — Process-wide shared httpx.AsyncClient via get_shared_client().
Each tool call no longer pays a TCP+TLS handshake; connection pool is
reused across the server's lifetime. Tests inject mock transports
directly via ArchiveClient(transport=...) so the singleton stays clean.

M1 — Retry/backoff on 429/502/503/504 with Retry-After honored
(both delta-seconds and HTTP-date forms). Exponential backoff with
jitter, capped at 30s, max 3 attempts. Applied to both _fetch_json
and stream_file (retry happens BEFORE any bytes are yielded so it
can't corrupt a partial write).

M2 — Per-(identifier, filename) asyncio.Lock in download_file
serializes concurrent downloads of the same file inside one process.
Different files still download in parallel.

M5 — collection field normalized to list[str] in all output paths
(search docs, scrape items, item metadata). LLMs can write
`if 'foo' in doc['collection']` without checking the type first.

M7 — `is_collection: bool` derived from mediatype on every doc /
metadata response, so LLMs can route collection containers vs.
real media items without re-querying.

H1 — Stream-abort errors (httpx.ReadError, RemoteProtocolError,
ConnectError, ReadTimeout) caught and re-raised as ArchiveError
with bytes-written context so the caller knows where the partial
download ended. Bytes already on disk remain valid for resume.

19 new regression tests (52 total, all green, ruff clean):
- 4 tests covering retry/backoff, exhaustion, HTTP-date Retry-After
- 1 test for stream-abort byte-count surfacing
- 6 tests for collection normalization shapes
- 2 tests for _enrich_doc (is_collection flag, collection field normalization)
- 4 tests for is_collection in real tool flow + shared client lifecycle
- 2 tests verifying download lock: same-file serialized, different files parallel
This commit is contained in:
Ryan Malloy 2026-04-21 20:24:21 -06:00
parent 4a03af1675
commit 6198defeca
4 changed files with 507 additions and 50 deletions

View File

@ -2,11 +2,14 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import errno import errno
import hashlib import hashlib
import os import os
import random
import re import re
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from email.utils import parsedate_to_datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from urllib.parse import quote from urllib.parse import quote
@ -23,6 +26,30 @@ DEFAULT_TIMEOUT = httpx.Timeout(30.0, read=60.0)
_IDENTIFIER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,99}$") _IDENTIFIER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,99}$")
_MAX_FILENAME = 512 _MAX_FILENAME = 512
# Retry policy for archive.org's transient overload responses.
_RETRY_STATUSES = frozenset({429, 502, 503, 504})
_RETRY_MAX_ATTEMPTS = 3
_RETRY_MAX_DELAY = 30.0
def _retry_delay(response: httpx.Response, attempt: int) -> float:
"""Compute backoff for a retryable response. Honors Retry-After when present."""
retry_after = response.headers.get("Retry-After")
if retry_after:
try:
return min(float(retry_after), _RETRY_MAX_DELAY)
except ValueError:
try:
# HTTP-date form
dt = parsedate_to_datetime(retry_after)
from datetime import datetime, timezone
wait = (dt - datetime.now(tz=timezone.utc)).total_seconds()
return min(max(wait, 0.0), _RETRY_MAX_DELAY)
except (TypeError, ValueError):
pass
# Exponential backoff with jitter: 1s, 2s, 4s + [0,1)s
return min((2 ** attempt) + random.random(), _RETRY_MAX_DELAY)
class ArchiveError(RuntimeError): class ArchiveError(RuntimeError):
"""Raised when archive.org returns an error payload or unexpected status.""" """Raised when archive.org returns an error payload or unexpected status."""
@ -131,12 +158,20 @@ class ArchiveClient:
# ---------- internal: error-surfacing fetch ---------- # ---------- internal: error-surfacing fetch ----------
async def _fetch_json(self, url: str, params: Any = None) -> Any: async def _fetch_json(self, url: str, params: Any = None) -> Any:
"""GET + JSON decode with archive.org-friendly error messages. """GET + JSON decode with archive.org-friendly error messages and retry.
Wraps raise_for_status so that 4xx/5xx responses include a body preview Retries on 429/502/503/504 with Retry-After honored. 4xx/5xx responses
in the exception invaluable for an LLM trying to fix a bad query. that aren't retryable include a body preview in the exception —
invaluable for an LLM trying to fix a bad query.
""" """
r = await self._client.get(url, params=params) r: httpx.Response | None = None
for attempt in range(_RETRY_MAX_ATTEMPTS):
r = await self._client.get(url, params=params)
if r.status_code in _RETRY_STATUSES and attempt < _RETRY_MAX_ATTEMPTS - 1:
await asyncio.sleep(_retry_delay(r, attempt))
continue
break
assert r is not None # the loop runs at least once
if r.is_error: if r.is_error:
body = r.text[:500] if r.content else "" body = r.text[:500] if r.content else ""
raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip()) raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip())
@ -243,8 +278,11 @@ class ArchiveClient:
) -> AsyncIterator[bytes]: ) -> AsyncIterator[bytes]:
"""Async byte iterator. If resume_from > 0, requires a 206 response. """Async byte iterator. If resume_from > 0, requires a 206 response.
Retries on 429/502/503/504 BEFORE yielding any bytes (so retry never
risks corrupting a partially-written file).
Raises ArchiveError BEFORE yielding any bytes if: Raises ArchiveError BEFORE yielding any bytes if:
- the server returns a 4xx/5xx - the server returns a non-retryable 4xx/5xx
- resume was requested but the server returned 200 (Range ignored) - resume was requested but the server returned 200 (Range ignored)
- the Content-Range start byte doesn't match resume_from - the Content-Range start byte doesn't match resume_from
""" """
@ -256,27 +294,32 @@ class ArchiveClient:
headers["Range"] = f"bytes={resume_from}-" headers["Range"] = f"bytes={resume_from}-"
url = self.download_url(identifier, filename) url = self.download_url(identifier, filename)
async with self._client.stream("GET", url, headers=headers) as r: for attempt in range(_RETRY_MAX_ATTEMPTS):
if r.is_error: async with self._client.stream("GET", url, headers=headers) as r:
body = (await r.aread())[:500].decode("utf-8", errors="replace") if r.status_code in _RETRY_STATUSES and attempt < _RETRY_MAX_ATTEMPTS - 1:
raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip()) await asyncio.sleep(_retry_delay(r, attempt))
if resume_from > 0: continue
if r.status_code != 206: if r.is_error:
raise ArchiveError( body = (await r.aread())[:500].decode("utf-8", errors="replace")
f"server ignored Range request (got HTTP {r.status_code}); " raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip())
f"local file may be stale — retry download_file with overwrite=True" if resume_from > 0:
) if r.status_code != 206:
# Verify the byte range starts where we expect. archive.org's CDN raise ArchiveError(
# is normally well-behaved here, but trust-but-verify. f"server ignored Range request (got HTTP {r.status_code}); "
cr = r.headers.get("Content-Range", "") f"local file may be stale — retry download_file with overwrite=True"
m = re.match(r"bytes\s+(\d+)-", cr) )
if m and int(m.group(1)) != resume_from: # Verify the byte range starts where we expect. archive.org's CDN
raise ArchiveError( # is normally well-behaved here, but trust-but-verify.
f"Content-Range start {m.group(1)} != resume_from {resume_from}; " cr = r.headers.get("Content-Range", "")
f"refusing to corrupt {filename}" m = re.match(r"bytes\s+(\d+)-", cr)
) if m and int(m.group(1)) != resume_from:
async for chunk in r.aiter_bytes(chunk_size=1 << 16): raise ArchiveError(
yield chunk f"Content-Range start {m.group(1)} != resume_from {resume_from}; "
f"refusing to corrupt {filename}"
)
async for chunk in r.aiter_bytes(chunk_size=1 << 16):
yield chunk
return
async def download_to_file( async def download_to_file(
self, self,
@ -321,6 +364,16 @@ class ArchiveClient:
hasher.update(chunk) hasher.update(chunk)
if chunk_cb: if chunk_cb:
chunk_cb(bytes_written) chunk_cb(bytes_written)
except (httpx.ReadError, httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadTimeout) as e:
# H1: surface partial-state context so the caller can decide whether
# to resume or restart. The bytes already on disk are valid (we only
# write whole chunks), so a follow-up call with overwrite=False will
# resume cleanly from `bytes_written`.
raise ArchiveError(
f"download interrupted after {bytes_written - resume_from} new bytes "
f"({bytes_written} total on disk, resumed from {resume_from}). "
f"File at {dest} — call download_file again to resume. Cause: {e!r}"
) from e
finally: finally:
f.close() f.close()
@ -335,3 +388,38 @@ class ArchiveClient:
result["md5_expected"] = verify_md5 result["md5_expected"] = verify_md5
result["md5_ok"] = actual.lower() == verify_md5.lower() result["md5_ok"] = actual.lower() == verify_md5.lower()
return result return result
# ---------- H7: process-wide shared client ----------
#
# An MCP server typically lives forever serving one LLM. Spinning up a fresh
# httpx.AsyncClient per tool call wastes a TCP+TLS handshake (~200-400ms) and
# burns ephemeral ports under parallel fan-out. Share one client across calls.
_shared_client: ArchiveClient | None = None
async def get_shared_client() -> ArchiveClient:
    """Return the process-wide shared ArchiveClient, creating it on first use.

    Tests should NOT use this — construct an ArchiveClient(transport=...)
    directly so MockTransport injection works without leaking into other tests.

    Returns:
        The single shared ArchiveClient instance for this process.
    """
    global _shared_client
    if _shared_client is None:
        # There is no await between the check and the assignment, so on one
        # event loop no other coroutine can interleave here; a second
        # "did someone else create it?" check would be unreachable.
        _shared_client = ArchiveClient()
    return _shared_client
async def close_shared_client() -> None:
    """Close the shared client and reset the singleton slot.

    Does nothing when no shared client exists, so repeated calls are harmless.
    Handy for test teardown.
    """
    global _shared_client
    client = _shared_client
    if client is None:
        return
    # Clear the slot before awaiting so a concurrent get_shared_client() call
    # builds a fresh client instead of receiving one that is being closed.
    _shared_client = None
    await client.aclose()

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import fnmatch import fnmatch
import os import os
from pathlib import Path from pathlib import Path
@ -13,12 +14,32 @@ from pydantic import Field
from mcarchive_org import __version__ from mcarchive_org import __version__
from mcarchive_org.client import ( from mcarchive_org.client import (
ArchiveClient,
ArchiveError, ArchiveError,
get_shared_client,
validate_filename, validate_filename,
validate_identifier, validate_identifier,
) )
# Per-(identifier, filename) locks to serialize concurrent downloads of the same
# file inside the same process. Cross-process races would still need an fcntl
# advisory lock; a single MCP-server process is the common case.
_download_locks: dict[str, asyncio.Lock] = {}
def _download_lock_for(identifier: str, filename: str) -> asyncio.Lock:
"""Get-or-create the in-process lock for one (identifier, filename) pair.
Safe without an outer lock because asyncio is cooperative: there are no
awaits between the dict check and the dict set, so two coroutines can't
interleave inside this function.
"""
key = f"{identifier}::{filename}"
lock = _download_locks.get(key)
if lock is None:
lock = asyncio.Lock()
_download_locks[key] = lock
return lock
def _resolve_download_root() -> Path: def _resolve_download_root() -> Path:
"""Resolve the download root lazily so env-var changes after import are honored.""" """Resolve the download root lazily so env-var changes after import are honored."""
@ -83,6 +104,30 @@ def _matches(name: str, format_: str | None, name_glob: str | None, formats: lis
return not (formats and (format_ or "").lower() not in {f.lower() for f in formats}) return not (formats and (format_ or "").lower() not in {f.lower() for f in formats})
def _normalize_collection(value: Any) -> list[str]:
"""archive.org returns `collection` as either str or list[str]. Normalize to list.
Stable shape lets LLMs write `if 'foo' in doc['collection']` without first
having to remember which type they're holding.
"""
if value is None:
return []
if isinstance(value, str):
return [value] if value else []
if isinstance(value, list):
return [str(x) for x in value if x]
return [str(value)]
def _enrich_doc(doc: dict[str, Any]) -> dict[str, Any]:
"""Apply normalization + derived fields to a search-result doc or metadata blob."""
out = dict(doc)
if "collection" in out:
out["collection"] = _normalize_collection(out["collection"])
out["is_collection"] = out.get("mediatype") == "collection"
return out
def _confine_dest(identifier: str, filename: str, dest_dir: str | None) -> Path: def _confine_dest(identifier: str, filename: str, dest_dir: str | None) -> Path:
"""Construct + verify the download destination path. """Construct + verify the download destination path.
@ -133,19 +178,21 @@ async def search_items(
"""Search archive.org items. Good for small/interactive queries. """Search archive.org items. Good for small/interactive queries.
Returns up to `rows` matching items plus `num_found` (total hits) and `has_more`. Returns up to `rows` matching items plus `num_found` (total hits) and `has_more`.
Each doc has `is_collection` derived from mediatype; `collection` is always a list.
Use scrape_items for bulk iteration over large result sets. Use scrape_items for bulk iteration over large result sets.
""" """
async with ArchiveClient() as c: c = await get_shared_client()
result = await c.search(query=query, fields=fields, sort=sort, rows=rows, page=page) result = await c.search(query=query, fields=fields, sort=sort, rows=rows, page=page)
total = result["num_found"] total = result["num_found"]
seen = (page - 1) * rows + len(result["docs"]) docs = [_enrich_doc(d) for d in result["docs"]]
seen = (page - 1) * rows + len(docs)
return { return {
"query": query, "query": query,
"num_found": total, "num_found": total,
"page": page, "page": page,
"rows": rows, "rows": rows,
"has_more": seen < total, "has_more": seen < total,
"docs": result["docs"], "docs": docs,
} }
@ -160,11 +207,12 @@ async def scrape_items(
"""Scrape API — high-throughput cursor-paginated search. count >= 100. """Scrape API — high-throughput cursor-paginated search. count >= 100.
Response includes `cursor` (for next page) when more results exist; missing when done. Response includes `cursor` (for next page) when more results exist; missing when done.
Each item has `is_collection` derived from mediatype; `collection` is always a list.
""" """
async with ArchiveClient() as c: c = await get_shared_client()
data = await c.scrape(query=query, fields=fields, sorts=sorts, count=count, cursor=cursor) data = await c.scrape(query=query, fields=fields, sorts=sorts, count=count, cursor=cursor)
return { return {
"items": data.get("items", []), "items": [_enrich_doc(d) for d in data.get("items", [])],
"count": data.get("count"), "count": data.get("count"),
"total": data.get("total"), "total": data.get("total"),
"next_cursor": data.get("cursor"), "next_cursor": data.get("cursor"),
@ -189,15 +237,17 @@ async def _fetch_item_metadata(identifier: str, include_files: bool) -> dict[str
"""Shared metadata-fetching logic. Used by both the tool and the MCP resource """Shared metadata-fetching logic. Used by both the tool and the MCP resource
so neither has to depend on the other's `.fn` attribute.""" so neither has to depend on the other's `.fn` attribute."""
validate_identifier(identifier) validate_identifier(identifier)
async with ArchiveClient() as c: c = await get_shared_client()
data = await c.metadata(identifier) data = await c.metadata(identifier)
md = data.get("metadata", {}) md = data.get("metadata", {})
mediatype = md.get("mediatype")
out: dict[str, Any] = { out: dict[str, Any] = {
"identifier": md.get("identifier", identifier), "identifier": md.get("identifier", identifier),
"title": md.get("title"), "title": md.get("title"),
"mediatype": md.get("mediatype"), "mediatype": mediatype,
"collection": md.get("collection"), "is_collection": mediatype == "collection",
"collection": _normalize_collection(md.get("collection")),
"creator": md.get("creator"), "creator": md.get("creator"),
"date": md.get("date"), "date": md.get("date"),
"description": md.get("description"), "description": md.get("description"),
@ -235,8 +285,8 @@ async def list_files(
Each entry includes a ready-to-use `download_url`. Each entry includes a ready-to-use `download_url`.
""" """
validate_identifier(identifier) validate_identifier(identifier)
async with ArchiveClient() as c: c = await get_shared_client()
files = await c.files(identifier) files = await c.files(identifier)
matches = [ matches = [
_enrich_file(identifier, f) _enrich_file(identifier, f)
@ -287,20 +337,25 @@ async def download_file(
The destination is path-confined to either `dest_dir` (when given) or The destination is path-confined to either `dest_dir` (when given) or
$MCARCHIVE_DOWNLOAD_ROOT/{identifier}. Filenames containing '..', absolute $MCARCHIVE_DOWNLOAD_ROOT/{identifier}. Filenames containing '..', absolute
paths, or NUL bytes are rejected before any FS or network I/O. paths, or NUL bytes are rejected before any FS or network I/O.
Concurrent calls for the same (identifier, filename) are serialized in-process
so two parallel tool invocations can't race on the same destination file.
""" """
dest = _confine_dest(identifier, filename, dest_dir) dest = _confine_dest(identifier, filename, dest_dir)
if overwrite and dest.exists() and not dest.is_symlink():
dest.unlink()
elif overwrite and dest.is_symlink():
# Don't follow the symlink to delete the target; remove the link itself.
dest.unlink()
try: lock = _download_lock_for(identifier, filename)
async with ArchiveClient() as c: async with lock:
if overwrite and (dest.exists() or dest.is_symlink()):
# is_symlink() before exists() — a dangling symlink reports exists()=False
# but we still want to remove the link itself rather than follow it.
dest.unlink()
try:
c = await get_shared_client()
result = await c.download_to_file(identifier, filename, dest, verify_md5=verify_md5) result = await c.download_to_file(identifier, filename, dest, verify_md5=verify_md5)
except ArchiveError as e: except ArchiveError as e:
# Re-raise with the destination context so the caller can act on it. # Re-raise with the destination context so the caller can act on it.
raise ArchiveError(f"{e} (dest={dest})") from e raise ArchiveError(f"{e} (dest={dest})") from e
result["identifier"] = identifier result["identifier"] = identifier
result["filename"] = filename result["filename"] = filename

View File

@ -223,6 +223,127 @@ async def test_invalid_json_response_surfaced():
# ---------- happy path ---------- # ---------- happy path ----------
# ---------- M1: retry/backoff with Retry-After ----------
async def test_retry_on_429_then_success(monkeypatch):
    """A 429 carrying ``Retry-After: 0`` is retried once; the second response wins."""
    recorded_delays: list[float] = []

    async def record_sleep(delay: float) -> None:
        recorded_delays.append(delay)

    monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", record_sleep)

    request_count = 0

    def handler(req: httpx.Request) -> httpx.Response:
        nonlocal request_count
        request_count += 1
        if request_count > 1:
            return httpx.Response(200, json={"response": {"numFound": 0, "docs": []}})
        return httpx.Response(429, headers={"Retry-After": "0"}, json={"error": "slow down"})

    async with _client_with(handler) as c:
        result = await c.search(query="x", rows=1)

    assert result["num_found"] == 0
    assert request_count == 2
    assert recorded_delays == [0.0]  # honored Retry-After: 0
async def test_retry_exhaustion_raises_with_body(monkeypatch):
    """If 429 persists past the retry budget, the final error body is surfaced."""
    monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", lambda d: _noop_sleep())

    calls = {"n": 0}

    def handler(req: httpx.Request) -> httpx.Response:
        calls["n"] += 1
        return httpx.Response(429, json={"error": "rate limit exhausted"})

    async with _client_with(handler) as c:
        with pytest.raises(ArchiveError, match="rate limit exhausted"):
            await c.search(query="x")

    # Without this the test would also pass if the client gave up on the very
    # first 429; the retry policy allows 3 attempts in total.
    assert calls["n"] == 3
async def _noop_sleep():
"""Used in place of asyncio.sleep when we don't care about backoff timing."""
async def test_retry_on_503_for_stream(monkeypatch, tmp_path):
    """Stream-level retry: one 503 is retried and the following 200 body lands on disk."""
    monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", lambda d: _noop_sleep())

    payload = b"actual file body"
    attempts = 0

    def handler(req: httpx.Request) -> httpx.Response:
        nonlocal attempts
        attempts += 1
        if attempts == 1:
            return httpx.Response(503, text="overloaded")
        return httpx.Response(200, content=payload)

    target = tmp_path / "f.bin"
    async with _client_with(handler) as c:
        result = await c.download_to_file("nasa", "f.bin", target)

    assert result["bytes_written"] == len(payload)
    assert attempts == 2
    assert target.read_bytes() == payload
async def test_retry_after_http_date_form(monkeypatch):
    """Retry-After may be an HTTP-date; it must be converted to delta-seconds."""
    recorded_delays: list[float] = []

    async def record_sleep(delay: float) -> None:
        recorded_delays.append(delay)

    monkeypatch.setattr("mcarchive_org.client.asyncio.sleep", record_sleep)

    request_count = 0

    def handler(req: httpx.Request) -> httpx.Response:
        nonlocal request_count
        request_count += 1
        if request_count > 1:
            return httpx.Response(200, json={"response": {"numFound": 0, "docs": []}})
        # A date in the past yields a zero-or-negative wait, which is clamped to 0.
        return httpx.Response(429, headers={"Retry-After": "Wed, 21 Oct 2015 07:28:00 GMT"})

    async with _client_with(handler) as c:
        await c.search(query="x")

    assert recorded_delays == [0.0]
# ---------- H1: stream-abort error context ----------
async def test_stream_abort_raises_archive_error_with_byte_count(tmp_path):
    """A mid-stream httpx failure is wrapped in ArchiveError carrying the byte
    count, so the caller knows where the partial download ended."""
    # 128KB is larger than the 64KB chunk size used when streaming, so at
    # least one chunk reaches the writer before the error fires.
    partial_payload = b"X" * (1 << 17)

    async def failing_body():
        yield partial_payload
        raise httpx.ReadError("simulated network drop")

    def handler(req: httpx.Request) -> httpx.Response:
        return httpx.Response(200, content=failing_body())

    dest = tmp_path / "interrupted.bin"
    async with _client_with(handler) as c:
        with pytest.raises(ArchiveError) as exc_info:
            await c.download_to_file("nasa", "interrupted.bin", dest)

    message = str(exc_info.value)
    assert "interrupted after" in message
    assert "ReadError" in message

    # Partial bytes ARE on disk: at least the first delivered chunk, and
    # whatever was written must be a prefix of what the server sent.
    written = dest.read_bytes()
    assert len(written) > 0
    assert written == partial_payload[: len(written)]
# ---------- happy path ----------
async def test_fresh_download_writes_full_body(tmp_path): async def test_fresh_download_writes_full_body(tmp_path):
body = b"hello world" * 100 body = b"hello world" * 100
dest = tmp_path / "new.bin" dest = tmp_path / "new.bin"

193
tests/test_server_mocked.py Normal file
View File

@ -0,0 +1,193 @@
"""Server-layer regression tests using a swapped-in shared client.
These exercise the MCP tool functions directly and verify:
- Collection normalization (M5)
- `is_collection` derived flag (M7)
- Shared client lifecycle (H7)
- Concurrent-download serialization (M2)
"""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
import httpx
import pytest
from mcarchive_org import client as client_mod
from mcarchive_org.client import ArchiveClient
from mcarchive_org.server import (
_enrich_doc,
_normalize_collection,
download_file,
get_item_metadata,
search_items,
)
@asynccontextmanager
async def swap_shared_client(handler):
    """Install a mock-backed ArchiveClient as the process-wide shared client.

    The server tools obtain their client through get_shared_client() and accept
    no transport argument, so the singleton slot itself has to be swapped. On
    exit the previous value is restored and the mock client is closed.
    """
    previous = client_mod._shared_client
    stand_in = ArchiveClient(transport=httpx.MockTransport(handler))
    client_mod._shared_client = stand_in
    try:
        yield stand_in
    finally:
        client_mod._shared_client = previous
        await stand_in.aclose()
# ---------- M5: collection normalization ----------
@pytest.mark.parametrize(
    ("raw", "expected"),
    [
        (None, []),
        ("", []),
        ("nasa", ["nasa"]),
        (["nasa", "opensource"], ["nasa", "opensource"]),
        ([], []),
        ([None, "nasa", ""], ["nasa"]),  # falsy items dropped
    ],
)
def test_normalize_collection_shapes(raw, expected):
    """Every input shape archive.org may send collapses to a list of strings."""
    normalized = _normalize_collection(raw)
    assert normalized == expected
def test_enrich_doc_marks_is_collection():
    """`is_collection` is True only for mediatype == "collection", else False."""
    container = _enrich_doc({"mediatype": "collection", "identifier": "nasa"})
    media_item = _enrich_doc({"mediatype": "audio", "identifier": "x"})
    untyped = _enrich_doc({"identifier": "x"})

    assert container["is_collection"] is True
    assert media_item["is_collection"] is False
    assert untyped["is_collection"] is False
def test_enrich_doc_normalizes_collection_field():
    """A scalar `collection` value comes back wrapped in a list."""
    enriched = _enrich_doc({"identifier": "x", "collection": "single"})
    assert enriched["collection"] == ["single"]
# ---------- M7: is_collection in real tool flow ----------
async def test_search_items_decorates_docs_with_is_collection():
    """search_items adds `is_collection` and list-shaped `collection` to each doc."""
    search_payload = {
        "response": {
            "numFound": 2,
            "docs": [
                {"identifier": "nasa", "mediatype": "collection", "collection": "nasa"},
                {"identifier": "song1", "mediatype": "audio", "collection": ["etree", "GratefulDead"]},
            ],
        }
    }

    def handler(req: httpx.Request) -> httpx.Response:
        return httpx.Response(200, json=search_payload)

    async with swap_shared_client(handler):
        result = await search_items(query="x", rows=2)

    docs = result["docs"]
    assert len(docs) == 2
    container, track = docs
    assert container["is_collection"] is True
    assert container["collection"] == ["nasa"]
    assert track["is_collection"] is False
    assert track["collection"] == ["etree", "GratefulDead"]
async def test_get_item_metadata_normalizes_collection():
    """get_item_metadata reports `is_collection` and a list-shaped `collection`."""
    metadata_payload = {
        "metadata": {
            "identifier": "nasa",
            "title": "NASA Images",
            "mediatype": "collection",
            "collection": "internetarchive",
        },
        "files_count": 0,
        "item_size": 0,
    }

    def handler(req: httpx.Request) -> httpx.Response:
        return httpx.Response(200, json=metadata_payload)

    async with swap_shared_client(handler):
        result = await get_item_metadata(identifier="nasa")

    assert result["is_collection"] is True
    assert result["collection"] == ["internetarchive"]
# ---------- H7: shared client lifecycle ----------
async def test_get_shared_client_returns_same_instance():
    """Two consecutive get_shared_client() calls hand back the same object."""
    await client_mod.close_shared_client()
    try:
        a = await client_mod.get_shared_client()
        b = await client_mod.get_shared_client()
        assert a is b
    finally:
        # Always tear down, even on assertion failure, so a live shared
        # client never leaks into the tests that run afterwards.
        await client_mod.close_shared_client()
async def test_close_shared_client_clears_singleton():
    """After close_shared_client(), the next get_shared_client() builds a new client."""
    try:
        a = await client_mod.get_shared_client()
        await client_mod.close_shared_client()
        b = await client_mod.get_shared_client()
        assert a is not b
    finally:
        # Always tear down, even on assertion failure, so a live shared
        # client never leaks into the tests that run afterwards.
        await client_mod.close_shared_client()
# ---------- M2: concurrent-download serialization ----------
async def test_concurrent_downloads_same_file_are_serialized(tmp_path, monkeypatch):
    """Two parallel download_file calls for the same (id, filename) must not
    interleave; otherwise they would race on the destination file."""
    monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))

    in_flight = 0
    peak = 0

    async def handler(req: httpx.Request) -> httpx.Response:
        nonlocal in_flight, peak
        in_flight += 1
        peak = max(peak, in_flight)
        await asyncio.sleep(0.05)  # hold the request long enough to overlap
        in_flight -= 1
        return httpx.Response(200, content=b"file-content")

    async with swap_shared_client(handler):
        await asyncio.gather(
            *(
                download_file(identifier="nasa", filename="shared.bin", overwrite=True)
                for _ in range(2)
            )
        )

    # The lock should have prevented any overlap.
    assert peak == 1
async def test_concurrent_downloads_different_files_run_in_parallel(tmp_path, monkeypatch):
    """Distinct filenames map to distinct locks, so their downloads overlap."""
    monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))

    in_flight = 0
    peak = 0

    async def handler(req: httpx.Request) -> httpx.Response:
        nonlocal in_flight, peak
        in_flight += 1
        peak = max(peak, in_flight)
        await asyncio.sleep(0.05)
        in_flight -= 1
        return httpx.Response(200, content=b"data")

    async with swap_shared_client(handler):
        await asyncio.gather(
            *(
                download_file(identifier="nasa", filename=name, overwrite=True)
                for name in ("a.bin", "b.bin")
            )
        )

    # Different files — should overlap.
    assert peak == 2