From 5021eaa9c7de9adf23a849fed5284824ff2e2c1d Mon Sep 17 00:00:00 2001
From: EruditionHerta
Date: Sat, 18 Apr 2026 18:56:18 +0800
Subject: [PATCH] feat: add GPU inference option to web UI

Allow users to select a CUDA device for inference in the web interface.
Previously the app was forced into CPU-only mode. Now supports:

- `--device cuda` to start on GPU
- `--device auto` to auto-detect (GPU if available, else CPU)
- Device selector dropdown in the web UI
- On-demand creation of dedicated GPU runtimes per request
- `/health` endpoint reports `cuda_available`
---
 app.py | 128 +++++++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 106 insertions(+), 22 deletions(-)

diff --git a/app.py b/app.py
index 1141b62..88e94d6 100644
--- a/app.py
+++ b/app.py
@@ -252,13 +252,24 @@ def __init__(self, default_runtime: NanoTTSService) -> None:
         self._lock = threading.Lock()
         self._cpu_execution_lock = threading.Lock()
         self._cpu_runtime: NanoTTSService | None = None
+        self._cuda_runtimes: dict[str, NanoTTSService] = {}

     @staticmethod
     def normalize_requested_execution_device(requested: str | None) -> str:
         normalized = str(requested or "default").strip().lower()
-        if normalized not in {"default", "cpu"}:
-            return "default"
-        return normalized
+        if normalized in {"default", "cpu"}:
+            return normalized
+        # Accept "cuda" or "cuda:N" formats
+        if normalized == "cuda" or normalized.startswith("cuda:"):
+            parts = normalized.split(":", 1)
+            if len(parts) == 2:
+                try:
+                    int(parts[1])
+                    return normalized
+                except ValueError:
+                    return "default"
+            return "cuda"
+        return "default"

     def is_dedicated_cpu_request(self, requested: str | None) -> bool:
         normalized = self.normalize_requested_execution_device(requested)
@@ -282,14 +293,37 @@ def _build_cpu_runtime_locked(self) -> NanoTTSService:
         )
         return self._cpu_runtime

+    def _build_cuda_runtime_locked(self, device: str) -> NanoTTSService:
+        if device in self._cuda_runtimes:
+            return self._cuda_runtimes[device]
+        self._cuda_runtimes[device] = NanoTTSService(
+            checkpoint_path=self.default_runtime.checkpoint_path,
+            audio_tokenizer_path=self.default_runtime.audio_tokenizer_path,
+            device=device,
+            dtype=self.default_runtime.dtype or "auto",
+            attn_implementation=self.default_runtime.attn_implementation or "auto",
+            output_dir=self.default_runtime.output_dir,
+            voice_presets=self.default_runtime.voice_presets,
+        )
+        return self._cuda_runtimes[device]
+
     def resolve_runtime(self, requested: str | None) -> tuple[NanoTTSService, str]:
         normalized = self.normalize_requested_execution_device(requested)
-        if normalized != "cpu":
+        if normalized == "default":
             return self.default_runtime, str(self.default_runtime.device.type)
-        if self.default_runtime.device.type == "cpu":
-            return self.default_runtime, "cpu"
-        with self._lock:
-            return self._build_cpu_runtime_locked(), "cpu"
+        if normalized == "cpu":
+            if self.default_runtime.device.type == "cpu":
+                return self.default_runtime, "cpu"
+            with self._lock:
+                return self._build_cpu_runtime_locked(), "cpu"
+        # A CUDA device was requested; reuse the default runtime only if it matches
+        if normalized.startswith("cuda"):
+            target_device = normalized
+            if self.default_runtime.device.type == "cuda" and target_device in {"cuda", str(self.default_runtime.device)}:
+                return self.default_runtime, str(self.default_runtime.device)
+            with self._lock:
+                return self._build_cuda_runtime_locked(target_device), target_device
+        return self.default_runtime, str(self.default_runtime.device.type)

     def _resolve_cpu_threads(self, cpu_threads: int | None) -> int:
         if cpu_threads is None:
@@ -657,6 +691,20 @@ async def _persist_uploaded_prompt_audio(upload: UploadFile | None) -> tuple[str
     return temp_path, _format_uploaded_prompt_display_name(original_filename)


+def _build_cuda_options_html(cuda_available: bool, runtime_device: str) -> str:
+    if not cuda_available:
+        return ""
+    import torch
+    parts = []
+    for i in range(torch.cuda.device_count()):
+        device_name = f"cuda:{i}"
+        label = f"CUDA:{i}"
+        if runtime_device == device_name:
+            label += " (runtime)"
+        parts.append(f'                            <option value="{device_name}">{label}</option>')
+    return "\n".join(parts)
+
+
 def _render_index_html(
     *,
     request: Request,
@@ -664,6 +712,7 @@ def _render_index_html(
     demo_entries: list[DemoEntry],
     warmup_status: str,
     text_normalization_status: str,
+    cuda_available: bool = False,
 ) -> str:
     base_path = request.scope.get("root_path", "").rstrip("/")
     template = """
@@ -1112,12 +1161,22 @@ def _render_index_html(
                     Buffered generation keeps chunk order and decodes codec sub-batches no larger than the current TTS batch.
                     Realtime Streaming Decode keeps output order and uses the smallest active chunk-group width
                     among auto batching, Max TTS Batch Size, and Max Codec Batch Size.
                 </p>
-                <div class="control">
-                    <label for="cpu-thread-count">CPU Threads</label>
-                    <input id="cpu-thread-count" type="number" min="0" step="1" />
-                </div>
+                <div class="control-row">
+                    <div class="control">
+                        <label for="execution-device">Execution Device</label>
+                        <select id="execution-device">
+                            <option value="default">Default (__RUNTIME_DEVICE__)</option>
+                            <option value="cpu">CPU</option>
+__CUDA_OPTIONS__
+                        </select>
+                    </div>
+                    <div class="control">
+                        <label for="cpu-thread-count">CPU Threads</label>
+                        <input id="cpu-thread-count" type="number" min="0" step="1" />
+                    </div>
+                </div>
                 <p class="hint">
-                    This app is CPU-only. CPU Threads maps to torch.set_num_threads for that request.
+                    Select the inference device. Default uses the runtime device (__RUNTIME_DEVICE__). CPU Threads maps to torch.set_num_threads for CPU requests.
                 </p>
             </div>
@@ -1709,6 +1768,7 @@ def _render_index_html(
             formData.append("enable_text_normalization", document.getElementById("enable-text-normalization").checked ? "1" : "0");
             formData.append("enable_normalize_tts_text", document.getElementById("enable-robust-text-normalization").checked ? "1" : "0");
             formData.append("cpu_threads", document.getElementById("cpu-thread-count").value || String(DEFAULT_CPU_THREADS));
+            formData.append("execution_device", document.getElementById("execution-device").value);
             return formData;
         }
@@ -2164,6 +2224,9 @@ def _render_index_html(
         "__TEXT_NORMALIZATION_STATUS__": text_normalization_status,
         "__CHECKPOINT__": str(runtime.checkpoint_path),
         "__AUDIO_TOKENIZER__": str(runtime.audio_tokenizer_path),
+        "__RUNTIME_DEVICE__": str(runtime.device),
+        "__CUDA_OPTIONS__": _build_cuda_options_html(cuda_available, str(runtime.device)),
+        "__CUDA_AVAILABLE__": json.dumps(cuda_available),
     }
     for placeholder, value in replacements.items():
         template = template.replace(placeholder, value)
@@ -2175,6 +2238,7 @@ def _build_app(
     warmup_manager: WarmupManager,
     text_normalizer_manager: WeTextProcessingManager | None,
     root_path: str | None,
+    cuda_available: bool = False,
 ) -> FastAPI:
     app = FastAPI(title="MOSS-TTS-Nano Demo", root_path=root_path or "")
     stream_jobs = StreamingJobManager()
@@ -2187,6 +2251,7 @@ def _resolve_voice_clone_text_chunks(
         text: str,
         voice_clone_max_text_tokens: int,
         cpu_threads: int,
+        execution_device: str = "default",
     ) -> list[str]:
         normalized_text = str(text or "").strip()
         if not normalized_text:
@@ -2194,7 +2259,7 @@ def _resolve_voice_clone_text_chunks(
         try:
             chunks, _, _ = runtime_manager.call_with_runtime(
-                requested_execution_device="cpu",
+                requested_execution_device=execution_device,
                 cpu_threads=cpu_threads,
                 callback=lambda selected_runtime: selected_runtime.split_voice_clone_text(
                     text=normalized_text,
@@ -2293,6 +2358,7 @@ def _run_streaming_job(
         tts_max_batch_size: int,
         codec_max_batch_size: int,
         cpu_threads: int,
+        execution_device: str,
         attn_implementation: str,
         do_sample: bool,
         text_temperature: float,
@@ -2305,7 +2371,8 @@ def _run_streaming_job(
         seed: int | None,
     ) -> None:
         try:
-            initial_execution_label = "cpu"
+            _normalized_device = RequestRuntimeManager.normalize_requested_execution_device(execution_device)
+            initial_execution_label = str(runtime.device) if _normalized_device == "default" else _normalized_device
             with job.lock:
                 job.started_at = time.monotonic()
                 job.state = "running"
@@ -2334,7 +2401,7 @@ def _stream_factory(selected_runtime: NanoTTSService):
             )

             for event, resolved_execution_device, resolved_cpu_threads in runtime_manager.iter_with_runtime(
-                requested_execution_device="cpu",
+                requested_execution_device=execution_device,
                 cpu_threads=cpu_threads,
                 factory=_stream_factory,
             ):
@@ -2425,6 +2492,7 @@ async def index(request: Request):
                 text_normalization_status=_text_normalization_status_text(
                     text_normalizer_manager.snapshot() if text_normalizer_manager is not None else None
                 ),
+                cuda_available=cuda_available,
             )
         )

@@ -2433,6 +2501,7 @@ async def health():
         return {
             "status": "ok",
             "device": str(runtime.device),
+            "cuda_available": cuda_available,
             "dtype": str(runtime.dtype),
             "cpu_runtime_loaded": runtime_manager.is_cpu_runtime_loaded(),
             "default_cpu_threads": runtime_manager.default_cpu_threads,
@@ -2509,6 +2578,7 @@ async def generate_stream_start(
         codec_max_batch_size: int = Form(0),
         enable_text_normalization: str = Form("1"),
         enable_normalize_tts_text: str = Form("1"),
+        execution_device: str = Form("default"),
        cpu_threads: int = Form(0),
        attn_implementation: str = Form("model_default"),
        do_sample: str = Form("1"),
@@ -2559,6 +2629,7 @@ async def generate_stream_start(
                 text=str(prepared_texts["text"]),
                 voice_clone_max_text_tokens=int(voice_clone_max_text_tokens),
                 cpu_threads=int(cpu_threads),
+                execution_device=execution_device,
             )
         job = stream_jobs.create()
         with job.lock:
@@ -2577,6 +2648,7 @@ async def generate_stream_start(
             "tts_max_batch_size": int(tts_max_batch_size),
             "codec_max_batch_size": int(codec_max_batch_size),
             "cpu_threads": int(cpu_threads),
+            "execution_device": execution_device,
             "attn_implementation": attn_implementation,
             "do_sample": _coerce_bool(do_sample, True),
             "text_temperature": float(text_temperature),
@@ -2594,7 +2666,8 @@ async def generate_stream_start(
         thread.start()

         prompt_audio_cleanup_path = None
-        initial_execution_label = "cpu"
+        _normalized_device = RequestRuntimeManager.normalize_requested_execution_device(execution_device)
+        initial_execution_label = str(runtime.device) if _normalized_device == "default" else _normalized_device

         return {
             "stream_id": job.stream_id,
@@ -2720,6 +2793,7 @@ async def generate(
         codec_max_batch_size: int = Form(0),
         enable_text_normalization: str = Form("1"),
         enable_normalize_tts_text: str = Form("1"),
+        execution_device: str = Form("default"),
         cpu_threads: int = Form(0),
         attn_implementation: str = Form("model_default"),
         do_sample: str = Form("1"),
@@ -2791,7 +2865,7 @@ def _synthesize(selected_runtime: NanoTTSService):
             )

         result, resolved_execution_device, resolved_cpu_threads = runtime_manager.call_with_runtime(
-            requested_execution_device="cpu",
+            requested_execution_device=execution_device,
             cpu_threads=cpu_threads,
             callback=_synthesize,
         )
@@ -2809,6 +2883,7 @@ async def generate(
                 text=str(prepared_texts["text"]),
                 voice_clone_max_text_tokens=int(voice_clone_max_text_tokens),
                 cpu_threads=int(cpu_threads),
+                execution_device=execution_device,
             )
         generated_audio_path = str(result["audio_path"])
         wav_bytes = _audio_to_wav_bytes(result["waveform_numpy"], int(result["sample_rate"]))
@@ -2847,7 +2922,7 @@ def main(argv: Optional[Sequence[str]] = None) -> None:
         default=str(DEFAULT_AUDIO_TOKENIZER_PATH),
     )
     parser.add_argument("--output-dir", "--output_dir", dest="output_dir", type=str, default=str(DEFAULT_OUTPUT_DIR))
-    parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "auto"])
+    parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "auto", "cuda"])
    parser.add_argument("--dtype", type=str, default="auto", choices=["auto", "float32", "float16", "bfloat16"])
     parser.add_argument(
         "--attn-implementation",
@@ -2867,9 +2942,18 @@ def main(argv: Optional[Sequence[str]] = None) -> None:
         level=logging.INFO,
     )

-    resolved_runtime_device = "cpu"
-    if args.device != "cpu":
-        logging.info("CPU-only app mode: ignoring --device=%s and forcing cpu.", args.device)
+    import torch as _torch
+    resolved_runtime_device = args.device
+    if resolved_runtime_device == "auto":
+        resolved_runtime_device = "cuda" if _torch.cuda.is_available() else "cpu"
+        logging.info("Auto device resolved to: %s", resolved_runtime_device)
+    elif resolved_runtime_device == "cuda" and not _torch.cuda.is_available():
+        logging.warning("--device=cuda specified but CUDA is not available, falling back to cpu.")
+        resolved_runtime_device = "cpu"
+
+    cuda_available = _torch.cuda.is_available()
+    if cuda_available:
+        logging.info("CUDA available: %s (device count: %d)", _torch.cuda.get_device_name(0), _torch.cuda.device_count())

     runtime = NanoTTSService(
         checkpoint_path=args.checkpoint_path,
@@ -2890,7 +2974,7 @@ def main(argv: Optional[Sequence[str]] = None) -> None:
     if args.share:
         logging.warning("--share is ignored by the FastAPI-based Nano-TTS app.")

-    app = _build_app(runtime, warmup_manager, text_normalizer_manager, root_path)
+    app = _build_app(runtime, warmup_manager, text_normalizer_manager, root_path, cuda_available=cuda_available)
     uvicorn.run(
         app,
         host=args.host,
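
---
Note for reviewers, not part of the patch: a minimal sanity check of the device
normalization rules added to RequestRuntimeManager, assuming app.py is importable
as `app`. The expected values follow directly from the patched
normalize_requested_execution_device:

    from app import RequestRuntimeManager

    norm = RequestRuntimeManager.normalize_requested_execution_device
    assert norm(None) == "default"      # empty input falls back to the runtime default
    assert norm(" CPU ") == "cpu"       # input is stripped and lower-cased
    assert norm("cuda") == "cuda"       # bare CUDA request
    assert norm("cuda:1") == "cuda:1"   # an explicit device index is preserved
    assert norm("cuda:x") == "default"  # a malformed index falls back to default
    assert norm("mps") == "default"     # unknown devices fall back to default

And a quick end-to-end check of the new `cuda_available` field on /health,
assuming the app is running locally (the host and port below are examples, not
values taken from this patch):

    import json
    import urllib.request

    # Expects keys added/kept by this patch: status, device, cuda_available, dtype.
    with urllib.request.urlopen("http://127.0.0.1:8000/health") as resp:
        health = json.load(resp)
    print(health["device"], health["cuda_available"])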