Skip to content

[bug] #4

@coder4nlp

Description

@coder4nlp
import os
import sys

sys.path.append("..")

import time
import torch
from diffusers import (
    QwenImagePipeline,
    QwenImageTransformer2DModel,
)

from utils import (
    GiB,
    get_args,
    strify,
    cachify,
    maybe_init_distributed,
    maybe_destroy_distributed,
)
import cache_dit


# Parse CLI options and (optionally) initialize multi-GPU distributed state.
args = get_args()
print(args)

rank, device = maybe_init_distributed(args)

# Load the Qwen-Image pipeline in bfloat16. The checkpoint directory can be
# overridden through the QWEN_IMAGE_DIR environment variable.
model_dir = os.environ.get("QWEN_IMAGE_DIR", "/root/Qwen-Image")
pipe = QwenImagePipeline.from_pretrained(
    model_dir,
    torch_dtype=torch.bfloat16,
)

# Sanity check: the rest of the script assumes the Qwen-Image transformer.
assert isinstance(pipe.transformer, QwenImageTransformer2DModel)

# Low-VRAM (< 96 GiB) strategy: either FP8 weight-only quantization or
# model CPU offload (the latter is enabled further below, after the cache
# and parallelism have been configured).
enable_quantization = args.quantize and GiB() < 96
if enable_quantization:
    print("Apply FP8 Weight Only Quantize ...")
    args.quantize_type = "fp8_w8a16_wo"  # force FP8 weight-only
    pipe.transformer = cache_dit.quantize(
        pipe.transformer,
        quant_type=args.quantize_type,
        # Keep the input projection layers in high precision.
        exclude_layers=[
            "img_in",
            "txt_in",
        ],
    )
    pipe.text_encoder = cache_dit.quantize(
        pipe.text_encoder,
        quant_type=args.quantize_type,
    )
    pipe.to(device)
elif GiB() >= 96:
    # Enough VRAM: move the whole pipeline onto the device directly.
    pipe.to(device)

# assert isinstance(pipe.vae, AutoencoderKLQwenImage)
# pipe.vae.enable_tiling()

# Apply cache and context parallelism here
if args.cache or args.parallel_type is not None:
    cachify(args, pipe)


if GiB() < 96 and not enable_quantization:
    # NOTE: Enabling cpu offload before enabling parallelism will
    # raise a shape error after the first pipe call, so we enable it after.
    # It seems to be a bug of diffusers that cpu offload is not fully
    # compatible with context parallelism, and vice versa.
    pipe.enable_model_cpu_offload(device=device)


# Quality-boosting suffixes appended to the user prompt, keyed by language.
positive_magic = {
    "en": ", Ultra HD, 4K, cinematic composition.",  # for english prompt
    "zh": ", 超清,4K,电影级构图.",  # for chinese prompt
}

# Generate image
# The prompt deliberately mixes English and Chinese text rendering targets.
prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""

# using an empty string if you do not have specific concept to remove
negative_prompt = " "

# Show the progress bar only on rank 0 in distributed runs.
pipe.set_progress_bar_config(disable=rank != 0)

def run_pipe():
    """Run one text-to-image generation with CLI-overridable settings.

    Returns the generated PIL image, or ``None`` in perf mode (where the
    pipeline is asked for raw latents to skip VAE decode).
    """
    width = 1024 if args.width is None else args.width
    height = 1024 if args.height is None else args.height
    # Snap sizes down to a multiple of 16 (VAE factor 8 * patch size 2) so
    # the latent token grid is well formed. With context parallelism,
    # unaligned sizes (e.g. 1664x938, see the traceback in this issue) make
    # the sequence length non-divisible by the mesh size and trip the shard
    # assert in diffusers' context_parallel hook.
    # NOTE(review): divisibility of the token count by the CP world size is
    # still assumed here — confirm for world sizes > 2.
    width = (width // 16) * 16
    height = (height // 16) * 16
    output = pipe(
        prompt=prompt + positive_magic["en"],
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_inference_steps=50 if args.steps is None else args.steps,
        true_cfg_scale=4.0,
        generator=torch.Generator(device="cpu").manual_seed(42),
        output_type="latent" if args.perf else "pil",
    )
    return output.images[0] if not args.perf else None

def run_pipe_v2():
    """Generate one fixed 1024x1024 image with the Chinese magic suffix.

    Uses a fixed CPU seed (42), 50 steps, and true CFG scale 4.0; returns
    the resulting PIL image.
    """
    result = pipe(
        prompt=prompt + positive_magic["zh"],
        negative_prompt=negative_prompt,
        width=1024,
        height=1024,
        num_inference_steps=50,
        true_cfg_scale=4.0,
        generator=torch.Generator(device="cpu").manual_seed(42),
        output_type="pil",
    )
    return result.images[0]


if args.compile:
    cache_dit.set_compile_configs()
    pipe.transformer = torch.compile(pipe.transformer)

# Warmup: the first call pays for torch.compile tracing and cache priming,
# so it is excluded from the timed measurement below. This call is
# identical to the timed one — reuse run_pipe_v2 instead of duplicating
# the pipe(...) invocation inline.
image = run_pipe_v2()

# Timed run (same settings as the warmup).
start = time.time()
image = run_pipe_v2()
end = time.time()

cache_dit.summary(pipe)

if rank == 0:
    time_cost = end - start
    save_path = f"qwen-image-fast.{strify(args, pipe)}.png"
    print(f"Time cost: {time_cost:.2f}s")
    if not args.perf:
        print(f"Saving image to {save_path}")
        image.save(save_path)

maybe_destroy_distributed()

Error traceback:

[rank1]:   File "/root/autodl-tmp/cache-dit/src/cache_dit/caching/cache_adapters/cache_adapter.py", line 218, in new_call
[rank1]:     outputs = original_call(self, *args, **kwargs)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage.py", line 691, in __call__
[rank1]:     noise_pred = self.transformer(
[rank1]:                  ^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 375, in __call__
[rank1]:     return super().__call__(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 736, in compile_wrapper
[rank1]:     return fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/hooks.py", line 188, in new_forward
[rank1]:     args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py", line 157, in pre_forward
[rank1]:     input_val = self._prepare_cp_input(input_val, cpm)
[rank1]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py", line 211, in _prepare_cp_input
[rank1]:     return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1495, in __call__
[rank1]:     return self._torchdynamo_orig_callable(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1272, in __call__
[rank1]:     result = self._inner_convert(
[rank1]:              ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 629, in __call__
[rank1]:     return _compile(
[rank1]:            ^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1111, in _compile
[rank1]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
[rank1]:     return function(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 793, in compile_inner
[rank1]:     return _compile_inner(code, one_graph, hooks, transform)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 832, in _compile_inner
[rank1]:     out_code = transform_code_object(code, transform)
[rank1]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
[rank1]:     transformations(instructions, code_options)
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 267, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 753, in transform
[rank1]:     tracer.run()
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3497, in run
[rank1]:     super().run()
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1363, in run
[rank1]:     while self.step():
[rank1]:           ^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1267, in step
[rank1]:     self.dispatch_table[inst.opcode](self, inst)
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 710, in inner
[rank1]:     self.jump(inst)
[rank1]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1697, in jump
[rank1]:     assert self.instruction_pointer is not None
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: 

[rank1]: from user code:
[rank1]:    File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py", line 261, in shard
[rank1]:     assert tensor.size()[dim] % mesh.size() == 0, (

[rank1]: Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

  0%|                                                                            | 0/50 [00:00<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/autodl-tmp/qwen-image-fast/qwen_image_fast.py", line 123, in <module>
[rank0]:     image = run_pipe(prompt, width=1664, height=938)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/qwen-image-fast/qwen_image_fast.py", line 93, in run_pipe
[rank0]:     output = pipe(
[rank0]:              ^^^^^
[rank0]:   File "/root/autodl-tmp/cache-dit/src/cache_dit/caching/cache_adapters/cache_adapter.py", line 218, in new_call
[rank0]:     outputs = original_call(self, *args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage.py", line 691, in __call__
[rank0]:     noise_pred = self.transformer(
[rank0]:                  ^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 375, in __call__
[rank0]:     return super().__call__(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 736, in compile_wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/hooks.py", line 188, in new_forward
[rank0]:     args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py", line 157, in pre_forward
[rank0]:     input_val = self._prepare_cp_input(input_val, cpm)
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py", line 211, in _prepare_cp_input
[rank0]:     return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1495, in __call__
[rank0]:     return self._torchdynamo_orig_callable(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1272, in __call__
[rank0]:     result = self._inner_convert(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 629, in __call__
[rank0]:     return _compile(
[rank0]:            ^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1111, in _compile
[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
[rank0]:     return function(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 793, in compile_inner
[rank0]:     return _compile_inner(code, one_graph, hooks, transform)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 832, in _compile_inner
[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
[rank0]:     transformations(instructions, code_options)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 267, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 753, in transform
[rank0]:     tracer.run()
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3497, in run
[rank0]:     super().run()
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1363, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1267, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 710, in inner
[rank0]:     self.jump(inst)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1697, in jump
[rank0]:     assert self.instruction_pointer is not None
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError: 

[rank0]: from user code:
[rank0]:    File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py", line 261, in shard
[rank0]:     assert tensor.size()[dim] % mesh.size() == 0, (

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions