diff --git a/app.py b/app.py
index 1141b62..88e94d6 100644
--- a/app.py
+++ b/app.py
@@ -252,13 +252,24 @@ def __init__(self, default_runtime: NanoTTSService) -> None:
         self._lock = threading.Lock()
         self._cpu_execution_lock = threading.Lock()
         self._cpu_runtime: NanoTTSService | None = None
+        self._cuda_runtimes: dict[str, NanoTTSService] = {}

     @staticmethod
     def normalize_requested_execution_device(requested: str | None) -> str:
         normalized = str(requested or "default").strip().lower()
-        if normalized not in {"default", "cpu"}:
-            return "default"
-        return normalized
+        if normalized in {"default", "cpu"}:
+            return normalized
+        # Allow "cuda" or "cuda:N" forms
+        if normalized == "cuda" or normalized.startswith("cuda:"):
+            parts = normalized.split(":", 1)
+            if len(parts) == 2:
+                try:
+                    int(parts[1])
+                    return normalized
+                except ValueError:
+                    return "default"
+            return "cuda"
+        return "default"

     def is_dedicated_cpu_request(self, requested: str | None) -> bool:
         normalized = self.normalize_requested_execution_device(requested)
@@ -282,14 +293,37 @@ def _build_cpu_runtime_locked(self) -> NanoTTSService:
         )
         return self._cpu_runtime

+    def _build_cuda_runtime_locked(self, device: str) -> NanoTTSService:
+        if device in self._cuda_runtimes:
+            return self._cuda_runtimes[device]
+        self._cuda_runtimes[device] = NanoTTSService(
+            checkpoint_path=self.default_runtime.checkpoint_path,
+            audio_tokenizer_path=self.default_runtime.audio_tokenizer_path,
+            device=device,
+            dtype=self.default_runtime.dtype or "auto",
+            attn_implementation=self.default_runtime.attn_implementation or "auto",
+            output_dir=self.default_runtime.output_dir,
+            voice_presets=self.default_runtime.voice_presets,
+        )
+        return self._cuda_runtimes[device]
+
     def resolve_runtime(self, requested: str | None) -> tuple[NanoTTSService, str]:
         normalized = self.normalize_requested_execution_device(requested)
-        if normalized != "cpu":
+        if normalized == "default":
             return self.default_runtime, str(self.default_runtime.device.type)
-        if self.default_runtime.device.type == "cpu":
-            return self.default_runtime, "cpu"
-        with self._lock:
-            return self._build_cpu_runtime_locked(), "cpu"
+        if normalized == "cpu":
+            if self.default_runtime.device.type == "cpu":
+                return self.default_runtime, "cpu"
+            with self._lock:
+                return self._build_cpu_runtime_locked(), "cpu"
+        # A CUDA device was requested
+        if normalized.startswith("cuda"):
+            target_device = normalized
+            if self.default_runtime.device.type == "cuda" and target_device in {"cuda", str(self.default_runtime.device)}:
+                return self.default_runtime, str(self.default_runtime.device)
+            with self._lock:
+                return self._build_cuda_runtime_locked(target_device), target_device
+        return self.default_runtime, str(self.default_runtime.device.type)

     def _resolve_cpu_threads(self, cpu_threads: int | None) -> int:
         if cpu_threads is None:
@@ -657,6 +691,20 @@ async def _persist_uploaded_prompt_audio(upload: UploadFile | None) -> tuple[str
     return temp_path, _format_uploaded_prompt_display_name(original_filename)


+def _build_cuda_options_html(cuda_available: bool, runtime_device: str) -> str:
+    if not cuda_available:
+        return ""
+    import torch
+    parts = []
+    for i in range(torch.cuda.device_count()):
+        device_name = f"cuda:{i}"
+        label = f"CUDA:{i}"
+        if runtime_device == device_name:
+            label += " (runtime)"
+        parts.append(f'<option value="{device_name}">{label}</option>')
+    return "\n".join(parts)
+
+
 def _render_index_html(
     *,
     request: Request,
@@ -664,6 +712,7 @@ def _render_index_html(
     demo_entries: list[DemoEntry],
     warmup_status: str,
     text_normalization_status: str,
+    cuda_available: bool = False,
 ) -> str:
     base_path = request.scope.get("root_path", "").rstrip("/")
     template = """
@@ -1112,12 +1161,22 @@ def _render_index_html(
           Buffered generation keeps chunk order and decodes codec sub-batches no larger than the current TTS batch.
           Realtime Streaming Decode keeps output order and uses the smallest active chunk-group width among auto
           batching, Max TTS Batch Size, and Max Codec Batch Size.
-