diff --git a/README.md b/README.md index 568213c..4288949 100644 --- a/README.md +++ b/README.md @@ -4,15 +4,17 @@ Model deployment infrastructure for RationAI using Ray Serve on Kubernetes. This repository contains: -- A Helm chart (`helm/rayservice/`) that renders and deploys a KubeRay `RayService`. -- A static RayService manifest (`ray-service.yaml`) for reference/manual apply workflows. -- Model implementations under `models/` (reference: `models/binary_classifier.py`). -- Documentation under `docs/` (MkDocs). +- `builders/`: WSI output aggregation services (e.g., `heatmap_builder.py`). +- `docker/`: Dockerfiles for building CPU and GPU environments. +- `docs/`: MkDocs documentation and architecture guides. +- `helm/rayservice/`: A Helm chart that renders and deploys a KubeRay `RayService`. +- `models/`: Python entrypoints for model implementations (e.g., `binary_classifier.py`, `virchow2.py`). ## Documentation - MkDocs content: `docs/` - Key pages: + - `docs/available-models.md` - `docs/get-started/quick-start.md` - `docs/guides/deployment-guide.md` - `docs/guides/adding-models.md` diff --git a/builders/heatmap_builder.py b/builders/heatmap_builder.py index 5b40636..8b89d31 100644 --- a/builders/heatmap_builder.py +++ b/builders/heatmap_builder.py @@ -1,7 +1,9 @@ import asyncio from concurrent.futures import ThreadPoolExecutor +from pathlib import Path from typing import Any, TypedDict +import numpy as np from fastapi import FastAPI from ray import serve @@ -35,19 +37,22 @@ async def root( output_bigtiff_tile_height: int, output_bigtiff_tile_width: int, ) -> str: + import pyvips + from ratiopath.masks.mask_builders import MaskBuilder from ratiopath.openslide import OpenSlide from ratiopath.tiling import grid_tiles from misc.fetch_tissue_tile import fetch_tissue_tile - from misc.tile_heatmap_builder import TileHeatmapBuilder model = serve.get_app_handle(model_id) model_config = await model.get_config.remote() - stride: int = round(stride_fraction * model_config["tile_size"]) + tile_size: int = model_config["tile_size"] + output_tile_size: int = model_config["output_tile_size"] + n_channels: int = model_config["n_channels"] + stride: int = round(stride_fraction * tile_size) loop = asyncio.get_running_loop() tasks: set[asyncio.Task[Any]] = set() - with ( OpenSlide(slide_path) as slide, OpenSlide(tissue_mask_path) as tissue_slide, @@ -63,50 +68,85 @@ async def root( ] scale_x = tissue_extent_x / extent_x scale_y = tissue_extent_y / extent_y - - mask_builder = TileHeatmapBuilder( - extent_x=extent_x, extent_y=extent_y, mpp_x=mpp_x, mpp_y=mpp_y + mask_builder = MaskBuilder( + source_extents=(extent_y, extent_x), + source_tile_extent=tile_size, + output_tile_extent=output_tile_size, + stride=stride, + n_channels=n_channels, + storage="memmap", ) + try: + + async def process_tile(x: int, y: int) -> None: + tile = await loop.run_in_executor( + executor, + fetch_tissue_tile, + slide, + tissue_slide, + x, + y, + level, + scale_x, + scale_y, + tissue_level, + tile_size, + ) + if tile is None: + return - async def process_tile(x: int, y: int) -> None: - tile = await loop.run_in_executor( - executor, - fetch_tissue_tile, - slide, - tissue_slide, - x, - y, - level, - scale_x, - scale_y, - tissue_level, - model_config["tile_size"], - ) - if tile is not None: prediction = await model.predict.remote(tile) - mask_builder.update(prediction, x, y) - - for x, y in grid_tiles( - slide_extent=(extent_x, extent_y), - tile_extent=(model_config["tile_size"], model_config["tile_size"]), - stride=(stride, stride), - ): - if len(tasks) >= self.max_concurrent_tasks: - _, tasks = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED + arr = np.asarray(prediction, dtype=np.float32) + + if arr.ndim == 2: + batch = arr[np.newaxis, np.newaxis, ...] + elif arr.ndim == 3: + batch = arr[np.newaxis, ...] + else: + raise ValueError(f"Unexpected prediction shape: {arr.shape}") + + mask_builder.update_batch( + batch=batch, + coords=np.array([[y, x]], dtype=np.int64), ) - tasks.add(asyncio.create_task(process_tile(x, y))) - - await asyncio.wait(tasks) + for x, y in grid_tiles( + slide_extent=(extent_x, extent_y), + tile_extent=(tile_size, tile_size), + stride=(stride, stride), + ): + if len(tasks) >= self.max_concurrent_tasks: + done, tasks = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in done: + task.result() + tasks.add(asyncio.create_task(process_tile(x, y))) + + if tasks: + done, _ = await asyncio.wait(tasks) + for task in done: + task.result() + + result = np.asarray(mask_builder.finalize()["mask"]) + + vips_image = mask_builder.resize_to_source(result) + vips_image = (vips_image * 255).cast(pyvips.BandFormat.UCHAR) + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + vips_image.tiffsave( + output_path, + bigtiff=True, + compression=pyvips.enums.ForeignTiffCompression.DEFLATE, + tile=True, + tile_width=output_bigtiff_tile_width, + tile_height=output_bigtiff_tile_height, + xres=1000 / mpp_x, + yres=1000 / mpp_y, + pyramid=True, + ) + finally: + mask_builder.cleanup() - mask_builder.flush() - mask_builder.save( - output_path, - tile_height=output_bigtiff_tile_height, - tile_width=output_bigtiff_tile_width, - ) - mask_builder.cleanup() return output_path diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 9f778f5..56397ef 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -52,4 +52,4 @@ RUN sudo apt-get update && sudo apt-get -y upgrade && \ RUN sudo apt-get remove -y --purge systemd systemd-sysv && sudo apt-get autoremove --purge -y && sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/* RUN pip install --no-cache-dir \ - onnxruntime lz4 ratiopath "mlflow<3.0" \ No newline at end of file + onnxruntime lz4 ratiopath pyvips "mlflow<3.0" \ No newline at end of file diff --git a/docs/available-models.md b/docs/available-models.md index ca7a0e1..dc593ff 100644 --- a/docs/available-models.md +++ b/docs/available-models.md @@ -11,6 +11,7 @@ All endpoints receive and return data over HTTP using `POST` requests. To minimi | **Prostate Classifier 1** | `/prostate-classifier-1` | Binary Classification | | **Episeg 1** | `/episeg-1` | Semantic Segmentation | | **Virchow2** | `/virchow2` | Foundation Model / Embeddings | +| **Prov-GigaPath** | `/prov-gigapath` | Foundation Model / Embeddings | | **Heatmap Builder** | `/heatmap-builder` | Pipeline / Custom Builder | --- @@ -80,7 +81,31 @@ with Client() as client: print(emb.shape) ``` -### 4. Heatmap Builder (`/heatmap-builder`) +### 4. Prov-GigaPath (`/prov-gigapath`) + +A foundation model for pathology tile embeddings (Prov-GigaPath). + +- **Input**: LZ4-compressed raw bytes of a tissue tile image (`uint8`, shape `(tile_size, tile_size, 3)`). +- **Output**: Output tensor matching the user's requested precision. +- **Headers**: + - `x-output-dtype` (optional, default: `float32`): Sets the return precision (for example `float16`). +- **SDK example**: + +```python +from rationai import Client +import numpy as np + +with Client() as client: + emb = client.models.embed_image( + model="prov-gigapath", + image=image, + output_dtype=np.float16, + timeout=30.0, + ) + print(emb.shape) +``` + +### 5. Heatmap Builder (`/heatmap-builder`) A processing pipeline element for aggregating inferences into spatial heatmaps. diff --git a/docs/guides/adding-models.md b/docs/guides/adding-models.md index 440ec92..e736e3b 100644 --- a/docs/guides/adding-models.md +++ b/docs/guides/adding-models.md @@ -141,6 +141,22 @@ What happens with data: - Output tensor is flattened and returned as a Python list. - Ray Serve maps each list item back to the original HTTP request. +### `get_config`: expose model settings for builders + +If your model is used by any Whole-Slide Inference builder (for example `HeatmapBuilder`), you must provide a `get_config` method that builders can call through a Serve handle. The builder uses this to read `tile_size`, `output_tile_size`, `n_channels`, and `mpp` so it can pick the right tiling grid and resolution. + +```python +async def get_config(self) -> dict[str, Any]: + return { + "tile_size": self.tile_size, + "output_tile_size": self.output_tile_size, + "n_channels": self.n_channels, + "mpp": self.mpp, + } +``` + +The builder calls it with `await model.get_config.remote()`; keep it cheap and avoid any I/O. + ### `root`: HTTP request parsing and serialization ```python @@ -172,6 +188,45 @@ app = MyModel.bind() This exported symbol is what `import_path: models.my_model:app` points to in Helm. +### Using Foundation Models from Your Model + +If you are deploying a model that is a downstream head (e.g., an MLP or Attention layer trained on top of a foundation model like Virchow2 or Prov-GigaPath), you **do not** need to re-export the entire foundation model into your ONNX artifact. + +Because foundation models are already deployed as independent services within the cluster, your model can directly invoke them via Ray Serve handles. In your `root` or `predict` method, call the foundation model first, then pass its output to your custom layers. If you call a foundation model's `predict` method directly, do **not** pass a raw `np.ndarray` image; first apply the same preprocessing/transforms that deployment expects, or call the model's ingress/request path instead. + +```python +# In your __init__ or reconfigure method: +from ray import serve +self.foundation_model = serve.get_app_handle("virchow2") +self.foundation_transform = build_virchow2_transform() + +# In your request handler: +@fastapi.post("/") +async def root(self, request: Request): + # 1. Fetch raw image bytes from request and decode to an image / np.ndarray + ... + + # 2. Apply the same transforms used by the foundation model deployment + image_tensor = self.foundation_transform(image) + if image_tensor.ndim == 3: + image_tensor = image_tensor.unsqueeze(0) + + # 3. Call the foundation model with the transformed tensor batch + embedding = await self.foundation_model.predict.remote(image_tensor) + + # 4. Pass the embedding to your own model's predict endpoint + return await self.predict(embedding) +``` + +## Whole-Slide (WSI) Inference and Output Builders + +When predicting on an entire Whole-Slide Image (WSI): + +1. **Heatmaps:** If your model's WSI output should be a spatial heatmap (e.g., probability maps or segmentation masks overlaying the WSI), you **do not need to implement WSI logic**. The cluster already provides a universal `HeatmapBuilder` service (running under `/heatmap-builder`). Users can pass your model's ID to the heatmap builder via the SDK, and it will tile the image, aggregate all localized predictions seamlessly, and output a multi-resolution BigTIFF mask automatically. + +2. **Custom WSI Aggregations (Non-Heatmap Outputs):** If your model generates something else across the entire slide (for example, a single slide-level scalar score, diagnostic classification, custom tabular statistics, embedded feature bags), you must **implement your own WSI aggregator service**. You should create a custom Application (similar to `HeatmapBuilder`) that takes paths to WSI files, iterates through the WSI tiles querying your base model for each tile, and correctly aggregates the results into your desired slide-level output format. + + ## Next Step After the Python entrypoint is ready, continue with the [Deployment Guide](deployment-guide.md). That guide covers the Helm application YAML, deployment, rollout monitoring, and smoke testing in the order you should run them. diff --git a/docs/guides/deployment-guide.md b/docs/guides/deployment-guide.md index 56dc372..6dad494 100644 --- a/docs/guides/deployment-guide.md +++ b/docs/guides/deployment-guide.md @@ -49,6 +49,8 @@ Create a file in `helm/rayservice/applications/` (for example `my-model.yaml`) w max_replicas: 4 ``` +- Add the new application file name to `helm/rayservice/values.yaml` under `applications`. + Notes: - Use a dedicated branch in `working_dir` during development. diff --git a/helm/rayservice/applications/episeg-1.yaml b/helm/rayservice/applications/episeg-1.yaml index 5c8e9db..66a08af 100644 --- a/helm/rayservice/applications/episeg-1.yaml +++ b/helm/rayservice/applications/episeg-1.yaml @@ -20,6 +20,8 @@ MLFLOW_TRACKING_URI: http://mlflow.rationai-mlflow:5000 user_config: tile_size: 1024 + output_tile_size: 1024 + n_channels: 1 mpp: 0.468 max_batch_size: 8 batch_wait_timeout_s: 0.1 diff --git a/helm/rayservice/workers/cpu-workers.yaml b/helm/rayservice/workers/cpu-workers.yaml index 0c98d85..eaa044b 100644 --- a/helm/rayservice/workers/cpu-workers.yaml +++ b/helm/rayservice/workers/cpu-workers.yaml @@ -11,7 +11,7 @@ template: type: RuntimeDefault containers: - name: ray-worker - image: cerit.io/rationai/model-service:2.54.0 + image: cerit.io/rationai/model-service:2.55.0 imagePullPolicy: Always resources: limits: diff --git a/misc/tile_heatmap_builder.py b/misc/tile_heatmap_builder.py deleted file mode 100644 index 843536c..0000000 --- a/misc/tile_heatmap_builder.py +++ /dev/null @@ -1,75 +0,0 @@ -import tempfile -from pathlib import Path - -import numpy as np -import pyvips - - -class TileHeatmapBuilder: - def __init__( - self, extent_x: int, extent_y: int, mpp_x: float, mpp_y: float - ) -> None: - self.extent_x = extent_x - self.extent_y = extent_y - self.mpp_x = mpp_x - self.mpp_y = mpp_y - - # Create temporary files - self.temp_dir = tempfile.TemporaryDirectory() - self.image_path = Path(self.temp_dir.name) / "image.dat" - self.count_path = Path(self.temp_dir.name) / "count.dat" - - self.image = np.memmap( - str(self.image_path), - dtype=np.float32, - mode="w+", - shape=(self.extent_y, self.extent_x), - ) - - self.count = np.memmap( - str(self.count_path), - dtype=np.uint8, - mode="w+", - shape=(self.extent_y, self.extent_x), - ) - - def update(self, tile: np.ndarray, x: int, y: int) -> None: - mm_y, mm_x = self.image[y : y + tile.shape[0], x : x + tile.shape[1]].shape - self.image[y : y + mm_y, x : x + mm_x] += tile[:mm_y, :mm_x] - self.count[y : y + mm_y, x : x + mm_x] += 1 - - def flush(self) -> None: - self.image.flush() - self.count.flush() - - def save(self, output_path: str, tile_width: int, tile_height: int) -> None: - image_vips = pyvips.Image.new_from_array(self.image) - count_vips = pyvips.Image.new_from_array(self.count) - - image_vips /= count_vips - image_vips *= 255 - image_vips = image_vips.cast(pyvips.BandFormat.UCHAR) - - Path(output_path).parent.mkdir(parents=True, exist_ok=True) - image_vips.tiffsave( - output_path, - bigtiff=True, - compression=pyvips.enums.ForeignTiffCompression.DEFLATE, - tile=True, - tile_width=tile_width, - tile_height=tile_height, - xres=1000 / self.mpp_x, - yres=1000 / self.mpp_y, - pyramid=True, - ) - - def cleanup(self) -> None: - if hasattr(self, "image"): - del self.image - if hasattr(self, "count"): - del self.count - - self.temp_dir.cleanup() - - def __del__(self) -> None: - self.cleanup() diff --git a/models/semantic_segmentation.py b/models/semantic_segmentation.py index cc1611b..67242c7 100644 --- a/models/semantic_segmentation.py +++ b/models/semantic_segmentation.py @@ -9,6 +9,8 @@ class Config(TypedDict): tile_size: int + output_tile_size: int + n_channels: int mpp: float model: dict[str, Any] max_batch_size: int @@ -29,6 +31,8 @@ class SemanticSegmentation: """Semantic segmentation for tissue tiles using ONNX Runtime with GPU and TensorRT support.""" tile_size: int + output_tile_size: int + n_channels: int def __init__(self) -> None: import lz4.frame @@ -42,6 +46,8 @@ def reconfigure(self, config: Config) -> None: import onnxruntime as ort self.tile_size = config["tile_size"] + self.output_tile_size = config["output_tile_size"] + self.n_channels = config["n_channels"] self.mpp = config["mpp"] cache_path = config["trt_cache_path"] @@ -116,8 +122,13 @@ def reconfigure(self, config: Config) -> None: self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] def get_config(self) -> dict[str, Any]: - """Return the current configuration (tile size and mpp).""" - return {"tile_size": self.tile_size, "mpp": self.mpp} + """Return the current configuration for builders.""" + return { + "tile_size": self.tile_size, + "output_tile_size": self.output_tile_size, + "n_channels": self.n_channels, + "mpp": self.mpp, + } @serve.batch async def predict(