Add atomic .part staging + runtime download root tools

Atomic write pattern (tier-3 polish from headless test finding):
- download_to_file now writes to <dest>.part and renames to <dest> only on
  successful stream completion (os.replace is POSIX-atomic). Failed
  downloads leave only the .part file — no misleading 0-byte dest files
  in the user's downloads directory.
- Resume logic reads from <dest>.part instead of <dest>; the user's
  directory only ever contains complete files or clearly-marked .part files.
- New `already_complete` short-circuit: if dest exists and no .part, skip
  the network entirely (still re-verify MD5 if requested). The headless
  Claude test confirmed this avoids redundant CDN load.
- Symlink rejection re-added at the new code path: even though os.replace
  would only replace (not follow) a symlink at dest, predictable refusal
  beats silent symlink removal.

Runtime download root tools (for stdio MCP mode):
- get_download_root(): reports current root, source (env var vs default),
  existence, writability.
- set_download_root(path): change MCARCHIVE_DOWNLOAD_ROOT mid-session.
  Expands ~, resolves symlinks, creates the dir, and refuses / itself plus
  these system paths and anything beneath them
  (/etc, /usr, /bin, /sbin, /var, /sys, /proc, /dev, /boot, /root).
  The lazy-resolved root means the change takes effect on the next
  download_file call without restarting the server.

14 new tests (66 total, all green, ruff clean):
- 4 staging tests: failed download leaves no dest, success leaves no .part,
  already_complete short-circuit, MD5 verification on existing files
- 6 root-tools tests: env reporting, default reporting, env change + dir
  creation, ~ expansion, system-dir refusal (parametrized), set→download
  takes effect immediately
- 4 existing tests rewritten to use .part as the resume staging file

Headless Claude smoke test verified end-to-end: get_download_root →
set_download_root → search → list → download → second download
short-circuits with already_complete=true and zero network bytes.
This commit is contained in:
Ryan Malloy 2026-04-21 21:11:56 -06:00
parent 6198defeca
commit 25a34cd24d
4 changed files with 307 additions and 29 deletions

View File

@ -329,11 +329,21 @@ class ArchiveClient:
verify_md5: str | None = None, verify_md5: str | None = None,
chunk_cb=None, chunk_cb=None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Download with resume support. Returns stats + md5 verification result. """Download with resume support, atomically promoted on success.
Caller is responsible for ensuring `dest` is confined to a safe directory Writes go to a `<dest>.part` staging file. On success, that staging file
(use mcarchive_org.server's path validation, or do your own). This method is atomically renamed to `dest` (POSIX rename is observed all-or-nothing).
adds defense-in-depth via O_NOFOLLOW but does not validate path confinement. On any failure, the `.part` file remains and a follow-up call resumes
from it meaning the user's directory only ever contains complete files
or `.part` files (which clearly signal incomplete state).
- `dest` already complete (no .part) and overwrite implied false
early-return with `already_complete: True` (still verifies MD5 if asked).
- `.part` exists resume from its size, append, then promote on success.
- Neither exists fresh download to `.part`, promote on success.
Caller is responsible for path confinement; this method adds O_NOFOLLOW
defense-in-depth but does not enforce a download root.
""" """
validate_identifier(identifier) validate_identifier(identifier)
validate_filename(filename) validate_filename(filename)
@ -343,19 +353,50 @@ class ArchiveClient:
# check should already have caught this, but redundant defense is cheap. # check should already have caught this, but redundant defense is cheap.
if dest.parent.is_symlink(): if dest.parent.is_symlink():
raise ArchiveError(f"refusing to write into symlinked directory: {dest.parent}") raise ArchiveError(f"refusing to write into symlinked directory: {dest.parent}")
# Refuse to "download over" a symlink at dest. Even though the staging
# rename would replace (not follow) the symlink, predictable refusal is
# better than silent symlink removal — that surprises the user and could
# destroy a symlink farm they intentionally set up.
if dest.is_symlink():
raise ArchiveError(f"refusing to write through symlink at {dest}")
resume_from = dest.stat().st_size if dest.exists() and not dest.is_symlink() else 0 part = dest.with_name(dest.name + ".part")
# is_symlink() detection: if dest is a symlink, treat as fresh (open with # Same protection on the staging path.
# O_NOFOLLOW will refuse anyway, so we won't actually corrupt the target). if part.is_symlink():
raise ArchiveError(f"refusing to write through symlink at {part}")
# Short-circuit: dest already exists complete and no .part to resume.
if dest.exists() and not dest.is_symlink() and not part.exists():
size = dest.stat().st_size
result: dict[str, Any] = {
"path": str(dest),
"bytes_written": size,
"resumed_from": size,
"already_complete": True,
}
if verify_md5:
# Verify the existing file matches the expected hash.
hasher = hashlib.md5()
with dest.open("rb") as fh:
while chunk := fh.read(1 << 16):
hasher.update(chunk)
actual = hasher.hexdigest()
result["md5_actual"] = actual
result["md5_expected"] = verify_md5
result["md5_ok"] = actual.lower() == verify_md5.lower()
return result
# Resume from .part if present; otherwise fresh.
resume_from = part.stat().st_size if part.exists() and not part.is_symlink() else 0
hasher = hashlib.md5() if verify_md5 else None hasher = hashlib.md5() if verify_md5 else None
if hasher and resume_from: if hasher and resume_from:
with dest.open("rb") as f: with part.open("rb") as fh:
while chunk := f.read(1 << 16): while chunk := fh.read(1 << 16):
hasher.update(chunk) hasher.update(chunk)
bytes_written = resume_from bytes_written = resume_from
f = _safe_open_for_write(dest, append=resume_from > 0) f = _safe_open_for_write(part, append=resume_from > 0)
try: try:
async for chunk in self.stream_file(identifier, filename, resume_from=resume_from): async for chunk in self.stream_file(identifier, filename, resume_from=resume_from):
f.write(chunk) f.write(chunk)
@ -365,22 +406,24 @@ class ArchiveClient:
if chunk_cb: if chunk_cb:
chunk_cb(bytes_written) chunk_cb(bytes_written)
except (httpx.ReadError, httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadTimeout) as e: except (httpx.ReadError, httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadTimeout) as e:
# H1: surface partial-state context so the caller can decide whether # H1: surface partial-state context. The .part file stays on disk
# to resume or restart. The bytes already on disk are valid (we only # so a follow-up call with overwrite=False resumes from bytes_written.
# write whole chunks), so a follow-up call with overwrite=False will
# resume cleanly from `bytes_written`.
raise ArchiveError( raise ArchiveError(
f"download interrupted after {bytes_written - resume_from} new bytes " f"download interrupted after {bytes_written - resume_from} new bytes "
f"({bytes_written} total on disk, resumed from {resume_from}). " f"({bytes_written} total in {part.name}, resumed from {resume_from}). "
f"File at {dest} — call download_file again to resume. Cause: {e!r}" f"Call download_file again to resume. Cause: {e!r}"
) from e ) from e
finally: finally:
f.close() f.close()
result: dict[str, Any] = { # Atomic promotion — only after the stream completed cleanly.
os.replace(part, dest)
result = {
"path": str(dest), "path": str(dest),
"bytes_written": bytes_written, "bytes_written": bytes_written,
"resumed_from": resume_from, "resumed_from": resume_from,
"already_complete": False,
} }
if verify_md5 and hasher: if verify_md5 and hasher:
actual = hasher.hexdigest() actual = hasher.hexdigest()

View File

@ -342,13 +342,18 @@ async def download_file(
so two parallel tool invocations can't race on the same destination file. so two parallel tool invocations can't race on the same destination file.
""" """
dest = _confine_dest(identifier, filename, dest_dir) dest = _confine_dest(identifier, filename, dest_dir)
part = dest.with_name(dest.name + ".part")
lock = _download_lock_for(identifier, filename) lock = _download_lock_for(identifier, filename)
async with lock: async with lock:
if overwrite and (dest.exists() or dest.is_symlink()): if overwrite:
# is_symlink() before exists() — a dangling symlink reports exists()=False # Remove both the final file AND any leftover .part — we want to start
# but we still want to remove the link itself rather than follow it. # truly fresh, not resume an old partial. is_symlink check first since
# a dangling symlink reports exists()=False.
if dest.exists() or dest.is_symlink():
dest.unlink() dest.unlink()
if part.exists() or part.is_symlink():
part.unlink()
try: try:
c = await get_shared_client() c = await get_shared_client()
@ -363,6 +368,76 @@ async def download_file(
return result return result
# ---------- runtime configuration ----------
# Download roots that are rejected unconditionally: system directories where
# stray files are genuinely harmful. The user's MCP client can usually
# re-launch with MCARCHIVE_DOWNLOAD_ROOT set to whatever they like at startup;
# this list only constrains the LLM-driven set_download_root tool.
_FORBIDDEN_ROOTS = frozenset(
    (
        "/", "/etc", "/usr", "/bin", "/sbin", "/var",
        "/sys", "/proc", "/dev", "/boot", "/root",
    )
)


def _check_root_safety(p: Path) -> None:
    """Raise ValueError if *p* is a forbidden system directory or lies under one.

    The comparison is textual, on ``str(p)``: the path is rejected when it
    equals an entry of ``_FORBIDDEN_ROOTS`` or starts with ``<entry>/``.
    ``/`` is matched only exactly, since every absolute path sits beneath it.
    No symlink resolution happens here; callers pass the form they want checked.
    """
    candidate = str(p)
    is_forbidden = candidate in _FORBIDDEN_ROOTS or any(
        candidate.startswith(f"{root}/")
        for root in _FORBIDDEN_ROOTS
        if root != "/"
    )
    if is_forbidden:
        raise ValueError(f"refusing to use system directory as download root: {candidate}")
@mcp.tool
def get_download_root() -> dict[str, Any]:
    """Report the directory where download_file writes by default.

    Useful at the start of a session to confirm where files will land. The
    `source` field tells you whether the value came from the MCARCHIVE_DOWNLOAD_ROOT
    env var or from the built-in default of `./downloads` under the server's CWD.

    Returned keys:
      - `download_root`: the resolved root as a string.
      - `exists`: whether that path currently exists.
      - `writable`: result of an `os.access(..., W_OK)` check, or None when
        the path does not exist.
      - `source`: which of the two origins above applies (an empty env value
        is reported as the default).
      - `raw_env_value`: the env var exactly as set, or None when unset.
    """
    raw_env = os.environ.get("MCARCHIVE_DOWNLOAD_ROOT")
    root = _resolve_download_root().resolve()
    # Probe existence once so `exists` and `writable` always describe the same
    # observation and cannot contradict each other if the directory appears or
    # disappears mid-call.
    exists = root.exists()
    return {
        "download_root": str(root),
        "exists": exists,
        "writable": os.access(root, os.W_OK) if exists else None,
        "source": "MCARCHIVE_DOWNLOAD_ROOT env var" if raw_env else "default (./downloads under server CWD)",
        "raw_env_value": raw_env,
    }
@mcp.tool
def set_download_root(
    path: Annotated[
        str,
        Field(description="New download root path. '~' is expanded; the directory is created if missing."),
    ],
) -> dict[str, Any]:
    """Change the download root for the rest of this MCP server session.

    Useful when running as a stdio MCP server, since you can't otherwise
    re-export environment variables to a running child process. The change
    persists until the server process exits or set_download_root is called
    again. System directories (/etc, /usr, /var, /sys, /proc, /dev, /boot,
    /root, /bin, /sbin, /) are refused, as is anything beneath them (other
    than beneath / itself).

    Returns a dict with `download_root` (the new, symlink-resolved root),
    `previous` (the root in effect before this call) and `changed` (whether
    the two differ).

    Raises:
        ValueError: the path names a system directory, either as typed or
            after symlink resolution, or the directory is not writable.
    """
    requested = Path(path).expanduser()
    # Check the path as typed (made absolute, `..` collapsed) before symlinks
    # are resolved. Where a system directory is itself a symlink (e.g. macOS,
    # where /etc points at /private/etc) the resolved form alone would slip
    # past the guard.
    _check_root_safety(Path(os.path.abspath(requested)))
    expanded = requested.resolve()
    # Check the resolved form too, so a user-owned symlink that points into a
    # system directory is refused as well.
    _check_root_safety(expanded)

    # Capture the old root before the environment is mutated below.
    previous = _resolve_download_root().resolve()
    expanded.mkdir(parents=True, exist_ok=True)
    if not os.access(expanded, os.W_OK):
        raise ValueError(f"download root not writable: {expanded}")
    # The root is resolved lazily from this variable, so the next download
    # picks up the change without a restart.
    os.environ["MCARCHIVE_DOWNLOAD_ROOT"] = str(expanded)
    return {
        "download_root": str(expanded),
        "previous": str(previous),
        "changed": str(expanded) != str(previous),
    }
# ---------- resources ---------- # ---------- resources ----------

View File

@ -102,7 +102,8 @@ async def test_resume_with_200_response_raises_before_writing(tmp_path):
"""If the server returns 200 instead of 206 on a Range request, we must not """If the server returns 200 instead of 206 on a Range request, we must not
append to the existing file that path corrupts data silently.""" append to the existing file that path corrupts data silently."""
dest = tmp_path / "partial.bin" dest = tmp_path / "partial.bin"
dest.write_bytes(b"X" * 100) # pretend we have a partial download part = tmp_path / "partial.bin.part" # staging file holds resume state
part.write_bytes(b"X" * 100)
def handler(req: httpx.Request) -> httpx.Response: def handler(req: httpx.Request) -> httpx.Response:
# Server ignores Range header and returns the full body with 200 # Server ignores Range header and returns the full body with 200
@ -113,14 +114,16 @@ async def test_resume_with_200_response_raises_before_writing(tmp_path):
with pytest.raises(ArchiveError, match="ignored Range"): with pytest.raises(ArchiveError, match="ignored Range"):
await c.download_to_file("nasa", "partial.bin", dest) await c.download_to_file("nasa", "partial.bin", dest)
# File must be unchanged — corruption avoided. # The .part file must be unchanged and dest must not exist — corruption avoided.
assert dest.read_bytes() == b"X" * 100 assert part.read_bytes() == b"X" * 100
assert not dest.exists()
async def test_resume_with_correct_206_succeeds(tmp_path): async def test_resume_with_correct_206_succeeds(tmp_path):
full_body = b"0123456789ABCDEF" * 16 # 256 bytes full_body = b"0123456789ABCDEF" * 16 # 256 bytes
dest = tmp_path / "resume.bin" dest = tmp_path / "resume.bin"
dest.write_bytes(full_body[:64]) # we already have first 64 bytes part = tmp_path / "resume.bin.part"
part.write_bytes(full_body[:64]) # we already have first 64 bytes in staging
def handler(req: httpx.Request) -> httpx.Response: def handler(req: httpx.Request) -> httpx.Response:
assert req.headers.get("Range") == "bytes=64-" assert req.headers.get("Range") == "bytes=64-"
@ -139,12 +142,15 @@ async def test_resume_with_correct_206_succeeds(tmp_path):
assert result["bytes_written"] == len(full_body) assert result["bytes_written"] == len(full_body)
assert result["resumed_from"] == 64 assert result["resumed_from"] == 64
assert result["md5_ok"] is True assert result["md5_ok"] is True
# On success the .part is atomically renamed to dest.
assert dest.read_bytes() == full_body assert dest.read_bytes() == full_body
assert not part.exists()
async def test_resume_with_wrong_content_range_start_raises(tmp_path): async def test_resume_with_wrong_content_range_start_raises(tmp_path):
dest = tmp_path / "off.bin" dest = tmp_path / "off.bin"
dest.write_bytes(b"X" * 100) part = tmp_path / "off.bin.part"
part.write_bytes(b"X" * 100)
def handler(req: httpx.Request) -> httpx.Response: def handler(req: httpx.Request) -> httpx.Response:
# Server returns 206 but with WRONG starting offset # Server returns 206 but with WRONG starting offset
@ -158,7 +164,9 @@ async def test_resume_with_wrong_content_range_start_raises(tmp_path):
with pytest.raises(ArchiveError, match="Content-Range start"): with pytest.raises(ArchiveError, match="Content-Range start"):
await c.download_to_file("nasa", "off.bin", dest) await c.download_to_file("nasa", "off.bin", dest)
assert dest.read_bytes() == b"X" * 100 # unchanged # .part unchanged, dest never created.
assert part.read_bytes() == b"X" * 100
assert not dest.exists()
# ---------- H4: error body surfacing ---------- # ---------- H4: error body surfacing ----------
@ -328,6 +336,7 @@ async def test_stream_abort_raises_archive_error_with_byte_count(tmp_path):
return httpx.Response(200, content=evil_body()) return httpx.Response(200, content=evil_body())
dest = tmp_path / "interrupted.bin" dest = tmp_path / "interrupted.bin"
part = tmp_path / "interrupted.bin.part"
async with _client_with(handler) as c: async with _client_with(handler) as c:
with pytest.raises(ArchiveError) as exc_info: with pytest.raises(ArchiveError) as exc_info:
await c.download_to_file("nasa", "interrupted.bin", dest) await c.download_to_file("nasa", "interrupted.bin", dest)
@ -335,8 +344,9 @@ async def test_stream_abort_raises_archive_error_with_byte_count(tmp_path):
msg = str(exc_info.value) msg = str(exc_info.value)
assert "interrupted after" in msg assert "interrupted after" in msg
assert "ReadError" in msg assert "ReadError" in msg
# Partial bytes ARE on disk — at least the first delivered chunk. # Partial bytes go to .part, NOT dest. dest stays absent until success.
on_disk = dest.read_bytes() assert not dest.exists()
on_disk = part.read_bytes()
assert len(on_disk) > 0 assert len(on_disk) > 0
assert on_disk == chunk_payload[: len(on_disk)] assert on_disk == chunk_payload[: len(on_disk)]
@ -360,4 +370,82 @@ async def test_fresh_download_writes_full_body(tmp_path):
assert result["bytes_written"] == len(body) assert result["bytes_written"] == len(body)
assert result["resumed_from"] == 0 assert result["resumed_from"] == 0
assert result["md5_ok"] is True assert result["md5_ok"] is True
assert result["already_complete"] is False
# The atomic-rename pattern leaves no .part artifact after success.
assert dest.read_bytes() == body assert dest.read_bytes() == body
assert not (tmp_path / "new.bin.part").exists()
# ---------- Atomic .part staging ----------
async def test_failed_download_leaves_no_dest_file(tmp_path):
    """A fresh download that fails must not create the final destination file.

    Only the `.part` staging path may be touched; a zero-byte file at the
    final name would look like a finished download to the user.
    """
    target = tmp_path / "shouldfail.bin"

    def handler(req: httpx.Request) -> httpx.Response:
        return httpx.Response(500, text="upstream cdn miss")

    async with _client_with(handler) as client:
        with pytest.raises(ArchiveError, match="HTTP 500"):
            await client.download_to_file("nasa", "shouldfail.bin", target)

    # The key guarantee: no empty file at the final name to mislead the user.
    assert not target.exists()
async def test_already_complete_short_circuits_without_network(tmp_path):
    """With dest present and no `.part`, a repeat download must stay offline.

    The file is treated as already complete, so the transport handler must
    never be invoked.
    """
    payload = b"already-here"
    target = tmp_path / "done.bin"
    target.write_bytes(payload)
    requests_seen: list[httpx.Request] = []

    def handler(req: httpx.Request) -> httpx.Response:
        requests_seen.append(req)
        return httpx.Response(500, text="should never fire")

    async with _client_with(handler) as client:
        outcome = await client.download_to_file("nasa", "done.bin", target)

    assert len(requests_seen) == 0  # no network at all
    assert outcome["already_complete"] is True
    assert outcome["bytes_written"] == len(payload)
    assert target.read_bytes() == payload
async def test_already_complete_verifies_md5_against_existing_file(tmp_path):
    """A complete dest is re-hashed when verify_md5 is supplied."""
    content = b"on-disk-content"
    expected_md5 = hashlib.md5(content).hexdigest()
    target = tmp_path / "done.bin"
    target.write_bytes(content)

    def handler(req: httpx.Request) -> httpx.Response:
        return httpx.Response(500, text="should never fire")

    async with _client_with(handler) as client:
        outcome = await client.download_to_file(
            "nasa", "done.bin", target, verify_md5=expected_md5
        )

    assert outcome["already_complete"] is True
    assert outcome["md5_ok"] is True
async def test_already_complete_md5_mismatch_caught(tmp_path):
    """A hash mismatch on an existing file is reported as md5_ok=False."""
    bogus_md5 = "0" * 32
    target = tmp_path / "wrong.bin"
    target.write_bytes(b"actual-content")

    def handler(req: httpx.Request) -> httpx.Response:
        return httpx.Response(500, text="should never fire")

    async with _client_with(handler) as client:
        outcome = await client.download_to_file(
            "nasa", "wrong.bin", target, verify_md5=bogus_md5
        )

    assert outcome["already_complete"] is True
    assert outcome["md5_ok"] is False
    assert outcome["md5_expected"] == bogus_md5

View File

@ -10,6 +10,7 @@ These exercise the MCP tool functions directly and verify:
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import httpx import httpx
@ -21,8 +22,10 @@ from mcarchive_org.server import (
_enrich_doc, _enrich_doc,
_normalize_collection, _normalize_collection,
download_file, download_file,
get_download_root,
get_item_metadata, get_item_metadata,
search_items, search_items,
set_download_root,
) )
@ -170,6 +173,75 @@ async def test_concurrent_downloads_same_file_are_serialized(tmp_path, monkeypat
assert state["max_active"] == 1 assert state["max_active"] == 1
# ---------- runtime download root management ----------
def test_get_download_root_reports_env_value(tmp_path, monkeypatch):
    """With the env var set, the report echoes it and names it as the source."""
    monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))

    report = get_download_root()

    expected = {
        "download_root": str(tmp_path.resolve()),
        "source": "MCARCHIVE_DOWNLOAD_ROOT env var",
        "raw_env_value": str(tmp_path),
    }
    assert {key: report[key] for key in expected} == expected
def test_get_download_root_reports_default_when_no_env(monkeypatch):
    """With the env var unset, the report names the built-in default as source."""
    monkeypatch.delenv("MCARCHIVE_DOWNLOAD_ROOT", raising=False)

    report = get_download_root()

    assert report["raw_env_value"] is None
    assert report["source"] == "default (./downloads under server CWD)"
def test_set_download_root_changes_env_and_creates_dir(tmp_path, monkeypatch):
    """set_download_root creates the directory and exports it via the env var."""
    # set_download_root writes os.environ directly. monkeypatch.delenv with
    # raising=False records nothing when the variable is already unset, so the
    # value written by the tool would outlive this test. Calling setenv first
    # makes monkeypatch remember the original state; delenv then gives the
    # "unset" precondition this test needs.
    monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", "")
    monkeypatch.delenv("MCARCHIVE_DOWNLOAD_ROOT")
    target = tmp_path / "new" / "spot"
    assert not target.exists()

    info = set_download_root(path=str(target))

    assert info["download_root"] == str(target.resolve())
    assert info["changed"] is True
    assert target.exists() and target.is_dir()
    assert os.environ["MCARCHIVE_DOWNLOAD_ROOT"] == str(target.resolve())
def test_set_download_root_expands_tilde(tmp_path, monkeypatch):
    """A leading '~' is expanded against HOME before the directory is created."""
    # setenv-then-delenv makes monkeypatch record the variable's original
    # state even when it starts out unset, so the value set_download_root
    # writes to os.environ is rolled back after the test.
    monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", "")
    monkeypatch.delenv("MCARCHIVE_DOWNLOAD_ROOT")
    monkeypatch.setenv("HOME", str(tmp_path))

    info = set_download_root(path="~/dl")

    assert info["download_root"] == str((tmp_path / "dl").resolve())
    assert (tmp_path / "dl").exists()
@pytest.mark.parametrize(
    "forbidden",
    [
        "/etc",
        "/usr/local",
        "/var/log",
        "/",
        "/sys",
    ],
)
def test_set_download_root_refuses_system_dirs(forbidden):
    """System directories, and paths beneath them, are rejected with ValueError."""
    with pytest.raises(ValueError, match="system directory"):
        set_download_root(path=forbidden)
async def test_set_download_root_takes_effect_for_next_download(tmp_path, monkeypatch):
    """The lazy-resolved root means a runtime change is honored by download_file
    on the very next call without restarting."""
    # setenv-then-delenv makes monkeypatch record the variable's original
    # state even when it starts out unset; otherwise the value that
    # set_download_root writes to os.environ would leak into later tests.
    monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", "")
    monkeypatch.delenv("MCARCHIVE_DOWNLOAD_ROOT")
    first_root = tmp_path / "first"
    second_root = tmp_path / "second"
    set_download_root(path=str(first_root))

    def handler(req):
        return httpx.Response(200, content=b"data")

    async with swap_shared_client(handler):
        await download_file(identifier="nasa", filename="a.bin", overwrite=True)
        # Now move the root to a different directory mid-session.
        set_download_root(path=str(second_root))
        await download_file(identifier="nasa", filename="b.bin", overwrite=True)

    # Each file must land under the root that was active when it was fetched.
    assert (first_root / "nasa" / "a.bin").exists()
    assert (second_root / "nasa" / "b.bin").exists()
# ---------- M2 (continued): cross-file parallelism ----------
async def test_concurrent_downloads_different_files_run_in_parallel(tmp_path, monkeypatch): async def test_concurrent_downloads_different_files_run_in_parallel(tmp_path, monkeypatch):
"""Different filenames get different locks — they should run concurrently.""" """Different filenames get different locks — they should run concurrently."""
monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path)) monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))