32 changes: 32 additions & 0 deletions docs/source/en/api/pipelines/qwenimage.md
@@ -95,6 +95,8 @@ image.save("qwen_fewsteps.png")

With [`QwenImageEditPlusPipeline`], one can provide multiple images as input references.

### Single prompt with multiple reference images

```py
import torch
from PIL import Image
@@ -114,6 +116,36 @@ image = pipe(
).images[0]
```
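
Since the hunk above is collapsed, here is a minimal sketch of such a call. It assumes the `Qwen/Qwen-Image-Edit-2509` checkpoint used in the batch example below; both image URLs are placeholders.

```py
import torch
from diffusers import QwenImageEditPlusPipeline
from diffusers.utils import load_image

pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
).to("cuda")

# Two reference images for a single prompt (flat list, batch_size=1)
subject_image = load_image("https://example.com/subject.png")        # placeholder URL
background_image = load_image("https://example.com/background.png")  # placeholder URL

image = pipe(
    image=[subject_image, background_image],
    prompt="Place the subject in front of the background",
    num_inference_steps=50,
).images[0]
image.save("qwen_edit_plus_multi_ref.png")
```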

### Batch processing with multiple prompts

The pipeline also supports batch processing, where you can edit multiple images with different prompts in a single call. Provide the input images for each prompt as a nested list, e.g. `[[img1], [img2]]` (each sublist can also hold more than one reference image, as sketched after the example below):

```py
import torch
from diffusers import QwenImageEditPlusPipeline
from diffusers.utils import load_image

pipe = QwenImageEditPlusPipeline.from_pretrained(
"Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
).to("cuda")

# Load the input image (reused for both edits below)
mountain_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/mountain.jpg")

# Process two different edits in a single batch
images = pipe(
image=[[mountain_image], [mountain_image]], # Nested list for batch_size=2
prompt=[
"Transform into a sunset scene with warm orange and pink sky",
"Add snow and make it a winter scene"
],
num_inference_steps=50
).images

# images[0] contains the sunset version
# images[1] contains the winter version
```
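
Each sublist may also contain more than one reference image per prompt, matching the nested format `[[img1, img2], [img3, img4]]` handled in `prepare_latents`. A minimal sketch, reusing the mountain image from above together with a placeholder second image:

```py
import torch
from diffusers import QwenImageEditPlusPipeline
from diffusers.utils import load_image

pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
).to("cuda")

mountain_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/mountain.jpg")
lake_image = load_image("https://example.com/lake.jpg")  # placeholder URL

# batch_size=2, two reference images per prompt
images = pipe(
    image=[[mountain_image, lake_image], [mountain_image, lake_image]],
    prompt=[
        "Blend the two scenes into a warm sunset panorama",
        "Blend the two scenes into a snowy winter panorama",
    ],
    num_inference_steps=50,
).images
```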

## Performance

### torch.compile
155 changes: 107 additions & 48 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -407,6 +407,32 @@ def _unpack_latents(latents, height, width, vae_scale_factor):

return latents

def _preprocess_image_list(self, images):
"""
Preprocess a list of PIL images for both condition encoder and VAE.

Args:
images: List of PIL images

Returns:
Tuple of (condition_sizes, condition_images, vae_sizes, vae_images)
"""
condition_sizes = []
condition_images = []
vae_sizes = []
vae_images = []

for img in images:
image_width, image_height = img.size
condition_width, condition_height = calculate_dimensions(CONDITION_IMAGE_SIZE, image_width / image_height)
vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
condition_sizes.append((condition_width, condition_height))
vae_sizes.append((vae_width, vae_height))
condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))

return condition_sizes, condition_images, vae_sizes, vae_images

# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
@@ -431,6 +457,18 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):

return image_latents

def _encode_and_pack_image(self, image, num_channels_latents, device, dtype, generator):
"""Encode a single image and pack it. Returns packed latents."""
image = image.to(device=device, dtype=dtype)
if image.shape[1] != self.latent_channels:
img_latents = self._encode_vae_image(image=image, generator=generator)
else:
img_latents = image

image_latent_height, image_latent_width = img_latents.shape[3:]
img_latents = self._pack_latents(img_latents, 1, num_channels_latents, image_latent_height, image_latent_width)
return img_latents

def prepare_latents(
self,
images,
@@ -454,30 +492,28 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if images is not None:
if not isinstance(images, list):
images = [images]
all_image_latents = []
for image in images:
image = image.to(device=device, dtype=dtype)
if image.shape[1] != self.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=generator)
else:
image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)

image_latent_height, image_latent_width = image_latents.shape[3:]
image_latents = self._pack_latents(
image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
)
all_image_latents.append(image_latents)
image_latents = torch.cat(all_image_latents, dim=1)
# Check if nested list (batch_size > 1): [[img1, img2], [img3, img4]]
is_nested = images and isinstance(images[0], list)

if is_nested:
# batch_size > 1: Process each batch item separately
batch_image_latents = []
for batch_images in images:
batch_item_latents = [
self._encode_and_pack_image(img, num_channels_latents, device, dtype, generator)
for img in batch_images
]
# Concatenate all images for this batch item along sequence dimension
batch_image_latents.append(torch.cat(batch_item_latents, dim=1))
# Stack all batch items to create final batch dimension
image_latents = torch.cat(batch_image_latents, dim=0)
else:
# batch_size == 1: Process flat list [img1, img2]
all_image_latents = [
self._encode_and_pack_image(img, num_channels_latents, device, dtype, generator) for img in images
]
image_latents = torch.cat(all_image_latents, dim=1)

if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -543,12 +579,15 @@ def __call__(
Function invoked when calling the pipeline for generation.

Args:
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, or `List[List[PIL.Image.Image]]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
latents as `image`, but if passing latents directly it is not encoded again. For batch processing with
multiple prompts (batch_size > 1), provide a nested list where each sublist contains the input images
for that prompt: `[[img1_for_prompt1], [img2_for_prompt2]]`. For a single prompt with multiple
reference images (batch_size == 1), use a flat list: `[img1, img2]`.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
@@ -627,7 +666,17 @@ def __call__(
[`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images.
"""
image_size = image[-1].size if isinstance(image, list) else image.size
# Handle both flat list [img1, img2] and nested list [[img1, img2], [img3, img4]]
if isinstance(image, list):
# Check if nested list (batch_size > 1)
if isinstance(image[0], list):
# Use last image from first batch item
image_size = image[0][-1].size
else:
# Flat list (batch_size == 1)
image_size = image[-1].size
else:
image_size = image.size
calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
height = height or calculated_height
width = width or calculated_width
@@ -663,32 +712,38 @@ def __call__(
else:
batch_size = prompt_embeds.shape[0]

# QwenImageEditPlusPipeline does not currently support batch_size > 1
if batch_size > 1:
raise ValueError(
f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. "
"Please process prompts one at a time."
)

device = self._execution_device
# 3. Preprocess image
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
if not isinstance(image, list):
image = [image]
condition_image_sizes = []
condition_images = []
vae_image_sizes = []
vae_images = []
for img in image:
image_width, image_height = img.size
condition_width, condition_height = calculate_dimensions(
CONDITION_IMAGE_SIZE, image_width / image_height

# Check if nested list (batch_size > 1) or flat list (batch_size == 1)
is_nested = isinstance(image[0], list)

if is_nested:
if batch_size > 1 and len(image) != batch_size:
raise ValueError(
f"Image batch_size ({len(image)}) must match batch_size for prompts ({batch_size}) for batch inference."
)
# batch_size > 1: image = [[img1, img2], [img3, img4]]
# Process each batch item separately
condition_image_sizes = []
condition_images = []
vae_image_sizes = []
vae_images = []

for batch_images in image:
cond_sizes, cond_imgs, vae_szs, vae_imgs = self._preprocess_image_list(batch_images)
condition_image_sizes.append(cond_sizes)
condition_images.append(cond_imgs)
vae_image_sizes.append(vae_szs)
vae_images.append(vae_imgs)
else:
# batch_size == 1: image = [img1, img2]
condition_image_sizes, condition_images, vae_image_sizes, vae_images = self._preprocess_image_list(
image
)
vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
condition_image_sizes.append((condition_width, condition_height))
vae_image_sizes.append((vae_width, vae_height))
condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
@@ -737,15 +792,19 @@ def __call__(
generator,
latents,
)
# Build img_shapes for each batch item (avoid shared references!)
# Normalize vae_image_sizes to nested list format for uniform processing
sizes_list = vae_image_sizes if is_nested else [vae_image_sizes]
img_shapes = [
[
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
*[
(1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
for vae_width, vae_height in vae_image_sizes
for vae_width, vae_height in batch_vae_sizes
],
]
] * batch_size
for batch_vae_sizes in sizes_list
]

# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
54 changes: 46 additions & 8 deletions tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
@@ -240,14 +240,52 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)

@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_num_images_per_prompt():
@pytest.mark.xfail(
condition=True,
reason="num_images_per_prompt > 1 is not yet supported for EditPlus pipeline",
strict=True,
)
def test_num_images_per_prompt(self):
super().test_num_images_per_prompt()

@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_inference_batch_consistent():
super().test_inference_batch_consistent()
def test_inference_batch_single_identical(self):
# Test that batch_size=1 gives identical results to non-batched inference
self._test_inference_batch_single_identical(expected_max_diff=1e-3)

def test_inference_batch_consistent(self):
# Test that batched inference gives consistent results
self._test_inference_batch_consistent()

def test_batch_processing_multiple_prompts(self):
# Test batch processing with multiple prompts (batch_size > 1)
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)

if str(device).startswith("mps"):
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=device).manual_seed(0)

image = Image.new("RGB", (32, 32))

# Test with nested list format for batch_size=2
inputs = {
"prompt": ["dance monkey", "jump around"],
"image": [[image], [image]], # Nested list for batch_size=2
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "pt",
}

images = pipe(**inputs).images

@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_inference_batch_single_identical():
super().test_inference_batch_single_identical()
# Should return 2 images (batch_size=2)
self.assertEqual(len(images), 2)
self.assertEqual(images[0].shape, (3, 32, 32))
self.assertEqual(images[1].shape, (3, 32, 32))