|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import os |
| 4 | +from pathlib import Path |
| 5 | +import requests |
| 6 | +from tqdm import tqdm |
| 7 | + |
| 8 | +from multiverse.utils._hash import sha256_file |
| 9 | + |
| 10 | + |
def _cache_dir() -> Path:
    """Return the dataset cache directory, creating it if necessary.

    The cache root is read from the ``MULTIVERSE_CACHE`` environment
    variable. When the variable is unset or empty, ``~/.multiverse`` is
    used instead. Datasets are kept in the ``datasets`` subdirectory of
    that root.

    Returns:
        Path to the (now existing) ``datasets`` directory.
    """
    override = os.environ.get("MULTIVERSE_CACHE")
    if override:
        # Expand a leading "~" so values such as "~/cache" do not create a
        # literal "~" directory relative to the working directory.
        base = Path(override).expanduser()
    else:
        # An empty value is treated like an unset one; Path("") would
        # otherwise silently resolve to the current working directory.
        base = Path.home() / ".multiverse"
    d = base / "datasets"
    d.mkdir(parents=True, exist_ok=True)
    return d
| 16 | + |
| 17 | + |
def _zenodo_download_url(record_id: str, artifact_path: str) -> str:
    """Build the direct-download URL for a file attached to a Zenodo record.

    Args:
        record_id: Identifier of the Zenodo record.
        artifact_path: Name of the file within the record.

    Returns:
        The ``https://zenodo.org/records/<id>/files/<name>?download=1`` URL
        with both components percent-encoded.
    """
    from urllib.parse import quote

    # Simple form which works for public Zenodo records if the filename matches exactly.
    # For robustness, you can query Zenodo's API and locate the file by filename.
    #
    # Percent-encode the components so names containing spaces, "#", "?" or
    # "%" do not corrupt the URL; "/" is kept so nested paths still work.
    record = quote(str(record_id), safe="")
    name = quote(artifact_path, safe="/")
    return f"https://zenodo.org/records/{record}/files/{name}?download=1"
| 22 | + |
| 23 | + |
def download_artifact(record_id: str, artifact_path: str, expected_sha256: str | None) -> Path:
    """Download a Zenodo artifact into the local cache and return its path.

    When ``expected_sha256`` is a real checksum (not ``None``, empty, or the
    ``"REPLACE_ME"`` placeholder) an already cached file with a matching
    digest is returned without touching the network, and a freshly
    downloaded file is verified before it is moved into the cache. Without
    a usable checksum the file is always downloaded again and is not
    verified.

    Args:
        record_id: Identifier of the Zenodo record.
        artifact_path: Name of the file within the record; also used as the
            path of the cached file relative to the cache directory.
        expected_sha256: Expected SHA-256 digest, compared case-insensitively
            against the result of ``sha256_file``.

    Returns:
        Path of the cached file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        ValueError: If the downloaded file does not match ``expected_sha256``.
    """
    out = _cache_dir() / artifact_path
    # artifact_path may contain sub-directories that do not exist yet.
    out.parent.mkdir(parents=True, exist_ok=True)

    verify = bool(expected_sha256) and expected_sha256 != "REPLACE_ME"
    wanted = expected_sha256.lower() if verify else None

    # Cache hit: compare case-insensitively on both sides, exactly like the
    # post-download check below.
    if verify and out.exists() and sha256_file(out).lower() == wanted:
        return out

    url = _zenodo_download_url(record_id, artifact_path)
    tmp = out.with_suffix(out.suffix + ".tmp")

    try:
        # Using the response as a context manager releases the connection
        # even when the transfer fails part-way through.
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()

            # A missing/zero Content-Length means "size unknown" for tqdm.
            total = int(r.headers.get("Content-Length", 0)) or None

            with tmp.open("wb") as f, tqdm(
                total=total, unit="B", unit_scale=True, desc=f"Downloading {artifact_path}"
            ) as pbar:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))

        if verify:
            # Verify the temporary file so an unverified download is never
            # visible under the final cache path.
            got = sha256_file(tmp)
            if got.lower() != wanted:
                # Any file still at the final path is known not to match
                # either, so leave nothing behind after a failed check.
                out.unlink(missing_ok=True)
                raise ValueError(
                    f"Checksum mismatch for {artifact_path}. Expected {expected_sha256}, got {got}."
                )

        tmp.replace(out)
    finally:
        # No-op after a successful replace; removes partial downloads otherwise.
        tmp.unlink(missing_ok=True)

    return out
0 commit comments