import os
import sys

sys.path.append("..")

import time

import torch
from diffusers import (
    QwenImagePipeline,
    QwenImageTransformer2DModel,
)
from utils import (
    GiB,
    get_args,
    strify,
    cachify,
    maybe_init_distributed,
    maybe_destroy_distributed,
)

import cache_dit

args = get_args()
print(args)

rank, device = maybe_init_distributed(args)

pipe = QwenImagePipeline.from_pretrained(
    os.environ.get(
        "QWEN_IMAGE_DIR",
        "/root/Qwen-Image",
    ),
    torch_dtype=torch.bfloat16,
)
assert isinstance(pipe.transformer, QwenImageTransformer2DModel)

enable_quantization = args.quantize and GiB() < 96
if GiB() < 96:
    if enable_quantization:
        print("Apply FP8 Weight Only Quantize ...")
        args.quantize_type = "fp8_w8a16_wo"  # force weight-only FP8
        pipe.transformer = cache_dit.quantize(
            pipe.transformer,
            quant_type=args.quantize_type,
            exclude_layers=[
                "img_in",
                "txt_in",
            ],
        )
        pipe.text_encoder = cache_dit.quantize(
            pipe.text_encoder,
            quant_type=args.quantize_type,
        )
        pipe.to(device)
else:
    pipe.to(device)

# assert isinstance(pipe.vae, AutoencoderKLQwenImage)
# pipe.vae.enable_tiling()

# Apply cache and context parallelism here
if args.cache or args.parallel_type is not None:
    cachify(args, pipe)

if GiB() < 96 and not enable_quantization:
    # NOTE: Enabling CPU offload before enabling parallelism raises a shape
    # error after the first pipe call, so we enable it afterwards. It seems
    # to be a diffusers bug that CPU offload is not fully compatible with
    # context parallelism, and vice versa.
    pipe.enable_model_cpu_offload(device=device)

positive_magic = {
    "en": ", Ultra HD, 4K, cinematic composition.",  # for English prompts
    "zh": ", 超清,4K,电影级构图.",  # for Chinese prompts
}

# Generate image
prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""

# Use an empty string if you do not have a specific concept to remove.
negative_prompt = " "

pipe.set_progress_bar_config(disable=rank != 0)


def run_pipe():
    # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
    output = pipe(
        prompt=prompt + positive_magic["en"],
        negative_prompt=negative_prompt,
        width=1024 if args.width is None else args.width,
        height=1024 if args.height is None else args.height,
        num_inference_steps=50 if args.steps is None else args.steps,
        true_cfg_scale=4.0,
        generator=torch.Generator(device="cpu").manual_seed(42),
        output_type="latent" if args.perf else "pil",
    )
    image = output.images[0] if not args.perf else None
    return image


def run_pipe_v2():
    # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
    output = pipe(
        prompt=prompt + positive_magic["zh"],
        negative_prompt=negative_prompt,
        width=1024,
        height=1024,
        num_inference_steps=50,
        true_cfg_scale=4.0,
        generator=torch.Generator(device="cpu").manual_seed(42),
        output_type="pil",
    )
    image = output.images[0]
    return image


if args.compile:
    cache_dit.set_compile_configs()
    pipe.transformer = torch.compile(pipe.transformer)

# Warmup run
output = pipe(
    prompt=prompt + positive_magic["zh"],
    negative_prompt=negative_prompt,
    width=1024,
    height=1024,
    num_inference_steps=50,
    true_cfg_scale=4.0,
    generator=torch.Generator(device="cpu").manual_seed(42),
    output_type="pil",
)
image = output.images[0]

# Timed run
start = time.time()
output = pipe(
    prompt=prompt + positive_magic["zh"],
    negative_prompt=negative_prompt,
    width=1024,
    height=1024,
    num_inference_steps=50,
    true_cfg_scale=4.0,
    generator=torch.Generator(device="cpu").manual_seed(42),
    output_type="pil",
)
image = output.images[0]
end = time.time()

cache_dit.summary(pipe)

if rank == 0:
    time_cost = end - start
    save_path = f"qwen-image-fast.{strify(args, pipe)}.png"
    print(f"Time cost: {time_cost:.2f}s")
    if not args.perf:
        print(f"Saving image to {save_path}")
        image.save(save_path)

maybe_destroy_distributed()
The following error is raised:
[rank1]: File "/root/autodl-tmp/cache-dit/src/cache_dit/caching/cache_adapters/cache_adapter.py", line 218, in new_call
[rank1]: outputs = original_call(self, *args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage.py", line 691, in __call__
[rank1]: noise_pred = self.transformer(
[rank1]: ^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 375, in __call__
[rank1]: return super().__call__(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 736, in compile_wrapper
[rank1]: return fn(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/hooks.py", line 188, in new_forward
[rank1]: args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py", line 157, in pre_forward
[rank1]: input_val = self._prepare_cp_input(input_val, cpm)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py", line 211, in _prepare_cp_input
[rank1]: return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1495, in __call__
[rank1]: return self._torchdynamo_orig_callable(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1272, in __call__
[rank1]: result = self._inner_convert(
[rank1]: ^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 629, in __call__
[rank1]: return _compile(
[rank1]: ^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1111, in _compile
[rank1]: guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
[rank1]: return function(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 793, in compile_inner
[rank1]: return _compile_inner(code, one_graph, hooks, transform)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 832, in _compile_inner
[rank1]: out_code = transform_code_object(code, transform)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
[rank1]: transformations(instructions, code_options)
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 267, in _fn
[rank1]: return fn(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 753, in transform
[rank1]: tracer.run()
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3497, in run
[rank1]: super().run()
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1363, in run
[rank1]: while self.step():
[rank1]: ^^^^^^^^^^^
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1267, in step
[rank1]: self.dispatch_table[inst.opcode](self, inst)
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 710, in inner
[rank1]: self.jump(inst)
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1697, in jump
[rank1]: assert self.instruction_pointer is not None
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError:
[rank1]: from user code:
[rank1]: File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py", line 261, in shard
[rank1]: assert tensor.size()[dim] % mesh.size() == 0, (
[rank1]: Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
0%| | 0/50 [00:00<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]: File "/root/autodl-tmp/qwen-image-fast/qwen_image_fast.py", line 123, in <module>
[rank0]: image = run_pipe(prompt, width=1664, height=938)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/autodl-tmp/qwen-image-fast/qwen_image_fast.py", line 93, in run_pipe
[rank0]: output = pipe(
[rank0]: ^^^^^
[rank0]: File "/root/autodl-tmp/cache-dit/src/cache_dit/caching/cache_adapters/cache_adapter.py", line 218, in new_call
[rank0]: outputs = original_call(self, *args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage.py", line 691, in __call__
[rank0]: noise_pred = self.transformer(
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 375, in __call__
[rank0]: return super().__call__(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 736, in compile_wrapper
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/hooks.py", line 188, in new_forward
[rank0]: args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py", line 157, in pre_forward
[rank0]: input_val = self._prepare_cp_input(input_val, cpm)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py", line 211, in _prepare_cp_input
[rank0]: return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1495, in __call__
[rank0]: return self._torchdynamo_orig_callable(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1272, in __call__
[rank0]: result = self._inner_convert(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 629, in __call__
[rank0]: return _compile(
[rank0]: ^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1111, in _compile
[rank0]: guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
[rank0]: return function(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 793, in compile_inner
[rank0]: return _compile_inner(code, one_graph, hooks, transform)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 832, in _compile_inner
[rank0]: out_code = transform_code_object(code, transform)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
[rank0]: transformations(instructions, code_options)
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 267, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 753, in transform
[rank0]: tracer.run()
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3497, in run
[rank0]: super().run()
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1363, in run
[rank0]: while self.step():
[rank0]: ^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1267, in step
[rank0]: self.dispatch_table[inst.opcode](self, inst)
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 710, in inner
[rank0]: self.jump(inst)
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1697, in jump
[rank0]: assert self.instruction_pointer is not None
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError:
[rank0]: from user code:
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py", line 261, in shard
[rank0]: assert tensor.size()[dim] % mesh.size() == 0, (
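The assertion that fails in EquipartitionSharder.shard requires tensor.size()[dim] % mesh.size() == 0 for every input the context-parallel hook shards, and the failing call used width=1664, height=938. Below is a minimal debugging sketch that just reproduces that divisibility condition so candidate sequence lengths can be tested against the context-parallel degree; the function name shards_evenly and the dummy shapes are illustrative assumptions, not anything from diffusers or cache_dit, and it does not identify which input (image latent tokens or text-embedding tokens) is the one that fails here.

import torch


def shards_evenly(tensor: torch.Tensor, split_dim: int, cp_size: int) -> bool:
    """Sketch: check the condition asserted in EquipartitionSharder.shard."""
    size = tensor.size(split_dim)
    ok = size % cp_size == 0
    if not ok:
        print(f"dim {split_dim} has size {size}, not divisible by cp_size={cp_size}")
    return ok


# Illustrative dummy shapes only: an even-length sequence shards across
# 2 ranks, an odd-length one would trigger the assertion above.
latents = torch.zeros(1, 1024, 64)
embeds = torch.zeros(1, 1023, 64)
shards_evenly(latents, split_dim=1, cp_size=2)  # True
shards_evenly(embeds, split_dim=1, cp_size=2)   # False

If the offending input turns out to be the text-embedding sequence, padding the prompt embeddings to a multiple of the context-parallel degree would be one possible workaround; if it is the image tokens, choosing a resolution whose packed latent grid divides evenly across the ranks would be another.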