Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
from datetime import datetime
from pathlib import Path

import folder_paths
import comfy.model_management as mm
import comfy.model_patcher
Expand Down Expand Up @@ -191,7 +192,7 @@ def get_torch_device_patched():
else:
devs = set(get_device_list())
device = torch.device(current_device) if str(current_device) in devs else torch.device("cpu")
logger.debug(f"[MultiGPU Core Patching] get_torch_device_patched returning device: {device} (current_device={current_device})")
logger.debug(f"[MultiGPU] get_torch_device_patched() -> {device}")
return device

def text_encoder_device_patched():
Expand All @@ -202,7 +203,7 @@ def text_encoder_device_patched():
else:
devs = set(get_device_list())
device = torch.device(current_text_encoder_device) if str(current_text_encoder_device) in devs else torch.device("cpu")
logger.info(f"[MultiGPU Core Patching] text_encoder_device_patched returning device: {device} (current_text_encoder_device={current_text_encoder_device})")
logger.debug(f"[MultiGPU] text_encoder_device_patched() -> {device}")
return device

def unet_offload_device_patched():
Expand All @@ -213,13 +214,13 @@ def unet_offload_device_patched():
else:
devs = set(get_device_list())
device = torch.device(current_unet_offload_device) if str(current_unet_offload_device) in devs else torch.device("cpu")
logger.debug(f"[MultiGPU Core Patching] unet_offload_device_patched returning device: {device} (current_unet_offload_device={current_unet_offload_device})")
logger.debug(f"[MultiGPU] unet_offload_device_patched() -> {device}")
return device

logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device")
logger.info(f"[MultiGPU DEBUG] Initial current_device: {current_device}")
logger.info(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}")
logger.info(f"[MultiGPU DEBUG] Initial current_unet_offload_device: {current_unet_offload_device}")
logger.debug(f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device")
logger.debug(f"[MultiGPU DEBUG] Initial current_device: {current_device}")
logger.debug(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}")
logger.debug(f"[MultiGPU DEBUG] Initial current_unet_offload_device: {current_unet_offload_device}")

mm.get_torch_device = get_torch_device_patched
mm.text_encoder_device = text_encoder_device_patched
Expand All @@ -244,6 +245,11 @@ def unet_offload_device_patched():
PulidInsightFaceLoader,
PulidEvaClipLoader,
UNetLoaderLP,
# LTXV2 Core Node Adapters
LTXV2AudioVAELoader,
LTXV2AVTextEncoderLoader,
LatentUpscaleModelLoader,
LTXV2CheckpointLoader,
)

from .wanvideo import (
Expand Down Expand Up @@ -323,6 +329,25 @@ def unet_offload_device_patched():
NODE_CLASS_MAPPINGS["DiffusersLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["DiffusersLoader"])
NODE_CLASS_MAPPINGS["DiffControlNetLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["DiffControlNetLoader"])

# ============================================================================
# LTXV2 CORE NODES (built into ComfyUI, always available - no custom_node check needed)
# ============================================================================
logger.info("[MultiGPU] Registering LTXV2/Core nodes...")

# Simple device selection
# Each wrapper adds a device dropdown to the base loader; the *_clip variant
# is used for text encoders (CLIP-typed outputs).
NODE_CLASS_MAPPINGS["LTXV2AudioVAELoaderMultiGPU"] = override_class(LTXV2AudioVAELoader)
NODE_CLASS_MAPPINGS["LTXV2AVTextEncoderLoaderMultiGPU"] = override_class_clip(LTXV2AVTextEncoderLoader)
NODE_CLASS_MAPPINGS["LatentUpscaleModelLoaderMultiGPU"] = override_class(LatentUpscaleModelLoader)
NODE_CLASS_MAPPINGS["LTXV2CheckpointLoaderMultiGPU"] = override_class(LTXV2CheckpointLoader)

# DisTorch2 layer distribution (for large models that need to be split across GPUs)
NODE_CLASS_MAPPINGS["LTXV2AudioVAELoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(LTXV2AudioVAELoader)
NODE_CLASS_MAPPINGS["LTXV2AVTextEncoderLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip(LTXV2AVTextEncoderLoader)
NODE_CLASS_MAPPINGS["LatentUpscaleModelLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(LatentUpscaleModelLoader)
NODE_CLASS_MAPPINGS["LTXV2CheckpointLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(LTXV2CheckpointLoader)

# NOTE: "8 nodes" = 4 simple + 4 DisTorch2 registrations above; keep the count
# in this log message in sync when adding/removing registrations.
logger.info("[MultiGPU] LTXV2/Core nodes registered: 8 nodes")

logger.info("[MultiGPU] Initiating custom_node Registration. . .")
dash_line = "-" * 47
fmt_reg = "{:<30}{:>5}{:>10}"
Expand Down
155 changes: 153 additions & 2 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,11 +429,162 @@ def load_unet(self, unet_name):
"""Load UNet with low-precision LoRA flag for CPU storage optimization."""
original_loader = NODE_CLASS_MAPPINGS["UNETLoader"]()
out = original_loader.load_unet(unet_name)

# Set the low-precision LoRA flag on the loaded model
if hasattr(out[0], 'model'):
out[0].model._distorch_high_precision_loras = False
elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'):
out[0].patcher.model._distorch_high_precision_loras = False

return out


# ============================================================================
# LTXV2 CORE NODE ADAPTERS (for ComfyUI built-in LTXV2 nodes)
# These adapters wrap the new comfy_api.latest style nodes to work with
# the traditional multigpu wrapper system
# ============================================================================

def _convert_node_output(result):
"""Convert io.NodeOutput to tuple for ComfyUI compatibility.

io.NodeOutput from comfy_api.latest stores values in self.args tuple.
"""
if isinstance(result, tuple):
return result

# NodeOutput stores values in self.args
if type(result).__name__ == 'NodeOutput' and hasattr(result, 'args'):
return result.args

# Fallback for other iterables
if hasattr(result, '__iter__') and not isinstance(result, (str, bytes)):
return tuple(result)

return (result,)


class LTXV2AudioVAELoader:
    """Adapter for ComfyUI core LTXV2 audio VAE loader."""

    TITLE = "LTXV2 Audio VAE Loader"
    CATEGORY = "audio"
    FUNCTION = "load"
    RETURN_TYPES = ("VAE",)
    RETURN_NAMES = ("audio_vae",)

    @classmethod
    def INPUT_TYPES(s):
        checkpoints = folder_paths.get_filename_list("checkpoints")
        return {
            "required": {
                "ckpt_name": (checkpoints, {"tooltip": "Audio VAE checkpoint"}),
            }
        }

    def load(self, ckpt_name):
        """Load LTXV2 audio VAE from checkpoint.

        Delegates to the core LTXVAudioVAELoader node and normalizes its
        NodeOutput result to a classic tuple.
        """
        core_cls = NODE_CLASS_MAPPINGS.get("LTXVAudioVAELoader")
        if core_cls is None:
            raise RuntimeError("LTXVAudioVAELoader not found in ComfyUI core nodes.")

        # comfy_api.latest-style nodes expose execute() callable on the class.
        return _convert_node_output(core_cls.execute(ckpt_name=ckpt_name))


class LTXV2AVTextEncoderLoader:
    """Adapter for ComfyUI core LTXV2 audio-video text encoder loader (Gemma 3)."""

    TITLE = "LTXV2 AV Text Encoder Loader"
    CATEGORY = "advanced/loaders"
    FUNCTION = "load"
    RETURN_TYPES = ("CLIP",)

    @classmethod
    def INPUT_TYPES(s):
        encoders = folder_paths.get_filename_list("text_encoders")
        checkpoints = folder_paths.get_filename_list("checkpoints")
        return {
            "required": {
                "text_encoder": (encoders, {"tooltip": "Text encoder model file"}),
                "ckpt_name": (checkpoints, {"tooltip": "LTXV2 checkpoint for text encoder config"}),
            }
        }

    def load(self, text_encoder, ckpt_name, device=None):
        """Load LTXV2 audio-video text encoder (Gemma 3 12B).

        ``device`` is accepted but not forwarded — presumably the MultiGPU
        wrapper injects it and device routing happens via the patched
        model_management functions (TODO: confirm against the wrapper).
        """
        core_cls = NODE_CLASS_MAPPINGS.get("LTXAVTextEncoderLoader")
        if core_cls is None:
            raise RuntimeError("LTXAVTextEncoderLoader not found in ComfyUI core nodes.")

        output = core_cls.execute(text_encoder=text_encoder, ckpt_name=ckpt_name)
        return _convert_node_output(output)


class LatentUpscaleModelLoader:
    """Adapter for ComfyUI core Latent Upscale Model Loader (used for LTXV2 and HunyuanVideo upscaling)."""

    TITLE = "Load Latent Upscale Model"
    CATEGORY = "loaders"
    FUNCTION = "load"
    RETURN_TYPES = ("LATENT_UPSCALE_MODEL",)

    @classmethod
    def INPUT_TYPES(s):
        models = folder_paths.get_filename_list("latent_upscale_models")
        return {
            "required": {
                "model_name": (models, {"tooltip": "Latent upscale model from ComfyUI/models/latent_upscale_models"}),
            }
        }

    def load(self, model_name):
        """Load latent upscale model for video upsampling.

        Looks up the core node of the same registration name and delegates.
        """
        core_cls = NODE_CLASS_MAPPINGS.get("LatentUpscaleModelLoader")
        if core_cls is None:
            raise RuntimeError("LatentUpscaleModelLoader not found in ComfyUI core nodes.")

        return _convert_node_output(core_cls.execute(model_name=model_name))


class LTXV2CheckpointLoader:
    """Combined loader for LTXV2 video model, video VAE, and audio VAE.

    This is a convenience node that loads all three components needed for
    LTXV2 audio-video generation from a single checkpoint file.
    """

    TITLE = "LTXV2 Checkpoint Loader (Video + Audio VAE)"
    CATEGORY = "lightricks/LTXV"
    FUNCTION = "load"
    RETURN_TYPES = ("MODEL", "VAE", "VAE")
    RETURN_NAMES = ("model", "vae", "audio_vae")

    @classmethod
    def INPUT_TYPES(s):
        checkpoints = folder_paths.get_filename_list("checkpoints")
        return {
            "required": {
                "ckpt_name": (checkpoints, {"tooltip": "LTXV2 checkpoint (contains video model, video VAE, and audio VAE)"}),
            }
        }

    def load(self, ckpt_name):
        """Load LTXV2 video model, video VAE, and audio VAE from one checkpoint.

        The video model + VAE come from the classic CheckpointLoaderSimple
        (instantiated, instance-method call); the audio VAE comes from the
        comfy_api-style LTXVAudioVAELoader (class-level execute()).
        """
        ckpt_cls = NODE_CLASS_MAPPINGS.get("CheckpointLoaderSimple")
        if ckpt_cls is None:
            raise RuntimeError("CheckpointLoaderSimple not found in ComfyUI core nodes.")

        # CLIP output of the checkpoint is intentionally discarded here.
        model, clip, vae = ckpt_cls().load_checkpoint(ckpt_name)

        audio_cls = NODE_CLASS_MAPPINGS.get("LTXVAudioVAELoader")
        if audio_cls is None:
            raise RuntimeError("LTXVAudioVAELoader not found in ComfyUI core nodes.")

        audio_vae = _convert_node_output(audio_cls.execute(ckpt_name=ckpt_name))[0]

        return (model, vae, audio_vae)