diff --git a/docs/learn/get-started/quick-start/tiled-dataset.md b/docs/learn/get-started/quick-start/tiled-dataset.md index 82a1143..5e7a5f7 100644 --- a/docs/learn/get-started/quick-start/tiled-dataset.md +++ b/docs/learn/get-started/quick-start/tiled-dataset.md @@ -5,34 +5,19 @@ In the [tiling tutorial](./tiling.md), we discussed how to create a tiled datase Before writing the data loaders, we must consider the structure and approximate size of our dataset. A tiled dataset typically consists of two highly interrelated components: 1. **The Parent Dataset:** Contains high-level metadata about the source files (e.g., Whole Slide Image file paths, original dimensions, patient IDs, or slide-level labels). -2. **The Tile Dataset:** Contains metadata about the individual chunks derived from those parents (e.g., $x$ and $y$ coordinates, the parent `slide_id`, and sometimes precomputed tile embeddings). +2. **The Tile Dataset:** Contains metadata about the individual chunks derived from those parents (e.g., `x` and `y` coordinates, the parent `slide_id`, and sometimes precomputed tile embeddings). -If your Parquet files are small enough, you can safely load the entire dataset into RAM using Pandas. However, in digital pathology and large-scale computer vision, a tile dataset can easily span hundreds of gigabytes across multiple Parquet partitions. Loading this entirely into memory will crash your system. Instead, we rely on **lazy loading** techniques to fetch only the necessary data points exactly when the model needs them. +If your Parquet files are small enough, you can safely load the entire dataset into RAM using Pandas. However, in digital pathology and large-scale computer vision, a tile dataset can easily span hundreds of gigabytes across multiple Parquet partitions. Loading this entirely into memory will crash your system. ---- +Let's build our data loading pipeline from the ground up to handle this efficiently. -## 1. 
Lazy Loading: The Hugging Face `datasets` Backend +----- -To handle massive tabular metadata, we use the Hugging Face [datasets](https://huggingface.co/docs/datasets/index) library. It is vastly superior to standard Pandas DataFrames for deep learning data loaders because of how it handles memory. +## 1. The Core Building Block: `TileDataset` -**How it works:** -Parquet is a heavily compressed, columnar storage format. It is great for saving disk space but terrible for the random row access required by PyTorch (e.g., `dataset[idx]`). - -When you load a Parquet file using Hugging Face `datasets`, the library translates the Parquet data into an uncompressed **Apache Arrow** format on your disk. It then utilizes **memory mapping** (`mmap`) to treat that file on your hard drive as if it were in your RAM. - -**Why it is efficient:** - -* **Zero RAM Overhead:** You can interact with a 200GB dataset while consuming mere megabytes of actual RAM. -* **$O(1)$ Random Access:** Reading a specific row is virtually instantaneous. -* **Smart Caching:** When you filter the dataset to find tiles belonging to a specific slide, Hugging Face streams the data, finds the matches, and caches the result on disk. - ---- - -## 2. The Tile Dataset (Reading Individual Tiles) - -At the lowest level, we need a standard PyTorch `Dataset` that takes a subset of our tiled data eg. fetches the actual pixel data, or the precomputed embeddings. +At the lowest level, we need a standard PyTorch `Dataset` that represents a single Whole Slide Image. Its job is simple: take a list of tile coordinates and fetch the actual pixel data (or precomputed embeddings) for those coordinates. -In our WSI use case, we use the `openslide` library to dynamically read pixel patches from the WSIs based on the $x$ and $y$ coordinates stored in our Arrow-mapped tile dataset. +In our WSI use case, we use the `openslide` library to dynamically read pixel patches from the WSIs based on the `x` and `y` metadata. 
```python from pathlib import Path @@ -57,7 +42,7 @@ class TileDataset(Dataset): ) -> None: super().__init__() self.slide_path = Path(slide_path) - self.tiles = tiles + self.tiles = tiles # We will discuss how to efficiently provide this next self.level = level self.extent_x = extent_x self.extent_y = extent_y @@ -77,16 +62,34 @@ class TileDataset(Dataset): ``` ---- +Notice that our `TileDataset` expects a `tiles` object containing the metadata. If we pass a standard Pandas DataFrame here, our RAM will quickly max out as we scale up to thousands of slides. + +----- + +## 2. Managing the Metadata: The Hugging Face `datasets` Backend + +To feed our `TileDataset` without crashing our system, we use the Hugging Face [datasets](https://huggingface.co/docs/datasets/index) library. It provides the `HFDataset` type used in the type hint above, and it suits deep learning data loaders far better than a standard Pandas DataFrame because it handles memory via **lazy loading**. + +**How it works:** +Parquet is a heavily compressed, columnar storage format. It is great for saving disk space but terrible for the random row access required by PyTorch (`dataset[idx]`). When you load a Parquet file using Hugging Face `datasets`, the library translates the Parquet data into an uncompressed **Apache Arrow** format on your disk. It then utilizes **memory mapping** (`mmap`) to treat that file on your hard drive as if it were in your RAM. + +**Why it is efficient:** + + * **Near-Zero RAM Overhead:** You can interact with a 200 GB dataset while consuming mere megabytes of actual RAM. + * **O(1) Random Access:** Reading a specific row of coordinates for our `TileDataset` is virtually instantaneous. + * **Smart Caching:** When you filter the massive tile dataset to find only the chunks belonging to a specific slide, Hugging Face streams the data, finds the matches, and caches the view on disk. -## 3. 
The Main Torch Dataset (Linking Slides and Tiles) +----- -Now we need a unified approach that combines our parent dataset (the slides) with our tile dataset (the patches). We achieve this through **relative tile splitting**—iterating through the parent metadata and dynamically filtering the massive tile dataset to extract only the chunks relative to that specific parent. +## 3. The Orchestrator: `SlideDataset` -By utilizing PyTorch's `ConcatDataset`, we can seamlessly chain these individual `SlideTileDataset` instances together into one massive, unified training set. +Now we need a unified approach that links our parent metadata (the slides) with our lazily-loaded tile metadata (the patches). We achieve this through **relative tile splitting**—iterating through the parent metadata and dynamically filtering the massive Hugging Face tile dataset to extract only the chunks relative to that specific slide. + +By utilizing PyTorch's `ConcatDataset`, we can seamlessly chain our individual `TileDataset` instances together into one massive, unified training set. ```python -from datasets import load_dataset +import pyarrow.compute as pc +from datasets import load_dataset, Dataset as HFDataset from torch.utils.data import ConcatDataset @@ -98,8 +101,12 @@ class SlideDataset(ConcatDataset[TileDataset]): slides_parquet_path: str, tiles_parquet_path: str, ) -> None: - slides_dataset = load_dataset("parquet", data_files=slides_parquet_path, split="train") # Train is default split name for Hugging Face datasets, even if we don't have multiple splits - tiles_dataset = load_dataset("parquet", data_files=tiles_parquet_path, split="train") + # 'train' is the default split name for Hugging Face datasets. 
+ self.slides_dataset = load_dataset("parquet", data_files=slides_parquet_path, split="train") + # Sort by slide_id, then flatten the indices so the underlying Arrow table is physically sorted + self.tiles_dataset = load_dataset("parquet", data_files=tiles_parquet_path, split="train").sort("slide_id").flatten_indices() + + self._slide_id_to_indices = self._build_tile_index(self.tiles_dataset) datasets = [ TileDataset( @@ -107,20 +114,65 @@ class SlideDataset(ConcatDataset[TileDataset]): level=slide["level"], extent_x=slide["extent_x"], extent_y=slide["extent_y"], - tiles=tiles_dataset.filter( - lambda row: row["slide_id"] == slide["slide_id"], - keep_in_memory=False, - ), + tiles=self.filter_tiles_by_slide(slide["slide_id"]), ) - for slide in slides_dataset + for slide in self.slides_dataset ] super().__init__(datasets) + + + @staticmethod + def _build_tile_index(tiles: HFDataset) -> dict[str, range]: + """Creates a fast lookup table for slide indices. + + This function builds a mapping from `slide_id` to the range of indices in the + `tiles` dataset that correspond to that slide. It assumes the dataset is sorted. + + Args: + tiles: A dataset containing a `slide_id` column, sorted by `slide_id`. + + Returns: + A dictionary mapping each `slide_id` to a range of indices. + """ + if len(tiles) == 0: + return {} + + # Get the underlying Arrow table (zero-copy) and merge the column's chunks into one array + table = tiles.data.table + slide_ids = table.column("slide_id").combine_chunks() + + # Since the dataset is sorted by 'slide_id', we can use + # run-end encoding to find group boundaries efficiently. 
+ run_ends = pc.run_end_encode(slide_ids) + + values = run_ends.values + ends = run_ends.run_ends + + index_map = {} + current_offset = 0 + + for sid, end in zip(values, ends): + end_py = end.as_py() + index_map[sid.as_py()] = range(current_offset, end_py) + current_offset = end_py + + return index_map + + def filter_tiles_by_slide(self, slide_id: str) -> HFDataset: + """Returns a view of the dataset using a slice or indices. 
+ + This uses the precomputed `_slide_id_to_indices` mapping to efficiently + retrieve the relevant tiles without copying data. + """ + tile_range = self._slide_id_to_indices.get(slide_id, range(0)) + return self.tiles_dataset.select(tile_range) + ``` ### Using the Dataset -Once constructed, you can pass this `ConcatDataset` directly into a standard PyTorch `DataLoader`. PyTorch will automatically calculate the cumulative length and map global batch indices to the correct underlying slide and tile. +Once constructed, you can pass this `SlideDataset` directly into a standard PyTorch `DataLoader`. PyTorch will automatically calculate the cumulative length and map global batch indices to the correct underlying slide and tile. ```python from torch.utils.data import DataLoader @@ -138,5 +190,4 @@ dataloader = DataLoader( shuffle=True, num_workers=8 ) - ``` \ No newline at end of file diff --git a/ratiopath/ray/aggregate/tensor_mean.py b/ratiopath/ray/aggregate/tensor_mean.py index 44b090b..cd9a452 100644 --- a/ratiopath/ray/aggregate/tensor_mean.py +++ b/ratiopath/ray/aggregate/tensor_mean.py @@ -48,10 +48,8 @@ class TensorMean(AggregateFnV2[dict, np.ndarray | float]): ... ) >>> # 1. Global Mean (axis=None) -> Result: 2.0 >>> ds.aggregate(TensorMean(on="m", axis=None)) - >>> >>> # 2. Batch Mean (axis=0) -> Result: np.array([[2, 2], [2, 2]]) >>> ds.aggregate(TensorMean(on="m", axis=0)) - >>> >>> # 3. Mean across Batch and Rows (axis=(0, 1)) -> Result: np.array([2, 2]) >>> ds.aggregate(TensorMean(on="m", axis=(0, 1))) """ diff --git a/ratiopath/ray/aggregate/tensor_std.py b/ratiopath/ray/aggregate/tensor_std.py index 554b9a2..1c11722 100644 --- a/ratiopath/ray/aggregate/tensor_std.py +++ b/ratiopath/ray/aggregate/tensor_std.py @@ -52,11 +52,9 @@ class TensorStd(AggregateFnV2[dict, np.ndarray | float]): ... ) >>> # 1. Global Std (axis=None) -> All elements reduced to one scalar >>> ds.aggregate(TensorStd(on="m", axis=None)) - >>> >>> # 2. 
Batch Std (axis=0) -> Result is a 2x2 matrix of std values >>> # calculated across the dataset rows. >>> ds.aggregate(TensorStd(on="m", axis=0)) - >>> >>> # 3. Int shorthand (axis=1) -> Internally uses axis=(0, 1) >>> # Collapses batch and the first dimension of the tensor. >>> ds.aggregate(TensorStd(on="m", axis=1))