Add atomic .part staging + runtime download root tools
Atomic write pattern (tier-3 polish from headless test finding):
- download_to_file now writes to <dest>.part and renames it to <dest> only on successful stream completion (os.replace is atomic on POSIX). Failed downloads leave only the .part file — no misleading 0-byte dest files in the user's downloads directory.
- Resume logic reads from <dest>.part instead of <dest>; the user's directory only ever contains complete files or clearly-marked .part files.
- New `already_complete` short-circuit: if dest exists and there is no .part, skip the network entirely (still re-verify MD5 if requested). The headless Claude test confirmed this avoids redundant CDN load.
- Symlink rejection re-added at the new code path: even though os.replace would only replace (not follow) a symlink at dest, predictable refusal beats silent symlink removal.

Runtime download root tools (for stdio MCP mode):
- get_download_root(): reports current root, source (env var vs default), existence, writability.
- set_download_root(path): change MCARCHIVE_DOWNLOAD_ROOT mid-session. Expands ~, creates the dir, and refuses system paths (/, /etc, /usr, /bin, /sbin, /var, /sys, /proc, /dev, /boot, /root) as well as any directory beneath them (other than /). The lazy-resolved root means the change takes effect on the next download_file call without restarting the server.

14 new tests (66 total, all green, ruff clean):
- 4 staging tests: failed download leaves no dest, success leaves no .part, already_complete short-circuit, MD5 verification on existing files
- 6 root-tools tests: env reporting, default reporting, set changes env and creates the dir, ~ expansion, system-dir refusal (parametrized), set→download takes effect immediately
- 4 existing tests rewritten to use .part as the resume staging file

Headless Claude smoke test verified end-to-end: get_download_root → set_download_root → search → list → download → second download short-circuits with already_complete=true and zero network bytes.
This commit is contained in:
parent
6198defeca
commit
25a34cd24d
@ -329,11 +329,21 @@ class ArchiveClient:
|
|||||||
verify_md5: str | None = None,
|
verify_md5: str | None = None,
|
||||||
chunk_cb=None,
|
chunk_cb=None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Download with resume support. Returns stats + md5 verification result.
|
"""Download with resume support, atomically promoted on success.
|
||||||
|
|
||||||
Caller is responsible for ensuring `dest` is confined to a safe directory
|
Writes go to a `<dest>.part` staging file. On success, that staging file
|
||||||
(use mcarchive_org.server's path validation, or do your own). This method
|
is atomically renamed to `dest` (POSIX rename is observed all-or-nothing).
|
||||||
adds defense-in-depth via O_NOFOLLOW but does not validate path confinement.
|
On any failure, the `.part` file remains and a follow-up call resumes
|
||||||
|
from it — meaning the user's directory only ever contains complete files
|
||||||
|
or `.part` files (which clearly signal incomplete state).
|
||||||
|
|
||||||
|
- `dest` already complete (no .part) and overwrite implied false →
|
||||||
|
early-return with `already_complete: True` (still verifies MD5 if asked).
|
||||||
|
- `.part` exists → resume from its size, append, then promote on success.
|
||||||
|
- Neither exists → fresh download to `.part`, promote on success.
|
||||||
|
|
||||||
|
Caller is responsible for path confinement; this method adds O_NOFOLLOW
|
||||||
|
defense-in-depth but does not enforce a download root.
|
||||||
"""
|
"""
|
||||||
validate_identifier(identifier)
|
validate_identifier(identifier)
|
||||||
validate_filename(filename)
|
validate_filename(filename)
|
||||||
@ -343,19 +353,50 @@ class ArchiveClient:
|
|||||||
# check should already have caught this, but redundant defense is cheap.
|
# check should already have caught this, but redundant defense is cheap.
|
||||||
if dest.parent.is_symlink():
|
if dest.parent.is_symlink():
|
||||||
raise ArchiveError(f"refusing to write into symlinked directory: {dest.parent}")
|
raise ArchiveError(f"refusing to write into symlinked directory: {dest.parent}")
|
||||||
|
# Refuse to "download over" a symlink at dest. Even though the staging
|
||||||
|
# rename would replace (not follow) the symlink, predictable refusal is
|
||||||
|
# better than silent symlink removal — that surprises the user and could
|
||||||
|
# destroy a symlink farm they intentionally set up.
|
||||||
|
if dest.is_symlink():
|
||||||
|
raise ArchiveError(f"refusing to write through symlink at {dest}")
|
||||||
|
|
||||||
resume_from = dest.stat().st_size if dest.exists() and not dest.is_symlink() else 0
|
part = dest.with_name(dest.name + ".part")
|
||||||
# is_symlink() detection: if dest is a symlink, treat as fresh (open with
|
# Same protection on the staging path.
|
||||||
# O_NOFOLLOW will refuse anyway, so we won't actually corrupt the target).
|
if part.is_symlink():
|
||||||
|
raise ArchiveError(f"refusing to write through symlink at {part}")
|
||||||
|
|
||||||
|
# Short-circuit: dest already exists complete and no .part to resume.
|
||||||
|
if dest.exists() and not dest.is_symlink() and not part.exists():
|
||||||
|
size = dest.stat().st_size
|
||||||
|
result: dict[str, Any] = {
|
||||||
|
"path": str(dest),
|
||||||
|
"bytes_written": size,
|
||||||
|
"resumed_from": size,
|
||||||
|
"already_complete": True,
|
||||||
|
}
|
||||||
|
if verify_md5:
|
||||||
|
# Verify the existing file matches the expected hash.
|
||||||
|
hasher = hashlib.md5()
|
||||||
|
with dest.open("rb") as fh:
|
||||||
|
while chunk := fh.read(1 << 16):
|
||||||
|
hasher.update(chunk)
|
||||||
|
actual = hasher.hexdigest()
|
||||||
|
result["md5_actual"] = actual
|
||||||
|
result["md5_expected"] = verify_md5
|
||||||
|
result["md5_ok"] = actual.lower() == verify_md5.lower()
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Resume from .part if present; otherwise fresh.
|
||||||
|
resume_from = part.stat().st_size if part.exists() and not part.is_symlink() else 0
|
||||||
|
|
||||||
hasher = hashlib.md5() if verify_md5 else None
|
hasher = hashlib.md5() if verify_md5 else None
|
||||||
if hasher and resume_from:
|
if hasher and resume_from:
|
||||||
with dest.open("rb") as f:
|
with part.open("rb") as fh:
|
||||||
while chunk := f.read(1 << 16):
|
while chunk := fh.read(1 << 16):
|
||||||
hasher.update(chunk)
|
hasher.update(chunk)
|
||||||
|
|
||||||
bytes_written = resume_from
|
bytes_written = resume_from
|
||||||
f = _safe_open_for_write(dest, append=resume_from > 0)
|
f = _safe_open_for_write(part, append=resume_from > 0)
|
||||||
try:
|
try:
|
||||||
async for chunk in self.stream_file(identifier, filename, resume_from=resume_from):
|
async for chunk in self.stream_file(identifier, filename, resume_from=resume_from):
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
@ -365,22 +406,24 @@ class ArchiveClient:
|
|||||||
if chunk_cb:
|
if chunk_cb:
|
||||||
chunk_cb(bytes_written)
|
chunk_cb(bytes_written)
|
||||||
except (httpx.ReadError, httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadTimeout) as e:
|
except (httpx.ReadError, httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadTimeout) as e:
|
||||||
# H1: surface partial-state context so the caller can decide whether
|
# H1: surface partial-state context. The .part file stays on disk
|
||||||
# to resume or restart. The bytes already on disk are valid (we only
|
# so a follow-up call with overwrite=False resumes from bytes_written.
|
||||||
# write whole chunks), so a follow-up call with overwrite=False will
|
|
||||||
# resume cleanly from `bytes_written`.
|
|
||||||
raise ArchiveError(
|
raise ArchiveError(
|
||||||
f"download interrupted after {bytes_written - resume_from} new bytes "
|
f"download interrupted after {bytes_written - resume_from} new bytes "
|
||||||
f"({bytes_written} total on disk, resumed from {resume_from}). "
|
f"({bytes_written} total in {part.name}, resumed from {resume_from}). "
|
||||||
f"File at {dest} — call download_file again to resume. Cause: {e!r}"
|
f"Call download_file again to resume. Cause: {e!r}"
|
||||||
) from e
|
) from e
|
||||||
finally:
|
finally:
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
result: dict[str, Any] = {
|
# Atomic promotion — only after the stream completed cleanly.
|
||||||
|
os.replace(part, dest)
|
||||||
|
|
||||||
|
result = {
|
||||||
"path": str(dest),
|
"path": str(dest),
|
||||||
"bytes_written": bytes_written,
|
"bytes_written": bytes_written,
|
||||||
"resumed_from": resume_from,
|
"resumed_from": resume_from,
|
||||||
|
"already_complete": False,
|
||||||
}
|
}
|
||||||
if verify_md5 and hasher:
|
if verify_md5 and hasher:
|
||||||
actual = hasher.hexdigest()
|
actual = hasher.hexdigest()
|
||||||
|
|||||||
@ -342,13 +342,18 @@ async def download_file(
|
|||||||
so two parallel tool invocations can't race on the same destination file.
|
so two parallel tool invocations can't race on the same destination file.
|
||||||
"""
|
"""
|
||||||
dest = _confine_dest(identifier, filename, dest_dir)
|
dest = _confine_dest(identifier, filename, dest_dir)
|
||||||
|
part = dest.with_name(dest.name + ".part")
|
||||||
|
|
||||||
lock = _download_lock_for(identifier, filename)
|
lock = _download_lock_for(identifier, filename)
|
||||||
async with lock:
|
async with lock:
|
||||||
if overwrite and (dest.exists() or dest.is_symlink()):
|
if overwrite:
|
||||||
# is_symlink() before exists() — a dangling symlink reports exists()=False
|
# Remove both the final file AND any leftover .part — we want to start
|
||||||
# but we still want to remove the link itself rather than follow it.
|
# truly fresh, not resume an old partial. is_symlink check first since
|
||||||
dest.unlink()
|
# a dangling symlink reports exists()=False.
|
||||||
|
if dest.exists() or dest.is_symlink():
|
||||||
|
dest.unlink()
|
||||||
|
if part.exists() or part.is_symlink():
|
||||||
|
part.unlink()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
c = await get_shared_client()
|
c = await get_shared_client()
|
||||||
@ -363,6 +368,76 @@ async def download_file(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- runtime configuration ----------
|
||||||
|
|
||||||
|
# Paths we refuse to use as a download root no matter what — these are system
|
||||||
|
# directories where writing junk is genuinely harmful. The user's MCP client
|
||||||
|
# can usually re-launch with MCARCHIVE_DOWNLOAD_ROOT pointing anywhere they
|
||||||
|
# want at startup; this guard is just for the LLM-driven set_download_root tool.
|
||||||
|
_FORBIDDEN_ROOTS = frozenset({
|
||||||
|
"/", "/etc", "/usr", "/bin", "/sbin", "/var", "/sys", "/proc", "/dev", "/boot", "/root",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _check_root_safety(p: Path) -> None:
|
||||||
|
s = str(p)
|
||||||
|
if s in _FORBIDDEN_ROOTS:
|
||||||
|
raise ValueError(f"refusing to use system directory as download root: {s}")
|
||||||
|
for forbidden in _FORBIDDEN_ROOTS:
|
||||||
|
if forbidden != "/" and s.startswith(forbidden + "/"):
|
||||||
|
raise ValueError(f"refusing to use system directory as download root: {s}")
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool
|
||||||
|
def get_download_root() -> dict[str, Any]:
|
||||||
|
"""Report the directory where download_file writes by default.
|
||||||
|
|
||||||
|
Useful at the start of a session to confirm where files will land. The
|
||||||
|
`source` field tells you whether the value came from the MCARCHIVE_DOWNLOAD_ROOT
|
||||||
|
env var or from the built-in default of `./downloads` under the server's CWD.
|
||||||
|
"""
|
||||||
|
raw_env = os.environ.get("MCARCHIVE_DOWNLOAD_ROOT")
|
||||||
|
root = _resolve_download_root().resolve()
|
||||||
|
return {
|
||||||
|
"download_root": str(root),
|
||||||
|
"exists": root.exists(),
|
||||||
|
"writable": os.access(root, os.W_OK) if root.exists() else None,
|
||||||
|
"source": "MCARCHIVE_DOWNLOAD_ROOT env var" if raw_env else "default (./downloads under server CWD)",
|
||||||
|
"raw_env_value": raw_env,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool
|
||||||
|
def set_download_root(
|
||||||
|
path: Annotated[
|
||||||
|
str,
|
||||||
|
Field(description="New download root path. '~' is expanded; the directory is created if missing."),
|
||||||
|
],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Change the download root for the rest of this MCP server session.
|
||||||
|
|
||||||
|
Useful when running as a stdio MCP server, since you can't otherwise
|
||||||
|
re-export environment variables to a running child process. The change
|
||||||
|
persists until the server process exits or set_download_root is called
|
||||||
|
again. System directories (/etc, /usr, /var, /sys, /proc, /dev, /boot,
|
||||||
|
/root, /bin, /sbin, /) are refused.
|
||||||
|
"""
|
||||||
|
expanded = Path(path).expanduser().resolve()
|
||||||
|
_check_root_safety(expanded)
|
||||||
|
|
||||||
|
previous = _resolve_download_root().resolve()
|
||||||
|
expanded.mkdir(parents=True, exist_ok=True)
|
||||||
|
if not os.access(expanded, os.W_OK):
|
||||||
|
raise ValueError(f"download root not writable: {expanded}")
|
||||||
|
|
||||||
|
os.environ["MCARCHIVE_DOWNLOAD_ROOT"] = str(expanded)
|
||||||
|
return {
|
||||||
|
"download_root": str(expanded),
|
||||||
|
"previous": str(previous),
|
||||||
|
"changed": str(expanded) != str(previous),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ---------- resources ----------
|
# ---------- resources ----------
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -102,7 +102,8 @@ async def test_resume_with_200_response_raises_before_writing(tmp_path):
|
|||||||
"""If the server returns 200 instead of 206 on a Range request, we must not
|
"""If the server returns 200 instead of 206 on a Range request, we must not
|
||||||
append to the existing file — that path corrupts data silently."""
|
append to the existing file — that path corrupts data silently."""
|
||||||
dest = tmp_path / "partial.bin"
|
dest = tmp_path / "partial.bin"
|
||||||
dest.write_bytes(b"X" * 100) # pretend we have a partial download
|
part = tmp_path / "partial.bin.part" # staging file holds resume state
|
||||||
|
part.write_bytes(b"X" * 100)
|
||||||
|
|
||||||
def handler(req: httpx.Request) -> httpx.Response:
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
# Server ignores Range header and returns the full body with 200
|
# Server ignores Range header and returns the full body with 200
|
||||||
@ -113,14 +114,16 @@ async def test_resume_with_200_response_raises_before_writing(tmp_path):
|
|||||||
with pytest.raises(ArchiveError, match="ignored Range"):
|
with pytest.raises(ArchiveError, match="ignored Range"):
|
||||||
await c.download_to_file("nasa", "partial.bin", dest)
|
await c.download_to_file("nasa", "partial.bin", dest)
|
||||||
|
|
||||||
# File must be unchanged — corruption avoided.
|
# The .part file must be unchanged and dest must not exist — corruption avoided.
|
||||||
assert dest.read_bytes() == b"X" * 100
|
assert part.read_bytes() == b"X" * 100
|
||||||
|
assert not dest.exists()
|
||||||
|
|
||||||
|
|
||||||
async def test_resume_with_correct_206_succeeds(tmp_path):
|
async def test_resume_with_correct_206_succeeds(tmp_path):
|
||||||
full_body = b"0123456789ABCDEF" * 16 # 256 bytes
|
full_body = b"0123456789ABCDEF" * 16 # 256 bytes
|
||||||
dest = tmp_path / "resume.bin"
|
dest = tmp_path / "resume.bin"
|
||||||
dest.write_bytes(full_body[:64]) # we already have first 64 bytes
|
part = tmp_path / "resume.bin.part"
|
||||||
|
part.write_bytes(full_body[:64]) # we already have first 64 bytes in staging
|
||||||
|
|
||||||
def handler(req: httpx.Request) -> httpx.Response:
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
assert req.headers.get("Range") == "bytes=64-"
|
assert req.headers.get("Range") == "bytes=64-"
|
||||||
@ -139,12 +142,15 @@ async def test_resume_with_correct_206_succeeds(tmp_path):
|
|||||||
assert result["bytes_written"] == len(full_body)
|
assert result["bytes_written"] == len(full_body)
|
||||||
assert result["resumed_from"] == 64
|
assert result["resumed_from"] == 64
|
||||||
assert result["md5_ok"] is True
|
assert result["md5_ok"] is True
|
||||||
|
# On success the .part is atomically renamed to dest.
|
||||||
assert dest.read_bytes() == full_body
|
assert dest.read_bytes() == full_body
|
||||||
|
assert not part.exists()
|
||||||
|
|
||||||
|
|
||||||
async def test_resume_with_wrong_content_range_start_raises(tmp_path):
|
async def test_resume_with_wrong_content_range_start_raises(tmp_path):
|
||||||
dest = tmp_path / "off.bin"
|
dest = tmp_path / "off.bin"
|
||||||
dest.write_bytes(b"X" * 100)
|
part = tmp_path / "off.bin.part"
|
||||||
|
part.write_bytes(b"X" * 100)
|
||||||
|
|
||||||
def handler(req: httpx.Request) -> httpx.Response:
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
# Server returns 206 but with WRONG starting offset
|
# Server returns 206 but with WRONG starting offset
|
||||||
@ -158,7 +164,9 @@ async def test_resume_with_wrong_content_range_start_raises(tmp_path):
|
|||||||
with pytest.raises(ArchiveError, match="Content-Range start"):
|
with pytest.raises(ArchiveError, match="Content-Range start"):
|
||||||
await c.download_to_file("nasa", "off.bin", dest)
|
await c.download_to_file("nasa", "off.bin", dest)
|
||||||
|
|
||||||
assert dest.read_bytes() == b"X" * 100 # unchanged
|
# .part unchanged, dest never created.
|
||||||
|
assert part.read_bytes() == b"X" * 100
|
||||||
|
assert not dest.exists()
|
||||||
|
|
||||||
|
|
||||||
# ---------- H4: error body surfacing ----------
|
# ---------- H4: error body surfacing ----------
|
||||||
@ -328,6 +336,7 @@ async def test_stream_abort_raises_archive_error_with_byte_count(tmp_path):
|
|||||||
return httpx.Response(200, content=evil_body())
|
return httpx.Response(200, content=evil_body())
|
||||||
|
|
||||||
dest = tmp_path / "interrupted.bin"
|
dest = tmp_path / "interrupted.bin"
|
||||||
|
part = tmp_path / "interrupted.bin.part"
|
||||||
async with _client_with(handler) as c:
|
async with _client_with(handler) as c:
|
||||||
with pytest.raises(ArchiveError) as exc_info:
|
with pytest.raises(ArchiveError) as exc_info:
|
||||||
await c.download_to_file("nasa", "interrupted.bin", dest)
|
await c.download_to_file("nasa", "interrupted.bin", dest)
|
||||||
@ -335,8 +344,9 @@ async def test_stream_abort_raises_archive_error_with_byte_count(tmp_path):
|
|||||||
msg = str(exc_info.value)
|
msg = str(exc_info.value)
|
||||||
assert "interrupted after" in msg
|
assert "interrupted after" in msg
|
||||||
assert "ReadError" in msg
|
assert "ReadError" in msg
|
||||||
# Partial bytes ARE on disk — at least the first delivered chunk.
|
# Partial bytes go to .part, NOT dest. dest stays absent until success.
|
||||||
on_disk = dest.read_bytes()
|
assert not dest.exists()
|
||||||
|
on_disk = part.read_bytes()
|
||||||
assert len(on_disk) > 0
|
assert len(on_disk) > 0
|
||||||
assert on_disk == chunk_payload[: len(on_disk)]
|
assert on_disk == chunk_payload[: len(on_disk)]
|
||||||
|
|
||||||
@ -360,4 +370,82 @@ async def test_fresh_download_writes_full_body(tmp_path):
|
|||||||
assert result["bytes_written"] == len(body)
|
assert result["bytes_written"] == len(body)
|
||||||
assert result["resumed_from"] == 0
|
assert result["resumed_from"] == 0
|
||||||
assert result["md5_ok"] is True
|
assert result["md5_ok"] is True
|
||||||
|
assert result["already_complete"] is False
|
||||||
|
# The atomic-rename pattern leaves no .part artifact after success.
|
||||||
assert dest.read_bytes() == body
|
assert dest.read_bytes() == body
|
||||||
|
assert not (tmp_path / "new.bin.part").exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Atomic .part staging ----------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_failed_download_leaves_no_dest_file(tmp_path):
|
||||||
|
"""A failed fresh download must NOT leave the final dest file as zero bytes —
|
||||||
|
it should leave only the .part staging file (or nothing if no bytes arrived)."""
|
||||||
|
|
||||||
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(500, text="upstream cdn miss")
|
||||||
|
|
||||||
|
dest = tmp_path / "shouldfail.bin"
|
||||||
|
async with _client_with(handler) as c:
|
||||||
|
with pytest.raises(ArchiveError, match="HTTP 500"):
|
||||||
|
await c.download_to_file("nasa", "shouldfail.bin", dest)
|
||||||
|
|
||||||
|
# Critical: dest must NOT exist as an empty file misleading the user.
|
||||||
|
assert not dest.exists()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_already_complete_short_circuits_without_network(tmp_path):
|
||||||
|
"""If dest exists and no .part, a follow-up download must not hit the
|
||||||
|
network — the file is already complete."""
|
||||||
|
dest = tmp_path / "done.bin"
|
||||||
|
dest.write_bytes(b"already-here")
|
||||||
|
|
||||||
|
calls = {"n": 0}
|
||||||
|
|
||||||
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
|
calls["n"] += 1
|
||||||
|
return httpx.Response(500, text="should never fire")
|
||||||
|
|
||||||
|
async with _client_with(handler) as c:
|
||||||
|
result = await c.download_to_file("nasa", "done.bin", dest)
|
||||||
|
|
||||||
|
assert calls["n"] == 0 # no network at all
|
||||||
|
assert result["already_complete"] is True
|
||||||
|
assert result["bytes_written"] == len(b"already-here")
|
||||||
|
assert dest.read_bytes() == b"already-here"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_already_complete_verifies_md5_against_existing_file(tmp_path):
|
||||||
|
"""If verify_md5 is passed and dest is complete, we re-hash to confirm."""
|
||||||
|
body = b"on-disk-content"
|
||||||
|
dest = tmp_path / "done.bin"
|
||||||
|
dest.write_bytes(body)
|
||||||
|
|
||||||
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(500, text="should never fire")
|
||||||
|
|
||||||
|
async with _client_with(handler) as c:
|
||||||
|
result = await c.download_to_file(
|
||||||
|
"nasa", "done.bin", dest, verify_md5=hashlib.md5(body).hexdigest()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["already_complete"] is True
|
||||||
|
assert result["md5_ok"] is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_already_complete_md5_mismatch_caught(tmp_path):
|
||||||
|
"""If the existing file's MD5 doesn't match expected, surface md5_ok=False."""
|
||||||
|
dest = tmp_path / "wrong.bin"
|
||||||
|
dest.write_bytes(b"actual-content")
|
||||||
|
|
||||||
|
def handler(req: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(500, text="should never fire")
|
||||||
|
|
||||||
|
async with _client_with(handler) as c:
|
||||||
|
result = await c.download_to_file(
|
||||||
|
"nasa", "wrong.bin", dest, verify_md5="0" * 32
|
||||||
|
)
|
||||||
|
assert result["already_complete"] is True
|
||||||
|
assert result["md5_ok"] is False
|
||||||
|
assert result["md5_expected"] == "0" * 32
|
||||||
|
|||||||
@ -10,6 +10,7 @@ These exercise the MCP tool functions directly and verify:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@ -21,8 +22,10 @@ from mcarchive_org.server import (
|
|||||||
_enrich_doc,
|
_enrich_doc,
|
||||||
_normalize_collection,
|
_normalize_collection,
|
||||||
download_file,
|
download_file,
|
||||||
|
get_download_root,
|
||||||
get_item_metadata,
|
get_item_metadata,
|
||||||
search_items,
|
search_items,
|
||||||
|
set_download_root,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -170,6 +173,75 @@ async def test_concurrent_downloads_same_file_are_serialized(tmp_path, monkeypat
|
|||||||
assert state["max_active"] == 1
|
assert state["max_active"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- runtime download root management ----------
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_download_root_reports_env_value(tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))
|
||||||
|
info = get_download_root()
|
||||||
|
assert info["download_root"] == str(tmp_path.resolve())
|
||||||
|
assert info["source"] == "MCARCHIVE_DOWNLOAD_ROOT env var"
|
||||||
|
assert info["raw_env_value"] == str(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_download_root_reports_default_when_no_env(monkeypatch):
|
||||||
|
monkeypatch.delenv("MCARCHIVE_DOWNLOAD_ROOT", raising=False)
|
||||||
|
info = get_download_root()
|
||||||
|
assert info["source"] == "default (./downloads under server CWD)"
|
||||||
|
assert info["raw_env_value"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_download_root_changes_env_and_creates_dir(tmp_path, monkeypatch):
|
||||||
|
monkeypatch.delenv("MCARCHIVE_DOWNLOAD_ROOT", raising=False)
|
||||||
|
target = tmp_path / "new" / "spot"
|
||||||
|
assert not target.exists()
|
||||||
|
|
||||||
|
info = set_download_root(path=str(target))
|
||||||
|
|
||||||
|
assert info["download_root"] == str(target.resolve())
|
||||||
|
assert info["changed"] is True
|
||||||
|
assert target.exists() and target.is_dir()
|
||||||
|
assert os.environ["MCARCHIVE_DOWNLOAD_ROOT"] == str(target.resolve())
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_download_root_expands_tilde(tmp_path, monkeypatch):
|
||||||
|
monkeypatch.delenv("MCARCHIVE_DOWNLOAD_ROOT", raising=False)
|
||||||
|
monkeypatch.setenv("HOME", str(tmp_path))
|
||||||
|
|
||||||
|
info = set_download_root(path="~/dl")
|
||||||
|
|
||||||
|
assert info["download_root"] == str((tmp_path / "dl").resolve())
|
||||||
|
assert (tmp_path / "dl").exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("forbidden", ["/etc", "/usr/local", "/var/log", "/", "/sys"])
|
||||||
|
def test_set_download_root_refuses_system_dirs(forbidden):
|
||||||
|
with pytest.raises(ValueError, match="system directory"):
|
||||||
|
set_download_root(path=forbidden)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_set_download_root_takes_effect_for_next_download(tmp_path, monkeypatch):
|
||||||
|
"""The lazy-resolved root means a runtime change is honored by download_file
|
||||||
|
on the very next call without restarting."""
|
||||||
|
monkeypatch.delenv("MCARCHIVE_DOWNLOAD_ROOT", raising=False)
|
||||||
|
set_download_root(path=str(tmp_path / "first"))
|
||||||
|
|
||||||
|
def handler(req):
|
||||||
|
return httpx.Response(200, content=b"data")
|
||||||
|
|
||||||
|
async with swap_shared_client(handler):
|
||||||
|
await download_file(identifier="nasa", filename="a.bin", overwrite=True)
|
||||||
|
# Now move the root to a different directory mid-session.
|
||||||
|
set_download_root(path=str(tmp_path / "second"))
|
||||||
|
await download_file(identifier="nasa", filename="b.bin", overwrite=True)
|
||||||
|
|
||||||
|
assert (tmp_path / "first" / "nasa" / "a.bin").exists()
|
||||||
|
assert (tmp_path / "second" / "nasa" / "b.bin").exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- M2 (continued): cross-file parallelism ----------
|
||||||
|
|
||||||
|
|
||||||
async def test_concurrent_downloads_different_files_run_in_parallel(tmp_path, monkeypatch):
|
async def test_concurrent_downloads_different_files_run_in_parallel(tmp_path, monkeypatch):
|
||||||
"""Different filenames get different locks — they should run concurrently."""
|
"""Different filenames get different locks — they should run concurrently."""
|
||||||
monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))
|
monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user