10 changes: 6 additions & 4 deletions README.md
@@ -4,15 +4,17 @@ Model deployment infrastructure for RationAI using Ray Serve on Kubernetes.

This repository contains:

- A Helm chart (`helm/rayservice/`) that renders and deploys a KubeRay `RayService`.
- A static RayService manifest (`ray-service.yaml`) for reference/manual apply workflows.
- Model implementations under `models/` (reference: `models/binary_classifier.py`).
- Documentation under `docs/` (MkDocs).
- `builders/`: WSI output aggregation services (e.g., `heatmap_builder.py`).
- `docker/`: Dockerfiles for building CPU and GPU environments.
- `docs/`: MkDocs documentation and architecture guides.
- `helm/rayservice/`: A Helm chart that renders and deploys a KubeRay `RayService`.
- `models/`: Python entrypoints for model implementations (e.g., `binary_classifier.py`, `virchow2.py`).

## Documentation

- MkDocs content: `docs/`
- Key pages:
- `docs/available-models.md`
- `docs/get-started/quick-start.md`
- `docs/guides/deployment-guide.md`
- `docs/guides/adding-models.md`
122 changes: 81 additions & 41 deletions builders/heatmap_builder.py
@@ -1,7 +1,9 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, TypedDict

import numpy as np
from fastapi import FastAPI
from ray import serve

@@ -35,19 +37,22 @@ async def root(
output_bigtiff_tile_height: int,
output_bigtiff_tile_width: int,
) -> str:
import pyvips
from ratiopath.masks.mask_builders import MaskBuilder
from ratiopath.openslide import OpenSlide
from ratiopath.tiling import grid_tiles

from misc.fetch_tissue_tile import fetch_tissue_tile
from misc.tile_heatmap_builder import TileHeatmapBuilder

model = serve.get_app_handle(model_id)
model_config = await model.get_config.remote()
stride: int = round(stride_fraction * model_config["tile_size"])
tile_size: int = model_config["tile_size"]
output_tile_size: int = model_config["output_tile_size"]
n_channels: int = model_config["n_channels"]
stride: int = round(stride_fraction * tile_size)

loop = asyncio.get_running_loop()
tasks: set[asyncio.Task[Any]] = set()

with (
OpenSlide(slide_path) as slide,
OpenSlide(tissue_mask_path) as tissue_slide,
@@ -63,50 +68,85 @@ async def root(
]
scale_x = tissue_extent_x / extent_x
scale_y = tissue_extent_y / extent_y

mask_builder = TileHeatmapBuilder(
extent_x=extent_x, extent_y=extent_y, mpp_x=mpp_x, mpp_y=mpp_y
mask_builder = MaskBuilder(
source_extents=(extent_y, extent_x),
source_tile_extent=tile_size,
output_tile_extent=output_tile_size,
stride=stride,
n_channels=n_channels,
storage="memmap",
)
try:

async def process_tile(x: int, y: int) -> None:
tile = await loop.run_in_executor(
executor,
fetch_tissue_tile,
slide,
tissue_slide,
x,
y,
level,
scale_x,
scale_y,
tissue_level,
tile_size,
)
if tile is None:
return

async def process_tile(x: int, y: int) -> None:
tile = await loop.run_in_executor(
executor,
fetch_tissue_tile,
slide,
tissue_slide,
x,
y,
level,
scale_x,
scale_y,
tissue_level,
model_config["tile_size"],
)
if tile is not None:
prediction = await model.predict.remote(tile)
mask_builder.update(prediction, x, y)

for x, y in grid_tiles(
slide_extent=(extent_x, extent_y),
tile_extent=(model_config["tile_size"], model_config["tile_size"]),
stride=(stride, stride),
):
if len(tasks) >= self.max_concurrent_tasks:
_, tasks = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
arr = np.asarray(prediction, dtype=np.float32)

if arr.ndim == 2:
batch = arr[np.newaxis, np.newaxis, ...]
elif arr.ndim == 3:
batch = arr[np.newaxis, ...]
else:
raise ValueError(f"Unexpected prediction shape: {arr.shape}")

mask_builder.update_batch(
batch=batch,
coords=np.array([[y, x]], dtype=np.int64),
)

tasks.add(asyncio.create_task(process_tile(x, y)))

await asyncio.wait(tasks)
for x, y in grid_tiles(
slide_extent=(extent_x, extent_y),
tile_extent=(tile_size, tile_size),
stride=(stride, stride),
):
if len(tasks) >= self.max_concurrent_tasks:
done, tasks = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
for task in done:
task.result()
tasks.add(asyncio.create_task(process_tile(x, y)))

if tasks:
done, _ = await asyncio.wait(tasks)
for task in done:
task.result()

result = np.asarray(mask_builder.finalize()["mask"])

vips_image = mask_builder.resize_to_source(result)
vips_image = (vips_image * 255).cast(pyvips.BandFormat.UCHAR)
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
vips_image.tiffsave(
output_path,
bigtiff=True,
compression=pyvips.enums.ForeignTiffCompression.DEFLATE,
tile=True,
tile_width=output_bigtiff_tile_width,
tile_height=output_bigtiff_tile_height,
xres=1000 / mpp_x,
yres=1000 / mpp_y,
pyramid=True,
)
finally:
mask_builder.cleanup()

mask_builder.flush()
mask_builder.save(
output_path,
tile_height=output_bigtiff_tile_height,
tile_width=output_bigtiff_tile_width,
)
mask_builder.cleanup()
return output_path


2 changes: 1 addition & 1 deletion docker/Dockerfile.cpu
@@ -52,4 +52,4 @@ RUN sudo apt-get update && sudo apt-get -y upgrade && \
RUN sudo apt-get remove -y --purge systemd systemd-sysv && sudo apt-get autoremove --purge -y && sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir \
onnxruntime lz4 ratiopath "mlflow<3.0"
onnxruntime lz4 ratiopath pyvips "mlflow<3.0"
27 changes: 26 additions & 1 deletion docs/available-models.md
@@ -11,6 +11,7 @@ All endpoints receive and return data over HTTP using `POST` requests. To minimi
| **Prostate Classifier 1** | `/prostate-classifier-1` | Binary Classification |
| **Episeg 1** | `/episeg-1` | Semantic Segmentation |
| **Virchow2** | `/virchow2` | Foundation Model / Embeddings |
| **Prov-GigaPath** | `/prov-gigapath` | Foundation Model / Embeddings |
| **Heatmap Builder** | `/heatmap-builder` | Pipeline / Custom Builder |

---
@@ -80,7 +81,31 @@ with Client() as client:
print(emb.shape)
```

### 4. Heatmap Builder (`/heatmap-builder`)
### 4. Prov-GigaPath (`/prov-gigapath`)

A foundation model for pathology tile embeddings (Prov-GigaPath).

- **Input**: LZ4-compressed raw bytes of a tissue tile image (`uint8`, shape `(tile_size, tile_size, 3)`).
- **Output**: Embedding tensor returned in the requested precision (see `x-output-dtype`).
- **Headers**:
- `x-output-dtype` (optional, default: `float32`): Sets the return precision (for example `float16`).
- **SDK example**:

```python
from rationai import Client
import numpy as np

with Client() as client:
emb = client.models.embed_image(
model="prov-gigapath",
image=image,
output_dtype=np.float16,
timeout=30.0,
)
print(emb.shape)
```

### 5. Heatmap Builder (`/heatmap-builder`)

A processing pipeline element for aggregating inferences into spatial heatmaps.

55 changes: 55 additions & 0 deletions docs/guides/adding-models.md
@@ -141,6 +141,22 @@ What happens with data:
- Output tensor is flattened and returned as a Python list.
- Ray Serve maps each list item back to the original HTTP request.

### `get_config`: expose model settings for builders

If your model is used by any Whole-Slide Inference builder (for example `HeatmapBuilder`), you must provide a `get_config` method that builders can call through a Serve handle. The builder uses this to read `tile_size`, `output_tile_size`, `n_channels`, and `mpp` so it can pick the right tiling grid and resolution.

```python
async def get_config(self) -> dict[str, Any]:
return {
"tile_size": self.tile_size,
"output_tile_size": self.output_tile_size,
"n_channels": self.n_channels,
"mpp": self.mpp,
}
```

The builder calls it with `await model.get_config.remote()`; keep it cheap and avoid any I/O.
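
As a rough sketch of the consuming side, the snippet below shows how a builder might validate the returned dictionary and derive its stride from it. The helper name `derive_tiling` is hypothetical and not part of this repository; only the key names and the `round(stride_fraction * tile_size)` rule come from the `HeatmapBuilder` change in this pull request.

```python
from typing import Any


def derive_tiling(config: dict[str, Any], stride_fraction: float) -> dict[str, int]:
    """Derive tiling parameters from a model's get_config() result.

    Hypothetical helper for illustration; the real builder reads the same
    keys inline inside its request handler.
    """
    required = ("tile_size", "output_tile_size", "n_channels")
    missing = [key for key in required if key not in config]
    if missing:
        raise KeyError(f"get_config() result is missing keys: {missing}")

    tile_size = int(config["tile_size"])
    # Same rule as the builder: the stride is a fraction of the input tile size.
    stride = round(stride_fraction * tile_size)
    if stride <= 0:
        raise ValueError("stride_fraction must yield a positive stride")

    return {
        "tile_size": tile_size,
        "output_tile_size": int(config["output_tile_size"]),
        "n_channels": int(config["n_channels"]),
        "stride": stride,
    }
```

Failing early on a missing key gives a clearer error than a `KeyError` raised halfway through a slide, which is one reason to keep the returned dictionary complete for every model a builder may call.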

### `root`: HTTP request parsing and serialization

```python
@@ -172,6 +188,45 @@ app = MyModel.bind()

This exported symbol is what `import_path: models.my_model:app` points to in Helm.

### Using Foundation Models from Your Model

If you are deploying a model that is a downstream head (e.g., an MLP or Attention layer trained on top of a foundation model like Virchow2 or Prov-GigaPath), you **do not** need to re-export the entire foundation model into your ONNX artifact.

Because foundation models are already deployed as independent services within the cluster, your model can directly invoke them via Ray Serve handles. In your `root` or `predict` method, call the foundation model first, then pass its output to your custom layers. If you call a foundation model's `predict` method directly, do **not** pass a raw `np.ndarray` image; first apply the same preprocessing/transforms that deployment expects, or call the model's ingress/request path instead.

```python
# In your __init__ or reconfigure method:
from ray import serve
self.foundation_model = serve.get_app_handle("virchow2")
self.foundation_transform = build_virchow2_transform()

# In your request handler:
@fastapi.post("/")
async def root(self, request: Request):
# 1. Fetch raw image bytes from request and decode to an image / np.ndarray
...

# 2. Apply the same transforms used by the foundation model deployment
image_tensor = self.foundation_transform(image)
if image_tensor.ndim == 3:
image_tensor = image_tensor.unsqueeze(0)

# 3. Call the foundation model with the transformed tensor batch
embedding = await self.foundation_model.predict.remote(image_tensor)

# 4. Pass the embedding to your own model's predict endpoint
return await self.predict(embedding)
```

## Whole-Slide (WSI) Inference and Output Builders

When predicting on an entire Whole-Slide Image (WSI):

1. **Heatmaps:** If your model's WSI output should be a spatial heatmap (e.g., probability maps or segmentation masks overlaying the WSI), you **do not need to implement WSI logic**. The cluster already provides a universal `HeatmapBuilder` service (running under `/heatmap-builder`). Users pass your model's ID to the heatmap builder via the SDK; it tiles the image, aggregates the per-tile predictions, and writes a multi-resolution BigTIFF mask.

2. **Custom WSI Aggregations (Non-Heatmap Outputs):** If your model produces something else for the entire slide (for example, a single slide-level scalar score, a diagnostic classification, custom tabular statistics, or embedded feature bags), you must **implement your own WSI aggregator service**. Create a custom Application (similar to `HeatmapBuilder`) that takes paths to WSI files, iterates over the WSI tiles, queries your base model for each tile, and aggregates the results into your desired slide-level output format.
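
A minimal sketch of such an aggregator loop, assuming a slide-level mean score as the output. `aggregate_slide_score` and `predict_tile` are hypothetical names used only for illustration; in a real service `predict_tile` would read the tile and await something like `model.predict.remote(tile)`. The bounded-concurrency pattern mirrors the one `HeatmapBuilder` uses in this pull request.

```python
from __future__ import annotations

import asyncio
from collections.abc import Callable, Coroutine, Iterable
from typing import Any

# Hypothetical signature: coroutine function returning a tile score,
# or None when the tile contains no tissue.
PredictTile = Callable[[int, int], Coroutine[Any, Any, "float | None"]]


async def aggregate_slide_score(
    tiles: Iterable[tuple[int, int]],
    predict_tile: PredictTile,
    max_concurrent_tasks: int = 8,
) -> float:
    """Return the mean tile score over a slide, skipping tiles that yield None."""
    scores: list[float] = []
    tasks: set[asyncio.Task[float | None]] = set()

    def collect(done: set[asyncio.Task[float | None]]) -> None:
        for task in done:
            score = task.result()  # re-raises any exception from the tile task
            if score is not None:
                scores.append(score)

    for x, y in tiles:
        # Keep at most max_concurrent_tasks predictions in flight.
        if len(tasks) >= max_concurrent_tasks:
            done, tasks = await asyncio.wait(
                tasks, return_when=asyncio.FIRST_COMPLETED
            )
            collect(done)
        tasks.add(asyncio.create_task(predict_tile(x, y)))

    # Drain whatever is still running after the grid is exhausted.
    if tasks:
        done, _ = await asyncio.wait(tasks)
        collect(done)

    if not scores:
        raise ValueError("no tile produced a prediction")
    return sum(scores) / len(scores)
```

Calling `task.result()` on every finished task matters: without it, an exception raised inside a tile task would be silently dropped and the slide-level result would be computed from incomplete data.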


## Next Step

After the Python entrypoint is ready, continue with the [Deployment Guide](deployment-guide.md). That guide covers the Helm application YAML, deployment, rollout monitoring, and smoke testing in the order you should run them.
2 changes: 2 additions & 0 deletions docs/guides/deployment-guide.md
@@ -49,6 +49,8 @@ Create a file in `helm/rayservice/applications/` (for example `my-model.yaml`) w
max_replicas: 4
```

- Add the new application file name to `helm/rayservice/values.yaml` under `applications`.

Notes:

- Use a dedicated branch in `working_dir` during development.
2 changes: 2 additions & 0 deletions helm/rayservice/applications/episeg-1.yaml
@@ -20,6 +20,8 @@
MLFLOW_TRACKING_URI: http://mlflow.rationai-mlflow:5000
user_config:
tile_size: 1024
output_tile_size: 1024
n_channels: 1
mpp: 0.468
max_batch_size: 8
batch_wait_timeout_s: 0.1
2 changes: 1 addition & 1 deletion helm/rayservice/workers/cpu-workers.yaml
@@ -11,7 +11,7 @@ template:
type: RuntimeDefault
containers:
- name: ray-worker
image: cerit.io/rationai/model-service:2.54.0
image: cerit.io/rationai/model-service:2.55.0
imagePullPolicy: Always
resources:
limits: