Skip to content

Commit 30a691a

Browse files
committed
Revert "code tidy"
This reverts commit 0252595.
1 parent 0252595 commit 30a691a

9 files changed

Lines changed: 163 additions & 0 deletions

File tree

src/multiverse/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Multiverse Archive loaders and utilities."""
2+
3+
from .datasets import list_datasets, get_spec, load_dataset
4+
5+
__all__ = ["list_datasets", "get_spec", "load_dataset"]

src/multiverse/_version.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = "0.1.0"
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from ._registry import list_datasets, get_spec
2+
from ._load import load_dataset
3+
4+
__all__ = ["list_datasets", "get_spec", "load_dataset"]
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from pathlib import Path
5+
import requests
6+
from tqdm import tqdm
7+
8+
from multiverse.utils._hash import sha256_file
9+
10+
11+
def _cache_dir() -> Path:
12+
base = Path(os.environ.get("MULTIVERSE_CACHE", Path.home() / ".multiverse"))
13+
d = base / "datasets"
14+
d.mkdir(parents=True, exist_ok=True)
15+
return d
16+
17+
18+
def _zenodo_download_url(record_id: str, artifact_path: str) -> str:
19+
# Simple form which works for public Zenodo records if the filename matches exactly.
20+
# For robustness, you can query Zenodo's API and locate the file by filename.
21+
return f"https://zenodo.org/records/{record_id}/files/{artifact_path}?download=1"
22+
23+
24+
def download_artifact(record_id: str, artifact_path: str, expected_sha256: str | None) -> Path:
    """Download a Zenodo artefact into the local cache and return its path.

    Parameters
    ----------
    record_id:
        Zenodo record identifier.
    artifact_path:
        Filename of the artefact within the record (also the cache-relative path).
    expected_sha256:
        Expected hex checksum; ``None`` or the placeholder ``"REPLACE_ME"``
        skips verification.

    Returns the path to the cached file.  Raises ``ValueError`` on checksum
    mismatch; ``requests`` raises on HTTP errors.
    """
    cache = _cache_dir()
    out = cache / artifact_path
    # artifact_path may contain subdirectories; make sure they exist before
    # writing the temp file (the original code failed here for nested paths).
    out.parent.mkdir(parents=True, exist_ok=True)

    verify = bool(expected_sha256) and expected_sha256 != "REPLACE_ME"

    # Reuse the cached copy only when we can prove it is intact.
    if out.exists() and verify:
        if sha256_file(out) == expected_sha256.lower():
            return out

    url = _zenodo_download_url(record_id, artifact_path)
    tmp = out.with_suffix(out.suffix + ".tmp")

    # Stream into a temp file and rename atomically so an interrupted
    # download never leaves a truncated file at the final path.  The
    # context manager closes the streamed connection (the original leaked it).
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        total = int(r.headers.get("Content-Length", 0))
        with tmp.open("wb") as f, tqdm(
            total=total,
            unit="B",
            unit_scale=True,
            desc=f"Downloading {artifact_path}",
        ) as pbar:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)
                    pbar.update(len(chunk))

    tmp.replace(out)

    if verify:
        got = sha256_file(out)
        if got.lower() != expected_sha256.lower():
            # Remove the corrupt download so the next call retries cleanly.
            out.unlink(missing_ok=True)
            raise ValueError(
                f"Checksum mismatch for {artifact_path}. Expected {expected_sha256}, got {got}."
            )

    return out

src/multiverse/datasets/_load.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
import zipfile
5+
6+
from aeon.datasets import load_from_tsfile
7+
8+
from multiverse.datasets._registry import get_spec
9+
from multiverse.datasets._download import download_artifact
10+
11+
12+
def load_dataset(name: str, split: str = "train"):
    """Fetch a registered dataset and load one split as aeon-compatible X, y.

    The Zenodo artefact is expected to be a zip archive containing
    `<NAME>_TRAIN.ts` and `<NAME>_TEST.ts` files.
    """
    spec = get_spec(name)
    if spec.zenodo_record_id == "REPLACE_ME":
        raise ValueError(
            "Dataset registry contains placeholders. Replace zenodo_record_id and sha256 in mtsc_registry.csv."
        )

    zip_path = download_artifact(spec.zenodo_record_id, spec.artifact_path, spec.sha256)

    normalised = split.lower()
    if normalised not in ("train", "test"):
        raise ValueError("split must be 'train' or 'test'")

    target = f"{spec.dataset}_{normalised.upper()}.ts"
    with zipfile.ZipFile(zip_path) as archive:
        candidates = [entry for entry in archive.namelist() if entry.endswith(target)]
        if not candidates:
            raise FileNotFoundError(f"Could not find {target} inside {spec.artifact_path}")
        member = candidates[0]

        # Extract next to the zip, e.g. ~/.multiverse/datasets/BasicMotions/
        dest_dir = zip_path.with_suffix("")
        dest_dir.mkdir(exist_ok=True)
        final_path = dest_dir / Path(member).name
        if not final_path.exists():
            archive.extract(member, path=dest_dir)
            # zipfile recreates any internal folders; flatten to the target name.
            raw_path = dest_dir / member
            if raw_path != final_path:
                raw_path.replace(final_path)

    X, y = load_from_tsfile(str(final_path))
    return X, y
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from importlib.resources import files
5+
import pandas as pd
6+
7+
8+
@dataclass(frozen=True)
class DatasetSpec:
    """Immutable description of one row of the dataset registry CSV."""

    # Dataset name, e.g. "BasicMotions" (case-insensitive lookup key).
    dataset: str
    # Zenodo record id; "REPLACE_ME" marks an unfilled placeholder.
    zenodo_record_id: str
    # Filename of the artefact inside the Zenodo record, e.g. "BasicMotions.zip".
    artifact_path: str
    # Expected hex checksum; "REPLACE_ME" disables verification.
    sha256: str
    # Archive format tag, e.g. "tszip".  NOTE: shadows the builtin `format`
    # inside this class body only; kept to match the CSV column name.
    format: str
    # Free-text notes from the registry.
    notes: str
    # Licence identifier, e.g. "CC-BY-4.0".
    licence: str
17+
18+
19+
def load_registry() -> pd.DataFrame:
    """Read the bundled dataset registry CSV into a DataFrame."""
    csv_resource = files("multiverse").joinpath("datasets/mtsc_registry.csv")
    return pd.read_csv(csv_resource)
22+
23+
24+
def get_spec(name: str) -> DatasetSpec:
    """Look up a dataset by name (case-insensitive) in the registry.

    Raises KeyError when the name is not registered.
    """
    registry = load_registry()
    matches = registry.loc[registry["dataset"].str.lower() == name.lower()]
    if matches.empty:
        raise KeyError(f"Unknown dataset: {name}. Use list_datasets().")
    record = matches.iloc[0].to_dict()
    return DatasetSpec(**record)
31+
32+
33+
def list_datasets() -> list[str]:
    """Return the names of all registered datasets, sorted alphabetically."""
    registry = load_registry()
    return sorted(registry["dataset"].tolist())
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
dataset,zenodo_record_id,artifact_path,sha256,format,notes,licence
2+
BasicMotions,REPLACE_ME,BasicMotions.zip,REPLACE_ME,tszip,UEA MTSC classic,CC-BY-4.0
3+
NATOPS,REPLACE_ME,NATOPS.zip,REPLACE_ME,tszip,UEA MTSC classic,CC-BY-4.0

src/multiverse/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

src/multiverse/utils/_hash.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from __future__ import annotations
2+
3+
import hashlib
4+
from pathlib import Path
5+
6+
7+
def sha256_file(path: str | Path, chunk_size: int = 1024 * 1024) -> str:
8+
h = hashlib.sha256()
9+
p = Path(path)
10+
with p.open("rb") as f:
11+
for chunk in iter(lambda: f.read(chunk_size), b""):
12+
h.update(chunk)
13+
return h.hexdigest()

0 commit comments

Comments
 (0)