3 changes: 3 additions & 0 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -262,6 +262,9 @@ def encode_prompt(
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
            prompt_embeds_mask = None

        return prompt_embeds, prompt_embeds_mask

    def check_inputs(
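The hunk above (repeated across the pipeline variants below) normalizes a no-op mask away: when `prompt_embeds_mask` keeps every token, it carries no information, so returning `None` lets the attention path skip mask construction and gives `torch.compile` a stable `None` input instead of a tensor whose contents vary between calls. A minimal standalone sketch of the equivalence, using dummy shapes rather than the pipeline's tensors:

```python
import torch
import torch.nn.functional as F

# Standalone sketch: an all-True key-padding mask admits every token,
# so attention with the mask matches attention without one.
q = torch.randn(1, 4, 7, 8)  # (batch, heads, seq_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

all_ones = torch.ones(1, 7, dtype=torch.bool)  # mask keeping all 7 tokens
attn_mask = all_ones[:, None, None, :]         # broadcast to (1, 1, 1, 7)

with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
no_mask = F.scaled_dot_product_attention(q, k, v)
torch.testing.assert_close(with_mask, no_mask)  # identical results
```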
@@ -324,6 +324,9 @@ def encode_prompt(
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
            prompt_embeds_mask = None

        return prompt_embeds, prompt_embeds_mask

    def check_inputs(
@@ -305,6 +305,9 @@ def encode_prompt(
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
            prompt_embeds_mask = None

        return prompt_embeds, prompt_embeds_mask

    def check_inputs(
3 changes: 3 additions & 0 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
@@ -309,6 +309,9 @@ def encode_prompt(
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
            prompt_embeds_mask = None

        return prompt_embeds, prompt_embeds_mask

    def check_inputs(
@@ -321,6 +321,9 @@ def encode_prompt(
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
            prompt_embeds_mask = None

        return prompt_embeds, prompt_embeds_mask

    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs
@@ -323,6 +323,9 @@ def encode_prompt(
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
            prompt_embeds_mask = None

        return prompt_embeds, prompt_embeds_mask

    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
@@ -305,6 +305,9 @@ def encode_prompt(
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
            prompt_embeds_mask = None

        return prompt_embeds, prompt_embeds_mask

    def check_inputs(
@@ -316,6 +316,9 @@ def encode_prompt(
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
            prompt_embeds_mask = None

        return prompt_embeds, prompt_embeds_mask

    def check_inputs(
@@ -328,6 +328,9 @@ def encode_prompt(
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
            prompt_embeds_mask = None

        return prompt_embeds, prompt_embeds_mask

    def get_image_caption(self, prompt_image, use_en_prompt=True, device=None):
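One consequence of this change is that `encode_prompt` may now return `None` where callers previously always received a tensor, so downstream consumers need a `None` branch. A hedged, self-contained sketch of that kind of handling (the `txt_seq_lens` name and shapes are illustrative, not the pipelines' exact code):

```python
import torch

# Illustrative sketch: a None mask means every prompt token is valid,
# so per-sample text lengths fall back to the full sequence length.
prompt_embeds = torch.randn(2, 7, 16)  # (batch, seq_len, dim)
prompt_embeds_mask = None              # as returned when no token is padding

if prompt_embeds_mask is not None:
    txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
else:
    txt_seq_lens = [prompt_embeds.shape[1]] * prompt_embeds.shape[0]

assert txt_seq_lens == [7, 7]
```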
71 changes: 71 additions & 0 deletions tests/models/transformers/test_models_transformer_qwenimage.py
@@ -276,3 +276,74 @@ def prepare_dummy_input(self, height, width):

    def test_torch_compile_recompilation_and_graph_break(self):
        super().test_torch_compile_recompilation_and_graph_break()

    def test_torch_compile_with_and_without_mask(self):
        """Test that torch.compile works with a None mask, an all-ones mask, and a padding mask."""
        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()
        model.compile(mode="default", fullgraph=True)

        # Test 1: Run with None mask (no padding, all tokens are valid)
        inputs_no_mask = inputs.copy()
        inputs_no_mask["encoder_hidden_states_mask"] = None

        # First run to allow compilation
        with torch.no_grad():
            output_no_mask = model(**inputs_no_mask)

        # Second run to verify no recompilation
        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(error_on_recompile=True),
            torch.no_grad(),
        ):
            output_no_mask_2 = model(**inputs_no_mask)

        self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
        self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])

        # Test 2: Run with an all-ones mask (should behave like None)
        inputs_all_ones = inputs.copy()
        # Keep the all-ones mask
        self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())

        # First run to allow compilation
        with torch.no_grad():
            output_all_ones = model(**inputs_all_ones)

        # Second run to verify no recompilation
        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(error_on_recompile=True),
            torch.no_grad(),
        ):
            output_all_ones_2 = model(**inputs_all_ones)

        self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
        self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])

        # Test 3: Run with an actual padding mask (contains zeros)
        inputs_with_padding = inputs.copy()
        mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
        mask_with_padding[:, 4:] = 0  # Tokens from index 4 onward are padding

        inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding

        # First run to allow compilation
        with torch.no_grad():
            output_with_padding = model(**inputs_with_padding)

        # Second run to verify no recompilation
        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(error_on_recompile=True),
            torch.no_grad(),
        ):
            output_with_padding_2 = model(**inputs_with_padding)

        self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
        self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])

        # Verify that the outputs differ (the padding mask should affect the results)
        self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
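All three recompilation checks in this test follow the same pattern: compile and warm up the model, then re-run it under `torch._dynamo.config.patch(error_on_recompile=True)` so that any Dynamo re-specialization raises instead of silently recompiling. A minimal sketch of that guard in isolation (the lambda and shapes are illustrative, not from the test):

```python
import torch

# Illustrative sketch of the recompile guard used in the test above.
fn = torch.compile(lambda x: x * 2, fullgraph=True)
fn(torch.randn(8))  # first call triggers compilation

# Any input that forces a recompile now raises instead of recompiling.
with torch._dynamo.config.patch(error_on_recompile=True):
    fn(torch.randn(8))  # same shape and dtype: the cached graph is reused
```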