src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py (+8, -18)
@@ -27,6 +27,7 @@
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import QwenImagePipelineOutput
+from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask
 
 
 if is_torch_xla_available():
@@ -210,14 +211,7 @@ def _get_qwen_prompt_embeds(
         hidden_states = encoder_hidden_states.hidden_states[-1]
         split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
-        prompt_embeds = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
-        )
-        encoder_attention_mask = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
-        )
+        prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states)
 
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
@@ -248,19 +242,15 @@ def encode_prompt(
         device = device or self._execution_device
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
 
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
 
-        prompt_embeds = prompt_embeds[:, :max_sequence_length]
-        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
-
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+        prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask(
+            prompt_embeds, prompt_embeds_mask, max_sequence_length
+        )
+        prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask(
+            prompt_embeds, prompt_embeds_mask, num_images_per_prompt
+        )
 
         return prompt_embeds, prompt_embeds_mask
 
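Note on the new helpers: `build_prompt_embeds_and_mask`, `slice_prompt_embeds_and_mask`, and `repeat_prompt_embeds_and_mask` live in the new `.utils` module, whose body is not part of this diff. As a reading aid, here is a minimal sketch of what they presumably do, reconstructed from the inline code they replace; the actual implementations in `.utils` may differ in details.

```python
import torch


def build_prompt_embeds_and_mask(split_hidden_states):
    # Pad each prompt's hidden states to the longest sequence in the batch and
    # build the matching 0/1 attention mask (taken from the deleted inline code).
    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
    max_seq_len = max(e.size(0) for e in split_hidden_states)
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
    )
    encoder_attention_mask = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
    )
    return prompt_embeds, encoder_attention_mask


def slice_prompt_embeds_and_mask(prompt_embeds, prompt_embeds_mask, max_sequence_length):
    # Truncate embeddings and mask to the configured maximum sequence length.
    return prompt_embeds[:, :max_sequence_length], prompt_embeds_mask[:, :max_sequence_length]


def repeat_prompt_embeds_and_mask(prompt_embeds, prompt_embeds_mask, num_images_per_prompt):
    # Repeat each prompt num_images_per_prompt times along the batch axis,
    # reproducing the repeat/view pattern from the deleted call sites.
    batch_size, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
    prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
    return prompt_embeds, prompt_embeds_mask
```

Factoring these out removes six near-identical copies of this padding and repeat logic across the QwenImage pipelines.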
src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py (+5, -15)
@@ -28,6 +28,7 @@
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import QwenImagePipelineOutput
+from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask
 
 
 if is_torch_xla_available():
@@ -274,14 +275,7 @@ def _get_qwen_prompt_embeds(
         hidden_states = encoder_hidden_states.hidden_states[-1]
         split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
-        prompt_embeds = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
-        )
-        encoder_attention_mask = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
-        )
+        prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states)
 
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
@@ -313,16 +307,12 @@ def encode_prompt(
         device = device or self._execution_device
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
 
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
 
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+        prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask(
+            prompt_embeds, prompt_embeds_mask, num_images_per_prompt
+        )
 
         return prompt_embeds, prompt_embeds_mask
 
(next file; header not captured in this view)
@@ -28,6 +28,7 @@
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import QwenImagePipelineOutput
+from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask
 
 
 if is_torch_xla_available():
@@ -256,14 +257,7 @@ def _get_qwen_prompt_embeds(
         hidden_states = encoder_hidden_states.hidden_states[-1]
         split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
-        prompt_embeds = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
-        )
-        encoder_attention_mask = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
-        )
+        prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states)
 
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
@@ -294,16 +288,12 @@ def encode_prompt(
         device = device or self._execution_device
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
 
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
 
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+        prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask(
+            prompt_embeds, prompt_embeds_mask, num_images_per_prompt
+        )
 
         return prompt_embeds, prompt_embeds_mask
 
src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py (+5, -15)
@@ -28,6 +28,7 @@
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import QwenImagePipelineOutput
+from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask
 
 
 if is_torch_xla_available():
@@ -257,14 +258,7 @@ def _get_qwen_prompt_embeds(
         hidden_states = outputs.hidden_states[-1]
         split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
-        prompt_embeds = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
-        )
-        encoder_attention_mask = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
-        )
+        prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states)
 
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
@@ -298,16 +292,12 @@ def encode_prompt(
         device = device or self._execution_device
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
 
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
 
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+        prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask(
+            prompt_embeds, prompt_embeds_mask, num_images_per_prompt
+        )
 
         return prompt_embeds, prompt_embeds_mask
 
(next file; header not captured in this view)
@@ -29,6 +29,7 @@
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import QwenImagePipelineOutput
+from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask
 
 
 if is_torch_xla_available():
@@ -268,14 +269,7 @@ def _get_qwen_prompt_embeds(
         hidden_states = outputs.hidden_states[-1]
         split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
-        prompt_embeds = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
-        )
-        encoder_attention_mask = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
-        )
+        prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states)
 
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
@@ -310,16 +304,12 @@ def encode_prompt(
         device = device or self._execution_device
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
 
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
 
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+        prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask(
+            prompt_embeds, prompt_embeds_mask, num_images_per_prompt
+        )
 
         return prompt_embeds, prompt_embeds_mask
 
src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py (+42, -23)
@@ -28,6 +28,7 @@
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import QwenImagePipelineOutput
+from .utils import build_prompt_embeds_and_mask, concat_prompt_embeds_for_cfg, repeat_prompt_embeds_and_mask
 
 
 if is_torch_xla_available():
@@ -270,14 +271,7 @@ def _get_qwen_prompt_embeds(
         hidden_states = outputs.hidden_states[-1]
         split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
-        prompt_embeds = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
-        )
-        encoder_attention_mask = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
-        )
+        prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states)
 
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
@@ -312,16 +306,12 @@ def encode_prompt(
         device = device or self._execution_device
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
 
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
 
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+        prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask(
+            prompt_embeds, prompt_embeds_mask, num_images_per_prompt
+        )
 
         return prompt_embeds, prompt_embeds_mask
 
@@ -724,6 +714,15 @@ def __call__(
                 max_sequence_length=max_sequence_length,
             )
 
+        use_batch_cfg = do_true_cfg and not self.transformer.is_cache_enabled
+        if use_batch_cfg:
+            prompt_embeds, prompt_embeds_mask = concat_prompt_embeds_for_cfg(
+                prompt_embeds,
+                prompt_embeds_mask,
+                negative_prompt_embeds,
+                negative_prompt_embeds_mask,
+            )
+
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
         latents, image_latents = self.prepare_latents(
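`concat_prompt_embeds_for_cfg` is likewise new in `.utils` and not shown here. From this call site and the `noise_pred.chunk(2)` split later in the loop, it must batch the negative and positive embeddings together, negative first. A hypothetical sketch, assuming it zero-pads both to a common sequence length (the two prompts can tokenize to different lengths, so this padding behavior is an assumption):

```python
import torch


def concat_prompt_embeds_for_cfg(
    prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask
):
    # Hypothetical sketch: pad positive and negative embeddings to a shared
    # sequence length, then batch them negative-first so that
    # `neg_noise_pred, noise_pred = noise_pred.chunk(2)` recovers them in order.
    max_len = max(prompt_embeds.size(1), negative_prompt_embeds.size(1))

    def pad(embeds, mask):
        extra = max_len - embeds.size(1)
        if extra > 0:
            embeds = torch.cat([embeds, embeds.new_zeros(embeds.size(0), extra, embeds.size(2))], dim=1)
            mask = torch.cat([mask, mask.new_zeros(mask.size(0), extra)], dim=1)
        return embeds, mask

    prompt_embeds, prompt_embeds_mask = pad(prompt_embeds, prompt_embeds_mask)
    negative_prompt_embeds, negative_prompt_embeds_mask = pad(
        negative_prompt_embeds, negative_prompt_embeds_mask
    )
    embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    mask = torch.cat([negative_prompt_embeds_mask, prompt_embeds_mask], dim=0)
    return embeds, mask
```

The `not self.transformer.is_cache_enabled` guard preserves the original two-pass path when caching is active, since the `cond`/`uncond` cache contexts below assume separate forward passes.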
@@ -799,7 +798,11 @@ def __call__(
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                with self.transformer.cache_context("cond"):
+                if use_batch_cfg:
+                    latent_model_input = torch.cat([latent_model_input] * 2)
+                    timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                if use_batch_cfg:
                     noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep / 1000,
@@ -811,20 +814,36 @@ def __call__(
                         return_dict=False,
                     )[0]
                     noise_pred = noise_pred[:, : latents.size(1)]
 
-                if do_true_cfg:
-                    with self.transformer.cache_context("uncond"):
-                        neg_noise_pred = self.transformer(
+                    neg_noise_pred, noise_pred = noise_pred.chunk(2)
+                else:
+                    with self.transformer.cache_context("cond"):
+                        noise_pred = self.transformer(
                             hidden_states=latent_model_input,
                             timestep=timestep / 1000,
                             guidance=guidance,
-                            encoder_hidden_states_mask=negative_prompt_embeds_mask,
-                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_hidden_states_mask=prompt_embeds_mask,
+                            encoder_hidden_states=prompt_embeds,
                             img_shapes=img_shapes,
                             attention_kwargs=self.attention_kwargs,
                             return_dict=False,
                         )[0]
-                        neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+                        noise_pred = noise_pred[:, : latents.size(1)]
+
+                    if do_true_cfg:
+                        with self.transformer.cache_context("uncond"):
+                            neg_noise_pred = self.transformer(
+                                hidden_states=latent_model_input,
+                                timestep=timestep / 1000,
+                                guidance=guidance,
+                                encoder_hidden_states_mask=negative_prompt_embeds_mask,
+                                encoder_hidden_states=negative_prompt_embeds,
+                                img_shapes=img_shapes,
+                                attention_kwargs=self.attention_kwargs,
+                                return_dict=False,
+                            )[0]
+                            neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+
+                if do_true_cfg:
                     comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
 
                     cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
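A note on the batched path above: when `use_batch_cfg` is set, `latent_model_input` is doubled and `timestep` is recomputed from the doubled batch, so a single transformer call covers both the conditional and unconditional branches; `neg_noise_pred, noise_pred = noise_pred.chunk(2)` then relies on `concat_prompt_embeds_for_cfg` having placed the negative embeddings first (see the sketch above).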
src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py (+8, -17)
@@ -13,6 +13,7 @@
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import QwenImagePipelineOutput
+from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask
 
 
 if is_torch_xla_available():
@@ -217,14 +218,7 @@ def _get_qwen_prompt_embeds(
         hidden_states = encoder_hidden_states.hidden_states[-1]
         split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
-        prompt_embeds = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
-        )
-        encoder_attention_mask = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
-        )
+        prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states)
 
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
@@ -291,19 +285,16 @@ def encode_prompt(
         device = device or self._execution_device
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
 
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
 
-        prompt_embeds = prompt_embeds[:, :max_sequence_length]
-        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+        prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask(
+            prompt_embeds, prompt_embeds_mask, max_sequence_length
+        )
 
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+        prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask(
+            prompt_embeds, prompt_embeds_mask, num_images_per_prompt
+        )
 
         return prompt_embeds, prompt_embeds_mask
 