diff --git a/helm/rayservice/applications/episeg-1.yaml b/helm/rayservice/applications/episeg-1.yaml
index 5677799..5c8e9db 100644
--- a/helm/rayservice/applications/episeg-1.yaml
+++ b/helm/rayservice/applications/episeg-1.yaml
@@ -6,7 +6,7 @@
   deployments:
     - name: SemanticSegmentation
       max_ongoing_requests: 16
-      max_queued_requests: 32
+      max_queued_requests: 128
       autoscaling_config:
         min_replicas: 0
         max_replicas: 4
diff --git a/helm/rayservice/applications/heatmap-builder.yaml b/helm/rayservice/applications/heatmap-builder.yaml
index 13b90cb..f39bf49 100644
--- a/helm/rayservice/applications/heatmap-builder.yaml
+++ b/helm/rayservice/applications/heatmap-builder.yaml
@@ -6,7 +6,7 @@
   deployments:
     - name: HeatmapBuilder
       max_ongoing_requests: 16
-      max_queued_requests: 32
+      max_queued_requests: 128
       autoscaling_config:
         min_replicas: 0
         max_replicas: 4
diff --git a/helm/rayservice/applications/prostate-classifier-1.yaml b/helm/rayservice/applications/prostate-classifier-1.yaml
index a177c43..6cd33b2 100644
--- a/helm/rayservice/applications/prostate-classifier-1.yaml
+++ b/helm/rayservice/applications/prostate-classifier-1.yaml
@@ -6,7 +6,7 @@
   deployments:
     - name: BinaryClassifier
       max_ongoing_requests: 512
-      max_queued_requests: 1024
+      max_queued_requests: 4096
       autoscaling_config:
         min_replicas: 0
         max_replicas: 4
diff --git a/helm/rayservice/applications/prov-gigapath.yaml b/helm/rayservice/applications/prov-gigapath.yaml
new file mode 100644
index 0000000..f0eb4db
--- /dev/null
+++ b/helm/rayservice/applications/prov-gigapath.yaml
@@ -0,0 +1,28 @@
+- name: prov-gigapath
+  import_path: models.prov_gigapath:app
+  route_prefix: /prov-gigapath
+  runtime_env:
+    config:
+      setup_timeout_seconds: 1800  # 30 minutes
+    working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip
+  deployments:
+    - name: ProvGigaPath
+      max_ongoing_requests: 1024
+      max_queued_requests: 8192
+      autoscaling_config:
+        min_replicas: 0
+        max_replicas: 4
+        target_ongoing_requests: 256
+      ray_actor_options:
+        num_cpus: 4
+        num_gpus: 1
+        memory: 8589934592  # 8 GiB
+        runtime_env:
+          env_vars:
+            HF_HOME: /mnt/huggingface_cache
+      user_config:
+        tile_size: 224
+        max_batch_size: 512
+        batch_wait_timeout_s: 0.1
+        model:
+          repo_id: prov-gigapath/prov-gigapath
diff --git a/helm/rayservice/applications/virchow2.yaml b/helm/rayservice/applications/virchow2.yaml
index cf797d8..eaac069 100644
--- a/helm/rayservice/applications/virchow2.yaml
+++ b/helm/rayservice/applications/virchow2.yaml
@@ -8,7 +8,7 @@
   deployments:
     - name: Virchow2
       max_ongoing_requests: 1024
-      max_queued_requests: 2048
+      max_queued_requests: 8192
       autoscaling_config:
         min_replicas: 0
         max_replicas: 4
diff --git a/helm/rayservice/values.yaml b/helm/rayservice/values.yaml
index b6e24b7..6e62751 100644
--- a/helm/rayservice/values.yaml
+++ b/helm/rayservice/values.yaml
@@ -6,4 +6,5 @@ applications:
   - episeg-1
   - heatmap-builder
   - prostate-classifier-1
+  - prov-gigapath
   - virchow2
diff --git a/models/prov_gigapath.py b/models/prov_gigapath.py
new file mode 100644
index 0000000..b59ed11
--- /dev/null
+++ b/models/prov_gigapath.py
@@ -0,0 +1,158 @@
+from __future__ import annotations
+
+import asyncio
+from typing import TYPE_CHECKING, Any, TypedDict
+
+import lz4.frame
+import numpy as np
+from fastapi import FastAPI, HTTPException, Request, Response
+from ray import serve
+
+
+if TYPE_CHECKING:
+    import torch
+
+
+class Config(TypedDict):
+    """Keys of the ``user_config`` mapping handed to ``reconfigure``."""
+
+    tile_size: int  # edge length of the square input tile, in pixels
+    model: dict[str, Any]  # must contain ``repo_id`` (loaded via ``hf_hub:``)
+    max_batch_size: int  # passed to ``predict.set_max_batch_size``
+    batch_wait_timeout_s: float  # passed to ``predict.set_batch_wait_timeout_s``
+
+
+fastapi = FastAPI()
+
+
+@serve.deployment(num_replicas="auto")
+@serve.ingress(fastapi)
+class ProvGigaPath:
+    """GigaPath tile encoder for pathology.
+
+    Exposes ``POST /``, which takes one LZ4-compressed tile and returns the
+    LZ4-compressed model output for it.
+    """
+
+    model: torch.nn.Module
+    transforms: Any
+    tile_size: int
+
+    def __init__(self) -> None:
+        """Pick the device; the model itself is loaded in ``reconfigure``."""
+        import torch
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    def reconfigure(self, config: Config) -> None:
+        """Load the model and apply the batching settings from ``config``.
+
+        Args:
+            config: The deployment's ``user_config``; see ``Config``.
+        """
+        import timm
+        from torchvision import transforms
+
+        self.tile_size = config["tile_size"]
+        model_config = dict(config["model"])
+        repo_id = model_config["repo_id"]
+
+        self.model = timm.create_model(
+            f"hf_hub:{repo_id}",
+            pretrained=True,
+        )
+        self.model = self.model.to(self.device).eval()
+
+        # Based on the HF documentation
+        self.transforms = transforms.Compose(
+            [
+                transforms.Resize(
+                    256, interpolation=transforms.InterpolationMode.BICUBIC
+                ),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+                ),
+            ]
+        )
+
+        self.predict.set_max_batch_size(config["max_batch_size"])  # type: ignore[attr-defined]
+        self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"])  # type: ignore[attr-defined]
+
+    @serve.batch
+    async def predict(self, inputs: list[torch.Tensor]) -> list[torch.Tensor]:
+        """Run the model on one batch and return one output per input.
+
+        Args:
+            inputs: Preprocessed tiles gathered into a batch by ``serve.batch``.
+
+        Returns:
+            The rows of the model output, moved to the CPU - one per input.
+        """
+        import torch
+
+        tensors = torch.stack(inputs).to(self.device)
+        with torch.inference_mode():
+            output = self.model(tensors)
+        # One device-to-host copy per batch instead of one per request.
+        return list(output.cpu())
+
+    @fastapi.post("/")
+    async def root(self, request: Request) -> Response:
+        """Encode one tile sent as an LZ4 frame.
+
+        The body must decompress to ``tile_size * tile_size * 3`` ``uint8``
+        values. The optional ``x-output-dtype`` header (default ``float32``)
+        selects the dtype of the returned array, whose shape is reported in
+        the ``x-output-shape`` response header.
+
+        Raises:
+            HTTPException: 400 if the dtype header or the body is invalid.
+        """
+        from PIL import Image
+
+        dtype_name = request.headers.get("x-output-dtype", "float32").lower()
+        try:
+            output_dtype = np.dtype(dtype_name)
+        except (TypeError, ValueError, SyntaxError) as err:
+            raise HTTPException(
+                status_code=400, detail="Unknown x-output-dtype"
+            ) from err
+        # Object, string and structured dtypes do not serialise to a plain
+        # numeric buffer (object arrays would be written out as pointers).
+        if output_dtype.kind not in ("f", "i", "u"):
+            raise HTTPException(
+                status_code=400, detail="x-output-dtype must be a numeric dtype"
+            )
+
+        body = await request.body()
+        try:
+            data = await asyncio.to_thread(lz4.frame.decompress, body)
+            image = np.frombuffer(data, dtype=np.uint8).reshape(
+                self.tile_size, self.tile_size, 3
+            )
+        except (RuntimeError, ValueError) as err:
+            raise HTTPException(
+                status_code=400, detail="Body is not a valid LZ4-compressed tile"
+            ) from err
+
+        tensor = self.transforms(Image.fromarray(image))
+
+        raw_output: torch.Tensor = await self.predict(tensor)
+        result = raw_output.cpu().numpy().astype(output_dtype, copy=False)
+        output_shape = str(result.shape)
+
+        # Compress off the event loop, mirroring the decompression above.
+        content = await asyncio.to_thread(lz4.frame.compress, result.tobytes())
+
+        return Response(
+            content=content,
+            media_type="application/octet-stream",
+            headers={
+                "x-output-shape": output_shape,
+            },
+        )
+
+
+app = ProvGigaPath.bind()  # type: ignore[attr-defined]