Hardening: address Hamilton review ship-blockers
Critical fixes:
- Validate identifier (^[A-Za-z0-9][A-Za-z0-9._-]{0,99}$) and filename (no '..'
components, absolute paths, NUL bytes, or drive letters; at most 512 chars) at
the client boundary
- Confine download destinations to dest_dir when given, otherwise to
MCARCHIVE_DOWNLOAD_ROOT/{identifier}, via a Path.resolve() + is_relative_to()
check; reject symlinked dirs
- Use O_NOFOLLOW on the destination open() to refuse symlink substitution
- Detect Range-ignored responses: if resume requested but server returns 200
(or 206 with wrong Content-Range start), raise ArchiveError BEFORE writing
any bytes — closes the silent file-corruption hole
Usability:
- Wrap raise_for_status everywhere with ArchiveError that includes the
response body preview — 4xx Solr errors now tell you what's wrong
- URL-encode filenames in download URLs (handles spaces and special chars)
- Map archive.org's {"error": ...} payloads on /metadata/{id}/files to
ArchiveError with the server's message
- Lazy-resolve download root so env-var changes after import are honored
- Refactor item_resource to a shared async helper (drops .fn type-ignore)
- Rename result key 'bytes' -> 'bytes_written' (avoids shadowing builtin)
Tests:
- New tests/test_client_mocked.py: 29 regression tests using
httpx.MockTransport covering every Hamilton finding above (path traversal,
symlink refusal, Range-ignored, Content-Range mismatch, error body
surfacing, malformed JSON, dark items, etc.)
- Set asyncio_mode = "auto" in pyproject for cleaner test markers
33/33 tests pass (4 live + 29 mocked), ruff clean.
This commit is contained in:
parent
5265a6440b
commit
4a03af1675
@ -38,6 +38,9 @@ build-backend = "hatchling.build"
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/mcarchive_org"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py310"
|
||||
|
||||
@ -2,29 +2,104 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
|
||||
ARCHIVE_BASE = "https://archive.org"
|
||||
DEFAULT_UA = "mcarchive-org/2026.04.21 (+https://archive.org/developers/)"
|
||||
# Per-chunk read timeout (60s) means a stalled stream is caught between chunks.
|
||||
# Don't relax this without thinking about hung TCP connections.
|
||||
DEFAULT_TIMEOUT = httpx.Timeout(30.0, read=60.0)
|
||||
|
||||
# Archive.org documents identifiers as [A-Za-z0-9._-], starting with alnum, max 100 chars.
|
||||
_IDENTIFIER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,99}$")
|
||||
_MAX_FILENAME = 512
|
||||
|
||||
|
||||
class ArchiveError(RuntimeError):
    """Signals an archive.org error payload or an unexpected HTTP status."""
|
||||
|
||||
|
||||
class ArchiveClient:
|
||||
"""Async client for the three archive.org endpoints we care about.
|
||||
# ---------- input validators (defense-in-depth) ----------
|
||||
|
||||
- advancedsearch.php : small Solr-style queries (<= ~10,000 rows paginated)
|
||||
- services/search/v1/scrape : bulk cursor-based iteration (count >= 100)
|
||||
|
||||
def validate_identifier(identifier: str) -> str:
    """Return *identifier* unchanged if it is a well-formed archive.org identifier.

    The identifier is interpolated into request URLs and local paths, so a
    value containing '/' or '..' would be a path-traversal vector. Anything
    that does not match ``_IDENTIFIER_RE`` in its entirety is rejected.

    Args:
        identifier: candidate item identifier.

    Returns:
        The same string, so the call can be used inline.

    Raises:
        ValueError: if *identifier* is not a ``str`` or does not fully match
            ``_IDENTIFIER_RE``.
    """
    # fullmatch(), not match(): the pattern ends in '$', which also matches just
    # before a trailing newline, so match() would accept an identifier that
    # ends in a newline character.
    if not isinstance(identifier, str) or _IDENTIFIER_RE.fullmatch(identifier) is None:
        raise ValueError(
            f"invalid archive.org identifier: {identifier!r} "
            f"(must match {_IDENTIFIER_RE.pattern})"
        )
    return identifier
|
||||
|
||||
|
||||
def validate_filename(filename: str) -> str:
    """Return *filename* unchanged unless it could escape a download directory.

    Forward slashes are permitted because archive.org file names may address
    files in subdirectories (e.g. 'cover/back.jpg'). Rejected inputs, checked
    in this order: non-string or empty values, names longer than
    ``_MAX_FILENAME`` characters, names containing a NUL byte, names starting
    with '/' or a backslash, names whose second character is ':' (Windows
    drive form), and names with a '..' path component (either separator).

    Raises:
        ValueError: on the first rule the name violates.
    """
    if not (isinstance(filename, str) and filename):
        raise ValueError(f"filename must be a non-empty string, got {filename!r}")

    if len(filename) > _MAX_FILENAME:
        raise ValueError(f"filename exceeds {_MAX_FILENAME} chars")

    if "\x00" in filename:
        raise ValueError(f"filename contains NUL byte: {filename!r}")

    # The name is non-empty here, so indexing the first character is safe.
    if filename[0] in "/\\":
        raise ValueError(f"filename must not be absolute: {filename!r}")

    # Slicing yields "" for a one-character name instead of raising.
    if filename[1:2] == ":":
        raise ValueError(f"filename must not be a Windows drive path: {filename!r}")

    # Treat backslashes as separators too so 'a\..\b' is caught.
    components = filename.replace("\\", "/").split("/")
    if ".." in components:
        raise ValueError(f"filename must not contain '..' components: {filename!r}")

    return filename
|
||||
|
||||
|
||||
# ---------- safe filesystem open ----------
|
||||
|
||||
|
||||
def _safe_open_for_write(dest: Path, append: bool):
    """Open *dest* for binary writing without following a symlink at the leaf.

    Guards against symlink substitution: even when the caller's confinement
    check passed, a symlink placed at *dest* could redirect the write.
    ``O_NOFOLLOW`` makes the kernel fail the open instead of following it.

    Args:
        dest: file to create (mode 0o644) or open.
        append: when true, open with ``O_APPEND`` and keep existing bytes;
            otherwise truncate with ``O_TRUNC``.

    Returns:
        A binary file object wrapping the descriptor; the caller closes it.

    Raises:
        ArchiveError: if ``O_NOFOLLOW`` is available and the open was refused
            because *dest* is a symlink.
        OSError: any other failure from ``os.open``.
    """
    nofollow = getattr(os, "O_NOFOLLOW", 0)  # not present on Windows -> no-op
    flags = os.O_WRONLY | os.O_CREAT | nofollow
    flags |= os.O_APPEND if append else os.O_TRUNC
    try:
        fd = os.open(dest, flags, 0o644)
    except OSError as e:
        # Linux and macOS report ELOOP when O_NOFOLLOW meets a symlink;
        # FreeBSD reports EMLINK for the same refusal.
        if nofollow and e.errno in (errno.ELOOP, errno.EMLINK):
            raise ArchiveError(f"refusing to write through symlink at {dest}") from e
        raise
    return os.fdopen(fd, "ab" if append else "wb")
|
||||
|
||||
|
||||
# ---------- client ----------
|
||||
|
||||
|
||||
class ArchiveClient:
|
||||
"""Async client for the archive.org endpoints we wrap.
|
||||
|
||||
- advancedsearch.php : Solr-style queries (<= ~10,000 rows paginated)
|
||||
- services/search/v1/scrape : bulk cursor pagination (count >= 100)
|
||||
- metadata/{id} : full item manifest including files[]
|
||||
- download/{id}/{file} : byte stream with Range support
|
||||
- download/{id}/{file} : byte stream with HTTP Range support
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -32,13 +107,17 @@ class ArchiveClient:
|
||||
base_url: str = ARCHIVE_BASE,
|
||||
user_agent: str = DEFAULT_UA,
|
||||
timeout: httpx.Timeout | float = DEFAULT_TIMEOUT,
|
||||
transport: httpx.AsyncBaseTransport | None = None,
|
||||
) -> None:
|
||||
self._base = base_url.rstrip("/")
|
||||
self._client = httpx.AsyncClient(
|
||||
headers={"User-Agent": user_agent, "Accept": "application/json"},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
)
|
||||
kwargs: dict[str, Any] = {
|
||||
"headers": {"User-Agent": user_agent, "Accept": "application/json"},
|
||||
"timeout": timeout,
|
||||
"follow_redirects": True,
|
||||
}
|
||||
if transport is not None:
|
||||
kwargs["transport"] = transport
|
||||
self._client = httpx.AsyncClient(**kwargs)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._client.aclose()
|
||||
@ -49,6 +128,25 @@ class ArchiveClient:
|
||||
async def __aexit__(self, *exc: object) -> None:
|
||||
await self.aclose()
|
||||
|
||||
# ---------- internal: error-surfacing fetch ----------
|
||||
|
||||
async def _fetch_json(self, url: str, params: Any = None) -> Any:
    """GET *url* and return the decoded JSON body.

    Error statuses are turned into ArchiveError carrying the status code, the
    final URL and up to 500 characters of the response body, so the caller
    can see why a query was rejected. A body that is not valid JSON raises
    ArchiveError with the repr of its first 200 characters.
    """
    response = await self._client.get(url, params=params)

    if response.is_error:
        preview = response.text[:500] if response.content else ""
        message = f"HTTP {response.status_code} from {response.url}: {preview}"
        raise ArchiveError(message.strip())

    try:
        payload = response.json()
    except ValueError as e:
        snippet = response.text[:200]
        raise ArchiveError(f"invalid JSON from {response.url}: {snippet!r}") from e
    return payload
|
||||
|
||||
# ---------- search ----------
|
||||
|
||||
async def search(
|
||||
@ -71,9 +169,7 @@ class ArchiveClient:
|
||||
for s in sort or []:
|
||||
params.append(("sort[]", s))
|
||||
|
||||
r = await self._client.get(f"{self._base}/advancedsearch.php", params=params)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
data = await self._fetch_json(f"{self._base}/advancedsearch.php", params=params)
|
||||
resp = data.get("response", {})
|
||||
return {
|
||||
"num_found": resp.get("numFound", 0),
|
||||
@ -103,39 +199,41 @@ class ArchiveClient:
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
r = await self._client.get(f"{self._base}/services/search/v1/scrape", params=params)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "error" in data:
|
||||
data = await self._fetch_json(f"{self._base}/services/search/v1/scrape", params=params)
|
||||
if isinstance(data, dict) and "error" in data:
|
||||
raise ArchiveError(f"{data.get('errorType', 'ScrapeError')}: {data['error']}")
|
||||
return data # keys: items, count, total, cursor (if more pages)
|
||||
|
||||
# ---------- metadata ----------
|
||||
|
||||
async def metadata(self, identifier: str) -> dict[str, Any]:
|
||||
"""Full metadata blob for an item."""
|
||||
r = await self._client.get(f"{self._base}/metadata/{identifier}")
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
"""Full metadata blob for an item. Empty {} from archive.org → not found."""
|
||||
validate_identifier(identifier)
|
||||
data = await self._fetch_json(f"{self._base}/metadata/{identifier}")
|
||||
if not data:
|
||||
raise ArchiveError(f"item not found: {identifier}")
|
||||
raise ArchiveError(f"item not found or unavailable: {identifier}")
|
||||
return data
|
||||
|
||||
async def files(self, identifier: str) -> list[dict[str, Any]]:
|
||||
"""Just the files[] slice — smaller payload when that's all you want."""
|
||||
r = await self._client.get(f"{self._base}/metadata/{identifier}/files")
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if isinstance(data, dict) and "result" in data:
|
||||
validate_identifier(identifier)
|
||||
data = await self._fetch_json(f"{self._base}/metadata/{identifier}/files")
|
||||
if isinstance(data, dict):
|
||||
if "error" in data:
|
||||
raise ArchiveError(f"archive.org error for {identifier}: {data['error']}")
|
||||
if "result" in data:
|
||||
return data["result"]
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
raise ArchiveError(f"unexpected files response for {identifier}")
|
||||
raise ArchiveError(f"unexpected files response shape for {identifier}: {type(data).__name__}")
|
||||
|
||||
# ---------- download ----------
|
||||
|
||||
def download_url(self, identifier: str, filename: str) -> str:
|
||||
return f"{self._base}/download/{identifier}/{filename}"
|
||||
"""Build the canonical download URL. Filename is URL-encoded but '/' preserved."""
|
||||
validate_identifier(identifier)
|
||||
validate_filename(filename)
|
||||
return f"{self._base}/download/{identifier}/{quote(filename, safe='/')}"
|
||||
|
||||
async def stream_file(
|
||||
self,
|
||||
@ -143,13 +241,40 @@ class ArchiveClient:
|
||||
filename: str,
|
||||
resume_from: int = 0,
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""Async byte iterator — caller is responsible for writing to disk."""
|
||||
"""Async byte iterator. If resume_from > 0, requires a 206 response.
|
||||
|
||||
Raises ArchiveError BEFORE yielding any bytes if:
|
||||
- the server returns a 4xx/5xx
|
||||
- resume was requested but the server returned 200 (Range ignored)
|
||||
- the Content-Range start byte doesn't match resume_from
|
||||
"""
|
||||
validate_identifier(identifier)
|
||||
validate_filename(filename)
|
||||
|
||||
headers = {}
|
||||
if resume_from > 0:
|
||||
headers["Range"] = f"bytes={resume_from}-"
|
||||
url = self.download_url(identifier, filename)
|
||||
|
||||
async with self._client.stream("GET", url, headers=headers) as r:
|
||||
r.raise_for_status()
|
||||
if r.is_error:
|
||||
body = (await r.aread())[:500].decode("utf-8", errors="replace")
|
||||
raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip())
|
||||
if resume_from > 0:
|
||||
if r.status_code != 206:
|
||||
raise ArchiveError(
|
||||
f"server ignored Range request (got HTTP {r.status_code}); "
|
||||
f"local file may be stale — retry download_file with overwrite=True"
|
||||
)
|
||||
# Verify the byte range starts where we expect. archive.org's CDN
|
||||
# is normally well-behaved here, but trust-but-verify.
|
||||
cr = r.headers.get("Content-Range", "")
|
||||
m = re.match(r"bytes\s+(\d+)-", cr)
|
||||
if m and int(m.group(1)) != resume_from:
|
||||
raise ArchiveError(
|
||||
f"Content-Range start {m.group(1)} != resume_from {resume_from}; "
|
||||
f"refusing to corrupt {filename}"
|
||||
)
|
||||
async for chunk in r.aiter_bytes(chunk_size=1 << 16):
|
||||
yield chunk
|
||||
|
||||
@ -161,20 +286,34 @@ class ArchiveClient:
|
||||
verify_md5: str | None = None,
|
||||
chunk_cb=None,
|
||||
) -> dict[str, Any]:
|
||||
"""Download with resume support. Returns stats + md5 verification result."""
|
||||
"""Download with resume support. Returns stats + md5 verification result.
|
||||
|
||||
Caller is responsible for ensuring `dest` is confined to a safe directory
|
||||
(use mcarchive_org.server's path validation, or do your own). This method
|
||||
adds defense-in-depth via O_NOFOLLOW but does not validate path confinement.
|
||||
"""
|
||||
validate_identifier(identifier)
|
||||
validate_filename(filename)
|
||||
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
resume_from = dest.stat().st_size if dest.exists() else 0
|
||||
# If parent is a symlink to elsewhere, refuse — our caller's confinement
|
||||
# check should already have caught this, but redundant defense is cheap.
|
||||
if dest.parent.is_symlink():
|
||||
raise ArchiveError(f"refusing to write into symlinked directory: {dest.parent}")
|
||||
|
||||
resume_from = dest.stat().st_size if dest.exists() and not dest.is_symlink() else 0
|
||||
# is_symlink() detection: if dest is a symlink, treat as fresh (open with
|
||||
# O_NOFOLLOW will refuse anyway, so we won't actually corrupt the target).
|
||||
|
||||
hasher = hashlib.md5() if verify_md5 else None
|
||||
if hasher and resume_from:
|
||||
# re-hash existing bytes so the final digest is correct
|
||||
with dest.open("rb") as f:
|
||||
while chunk := f.read(1 << 16):
|
||||
hasher.update(chunk)
|
||||
|
||||
bytes_written = resume_from
|
||||
mode = "ab" if resume_from else "wb"
|
||||
with dest.open(mode) as f:
|
||||
f = _safe_open_for_write(dest, append=resume_from > 0)
|
||||
try:
|
||||
async for chunk in self.stream_file(identifier, filename, resume_from=resume_from):
|
||||
f.write(chunk)
|
||||
bytes_written += len(chunk)
|
||||
@ -182,10 +321,12 @@ class ArchiveClient:
|
||||
hasher.update(chunk)
|
||||
if chunk_cb:
|
||||
chunk_cb(bytes_written)
|
||||
finally:
|
||||
f.close()
|
||||
|
||||
result = {
|
||||
result: dict[str, Any] = {
|
||||
"path": str(dest),
|
||||
"bytes": bytes_written,
|
||||
"bytes_written": bytes_written,
|
||||
"resumed_from": resume_from,
|
||||
}
|
||||
if verify_md5 and hasher:
|
||||
|
||||
@ -6,15 +6,24 @@ import fnmatch
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from pydantic import Field
|
||||
|
||||
from mcarchive_org import __version__
|
||||
from mcarchive_org.client import ArchiveClient
|
||||
from mcarchive_org.client import (
|
||||
ArchiveClient,
|
||||
ArchiveError,
|
||||
validate_filename,
|
||||
validate_identifier,
|
||||
)
|
||||
|
||||
DEFAULT_DOWNLOAD_ROOT = Path(
|
||||
os.environ.get("MCARCHIVE_DOWNLOAD_ROOT", Path.cwd() / "downloads")
|
||||
|
||||
def _resolve_download_root() -> Path:
    """Return the download root, read from the environment on every call.

    Reading lazily means a change to MCARCHIVE_DOWNLOAD_ROOT after import is
    honored. When the variable is unset or empty the root is
    ``<current working directory>/downloads``; otherwise it is the configured
    value with a leading '~' expanded.
    """
    configured = os.environ.get("MCARCHIVE_DOWNLOAD_ROOT")
    if not configured:
        # An empty value would become Path("") == ".", silently making the
        # current directory itself the download root; use the default instead.
        return Path.cwd() / "downloads"
    return Path(configured).expanduser()
|
||||
|
||||
mcp = FastMCP(
|
||||
@ -42,18 +51,29 @@ def _human_size(n: int | str | None) -> str:
|
||||
return f"{x:.1f} PB"
|
||||
|
||||
|
||||
def _parse_size(raw: Any) -> int | None:
    """Convert an archive.org size field to ``int``, or ``None`` if it cannot be.

    The value is stringified and stripped before parsing, so both numbers and
    padded numeric strings are accepted.
    """
    if raw is not None:
        try:
            return int(str(raw).strip())
        except (TypeError, ValueError):
            pass
    return None
|
||||
|
||||
|
||||
def _enrich_file(identifier: str, f: dict[str, Any]) -> dict[str, Any]:
|
||||
name = f.get("name", "")
|
||||
size = _parse_size(f.get("size"))
|
||||
return {
|
||||
"name": name,
|
||||
"format": f.get("format"),
|
||||
"size": int(f["size"]) if f.get("size") and str(f["size"]).isdigit() else None,
|
||||
"size_human": _human_size(f.get("size")),
|
||||
"size": size,
|
||||
"size_human": _human_size(size if size is not None else f.get("size")),
|
||||
"md5": f.get("md5"),
|
||||
"sha1": f.get("sha1"),
|
||||
"mtime": f.get("mtime"),
|
||||
"source": f.get("source"),
|
||||
"download_url": f"https://archive.org/download/{identifier}/{name}",
|
||||
"download_url": f"https://archive.org/download/{identifier}/{quote(name, safe='/')}",
|
||||
}
|
||||
|
||||
|
||||
@ -63,6 +83,36 @@ def _matches(name: str, format_: str | None, name_glob: str | None, formats: lis
|
||||
return not (formats and (format_ or "").lower() not in {f.lower() for f in formats})
|
||||
|
||||
|
||||
def _confine_dest(identifier: str, filename: str, dest_dir: str | None) -> Path:
    """Build the download destination and verify it cannot escape its directory.

    The target directory is ``dest_dir`` when the caller supplies one,
    otherwise ``<download root>/<identifier>``. It is created if missing.

    Args:
        identifier: archive.org item identifier; checked by validate_identifier.
        filename: file name within the item; checked by validate_filename.
        dest_dir: optional caller-chosen directory ('~' is expanded).

    Returns:
        The resolved destination path, always located under the target directory.

    Raises:
        ValueError: if the identifier or filename is invalid, if the default
            per-item directory is a symlink or resolves outside the download
            root, or if the resolved destination is not under the target
            directory.
    """
    validate_identifier(identifier)
    validate_filename(filename)

    if dest_dir:
        # Explicit caller choice: follow it to the real directory and confine
        # the file to that location.
        target_dir = Path(dest_dir).expanduser().resolve()
    else:
        download_root = _resolve_download_root().resolve()
        default_dir = download_root / identifier
        # Test the *unresolved* path: resolve() follows symlinks, so calling
        # is_symlink() on its result could never detect one.
        if default_dir.is_symlink():
            raise ValueError(f"download directory must not be a symlink: {default_dir}")
        target_dir = default_dir.resolve()
        if not target_dir.is_relative_to(download_root):
            raise ValueError(
                f"refusing download directory outside root: "
                f"{target_dir} not under {download_root}"
            )

    target_dir.mkdir(parents=True, exist_ok=True)

    # resolve() collapses '..' even when the leaf does not exist yet, so any
    # escape attempt shows up in the containment check below.
    dest = (target_dir / filename).resolve()
    if not dest.is_relative_to(target_dir):
        raise ValueError(
            f"refusing destination outside target dir: {dest} not under {target_dir}"
        )
    return dest
|
||||
|
||||
|
||||
# ---------- tools ----------
|
||||
|
||||
|
||||
@ -132,6 +182,13 @@ async def get_item_metadata(
|
||||
|
||||
By default omits the (potentially huge) files[] array — call list_files for that.
|
||||
"""
|
||||
return await _fetch_item_metadata(identifier, include_files=include_files)
|
||||
|
||||
|
||||
async def _fetch_item_metadata(identifier: str, include_files: bool) -> dict[str, Any]:
|
||||
"""Shared metadata-fetching logic. Used by both the tool and the MCP resource
|
||||
so neither has to depend on the other's `.fn` attribute."""
|
||||
validate_identifier(identifier)
|
||||
async with ArchiveClient() as c:
|
||||
data = await c.metadata(identifier)
|
||||
|
||||
@ -177,6 +234,7 @@ async def list_files(
|
||||
|
||||
Each entry includes a ready-to-use `download_url`.
|
||||
"""
|
||||
validate_identifier(identifier)
|
||||
async with ArchiveClient() as c:
|
||||
files = await c.files(identifier)
|
||||
|
||||
@ -199,8 +257,10 @@ def get_file_url(
|
||||
filename: Annotated[str, Field(description="Exact filename as shown in list_files.")],
|
||||
) -> dict[str, str]:
|
||||
"""Build the canonical download URL for a file without fetching anything."""
|
||||
validate_identifier(identifier)
|
||||
validate_filename(filename)
|
||||
return {
|
||||
"url": f"https://archive.org/download/{identifier}/{filename}",
|
||||
"url": f"https://archive.org/download/{identifier}/{quote(filename, safe='/')}",
|
||||
"item_url": f"https://archive.org/details/{identifier}",
|
||||
}
|
||||
|
||||
@ -222,18 +282,29 @@ async def download_file(
|
||||
Field(description="If false and file exists, resume the download (Range request)."),
|
||||
] = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Download a file to disk. Supports resume via HTTP Range when overwrite=false."""
|
||||
target_dir = Path(dest_dir).expanduser() if dest_dir else (DEFAULT_DOWNLOAD_ROOT / identifier)
|
||||
dest = target_dir / filename
|
||||
if overwrite and dest.exists():
|
||||
"""Download a file to disk. Supports resume via HTTP Range when overwrite=false.
|
||||
|
||||
The destination is path-confined to either `dest_dir` (when given) or
|
||||
$MCARCHIVE_DOWNLOAD_ROOT/{identifier}. Filenames containing '..', absolute
|
||||
paths, or NUL bytes are rejected before any FS or network I/O.
|
||||
"""
|
||||
dest = _confine_dest(identifier, filename, dest_dir)
|
||||
if overwrite and dest.exists() and not dest.is_symlink():
|
||||
dest.unlink()
|
||||
elif overwrite and dest.is_symlink():
|
||||
# Don't follow the symlink to delete the target; remove the link itself.
|
||||
dest.unlink()
|
||||
|
||||
try:
|
||||
async with ArchiveClient() as c:
|
||||
result = await c.download_to_file(identifier, filename, dest, verify_md5=verify_md5)
|
||||
except ArchiveError as e:
|
||||
# Re-raise with the destination context so the caller can act on it.
|
||||
raise ArchiveError(f"{e} (dest={dest})") from e
|
||||
|
||||
result["identifier"] = identifier
|
||||
result["filename"] = filename
|
||||
result["size_human"] = _human_size(result.get("bytes"))
|
||||
result["size_human"] = _human_size(result.get("bytes_written"))
|
||||
return result
|
||||
|
||||
|
||||
@ -243,7 +314,7 @@ async def download_file(
|
||||
@mcp.resource("archive://item/{identifier}")
|
||||
async def item_resource(identifier: str) -> dict[str, Any]:
|
||||
"""Expose item metadata as a readable MCP resource."""
|
||||
return await get_item_metadata.fn(identifier=identifier, include_files=False) # type: ignore[attr-defined]
|
||||
return await _fetch_item_metadata(identifier, include_files=False)
|
||||
|
||||
|
||||
# ---------- entry point ----------
|
||||
|
||||
@ -12,7 +12,7 @@ import pytest
|
||||
|
||||
from mcarchive_org.client import ArchiveClient
|
||||
|
||||
pytestmark = [pytest.mark.asyncio, pytest.mark.network]
|
||||
pytestmark = pytest.mark.network
|
||||
|
||||
|
||||
async def test_search_nasa_item():
|
||||
@ -41,7 +41,7 @@ async def test_download_small_file(tmp_path: Path):
|
||||
result = await c.download_to_file(
|
||||
"nasa", small["name"], dest, verify_md5=small.get("md5")
|
||||
)
|
||||
assert result["bytes"] > 0
|
||||
assert result["bytes_written"] > 0
|
||||
if small.get("md5"):
|
||||
assert result["md5_ok"] is True
|
||||
|
||||
|
||||
242
tests/test_client_mocked.py
Normal file
242
tests/test_client_mocked.py
Normal file
@ -0,0 +1,242 @@
|
||||
"""Failure-mode regression tests using httpx.MockTransport (no network).
|
||||
|
||||
Each test pins down one of the Hamilton review findings (C1/C2/C3/H4 etc.) so
|
||||
future refactors can't silently regress safety.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from mcarchive_org.client import (
|
||||
ArchiveClient,
|
||||
ArchiveError,
|
||||
validate_filename,
|
||||
validate_identifier,
|
||||
)
|
||||
from mcarchive_org.server import _confine_dest
|
||||
|
||||
|
||||
def _client_with(handler) -> ArchiveClient:
    """Return an ArchiveClient whose requests are answered by *handler*.

    The handler is wrapped in httpx.MockTransport and passed as the client's
    transport.
    """
    mock_transport = httpx.MockTransport(handler)
    return ArchiveClient(transport=mock_transport)
|
||||
|
||||
|
||||
# ---------- C1: identifier + filename validation ----------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bad", ["", "../etc", "foo/bar", "has space", "a" * 200])
def test_invalid_identifier_rejected(bad):
    """Every malformed identifier is refused with the validator's message."""
    expected_message = r"invalid archive\.org identifier"
    with pytest.raises(ValueError, match=expected_message):
        validate_identifier(bad)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "bad",
    [
        "../escape.txt",
        "/etc/passwd",
        "C:\\windows.txt",
        "with\x00null.bin",
        "foo/../bar.mp3",
        "foo\\..\\bar.mp3",
        "",
    ],
)
def test_invalid_filename_rejected(bad):
    """Traversal, absolute, drive-letter, NUL and empty names all raise."""
    rejection = pytest.raises(ValueError)
    with rejection:
        validate_filename(bad)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "ok",
    ["song.mp3", "cover/back.jpg", "subdir/file with space.txt", "a.b.c.d"],
)
def test_legitimate_filenames_accepted(ok):
    """Ordinary names, including subdirectory paths, pass through unchanged."""
    returned = validate_filename(ok)
    assert returned == ok
|
||||
|
||||
|
||||
def test_confine_dest_blocks_traversal(tmp_path, monkeypatch):
    """A '..' filename must never yield a destination path."""
    monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))
    # The filename validator rejects '..' before _confine_dest reaches its own
    # path-resolution check, so the ValueError comes from the validator; the
    # two layers agree on the outcome.
    escaping_name = "../escape.txt"
    with pytest.raises(ValueError):
        _confine_dest("nasa", escaping_name, dest_dir=None)
|
||||
|
||||
|
||||
def test_confine_dest_legit_filename_lands_in_root(tmp_path, monkeypatch):
    """With no dest_dir, a plain filename resolves to a path under the root."""
    monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))
    resolved = _confine_dest("nasa", "globe.jpg", dest_dir=None)
    assert resolved.is_relative_to(tmp_path)
    assert resolved.name == "globe.jpg"
|
||||
|
||||
|
||||
# ---------- C2: symlink refusal ----------
|
||||
|
||||
|
||||
async def test_download_refuses_symlink_at_dest(tmp_path):
    """A symlink at the destination aborts the download and spares its target."""
    original = b"original-content"
    target = tmp_path / "real.bin"
    target.write_bytes(original)

    link = tmp_path / "evil.bin"
    link.symlink_to(target)

    def handler(req: httpx.Request) -> httpx.Response:
        return httpx.Response(200, content=b"new-content-that-should-not-overwrite")

    async with _client_with(handler) as c:
        with pytest.raises(ArchiveError, match="symlink"):
            await c.download_to_file("nasa", "evil.bin", link)

    # The file the symlink points at must be untouched.
    assert target.read_bytes() == original
|
||||
|
||||
|
||||
# ---------- C3: Range-ignored detection ----------
|
||||
|
||||
|
||||
async def test_resume_with_200_response_raises_before_writing(tmp_path):
    """A 200 answer to a Range request must raise and leave the file alone.

    Appending a full body to an existing partial file would corrupt it
    silently, so the download has to fail before any byte is written.
    """
    partial = b"X" * 100  # stands in for an interrupted earlier download
    dest = tmp_path / "partial.bin"
    dest.write_bytes(partial)

    def handler(req: httpx.Request) -> httpx.Response:
        # The server disregards the Range header and sends everything with 200.
        assert req.headers.get("Range") == "bytes=100-"
        return httpx.Response(200, content=b"FULL_FILE_BODY")

    async with _client_with(handler) as c:
        with pytest.raises(ArchiveError, match="ignored Range"):
            await c.download_to_file("nasa", "partial.bin", dest)

    # Nothing was appended, so no corruption occurred.
    assert dest.read_bytes() == partial
|
||||
|
||||
|
||||
async def test_resume_with_correct_206_succeeds(tmp_path):
    """A 206 with the matching Content-Range completes the partial file."""
    full_body = b"0123456789ABCDEF" * 16  # 256 bytes in total
    already_have = 64
    dest = tmp_path / "resume.bin"
    dest.write_bytes(full_body[:already_have])  # the first 64 bytes are on disk

    def handler(req: httpx.Request) -> httpx.Response:
        assert req.headers.get("Range") == "bytes=64-"
        content_range = f"bytes 64-{len(full_body)-1}/{len(full_body)}"
        return httpx.Response(
            206,
            content=full_body[already_have:],
            headers={"Content-Range": content_range},
        )

    expected_md5 = hashlib.md5(full_body).hexdigest()
    async with _client_with(handler) as c:
        result = await c.download_to_file(
            "nasa", "resume.bin", dest, verify_md5=expected_md5
        )

    assert result["bytes_written"] == len(full_body)
    assert result["resumed_from"] == 64
    assert result["md5_ok"] is True
    assert dest.read_bytes() == full_body
|
||||
|
||||
|
||||
async def test_resume_with_wrong_content_range_start_raises(tmp_path):
    """A 206 whose Content-Range starts at the wrong offset must be refused."""
    existing = b"X" * 100
    dest = tmp_path / "off.bin"
    dest.write_bytes(existing)

    def handler(req: httpx.Request) -> httpx.Response:
        # Status is 206, but the range begins at byte 50 rather than byte 100.
        return httpx.Response(
            206,
            content=b"junk",
            headers={"Content-Range": "bytes 50-99/100"},
        )

    async with _client_with(handler) as c:
        with pytest.raises(ArchiveError, match="Content-Range start"):
            await c.download_to_file("nasa", "off.bin", dest)

    assert dest.read_bytes() == existing  # file left as it was
|
||||
|
||||
|
||||
# ---------- H4: error body surfacing ----------
|
||||
|
||||
|
||||
async def test_search_400_includes_response_body():
    """A 400 from search raises ArchiveError that carries the server's body."""
    error_body = '{"error":"bad query syntax"}'

    def handler(req: httpx.Request) -> httpx.Response:
        return httpx.Response(400, text=error_body)

    async with _client_with(handler) as c:
        with pytest.raises(ArchiveError, match="bad query syntax"):
            await c.search(query="INVALID:::")
|
||||
|
||||
|
||||
async def test_metadata_404_includes_status():
    """A 404 from metadata raises ArchiveError that names the status code."""

    def handler(req: httpx.Request) -> httpx.Response:
        not_found = httpx.Response(404, text="not found")
        return not_found

    async with _client_with(handler) as c:
        with pytest.raises(ArchiveError, match="HTTP 404"):
            await c.metadata("nasa")
|
||||
|
||||
|
||||
async def test_metadata_empty_dict_means_not_found():
    """An empty JSON object from metadata is reported as a missing item."""

    def handler(req: httpx.Request) -> httpx.Response:
        empty_payload = {}
        return httpx.Response(200, json=empty_payload)

    async with _client_with(handler) as c:
        with pytest.raises(ArchiveError, match="not found or unavailable"):
            await c.metadata("nasa")
|
||||
|
||||
|
||||
async def test_files_returns_error_payload_as_archive_error():
    """An {"error": ...} body from the files endpoint becomes ArchiveError."""
    error_payload = {"error": "item is dark"}

    def handler(req: httpx.Request) -> httpx.Response:
        return httpx.Response(200, json=error_payload)

    async with _client_with(handler) as c:
        with pytest.raises(ArchiveError, match="item is dark"):
            await c.files("nasa")
|
||||
|
||||
|
||||
async def test_scrape_error_payload_surfaced():
    """A scrape error body surfaces with both its errorType and its message."""
    error_payload = {"error": "count too small", "errorType": "RangeException"}

    def handler(req: httpx.Request) -> httpx.Response:
        return httpx.Response(200, json=error_payload)

    async with _client_with(handler) as c:
        with pytest.raises(ArchiveError, match=r"RangeException.*count too small"):
            await c.scrape(query="identifier:nasa", count=100)
|
||||
|
||||
|
||||
async def test_invalid_json_response_surfaced():
    """A 200 response whose body is not JSON is reported as ArchiveError."""
    html_body = "<html>not json</html>"

    def handler(req: httpx.Request) -> httpx.Response:
        return httpx.Response(200, text=html_body)

    async with _client_with(handler) as c:
        with pytest.raises(ArchiveError, match="invalid JSON"):
            await c.metadata("nasa")
|
||||
|
||||
|
||||
# ---------- happy path ----------
|
||||
|
||||
|
||||
async def test_fresh_download_writes_full_body(tmp_path):
    """With no existing file, the whole body is written and the MD5 verifies."""
    body = b"hello world" * 100
    digest = hashlib.md5(body).hexdigest()
    dest = tmp_path / "new.bin"

    def handler(req: httpx.Request) -> httpx.Response:
        # A fresh download must not send a Range header.
        assert "Range" not in req.headers
        return httpx.Response(200, content=body)

    async with _client_with(handler) as c:
        result = await c.download_to_file("nasa", "new.bin", dest, verify_md5=digest)

    assert result["bytes_written"] == len(body)
    assert result["resumed_from"] == 0
    assert result["md5_ok"] is True
    assert dest.read_bytes() == body
|
||||
Loading…
x
Reference in New Issue
Block a user