Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion helm/rayservice/applications/episeg-1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
deployments:
- name: SemanticSegmentation
max_ongoing_requests: 16
max_queued_requests: 32
max_queued_requests: 128
Comment thread
Jurgee marked this conversation as resolved.
autoscaling_config:
min_replicas: 0
max_replicas: 4
Expand Down
2 changes: 1 addition & 1 deletion helm/rayservice/applications/heatmap-builder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
deployments:
- name: HeatmapBuilder
max_ongoing_requests: 16
max_queued_requests: 32
max_queued_requests: 128
Comment thread
Jurgee marked this conversation as resolved.
autoscaling_config:
min_replicas: 0
max_replicas: 4
Expand Down
2 changes: 1 addition & 1 deletion helm/rayservice/applications/prostate-classifier-1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
deployments:
- name: BinaryClassifier
max_ongoing_requests: 512
max_queued_requests: 1024
max_queued_requests: 4096
Comment thread
Jurgee marked this conversation as resolved.
autoscaling_config:
min_replicas: 0
max_replicas: 4
Expand Down
28 changes: 28 additions & 0 deletions helm/rayservice/applications/prov-gigapath.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Ray Serve application entry for the Prov-GigaPath tile encoder.
- name: prov-gigapath
  # Module path and attribute of the bound Serve application.
  import_path: models.prov_gigapath:app
  route_prefix: /prov-gigapath
  runtime_env:
    config:
      setup_timeout_seconds: 1800  # 30 minutes
    working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip
  deployments:
    - name: ProvGigaPath
      max_ongoing_requests: 1024
      max_queued_requests: 8192
      autoscaling_config:
        min_replicas: 0
        max_replicas: 4
        target_ongoing_requests: 256
      ray_actor_options:
        num_cpus: 4
        num_gpus: 1
        memory: 8589934592  # 8 GiB, in bytes
        runtime_env:
          env_vars:
            HF_HOME: /mnt/huggingface_cache
      # Keys mirror the `Config` TypedDict in models/prov_gigapath.py.
      user_config:
        tile_size: 224
        max_batch_size: 512
        batch_wait_timeout_s: 0.1
        model:
          repo_id: prov-gigapath/prov-gigapath
2 changes: 1 addition & 1 deletion helm/rayservice/applications/virchow2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
deployments:
- name: Virchow2
max_ongoing_requests: 1024
max_queued_requests: 2048
max_queued_requests: 8192
Comment thread
Jurgee marked this conversation as resolved.
autoscaling_config:
min_replicas: 0
max_replicas: 4
Expand Down
1 change: 1 addition & 0 deletions helm/rayservice/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ applications:
- episeg-1
- heatmap-builder
- prostate-classifier-1
- prov-gigapath
- virchow2
108 changes: 108 additions & 0 deletions models/prov_gigapath.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any, TypedDict

import lz4.frame
import numpy as np
from fastapi import FastAPI, Request, Response
from ray import serve


if TYPE_CHECKING:
import torch


class Config(TypedDict):
    """Schema of the ``user_config`` mapping read by ``ProvGigaPath.reconfigure``."""

    # Edge length, in pixels, of the square tiles posted by clients.
    tile_size: int
    # Model settings; ``reconfigure`` reads the ``repo_id`` key.
    model: dict[str, Any]
    # Forwarded to ``predict.set_max_batch_size``.
    max_batch_size: int
    # Forwarded to ``predict.set_batch_wait_timeout_s``.
    batch_wait_timeout_s: float


# ASGI application object; ProvGigaPath registers its routes on it and wraps it
# with ``serve.ingress``.
# NOTE: the variable shares its name with the ``fastapi`` package, so a later
# ``import fastapi`` in this module would rebind it.
fastapi = FastAPI()


@serve.deployment(num_replicas="auto")
@serve.ingress(fastapi)
class ProvGigaPath:
    """GigaPath tile encoder for pathology.

    Wire format of ``POST /``:

    * Request body: an LZ4 frame containing the raw ``uint8`` bytes of one
      ``tile_size x tile_size x 3`` image.
    * Optional request header ``x-output-dtype``: NumPy dtype name for the
      returned array (default ``float32``).
    * Response body: an LZ4 frame containing the raw bytes of the model
      output for that tile; its shape is reported in ``x-output-shape``.
    """

    model: torch.nn.Module
    transforms: Any
    tile_size: int

    def __init__(self) -> None:
        """Pick the inference device; the model itself is loaded in ``reconfigure``."""
        import torch

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def reconfigure(self, config: Config) -> None:
        """Load the model and apply batching settings from ``config``.

        Args:
            config: Mapping with ``tile_size``, ``model`` (must contain
                ``repo_id``), ``max_batch_size`` and ``batch_wait_timeout_s``.

        Raises:
            KeyError: If a required key is missing from ``config``.
        """
        import timm
        from torchvision import transforms

        self.tile_size = config["tile_size"]
        repo_id = config["model"]["repo_id"]

        # Weights are pulled from the Hugging Face Hub repository ``repo_id``.
        self.model = timm.create_model(f"hf_hub:{repo_id}", pretrained=True)
        self.model = self.model.to(self.device).eval()

        # Based on the HF documentation.
        # NOTE(review): the 256 -> 224 resize/crop is fixed and independent of
        # ``tile_size`` — confirm this is intended for other tile sizes.
        self.transforms = transforms.Compose(
            [
                transforms.Resize(
                    256, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
                ),
            ]
        )

        self.predict.set_max_batch_size(config["max_batch_size"])  # type: ignore[attr-defined]
        self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"])  # type: ignore[attr-defined]

    @serve.batch
    async def predict(self, inputs: list[torch.Tensor]) -> list[torch.Tensor]:
        """Run the model on a batch of preprocessed tiles.

        Callers pass a single tensor and await a single result; the
        ``serve.batch`` decorator groups concurrent calls into ``inputs``.

        Args:
            inputs: Preprocessed tile tensors of identical shape.

        Returns:
            One output tensor per input, in the same order.
        """
        import torch

        tensors = torch.stack(inputs).to(self.device)
        with torch.inference_mode():
            output = self.model(tensors)
        # Iterating the batch output yields one tensor per input.
        return list(output)

    @fastapi.post("/")
    async def root(self, request: Request) -> Response:
        """Encode one tile and return its embedding.

        Args:
            request: HTTP request carrying the LZ4-compressed tile (see the
                class docstring for the wire format).

        Returns:
            An ``application/octet-stream`` response with the LZ4-compressed
            output bytes and an ``x-output-shape`` header.

        Raises:
            HTTPException: With status 400 when ``x-output-dtype`` is not a
                numeric NumPy dtype, or when the body is not a valid LZ4 frame
                of exactly ``tile_size * tile_size * 3`` bytes.
        """
        from fastapi import HTTPException
        from PIL import Image

        # Validate the cheap header first so a bad request never reaches the
        # model.
        dtype_name = request.headers.get("x-output-dtype", "float32").lower()
        try:
            output_dtype = np.dtype(dtype_name)
        except (TypeError, ValueError) as err:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported x-output-dtype: {dtype_name!r}",
            ) from err
        # Only numeric kinds are allowed: e.g. ``object`` would make
        # ``tobytes()`` emit raw pointers instead of values.
        if output_dtype.kind not in "biufc":
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported x-output-dtype: {dtype_name!r}",
            )

        payload = await request.body()
        try:
            # Decompression is CPU-bound; keep it off the event loop.
            data = await asyncio.to_thread(lz4.frame.decompress, payload)
            image = np.frombuffer(data, dtype=np.uint8).reshape(
                self.tile_size, self.tile_size, 3
            )
        except (RuntimeError, ValueError) as err:
            # Corrupt LZ4 frame or wrong payload size: a client error.
            raise HTTPException(
                status_code=400,
                detail=(
                    "Request body must be an LZ4 frame holding "
                    f"{self.tile_size}x{self.tile_size}x3 uint8 values"
                ),
            ) from err

        tensor = self.transforms(Image.fromarray(image))

        raw_output: torch.Tensor = await self.predict(tensor)
        result = raw_output.cpu().numpy().astype(output_dtype, copy=False)
        output_shape = str(result.shape)

        # Compress off the event loop, matching how the request is decompressed.
        compressed = await asyncio.to_thread(lz4.frame.compress, result.tobytes())

        return Response(
            content=compressed,
            media_type="application/octet-stream",
            headers={
                "x-output-shape": output_shape,
            },
        )


# Bound Serve application; the Helm application config refers to it through
# ``import_path: models.prov_gigapath:app``.
app = ProvGigaPath.bind() # type: ignore[attr-defined]
Loading