diff --git a/__init__.py b/__init__.py index 2f5efd4..f37e095 100644 --- a/__init__.py +++ b/__init__.py @@ -6,6 +6,7 @@ import json from datetime import datetime from pathlib import Path + import folder_paths import comfy.model_management as mm import comfy.model_patcher @@ -191,7 +192,7 @@ def get_torch_device_patched(): else: devs = set(get_device_list()) device = torch.device(current_device) if str(current_device) in devs else torch.device("cpu") - logger.debug(f"[MultiGPU Core Patching] get_torch_device_patched returning device: {device} (current_device={current_device})") + logger.debug(f"[MultiGPU] get_torch_device_patched() -> {device}") return device def text_encoder_device_patched(): @@ -202,7 +203,7 @@ def text_encoder_device_patched(): else: devs = set(get_device_list()) device = torch.device(current_text_encoder_device) if str(current_text_encoder_device) in devs else torch.device("cpu") - logger.info(f"[MultiGPU Core Patching] text_encoder_device_patched returning device: {device} (current_text_encoder_device={current_text_encoder_device})") + logger.debug(f"[MultiGPU] text_encoder_device_patched() -> {device}") return device def unet_offload_device_patched(): @@ -213,13 +214,13 @@ def unet_offload_device_patched(): else: devs = set(get_device_list()) device = torch.device(current_unet_offload_device) if str(current_unet_offload_device) in devs else torch.device("cpu") - logger.debug(f"[MultiGPU Core Patching] unet_offload_device_patched returning device: {device} (current_unet_offload_device={current_unet_offload_device})") + logger.debug(f"[MultiGPU] unet_offload_device_patched() -> {device}") return device -logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device") -logger.info(f"[MultiGPU DEBUG] Initial current_device: {current_device}") -logger.info(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}") -logger.info(f"[MultiGPU DEBUG] Initial current_unet_offload_device: {current_unet_offload_device}") +logger.debug(f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device") +logger.debug(f"[MultiGPU DEBUG] Initial current_device: {current_device}") +logger.debug(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}") +logger.debug(f"[MultiGPU DEBUG] Initial current_unet_offload_device: {current_unet_offload_device}") mm.get_torch_device = get_torch_device_patched mm.text_encoder_device = text_encoder_device_patched @@ -244,6 +245,11 @@ def unet_offload_device_patched(): PulidInsightFaceLoader, PulidEvaClipLoader, UNetLoaderLP, + # LTXV2 Core Node Adapters + LTXV2AudioVAELoader, + LTXV2AVTextEncoderLoader, + LatentUpscaleModelLoader, + LTXV2CheckpointLoader, ) from .wanvideo import ( @@ -323,6 +329,25 @@ def unet_offload_device_patched(): NODE_CLASS_MAPPINGS["DiffusersLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["DiffusersLoader"]) NODE_CLASS_MAPPINGS["DiffControlNetLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["DiffControlNetLoader"]) +# ============================================================================ +# LTXV2 CORE NODES (built into ComfyUI, always available - no custom_node check needed) +# ============================================================================ +logger.info("[MultiGPU] Registering LTXV2/Core nodes...") + +# Simple device selection +NODE_CLASS_MAPPINGS["LTXV2AudioVAELoaderMultiGPU"] = override_class(LTXV2AudioVAELoader) +NODE_CLASS_MAPPINGS["LTXV2AVTextEncoderLoaderMultiGPU"] = override_class_clip(LTXV2AVTextEncoderLoader) +NODE_CLASS_MAPPINGS["LatentUpscaleModelLoaderMultiGPU"] = override_class(LatentUpscaleModelLoader) +NODE_CLASS_MAPPINGS["LTXV2CheckpointLoaderMultiGPU"] = override_class(LTXV2CheckpointLoader) + +# DisTorch2 layer distribution (for large models that need to be split across GPUs) +NODE_CLASS_MAPPINGS["LTXV2AudioVAELoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(LTXV2AudioVAELoader) +NODE_CLASS_MAPPINGS["LTXV2AVTextEncoderLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip(LTXV2AVTextEncoderLoader) +NODE_CLASS_MAPPINGS["LatentUpscaleModelLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(LatentUpscaleModelLoader) +NODE_CLASS_MAPPINGS["LTXV2CheckpointLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(LTXV2CheckpointLoader) + +logger.info("[MultiGPU] LTXV2/Core nodes registered: 8 nodes") + logger.info("[MultiGPU] Initiating custom_node Registration. . .") dash_line = "-" * 47 fmt_reg = "{:<30}{:>5}{:>10}" diff --git a/nodes.py b/nodes.py index 4a75b78..e815ef2 100644 --- a/nodes.py +++ b/nodes.py @@ -429,11 +429,162 @@ def load_unet(self, unet_name): """Load UNet with low-precision LoRA flag for CPU storage optimization.""" original_loader = NODE_CLASS_MAPPINGS["UNETLoader"]() out = original_loader.load_unet(unet_name) - + # Set the low-precision LoRA flag on the loaded model if hasattr(out[0], 'model'): out[0].model._distorch_high_precision_loras = False elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'): out[0].patcher.model._distorch_high_precision_loras = False - + return out + + +# ============================================================================ +# LTXV2 CORE NODE ADAPTERS (for ComfyUI built-in LTXV2 nodes) +# These adapters wrap the new comfy_api.latest style nodes to work with +# the traditional multigpu wrapper system +# ============================================================================ + +def _convert_node_output(result): + """Convert io.NodeOutput to tuple for ComfyUI compatibility. + + io.NodeOutput from comfy_api.latest stores values in self.args tuple. + """ + if isinstance(result, tuple): + return result + + # NodeOutput stores values in self.args + if type(result).__name__ == 'NodeOutput' and hasattr(result, 'args'): + return result.args + + # Fallback for other iterables + if hasattr(result, '__iter__') and not isinstance(result, (str, bytes)): + return tuple(result) + + return (result,) + + +class LTXV2AudioVAELoader: + """Adapter for ComfyUI core LTXV2 audio VAE loader.""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), + {"tooltip": "Audio VAE checkpoint"}) + } + } + + RETURN_TYPES = ("VAE",) + RETURN_NAMES = ("audio_vae",) + FUNCTION = "load" + CATEGORY = "audio" + TITLE = "LTXV2 Audio VAE Loader" + + def load(self, ckpt_name): + """Load LTXV2 audio VAE from checkpoint.""" + core_node = NODE_CLASS_MAPPINGS.get("LTXVAudioVAELoader") + if core_node is None: + raise RuntimeError("LTXVAudioVAELoader not found in ComfyUI core nodes.") + + result = core_node.execute(ckpt_name=ckpt_name) + return _convert_node_output(result) + + +class LTXV2AVTextEncoderLoader: + """Adapter for ComfyUI core LTXV2 audio-video text encoder loader (Gemma 3).""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "text_encoder": (folder_paths.get_filename_list("text_encoders"), + {"tooltip": "Text encoder model file"}), + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), + {"tooltip": "LTXV2 checkpoint for text encoder config"}) + } + } + + RETURN_TYPES = ("CLIP",) + FUNCTION = "load" + CATEGORY = "advanced/loaders" + TITLE = "LTXV2 AV Text Encoder Loader" + + def load(self, text_encoder, ckpt_name, device=None): + """Load LTXV2 audio-video text encoder (Gemma 3 12B).""" + core_node = NODE_CLASS_MAPPINGS.get("LTXAVTextEncoderLoader") + if core_node is None: + raise RuntimeError("LTXAVTextEncoderLoader not found in ComfyUI core nodes.") + + result = core_node.execute(text_encoder=text_encoder, ckpt_name=ckpt_name) + return _convert_node_output(result) + + +class LatentUpscaleModelLoader: + """Adapter for ComfyUI core Latent Upscale Model Loader (used for LTXV2 and HunyuanVideo upscaling).""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model_name": (folder_paths.get_filename_list("latent_upscale_models"), + {"tooltip": "Latent upscale model from ComfyUI/models/latent_upscale_models"}) + } + } + + RETURN_TYPES = ("LATENT_UPSCALE_MODEL",) + FUNCTION = "load" + CATEGORY = "loaders" + TITLE = "Load Latent Upscale Model" + + def load(self, model_name): + """Load latent upscale model for video upsampling.""" + core_node = NODE_CLASS_MAPPINGS.get("LatentUpscaleModelLoader") + if core_node is None: + raise RuntimeError("LatentUpscaleModelLoader not found in ComfyUI core nodes.") + + result = core_node.execute(model_name=model_name) + return _convert_node_output(result) + + +class LTXV2CheckpointLoader: + """Combined loader for LTXV2 video model, video VAE, and audio VAE. + + This is a convenience node that loads all three components needed for + LTXV2 audio-video generation from a single checkpoint file. + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), + {"tooltip": "LTXV2 checkpoint (contains video model, video VAE, and audio VAE)"}) + } + } + + RETURN_TYPES = ("MODEL", "VAE", "VAE") + RETURN_NAMES = ("model", "vae", "audio_vae") + FUNCTION = "load" + CATEGORY = "lightricks/LTXV" + TITLE = "LTXV2 Checkpoint Loader (Video + Audio VAE)" + + def load(self, ckpt_name): + """Load LTXV2 video model, video VAE, and audio VAE from one checkpoint.""" + # Load video model and VAE using CheckpointLoaderSimple + checkpoint_loader = NODE_CLASS_MAPPINGS.get("CheckpointLoaderSimple") + if checkpoint_loader is None: + raise RuntimeError("CheckpointLoaderSimple not found in ComfyUI core nodes.") + + model, clip, vae = checkpoint_loader().load_checkpoint(ckpt_name) + + # Load audio VAE from the same checkpoint + audio_vae_loader = NODE_CLASS_MAPPINGS.get("LTXVAudioVAELoader") + if audio_vae_loader is None: + raise RuntimeError("LTXVAudioVAELoader not found in ComfyUI core nodes.") + + audio_vae_result = audio_vae_loader.execute(ckpt_name=ckpt_name) + audio_vae = _convert_node_output(audio_vae_result)[0] + + return (model, vae, audio_vae)