diff --git a/docs/source/en/api/pipelines/qwenimage.md b/docs/source/en/api/pipelines/qwenimage.md
index ee3dd3b28e4d..48d4fc9b4f6d 100644
--- a/docs/source/en/api/pipelines/qwenimage.md
+++ b/docs/source/en/api/pipelines/qwenimage.md
@@ -95,6 +95,8 @@ image.save("qwen_fewsteps.png")

 With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference.

+### Single prompt with multiple reference images
+
 ```py
 import torch
 from PIL import Image
@@ -114,6 +116,36 @@ image = pipe(
 ).images[0]
 ```

+### Batch processing with multiple prompts
+
+The pipeline also supports batch processing where you can edit multiple images with different prompts simultaneously. Use a nested list format `[[img1], [img2]]` to provide input images for each prompt:
+
+```py
+import torch
+from diffusers import QwenImageEditPlusPipeline
+from diffusers.utils import load_image
+
+pipe = QwenImageEditPlusPipeline.from_pretrained(
+    "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Load input images
+mountain_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/mountain.jpg")
+
+# Process two different edits in a single batch
+images = pipe(
+    image=[[mountain_image], [mountain_image]],  # Nested list for batch_size=2
+    prompt=[
+        "Transform into a sunset scene with warm orange and pink sky",
+        "Add snow and make it a winter scene"
+    ],
+    num_inference_steps=50
+).images
+
+# images[0] contains the sunset version
+# images[1] contains the winter version
+```
+
 ## Performance

 ### torch.compile
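The nested-list format is written to accept more than one reference image per prompt as well (the code comments later in this patch use `[[img1, img2], [img3, img4]]` as the shape of the input). Below is a minimal usage sketch under that assumption; the image paths (`person.png`, `product.png`, `scene.png`) are placeholders, and each prompt is given the same number of similarly sized references so that the packed image latents of every batch item remain stackable.

```py
import torch
from diffusers import QwenImageEditPlusPipeline
from diffusers.utils import load_image

pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
).to("cuda")

# Placeholder reference images; substitute your own paths or URLs.
person = load_image("person.png")
product = load_image("product.png")
scene = load_image("scene.png")

images = pipe(
    # One sublist of reference images per prompt (batch_size=2).
    image=[[person, product], [person, scene]],
    prompt=[
        "Place the product in the person's hand",
        "Put the person into the scene at golden hour",
    ],
    num_inference_steps=50,
).images
```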
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
index 257e2d846c7c..f3c8910ded78 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -407,6 +407,32 @@ def _unpack_latents(latents, height, width, vae_scale_factor):

         return latents

+    def _preprocess_image_list(self, images):
+        """
+        Preprocess a list of PIL images for both condition encoder and VAE.
+
+        Args:
+            images: List of PIL images
+
+        Returns:
+            Tuple of (condition_sizes, condition_images, vae_sizes, vae_images)
+        """
+        condition_sizes = []
+        condition_images = []
+        vae_sizes = []
+        vae_images = []
+
+        for img in images:
+            image_width, image_height = img.size
+            condition_width, condition_height = calculate_dimensions(CONDITION_IMAGE_SIZE, image_width / image_height)
+            vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
+            condition_sizes.append((condition_width, condition_height))
+            vae_sizes.append((vae_width, vae_height))
+            condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
+            vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+
+        return condition_sizes, condition_images, vae_sizes, vae_images
+
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
@@ -431,6 +457,18 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):

         return image_latents

+    def _encode_and_pack_image(self, image, num_channels_latents, device, dtype, generator):
+        """Encode a single image and pack it. Returns packed latents."""
+        image = image.to(device=device, dtype=dtype)
+        if image.shape[1] != self.latent_channels:
+            img_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            img_latents = image
+
+        image_latent_height, image_latent_width = img_latents.shape[3:]
+        img_latents = self._pack_latents(img_latents, 1, num_channels_latents, image_latent_height, image_latent_width)
+        return img_latents
+
     def prepare_latents(
         self,
         images,
@@ -454,30 +492,28 @@ def prepare_latents(
         if images is not None:
             if not isinstance(images, list):
                 images = [images]
-            all_image_latents = []
-            for image in images:
-                image = image.to(device=device, dtype=dtype)
-                if image.shape[1] != self.latent_channels:
-                    image_latents = self._encode_vae_image(image=image, generator=generator)
-                else:
-                    image_latents = image
-                if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
-                    # expand init_latents for batch_size
-                    additional_image_per_prompt = batch_size // image_latents.shape[0]
-                    image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
-                elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
-                    raise ValueError(
-                        f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
-                    )
-                else:
-                    image_latents = torch.cat([image_latents], dim=0)
-                image_latent_height, image_latent_width = image_latents.shape[3:]
-                image_latents = self._pack_latents(
-                    image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
-                )
-                all_image_latents.append(image_latents)
-            image_latents = torch.cat(all_image_latents, dim=1)
+            # Check if nested list (batch_size > 1): [[img1, img2], [img3, img4]]
+            is_nested = images and isinstance(images[0], list)
+
+            if is_nested:
+                # batch_size > 1: Process each batch item separately
+                batch_image_latents = []
+                for batch_images in images:
+                    batch_item_latents = [
+                        self._encode_and_pack_image(img, num_channels_latents, device, dtype, generator)
+                        for img in batch_images
+                    ]
+                    # Concatenate all images for this batch item along sequence dimension
+                    batch_image_latents.append(torch.cat(batch_item_latents, dim=1))
+                # Stack all batch items to create final batch dimension
+                image_latents = torch.cat(batch_image_latents, dim=0)
+            else:
+                # batch_size == 1: Process flat list [img1, img2]
+                all_image_latents = [
+                    self._encode_and_pack_image(img, num_channels_latents, device, dtype, generator) for img in images
+                ]
+                image_latents = torch.cat(all_image_latents, dim=1)

         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -543,12 +579,15 @@ def __call__(
         Function invoked when calling the pipeline for generation.

         Args:
-            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, or `List[List[PIL.Image.Image]]`):
                 `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                 numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
                 or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                 list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
-                latents as `image`, but if passing latents directly it is not encoded again.
+                latents as `image`, but if passing latents directly it is not encoded again. For batch processing with
+                multiple prompts (batch_size > 1), provide a nested list where each sublist contains the input images
+                for that prompt: `[[img1_for_prompt1], [img2_for_prompt2]]`. For a single prompt with multiple
+                reference images (batch_size == 1), use a flat list: `[img1, img2]`.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
@@ -627,7 +666,17 @@ def __call__(
            [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is a list with the generated images.
        """
-        image_size = image[-1].size if isinstance(image, list) else image.size
+        # Handle both flat list [img1, img2] and nested list [[img1, img2], [img3, img4]]
+        if isinstance(image, list):
+            # Check if nested list (batch_size > 1)
+            if isinstance(image[0], list):
+                # Use last image from first batch item
+                image_size = image[0][-1].size
+            else:
+                # Flat list (batch_size == 1)
+                image_size = image[-1].size
+        else:
+            image_size = image.size
         calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
         height = height or calculated_height
         width = width or calculated_width
@@ -663,32 +712,38 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]

-        # QwenImageEditPlusPipeline does not currently support batch_size > 1
-        if batch_size > 1:
-            raise ValueError(
-                f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. "
-                "Please process prompts one at a time."
-            )
-
         device = self._execution_device

         # 3. Preprocess image
         if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
             if not isinstance(image, list):
                 image = [image]
-            condition_image_sizes = []
-            condition_images = []
-            vae_image_sizes = []
-            vae_images = []
-            for img in image:
-                image_width, image_height = img.size
-                condition_width, condition_height = calculate_dimensions(
-                    CONDITION_IMAGE_SIZE, image_width / image_height
+
+            # Check if nested list (batch_size > 1) or flat list (batch_size == 1)
+            is_nested = isinstance(image[0], list)
+
+            if is_nested:
+                if batch_size > 1 and len(image) != batch_size:
+                    raise ValueError(
+                        f"Image batch_size ({len(image)}) must match batch_size for prompts ({batch_size}) for batch inference."
+                    )
+                # batch_size > 1: image = [[img1, img2], [img3, img4]]
+                # Process each batch item separately
+                condition_image_sizes = []
+                condition_images = []
+                vae_image_sizes = []
+                vae_images = []
+
+                for batch_images in image:
+                    cond_sizes, cond_imgs, vae_szs, vae_imgs = self._preprocess_image_list(batch_images)
+                    condition_image_sizes.append(cond_sizes)
+                    condition_images.append(cond_imgs)
+                    vae_image_sizes.append(vae_szs)
+                    vae_images.append(vae_imgs)
+            else:
+                # batch_size == 1: image = [img1, img2]
+                condition_image_sizes, condition_images, vae_image_sizes, vae_images = self._preprocess_image_list(
+                    image
                 )
-                vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
-                condition_image_sizes.append((condition_width, condition_height))
-                vae_image_sizes.append((vae_width, vae_height))
-                condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
-                vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))

         has_neg_prompt = negative_prompt is not None or (
             negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
@@ -737,15 +792,19 @@ def __call__(
             generator,
             latents,
         )
+        # Build img_shapes for each batch item (avoid shared references!)
+        # Normalize vae_image_sizes to nested list format for uniform processing
+        sizes_list = vae_image_sizes if is_nested else [vae_image_sizes]
         img_shapes = [
             [
                 (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
                 *[
                     (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
-                    for vae_width, vae_height in vae_image_sizes
+                    for vae_width, vae_height in batch_vae_sizes
                 ],
             ]
-        ] * batch_size
+            for batch_vae_sizes in sizes_list
+        ]

         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
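The `img_shapes` change above swaps `] * batch_size` for a comprehension over `sizes_list`. Besides letting each batch item carry its own per-image shapes, it sidesteps the Python pitfall the "(avoid shared references!)" comment points at: multiplying a nested list repeats references to a single inner list. A minimal, diffusers-independent illustration:

```py
# Multiplying a nested list repeats references to the same inner list object.
shared = [[(1, 4, 4)]] * 2
shared[0].append((1, 2, 2))
print(shared[1])               # [(1, 4, 4), (1, 2, 2)]  (mutated through the alias)
print(shared[0] is shared[1])  # True

# A comprehension builds an independent inner list per batch item.
independent = [[(1, 4, 4)] for _ in range(2)]
independent[0].append((1, 2, 2))
print(independent[1])                    # [(1, 4, 4)]
print(independent[0] is independent[1])  # False
```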
diff --git a/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py b/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
index 6faf34728286..a6a16a4d194b 100644
--- a/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
+++ b/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
@@ -240,14 +240,52 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
     def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
         super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)

-    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
-    def test_num_images_per_prompt():
+    @pytest.mark.xfail(
+        condition=True,
+        reason="num_images_per_prompt > 1 is not yet supported for EditPlus pipeline",
+        strict=True,
+    )
+    def test_num_images_per_prompt(self):
         super().test_num_images_per_prompt()

-    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
-    def test_inference_batch_consistent():
-        super().test_inference_batch_consistent()
+    def test_inference_batch_single_identical(self):
+        # Test that batch_size=1 gives identical results to non-batched inference
+        self._test_inference_batch_single_identical(expected_max_diff=1e-3)
+
+    def test_inference_batch_consistent(self):
+        # Test that batched inference gives consistent results
+        self._test_inference_batch_consistent()
+
+    def test_batch_processing_multiple_prompts(self):
+        # Test batch processing with multiple prompts (batch_size > 1)
+        device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=device).manual_seed(0)
+
+        image = Image.new("RGB", (32, 32))
+
+        # Test with nested list format for batch_size=2
+        inputs = {
+            "prompt": ["dance monkey", "jump around"],
+            "image": [[image], [image]],  # Nested list for batch_size=2
+            "generator": generator,
+            "num_inference_steps": 2,
+            "height": 32,
+            "width": 32,
+            "max_sequence_length": 16,
+            "output_type": "pt",
+        }
+
+        images = pipe(**inputs).images

-    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
-    def test_inference_batch_single_identical():
-        super().test_inference_batch_single_identical()
+        # Should return 2 images (batch_size=2)
+        self.assertEqual(len(images), 2)
+        self.assertEqual(images[0].shape, (3, 32, 32))
+        self.assertEqual(images[1].shape, (3, 32, 32))