Hardening: address Hamilton review ship-blockers
Critical fixes:
- Validate identifier (^[A-Za-z0-9][A-Za-z0-9._-]{0,99}$) and filename (no '..'
components, absolute paths, NUL bytes, or drive letters; at most 512 chars) at
the client boundary
- Confine download destinations under the caller's dest_dir, or under
MCARCHIVE_DOWNLOAD_ROOT/{identifier} by default, via Path.resolve() +
is_relative_to() check; reject symlinked dirs
- Use O_NOFOLLOW on the destination open() to refuse symlink substitution
- Detect Range-ignored responses: if resume requested but server returns 200
(or 206 with wrong Content-Range start), raise ArchiveError BEFORE writing
any bytes — closes the silent file-corruption hole
Usability:
- Wrap raise_for_status everywhere with ArchiveError that includes the
response body preview — 4xx Solr errors now tell you what's wrong
- URL-encode filenames in download URLs (handles spaces and special chars)
- Map archive.org's {"error": ...} payloads on /metadata/{id}/files to
ArchiveError with the server's message
- Lazy-resolve download root so env-var changes after import are honored
- Refactor item_resource to a shared async helper (drops .fn type-ignore)
- Rename result key 'bytes' -> 'bytes_written' (avoids shadowing builtin)
Tests:
- New tests/test_client_mocked.py: 29 regression tests using
httpx.MockTransport covering every Hamilton finding above (path traversal,
symlink refusal, Range-ignored, Content-Range mismatch, error body
surfacing, malformed JSON, dark items, etc.)
- Set asyncio_mode = "auto" in pyproject for cleaner test markers
33/33 tests pass (4 live + 29 mocked), ruff clean.
This commit is contained in:
parent
5265a6440b
commit
4a03af1675
@ -38,6 +38,9 @@ build-backend = "hatchling.build"
|
|||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["src/mcarchive_org"]
|
packages = ["src/mcarchive_org"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 100
|
line-length = 100
|
||||||
target-version = "py310"
|
target-version = "py310"
|
||||||
|
|||||||
@ -2,29 +2,104 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import errno
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import os
|
||||||
|
import re
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
# Default base URL used by ArchiveClient when no base_url is passed.
ARCHIVE_BASE = "https://archive.org"
# Default User-Agent header value sent by ArchiveClient.
DEFAULT_UA = "mcarchive-org/2026.04.21 (+https://archive.org/developers/)"
# Per-chunk read timeout (60s) means a stalled stream is caught between chunks.
# Don't relax this without thinking about hung TCP connections.
DEFAULT_TIMEOUT = httpx.Timeout(30.0, read=60.0)

# Archive.org documents identifiers as [A-Za-z0-9._-], starting with alnum, max 100 chars.
# NOTE(review): `$` also matches just before a trailing newline, so this pattern
# is only a strict whole-string check when applied with fullmatch().
_IDENTIFIER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,99}$")
# Upper bound, in characters, that validate_filename accepts.
_MAX_FILENAME = 512
|
||||||
|
|
||||||
|
|
||||||
class ArchiveError(RuntimeError):
    """Raised when archive.org returns an error payload or unexpected status.

    Also used for failures detected locally in this module, such as an
    undecodable JSON body or a refused write through a symlink.
    """
|
||||||
|
|
||||||
|
|
||||||
class ArchiveClient:
|
# ---------- input validators (defense-in-depth) ----------
|
||||||
"""Async client for the three archive.org endpoints we care about.
|
|
||||||
|
|
||||||
- advancedsearch.php : small Solr-style queries (<= ~10,000 rows paginated)
|
|
||||||
- services/search/v1/scrape : bulk cursor-based iteration (count >= 100)
|
def validate_identifier(identifier: str) -> str:
    """Return ``identifier`` unchanged if it matches archive.org's grammar.

    Both archive.org and the local filesystem trust this string: a '/' or
    '..' in an identifier is a path-traversal vector, so anything outside
    the documented alphabet is rejected.

    Args:
        identifier: Candidate archive.org item identifier.

    Returns:
        The same string, for convenient inline use.

    Raises:
        ValueError: If ``identifier`` is not a str or does not match
            ``_IDENTIFIER_RE`` in its entirety.
    """
    # fullmatch(), not match(): the pattern ends in `$`, which also matches
    # just before a trailing newline, so match() would let "nasa\n" through.
    if not isinstance(identifier, str) or _IDENTIFIER_RE.fullmatch(identifier) is None:
        raise ValueError(
            f"invalid archive.org identifier: {identifier!r} "
            f"(must match {_IDENTIFIER_RE.pattern})"
        )
    return identifier
|
||||||
|
|
||||||
|
|
||||||
|
def validate_filename(filename: str) -> str:
    """Return ``filename`` unchanged unless it could escape a download root.

    Forward slashes are permitted because archive.org files[].name may name
    files in subdirectories (e.g. 'cover/back.jpg'). Rejected inputs are:
    non-strings and empty strings, names longer than ``_MAX_FILENAME``
    characters, names containing a NUL byte, names starting with '/' or a
    backslash, names whose second character is ':' (Windows drive paths), and
    names with a '..' component under either separator.

    Raises:
        ValueError: On any of the conditions above.
    """
    if not (isinstance(filename, str) and filename):
        raise ValueError(f"filename must be a non-empty string, got {filename!r}")
    if len(filename) > _MAX_FILENAME:
        raise ValueError(f"filename exceeds {_MAX_FILENAME} chars")
    if "\x00" in filename:
        raise ValueError(f"filename contains NUL byte: {filename!r}")
    if filename[0] in ("/", "\\"):
        raise ValueError(f"filename must not be absolute: {filename!r}")
    # Slicing keeps this safe for one-character names.
    if filename[1:2] == ":":
        raise ValueError(f"filename must not be a Windows drive path: {filename!r}")
    # Normalise separators first so 'a\\..\\b' is caught as well as 'a/../b'.
    components = filename.replace("\\", "/").split("/")
    if ".." in components:
        raise ValueError(f"filename must not contain '..' components: {filename!r}")
    return filename
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- safe filesystem open ----------
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_open_for_write(dest: Path, append: bool):
|
||||||
|
"""Open dest for writing, refusing to follow a symlink at the leaf.
|
||||||
|
|
||||||
|
Defense against the symlink-substitution race: even if our path-confinement
|
||||||
|
check passes, a symlink at `dest` could redirect the write. O_NOFOLLOW tells
|
||||||
|
the kernel to fail the open instead.
|
||||||
|
"""
|
||||||
|
flags = os.O_WRONLY | os.O_CREAT
|
||||||
|
flags |= os.O_APPEND if append else os.O_TRUNC
|
||||||
|
nofollow = getattr(os, "O_NOFOLLOW", 0) # not present on Windows
|
||||||
|
flags |= nofollow
|
||||||
|
try:
|
||||||
|
fd = os.open(dest, flags, 0o644)
|
||||||
|
except OSError as e:
|
||||||
|
if nofollow and e.errno == errno.ELOOP:
|
||||||
|
raise ArchiveError(f"refusing to write through symlink at {dest}") from e
|
||||||
|
raise
|
||||||
|
return os.fdopen(fd, "ab" if append else "wb")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- client ----------
|
||||||
|
|
||||||
|
|
||||||
|
class ArchiveClient:
|
||||||
|
"""Async client for the archive.org endpoints we wrap.
|
||||||
|
|
||||||
|
- advancedsearch.php : Solr-style queries (<= ~10,000 rows paginated)
|
||||||
|
- services/search/v1/scrape : bulk cursor pagination (count >= 100)
|
||||||
- metadata/{id} : full item manifest including files[]
|
- metadata/{id} : full item manifest including files[]
|
||||||
- download/{id}/{file} : byte stream with Range support
|
- download/{id}/{file} : byte stream with HTTP Range support
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -32,13 +107,17 @@ class ArchiveClient:
|
|||||||
base_url: str = ARCHIVE_BASE,
|
base_url: str = ARCHIVE_BASE,
|
||||||
user_agent: str = DEFAULT_UA,
|
user_agent: str = DEFAULT_UA,
|
||||||
timeout: httpx.Timeout | float = DEFAULT_TIMEOUT,
|
timeout: httpx.Timeout | float = DEFAULT_TIMEOUT,
|
||||||
|
transport: httpx.AsyncBaseTransport | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._base = base_url.rstrip("/")
|
self._base = base_url.rstrip("/")
|
||||||
self._client = httpx.AsyncClient(
|
kwargs: dict[str, Any] = {
|
||||||
headers={"User-Agent": user_agent, "Accept": "application/json"},
|
"headers": {"User-Agent": user_agent, "Accept": "application/json"},
|
||||||
timeout=timeout,
|
"timeout": timeout,
|
||||||
follow_redirects=True,
|
"follow_redirects": True,
|
||||||
)
|
}
|
||||||
|
if transport is not None:
|
||||||
|
kwargs["transport"] = transport
|
||||||
|
self._client = httpx.AsyncClient(**kwargs)
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
async def aclose(self) -> None:
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
@ -49,6 +128,25 @@ class ArchiveClient:
|
|||||||
async def __aexit__(self, *exc: object) -> None:
|
async def __aexit__(self, *exc: object) -> None:
|
||||||
await self.aclose()
|
await self.aclose()
|
||||||
|
|
||||||
|
# ---------- internal: error-surfacing fetch ----------
|
||||||
|
|
||||||
|
async def _fetch_json(self, url: str, params: Any = None) -> Any:
    """GET ``url`` and decode the body as JSON, surfacing failures as ArchiveError.

    Error responses are reported with up to 500 characters of the response
    body, so the caller can see why a query was rejected. A body that is not
    valid JSON is reported with a 200-character preview.

    Raises:
        ArchiveError: On an error status or an undecodable JSON body.
    """
    response = await self._client.get(url, params=params)
    if response.is_error:
        preview = response.text[:500] if response.content else ""
        message = f"HTTP {response.status_code} from {response.url}: {preview}"
        raise ArchiveError(message.strip())
    try:
        payload = response.json()
    except ValueError as e:
        raise ArchiveError(
            f"invalid JSON from {response.url}: {response.text[:200]!r}"
        ) from e
    return payload
|
||||||
|
|
||||||
# ---------- search ----------
|
# ---------- search ----------
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
@ -71,9 +169,7 @@ class ArchiveClient:
|
|||||||
for s in sort or []:
|
for s in sort or []:
|
||||||
params.append(("sort[]", s))
|
params.append(("sort[]", s))
|
||||||
|
|
||||||
r = await self._client.get(f"{self._base}/advancedsearch.php", params=params)
|
data = await self._fetch_json(f"{self._base}/advancedsearch.php", params=params)
|
||||||
r.raise_for_status()
|
|
||||||
data = r.json()
|
|
||||||
resp = data.get("response", {})
|
resp = data.get("response", {})
|
||||||
return {
|
return {
|
||||||
"num_found": resp.get("numFound", 0),
|
"num_found": resp.get("numFound", 0),
|
||||||
@ -103,39 +199,41 @@ class ArchiveClient:
|
|||||||
if cursor:
|
if cursor:
|
||||||
params["cursor"] = cursor
|
params["cursor"] = cursor
|
||||||
|
|
||||||
r = await self._client.get(f"{self._base}/services/search/v1/scrape", params=params)
|
data = await self._fetch_json(f"{self._base}/services/search/v1/scrape", params=params)
|
||||||
r.raise_for_status()
|
if isinstance(data, dict) and "error" in data:
|
||||||
data = r.json()
|
|
||||||
if "error" in data:
|
|
||||||
raise ArchiveError(f"{data.get('errorType', 'ScrapeError')}: {data['error']}")
|
raise ArchiveError(f"{data.get('errorType', 'ScrapeError')}: {data['error']}")
|
||||||
return data # keys: items, count, total, cursor (if more pages)
|
return data # keys: items, count, total, cursor (if more pages)
|
||||||
|
|
||||||
# ---------- metadata ----------
|
# ---------- metadata ----------
|
||||||
|
|
||||||
async def metadata(self, identifier: str) -> dict[str, Any]:
    """Fetch the full metadata blob for an item.

    The identifier is validated before any request is made. A falsy
    response (such as an empty object) is treated as "not found".

    Raises:
        ValueError: If ``identifier`` fails validation.
        ArchiveError: If the request fails or the response is empty.
    """
    validate_identifier(identifier)
    url = f"{self._base}/metadata/{identifier}"
    blob = await self._fetch_json(url)
    if blob:
        return blob
    raise ArchiveError(f"item not found or unavailable: {identifier}")
|
||||||
|
|
||||||
async def files(self, identifier: str) -> list[dict[str, Any]]:
    """Fetch only the files[] slice of an item's metadata.

    Accepts either a bare JSON list or an object with a "result" key. An
    object carrying an "error" key is raised as ArchiveError with the
    server's message.

    Raises:
        ValueError: If ``identifier`` fails validation.
        ArchiveError: On a request failure, a server-reported error, or an
            unrecognised response shape.
    """
    validate_identifier(identifier)
    payload = await self._fetch_json(f"{self._base}/metadata/{identifier}/files")
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict):
        # An error report wins over any result the same object might carry.
        if "error" in payload:
            raise ArchiveError(f"archive.org error for {identifier}: {payload['error']}")
        if "result" in payload:
            return payload["result"]
    raise ArchiveError(
        f"unexpected files response shape for {identifier}: {type(payload).__name__}"
    )
|
||||||
|
|
||||||
# ---------- download ----------
|
# ---------- download ----------
|
||||||
|
|
||||||
def download_url(self, identifier: str, filename: str) -> str:
    """Build the download URL for ``filename`` inside item ``identifier``.

    Both inputs are validated first. The filename is percent-encoded with
    '/' left intact so subdirectory files keep their path structure.

    Raises:
        ValueError: If either input fails validation.
    """
    validate_identifier(identifier)
    validate_filename(filename)
    encoded_name = quote(filename, safe="/")
    return f"{self._base}/download/{identifier}/{encoded_name}"
|
||||||
|
|
||||||
async def stream_file(
|
async def stream_file(
|
||||||
self,
|
self,
|
||||||
@ -143,13 +241,40 @@ class ArchiveClient:
|
|||||||
filename: str,
|
filename: str,
|
||||||
resume_from: int = 0,
|
resume_from: int = 0,
|
||||||
) -> AsyncIterator[bytes]:
|
) -> AsyncIterator[bytes]:
|
||||||
"""Async byte iterator — caller is responsible for writing to disk."""
|
"""Async byte iterator. If resume_from > 0, requires a 206 response.
|
||||||
|
|
||||||
|
Raises ArchiveError BEFORE yielding any bytes if:
|
||||||
|
- the server returns a 4xx/5xx
|
||||||
|
- resume was requested but the server returned 200 (Range ignored)
|
||||||
|
- the Content-Range start byte doesn't match resume_from
|
||||||
|
"""
|
||||||
|
validate_identifier(identifier)
|
||||||
|
validate_filename(filename)
|
||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
if resume_from > 0:
|
if resume_from > 0:
|
||||||
headers["Range"] = f"bytes={resume_from}-"
|
headers["Range"] = f"bytes={resume_from}-"
|
||||||
url = self.download_url(identifier, filename)
|
url = self.download_url(identifier, filename)
|
||||||
|
|
||||||
async with self._client.stream("GET", url, headers=headers) as r:
|
async with self._client.stream("GET", url, headers=headers) as r:
|
||||||
r.raise_for_status()
|
if r.is_error:
|
||||||
|
body = (await r.aread())[:500].decode("utf-8", errors="replace")
|
||||||
|
raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip())
|
||||||
|
if resume_from > 0:
|
||||||
|
if r.status_code != 206:
|
||||||
|
raise ArchiveError(
|
||||||
|
f"server ignored Range request (got HTTP {r.status_code}); "
|
||||||
|
f"local file may be stale — retry download_file with overwrite=True"
|
||||||
|
)
|
||||||
|
# Verify the byte range starts where we expect. archive.org's CDN
|
||||||
|
# is normally well-behaved here, but trust-but-verify.
|
||||||
|
cr = r.headers.get("Content-Range", "")
|
||||||
|
m = re.match(r"bytes\s+(\d+)-", cr)
|
||||||
|
if m and int(m.group(1)) != resume_from:
|
||||||
|
raise ArchiveError(
|
||||||
|
f"Content-Range start {m.group(1)} != resume_from {resume_from}; "
|
||||||
|
f"refusing to corrupt {filename}"
|
||||||
|
)
|
||||||
async for chunk in r.aiter_bytes(chunk_size=1 << 16):
|
async for chunk in r.aiter_bytes(chunk_size=1 << 16):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
@ -161,20 +286,34 @@ class ArchiveClient:
|
|||||||
verify_md5: str | None = None,
|
verify_md5: str | None = None,
|
||||||
chunk_cb=None,
|
chunk_cb=None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Download with resume support. Returns stats + md5 verification result."""
|
"""Download with resume support. Returns stats + md5 verification result.
|
||||||
|
|
||||||
|
Caller is responsible for ensuring `dest` is confined to a safe directory
|
||||||
|
(use mcarchive_org.server's path validation, or do your own). This method
|
||||||
|
adds defense-in-depth via O_NOFOLLOW but does not validate path confinement.
|
||||||
|
"""
|
||||||
|
validate_identifier(identifier)
|
||||||
|
validate_filename(filename)
|
||||||
|
|
||||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
resume_from = dest.stat().st_size if dest.exists() else 0
|
# If parent is a symlink to elsewhere, refuse — our caller's confinement
|
||||||
|
# check should already have caught this, but redundant defense is cheap.
|
||||||
|
if dest.parent.is_symlink():
|
||||||
|
raise ArchiveError(f"refusing to write into symlinked directory: {dest.parent}")
|
||||||
|
|
||||||
|
resume_from = dest.stat().st_size if dest.exists() and not dest.is_symlink() else 0
|
||||||
|
# is_symlink() detection: if dest is a symlink, treat as fresh (open with
|
||||||
|
# O_NOFOLLOW will refuse anyway, so we won't actually corrupt the target).
|
||||||
|
|
||||||
hasher = hashlib.md5() if verify_md5 else None
|
hasher = hashlib.md5() if verify_md5 else None
|
||||||
if hasher and resume_from:
|
if hasher and resume_from:
|
||||||
# re-hash existing bytes so the final digest is correct
|
|
||||||
with dest.open("rb") as f:
|
with dest.open("rb") as f:
|
||||||
while chunk := f.read(1 << 16):
|
while chunk := f.read(1 << 16):
|
||||||
hasher.update(chunk)
|
hasher.update(chunk)
|
||||||
|
|
||||||
bytes_written = resume_from
|
bytes_written = resume_from
|
||||||
mode = "ab" if resume_from else "wb"
|
f = _safe_open_for_write(dest, append=resume_from > 0)
|
||||||
with dest.open(mode) as f:
|
try:
|
||||||
async for chunk in self.stream_file(identifier, filename, resume_from=resume_from):
|
async for chunk in self.stream_file(identifier, filename, resume_from=resume_from):
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
bytes_written += len(chunk)
|
bytes_written += len(chunk)
|
||||||
@ -182,10 +321,12 @@ class ArchiveClient:
|
|||||||
hasher.update(chunk)
|
hasher.update(chunk)
|
||||||
if chunk_cb:
|
if chunk_cb:
|
||||||
chunk_cb(bytes_written)
|
chunk_cb(bytes_written)
|
||||||
|
finally:
|
||||||
|
f.close()
|
||||||
|
|
||||||
result = {
|
result: dict[str, Any] = {
|
||||||
"path": str(dest),
|
"path": str(dest),
|
||||||
"bytes": bytes_written,
|
"bytes_written": bytes_written,
|
||||||
"resumed_from": resume_from,
|
"resumed_from": resume_from,
|
||||||
}
|
}
|
||||||
if verify_md5 and hasher:
|
if verify_md5 and hasher:
|
||||||
|
|||||||
@ -6,15 +6,24 @@ import fnmatch
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any
|
from typing import Annotated, Any
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from mcarchive_org import __version__
|
from mcarchive_org import __version__
|
||||||
from mcarchive_org.client import ArchiveClient
|
from mcarchive_org.client import (
|
||||||
|
ArchiveClient,
|
||||||
|
ArchiveError,
|
||||||
|
validate_filename,
|
||||||
|
validate_identifier,
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_DOWNLOAD_ROOT = Path(
|
|
||||||
os.environ.get("MCARCHIVE_DOWNLOAD_ROOT", Path.cwd() / "downloads")
|
def _resolve_download_root() -> Path:
|
||||||
|
"""Resolve the download root lazily so env-var changes after import are honored."""
|
||||||
|
return Path(
|
||||||
|
os.environ.get("MCARCHIVE_DOWNLOAD_ROOT", str(Path.cwd() / "downloads"))
|
||||||
).expanduser()
|
).expanduser()
|
||||||
|
|
||||||
mcp = FastMCP(
|
mcp = FastMCP(
|
||||||
@ -42,18 +51,29 @@ def _human_size(n: int | str | None) -> str:
|
|||||||
return f"{x:.1f} PB"
|
return f"{x:.1f} PB"
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_size(raw: Any) -> int | None:
|
||||||
|
"""Best-effort int parse of archive.org size fields. None if unparseable."""
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return int(str(raw).strip())
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _enrich_file(identifier: str, f: dict[str, Any]) -> dict[str, Any]:
    """Project one raw files[] entry into the shape returned to callers.

    Copies the descriptive fields, adds a parsed integer ``size`` and a
    human-readable ``size_human``, and attaches a percent-encoded
    ``download_url`` on archive.org.
    """
    name = f.get("name", "")
    raw_size = f.get("size")
    size = _parse_size(raw_size)
    entry: dict[str, Any] = {
        "name": name,
        "format": f.get("format"),
        "size": size,
        # Fall back to the raw field when it could not be parsed as an int.
        "size_human": _human_size(raw_size if size is None else size),
    }
    for key in ("md5", "sha1", "mtime", "source"):
        entry[key] = f.get(key)
    entry["download_url"] = f"https://archive.org/download/{identifier}/{quote(name, safe='/')}"
    return entry
|
||||||
|
|
||||||
|
|
||||||
@ -63,6 +83,36 @@ def _matches(name: str, format_: str | None, name_glob: str | None, formats: lis
|
|||||||
return not (formats and (format_ or "").lower() not in {f.lower() for f in formats})
|
return not (formats and (format_ or "").lower() not in {f.lower() for f in formats})
|
||||||
|
|
||||||
|
|
||||||
|
def _confine_dest(identifier: str, filename: str, dest_dir: str | None) -> Path:
    """Construct and verify the download destination path.

    Validates the inputs, picks the target directory (``dest_dir`` when
    given, otherwise ``<download root>/<identifier>``), creates it, and
    confirms the resolved destination lies inside it.

    Args:
        identifier: archive.org item identifier.
        filename: File name within the item; may contain '/' for subdirectories.
        dest_dir: Optional caller-chosen directory; '~' is expanded.

    Returns:
        The fully resolved destination path.

    Raises:
        ValueError: If an input is invalid, the target directory is a
            symlink, or the destination resolves outside the target directory.
    """
    validate_identifier(identifier)
    validate_filename(filename)

    if dest_dir:
        requested_dir = Path(dest_dir).expanduser()
    else:
        # The root itself may legitimately sit behind a symlink, so resolve it
        # and only scrutinise the per-item directory beneath it.
        requested_dir = _resolve_download_root().resolve() / identifier

    # The symlink test must run on the unresolved path: resolve() follows
    # links, so the resolved path is never a symlink.
    if requested_dir.is_symlink():
        raise ValueError(f"download directory must not be a symlink: {requested_dir}")

    target_dir = requested_dir.resolve()
    target_dir.mkdir(parents=True, exist_ok=True)

    # Build dest, then resolve and confirm containment. Path.resolve() collapses
    # '..' even on non-existent leaves, so escape attempts surface here.
    dest = (target_dir / filename).resolve()
    if not dest.is_relative_to(target_dir):
        raise ValueError(
            f"refusing destination outside target dir: {dest} not under {target_dir}"
        )
    return dest
|
||||||
|
|
||||||
|
|
||||||
# ---------- tools ----------
|
# ---------- tools ----------
|
||||||
|
|
||||||
|
|
||||||
@ -132,6 +182,13 @@ async def get_item_metadata(
|
|||||||
|
|
||||||
By default omits the (potentially huge) files[] array — call list_files for that.
|
By default omits the (potentially huge) files[] array — call list_files for that.
|
||||||
"""
|
"""
|
||||||
|
return await _fetch_item_metadata(identifier, include_files=include_files)
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_item_metadata(identifier: str, include_files: bool) -> dict[str, Any]:
|
||||||
|
"""Shared metadata-fetching logic. Used by both the tool and the MCP resource
|
||||||
|
so neither has to depend on the other's `.fn` attribute."""
|
||||||
|
validate_identifier(identifier)
|
||||||
async with ArchiveClient() as c:
|
async with ArchiveClient() as c:
|
||||||
data = await c.metadata(identifier)
|
data = await c.metadata(identifier)
|
||||||
|
|
||||||
@ -177,6 +234,7 @@ async def list_files(
|
|||||||
|
|
||||||
Each entry includes a ready-to-use `download_url`.
|
Each entry includes a ready-to-use `download_url`.
|
||||||
"""
|
"""
|
||||||
|
validate_identifier(identifier)
|
||||||
async with ArchiveClient() as c:
|
async with ArchiveClient() as c:
|
||||||
files = await c.files(identifier)
|
files = await c.files(identifier)
|
||||||
|
|
||||||
@ -199,8 +257,10 @@ def get_file_url(
|
|||||||
filename: Annotated[str, Field(description="Exact filename as shown in list_files.")],
|
filename: Annotated[str, Field(description="Exact filename as shown in list_files.")],
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""Build the canonical download URL for a file without fetching anything."""
|
"""Build the canonical download URL for a file without fetching anything."""
|
||||||
|
validate_identifier(identifier)
|
||||||
|
validate_filename(filename)
|
||||||
return {
|
return {
|
||||||
"url": f"https://archive.org/download/{identifier}/{filename}",
|
"url": f"https://archive.org/download/{identifier}/{quote(filename, safe='/')}",
|
||||||
"item_url": f"https://archive.org/details/{identifier}",
|
"item_url": f"https://archive.org/details/{identifier}",
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -222,18 +282,29 @@ async def download_file(
|
|||||||
Field(description="If false and file exists, resume the download (Range request)."),
|
Field(description="If false and file exists, resume the download (Range request)."),
|
||||||
] = False,
|
] = False,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Download a file to disk. Supports resume via HTTP Range when overwrite=false."""
|
"""Download a file to disk. Supports resume via HTTP Range when overwrite=false.
|
||||||
target_dir = Path(dest_dir).expanduser() if dest_dir else (DEFAULT_DOWNLOAD_ROOT / identifier)
|
|
||||||
dest = target_dir / filename
|
The destination is path-confined to either `dest_dir` (when given) or
|
||||||
if overwrite and dest.exists():
|
$MCARCHIVE_DOWNLOAD_ROOT/{identifier}. Filenames containing '..', absolute
|
||||||
|
paths, or NUL bytes are rejected before any FS or network I/O.
|
||||||
|
"""
|
||||||
|
dest = _confine_dest(identifier, filename, dest_dir)
|
||||||
|
if overwrite and dest.exists() and not dest.is_symlink():
|
||||||
|
dest.unlink()
|
||||||
|
elif overwrite and dest.is_symlink():
|
||||||
|
# Don't follow the symlink to delete the target; remove the link itself.
|
||||||
dest.unlink()
|
dest.unlink()
|
||||||
|
|
||||||
|
try:
|
||||||
async with ArchiveClient() as c:
|
async with ArchiveClient() as c:
|
||||||
result = await c.download_to_file(identifier, filename, dest, verify_md5=verify_md5)
|
result = await c.download_to_file(identifier, filename, dest, verify_md5=verify_md5)
|
||||||
|
except ArchiveError as e:
|
||||||
|
# Re-raise with the destination context so the caller can act on it.
|
||||||
|
raise ArchiveError(f"{e} (dest={dest})") from e
|
||||||
|
|
||||||
result["identifier"] = identifier
|
result["identifier"] = identifier
|
||||||
result["filename"] = filename
|
result["filename"] = filename
|
||||||
result["size_human"] = _human_size(result.get("bytes"))
|
result["size_human"] = _human_size(result.get("bytes_written"))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@ -243,7 +314,7 @@ async def download_file(
|
|||||||
@mcp.resource("archive://item/{identifier}")
async def item_resource(identifier: str) -> dict[str, Any]:
    """Expose item metadata as a readable MCP resource.

    Delegates to _fetch_item_metadata with include_files=False.
    """
    return await _fetch_item_metadata(identifier, include_files=False)
|
||||||
|
|
||||||
|
|
||||||
# ---------- entry point ----------
|
# ---------- entry point ----------
|
||||||
|
|||||||
@ -12,7 +12,7 @@ import pytest
|
|||||||
|
|
||||||
from mcarchive_org.client import ArchiveClient
|
from mcarchive_org.client import ArchiveClient
|
||||||
|
|
||||||
pytestmark = [pytest.mark.asyncio, pytest.mark.network]
|
pytestmark = pytest.mark.network
|
||||||
|
|
||||||
|
|
||||||
async def test_search_nasa_item():
|
async def test_search_nasa_item():
|
||||||
@ -41,7 +41,7 @@ async def test_download_small_file(tmp_path: Path):
|
|||||||
result = await c.download_to_file(
|
result = await c.download_to_file(
|
||||||
"nasa", small["name"], dest, verify_md5=small.get("md5")
|
"nasa", small["name"], dest, verify_md5=small.get("md5")
|
||||||
)
|
)
|
||||||
assert result["bytes"] > 0
|
assert result["bytes_written"] > 0
|
||||||
if small.get("md5"):
|
if small.get("md5"):
|
||||||
assert result["md5_ok"] is True
|
assert result["md5_ok"] is True
|
||||||
|
|
||||||
|
|||||||
242
tests/test_client_mocked.py
Normal file
242
tests/test_client_mocked.py
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
"""Failure-mode regression tests using httpx.MockTransport (no network).
|
||||||
|
|
||||||
|
Each test pins down one of the Hamilton review findings (C1/C2/C3/H4 etc.) so
|
||||||
|
future refactors can't silently regress safety.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mcarchive_org.client import (
|
||||||
|
ArchiveClient,
|
||||||
|
ArchiveError,
|
||||||
|
validate_filename,
|
||||||
|
validate_identifier,
|
||||||
|
)
|
||||||
|
from mcarchive_org.server import _confine_dest
|
||||||
|
|
||||||
|
|
||||||
|
def _client_with(handler) -> ArchiveClient:
    """Build an ArchiveClient backed by a MockTransport handler.

    ``handler`` receives each outgoing httpx.Request and returns the
    httpx.Response to serve, so no network access happens.
    """
    return ArchiveClient(transport=httpx.MockTransport(handler))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- C1: identifier + filename validation ----------
|
||||||
|
|
||||||
|
|
||||||
|
# Cases: empty, traversal, embedded slash, whitespace, and over-long identifiers.
@pytest.mark.parametrize("bad", ["", "../etc", "foo/bar", "has space", "a" * 200])
def test_invalid_identifier_rejected(bad):
    """Each malformed identifier must raise ValueError with the standard message."""
    with pytest.raises(ValueError, match=r"invalid archive\.org identifier"):
        validate_identifier(bad)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "bad",
    [
        "../escape.txt",  # leading '..' component
        "/etc/passwd",  # absolute path
        "C:\\windows.txt",  # Windows drive path
        "with\x00null.bin",  # embedded NUL byte
        "foo/../bar.mp3",  # '..' in the middle, forward slashes
        "foo\\..\\bar.mp3",  # '..' in the middle, backslashes
        "",  # empty string
    ],
)
def test_invalid_filename_rejected(bad):
    """Every unsafe filename shape must be rejected with ValueError."""
    with pytest.raises(ValueError):
        validate_filename(bad)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "ok",
    ["song.mp3", "cover/back.jpg", "subdir/file with space.txt", "a.b.c.d"],
)
def test_legitimate_filenames_accepted(ok):
    """Ordinary names, including subdirectories and spaces, pass through unchanged."""
    assert validate_filename(ok) == ok
|
||||||
|
|
||||||
|
|
||||||
|
def test_confine_dest_blocks_traversal(tmp_path, monkeypatch):
    """A '..' filename must never produce a destination path."""
    monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))
    # validate_filename catches '..' before _confine_dest's path-resolution check,
    # so this raises ValueError from the validator — both layers in agreement.
    with pytest.raises(ValueError):
        _confine_dest("nasa", "../escape.txt", dest_dir=None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_confine_dest_legit_filename_lands_in_root(tmp_path, monkeypatch):
    """A plain filename must resolve to a path under the configured root.

    With MCARCHIVE_DOWNLOAD_ROOT set to the temp directory and no explicit
    dest_dir, the returned path has to be inside that directory and keep the
    requested file name.
    """
    download_root = str(tmp_path)
    monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", download_root)

    resolved = _confine_dest("nasa", "globe.jpg", dest_dir=None)

    assert resolved.is_relative_to(tmp_path)
    assert resolved.name == "globe.jpg"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- C2: symlink refusal ----------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_download_refuses_symlink_at_dest(tmp_path):
    """download_to_file must refuse a destination path that is a symlink.

    A real file is created, a symlink pointing at it is used as the download
    destination, and the call is expected to raise ArchiveError mentioning
    "symlink" while the real file keeps its original content.
    """
    real_file = tmp_path / "real.bin"
    real_file.write_bytes(b"original-content")

    symlink_path = tmp_path / "evil.bin"
    symlink_path.symlink_to(real_file)

    def serve_replacement(request: httpx.Request) -> httpx.Response:
        # Body that would clobber the real file if the symlink were followed.
        return httpx.Response(200, content=b"new-content-that-should-not-overwrite")

    async with _client_with(serve_replacement) as client:
        with pytest.raises(ArchiveError, match="symlink"):
            await client.download_to_file("nasa", "evil.bin", symlink_path)

    # The file behind the symlink must not have been touched.
    assert real_file.read_bytes() == b"original-content"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- C3: Range-ignored detection ----------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resume_with_200_response_raises_before_writing(tmp_path):
    """A 200 reply to a Range request must abort the resume untouched.

    Appending a full-body 200 response to an existing partial file would
    silently corrupt it, so the client has to raise ArchiveError ("ignored
    Range") and leave the partial file exactly as it was.
    """
    partial_content = b"X" * 100  # stand-in for an earlier partial download
    partial_file = tmp_path / "partial.bin"
    partial_file.write_bytes(partial_content)

    def ignore_range(request: httpx.Request) -> httpx.Response:
        # The client asked to resume at byte 100 ...
        assert request.headers.get("Range") == "bytes=100-"
        # ... but the server answers 200 with the whole body instead of 206.
        return httpx.Response(200, content=b"FULL_FILE_BODY")

    async with _client_with(ignore_range) as client:
        with pytest.raises(ArchiveError, match="ignored Range"):
            await client.download_to_file("nasa", "partial.bin", partial_file)

    # Nothing may have been appended: corruption avoided.
    assert partial_file.read_bytes() == partial_content
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resume_with_correct_206_succeeds(tmp_path):
    """A well-formed 206 reply must complete a resumed download.

    The destination starts with the first 64 bytes of a 256-byte payload; the
    mock server returns the remainder with a matching Content-Range header.
    The result dict and the final file content are then checked.
    """
    full_body = b"0123456789ABCDEF" * 16  # 256 bytes in total
    total = len(full_body)
    expected_md5 = hashlib.md5(full_body).hexdigest()

    dest = tmp_path / "resume.bin"
    dest.write_bytes(full_body[:64])  # the first 64 bytes are already on disk

    def serve_remainder(request: httpx.Request) -> httpx.Response:
        assert request.headers.get("Range") == "bytes=64-"
        content_range = f"bytes 64-{total - 1}/{total}"
        return httpx.Response(
            206,
            content=full_body[64:],
            headers={"Content-Range": content_range},
        )

    async with _client_with(serve_remainder) as client:
        result = await client.download_to_file(
            "nasa",
            "resume.bin",
            dest,
            verify_md5=expected_md5,
        )

    assert result["bytes_written"] == total
    assert result["resumed_from"] == 64
    assert result["md5_ok"] is True
    assert dest.read_bytes() == full_body
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resume_with_wrong_content_range_start_raises(tmp_path):
    """A 206 whose Content-Range starts at the wrong offset must be rejected.

    The destination already holds 100 bytes, but the mock server's 206 reply
    claims to start at byte 50. The client is expected to raise ArchiveError
    ("Content-Range start") and leave the file unmodified.
    """
    existing = b"X" * 100
    dest = tmp_path / "off.bin"
    dest.write_bytes(existing)

    def serve_misaligned(request: httpx.Request) -> httpx.Response:
        # 206 status, but the advertised start offset is wrong.
        wrong_range = {"Content-Range": "bytes 50-99/100"}
        return httpx.Response(206, content=b"junk", headers=wrong_range)

    async with _client_with(serve_misaligned) as client:
        with pytest.raises(ArchiveError, match="Content-Range start"):
            await client.download_to_file("nasa", "off.bin", dest)

    # File content must be unchanged.
    assert dest.read_bytes() == existing
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- H4: error body surfacing ----------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_search_400_includes_response_body():
    """A 400 from search must raise ArchiveError carrying the response body.

    The mock server's body contains "bad query syntax", which has to appear
    in the raised error's message.
    """

    def reject_query(request: httpx.Request) -> httpx.Response:
        error_body = '{"error":"bad query syntax"}'
        return httpx.Response(400, text=error_body)

    async with _client_with(reject_query) as client:
        with pytest.raises(ArchiveError, match="bad query syntax"):
            await client.search(query="INVALID:::")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_metadata_404_includes_status():
    """A 404 from metadata must raise ArchiveError mentioning "HTTP 404"."""

    def not_found(request: httpx.Request) -> httpx.Response:
        return httpx.Response(404, text="not found")

    async with _client_with(not_found) as client:
        with pytest.raises(ArchiveError, match="HTTP 404"):
            await client.metadata("nasa")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_metadata_empty_dict_means_not_found():
    """A 200 with an empty JSON object must be reported as a missing item.

    The raised ArchiveError is expected to say "not found or unavailable".
    """

    def empty_object(request: httpx.Request) -> httpx.Response:
        empty_payload = {}
        return httpx.Response(200, json=empty_payload)

    async with _client_with(empty_object) as client:
        with pytest.raises(ArchiveError, match="not found or unavailable"):
            await client.metadata("nasa")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_files_returns_error_payload_as_archive_error():
    """An {"error": ...} payload from files must become an ArchiveError.

    The server replies 200 with an error object; the error text ("item is
    dark") has to appear in the raised exception.
    """

    def dark_item(request: httpx.Request) -> httpx.Response:
        error_payload = {"error": "item is dark"}
        return httpx.Response(200, json=error_payload)

    async with _client_with(dark_item) as client:
        with pytest.raises(ArchiveError, match="item is dark"):
            await client.files("nasa")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_scrape_error_payload_surfaced():
    """An error payload from scrape must surface its type and message.

    The raised ArchiveError has to mention the errorType ("RangeException")
    followed by the error text ("count too small").
    """

    def range_error(request: httpx.Request) -> httpx.Response:
        error_payload = {"error": "count too small", "errorType": "RangeException"}
        return httpx.Response(200, json=error_payload)

    async with _client_with(range_error) as client:
        with pytest.raises(ArchiveError, match=r"RangeException.*count too small"):
            await client.scrape(query="identifier:nasa", count=100)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_invalid_json_response_surfaced():
    """A non-JSON 200 body must raise ArchiveError mentioning "invalid JSON"."""

    def html_instead_of_json(request: httpx.Request) -> httpx.Response:
        html_body = "<html>not json</html>"
        return httpx.Response(200, text=html_body)

    async with _client_with(html_instead_of_json) as client:
        with pytest.raises(ArchiveError, match="invalid JSON"):
            await client.metadata("nasa")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- happy path ----------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_fresh_download_writes_full_body(tmp_path):
    """A download to a non-existent destination must write the whole body.

    The request must not carry a Range header, and the result dict must
    report the full byte count, a zero resume offset, and a passing MD5 check.
    """
    body = b"hello world" * 100
    expected_md5 = hashlib.md5(body).hexdigest()
    dest = tmp_path / "new.bin"

    def serve_full_body(request: httpx.Request) -> httpx.Response:
        # Nothing exists on disk yet, so no resume should be attempted.
        assert "Range" not in request.headers
        return httpx.Response(200, content=body)

    async with _client_with(serve_full_body) as client:
        result = await client.download_to_file(
            "nasa",
            "new.bin",
            dest,
            verify_md5=expected_md5,
        )

    assert result["bytes_written"] == len(body)
    assert result["resumed_from"] == 0
    assert result["md5_ok"] is True
    assert dest.read_bytes() == body
|
||||||
Loading…
x
Reference in New Issue
Block a user