
Commit b9fd68e

feat: support 🔥FLUX.2 context parallel (#492)

* feat: support hybrid cache + tp for flux.2
* feat: enable seq offload for FLUX.2 w/ GPU=1
* feat: support FLUX.2 context parallel

1 parent 4cae178 commit b9fd68e

File tree

8 files changed: +284 -6 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ The comparison between **cache-dit** and other algorithms shows that within a sp

 | 📚Model | Cache | CP | TP | 📚Model | Cache | CP | TP |
 |:---|:---|:---|:---|:---|:---|:---|:---|
-| **🔥[FLUX.2: 56B](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️🔥 | ✖️ | ✔️🔥 | **🎉[FLUX.1 `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✖️ |
+| **🔥[FLUX.2: 56B](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️🔥 | ✔️🔥 | ✔️🔥 | **🎉[FLUX.1 `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✖️ |
 | **🎉[FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✔️ | **🎉[FLUX.1-Fill `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✖️ |
 | **🎉[FLUX.1-Fill](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✔️ | **🎉[Qwen-Image `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✖️ |
 | **🎉[Qwen-Image](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✔️ | **🎉[Qwen...Edit `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✖️ |

docs/User_Guide.md

Lines changed: 3 additions & 3 deletions
@@ -75,13 +75,13 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
 ```

 > [!Tip]
-> One **Model Series** may contain **many** pipelines. cache-dit applies optimizations at the **Transformer** level; thus, any pipelines that include the supported transformer are already supported by cache-dit. ✔️: known work and official supported now; ✖️: unofficial supported now, but maybe support in the future; **[`Q`](https://github.com/nunchaku-tech/nunchaku)**: **4-bits** models w/ [nunchaku](https://github.com/nunchaku-tech/nunchaku) + SVDQ **W4A4**.
+> One **Model Series** may contain **many** pipelines. cache-dit applies optimizations at the **Transformer** level; thus, any pipelines that include the supported transformer are already supported by cache-dit. ✔️: known work and official supported now; ✖️: unofficial supported now, but maybe support in the future; **[`Q`](https://github.com/nunchaku-tech/nunchaku)**: **4-bits** models w/ [nunchaku](https://github.com/nunchaku-tech/nunchaku) + SVDQ **W4A4**; **🔥FLUX.2**: 24B + 32B = 56B.


 <div align="center">

 | 📚Model | Cache | CP | TP | 📚Model | Cache | CP | TP |
 |:---|:---|:---|:---|:---|:---|:---|:---|
-| **🔥🔥[FLUX.2](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | 🔥✔️ | ✖️ | 🔥✔️ | **🎉[FLUX.1 `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✖️ |
+| **🔥[FLUX.2: 56B](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️🔥 | ✔️🔥 | ✔️🔥 | **🎉[FLUX.1 `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✖️ |
 | **🎉[FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✔️ | **🎉[FLUX.1-Fill `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✖️ |
 | **🎉[FLUX.1-Fill](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✔️ | **🎉[Qwen-Image `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✖️ |
 | **🎉[Qwen-Image](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✔️ | **🎉[Qwen...Edit `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✔️ | ✔️ | ✖️ |

@@ -703,7 +703,7 @@ As we can observe, in the case of **static cache**, the image of `SCM Slow S*` (

 <div id="context-parallelism"></div>

-cache-dit is compatible with context parallelism. Currently, we support the use of `Hybrid Cache` + `Context Parallelism` scheme (via NATIVE_DIFFUSER parallelism backend) in cache-dit. Users can use Context Parallelism to further accelerate the speed of inference! For more details, please refer to [📚examples/parallelism](https://github.com/vipshop/cache-dit/tree/main/examples/parallelism). Currently, cache-dit supported context parallelism for [FLUX.1](https://huggingface.co/black-forest-labs/FLUX.1-dev), [Qwen-Image](https://github.com/QwenLM/Qwen-Image), [Qwen-Image-Lightning](https://github.com/ModelTC/Qwen-Image-Lightning), [LTXVideo](https://huggingface.co/Lightricks/LTX-Video), [Wan 2.1](https://github.com/Wan-Video/Wan2.1), [Wan 2.2](https://github.com/Wan-Video/Wan2.2), [HunyuanImage-2.1](https://huggingface.co/tencent/HunyuanImage-2.1), [HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo), [CogVideoX 1.0](https://github.com/zai-org/CogVideo), [CogVideoX 1.5](https://github.com/zai-org/CogVideo), [CogView 3/4](https://github.com/zai-org/CogView4) and [VisualCloze](https://github.com/lzyhha/VisualCloze), etc. cache-dit will support more models in the future.
+cache-dit is compatible with context parallelism. Currently, we support the use of `Hybrid Cache` + `Context Parallelism` scheme (via NATIVE_DIFFUSER parallelism backend) in cache-dit. Users can use Context Parallelism to further accelerate the speed of inference! For more details, please refer to [📚examples/parallelism](https://github.com/vipshop/cache-dit/tree/main/examples/parallelism). Currently, cache-dit supported context parallelism for [FLUX.1](https://huggingface.co/black-forest-labs/FLUX.1-dev), 🔥[FLUX.2](https://huggingface.co/black-forest-labs/FLUX.2-dev), [Qwen-Image](https://github.com/QwenLM/Qwen-Image), [Qwen-Image-Lightning](https://github.com/ModelTC/Qwen-Image-Lightning), [LTXVideo](https://huggingface.co/Lightricks/LTX-Video), [Wan 2.1](https://github.com/Wan-Video/Wan2.1), [Wan 2.2](https://github.com/Wan-Video/Wan2.2), [HunyuanImage-2.1](https://huggingface.co/tencent/HunyuanImage-2.1), [HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo), [CogVideoX 1.0](https://github.com/zai-org/CogVideo), [CogVideoX 1.5](https://github.com/zai-org/CogVideo), [CogView 3/4](https://github.com/zai-org/CogView4) and [VisualCloze](https://github.com/lzyhha/VisualCloze), etc. cache-dit will support more models in the future.

 ```python
 # pip3 install "cache-dit[parallelism]"
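
The guide's code block continues beyond this hunk. As a quick orientation for running the `Hybrid Cache` + `Context Parallelism` scheme, the sketch below shows only the per-rank distributed bootstrap such a run needs before any cache-dit call. It is a minimal sketch under stated assumptions: the script is launched with `torchrun --nproc_per_node=<N>`, and the helper name `init_distributed` is hypothetical (the shipped examples use `maybe_init_distributed` from `examples/utils.py`, whose body is not part of this diff).

```python
# Minimal, hypothetical sketch of the per-rank setup a context-parallel run assumes.
# Assumption: launched via `torchrun --nproc_per_node=<N> script.py`, which provides
# the rendezvous environment variables that init_process_group() reads.
import torch
import torch.distributed as dist


def init_distributed():
    # NCCL backend: context parallelism relies on GPU collectives (all2all, all-gather).
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    local_rank = rank % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)  # bind each rank to its own GPU
    return rank, torch.device("cuda", local_rank)
```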
Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
import os
import sys

sys.path.append("..")

import time

import torch
from diffusers import Flux2Pipeline, Flux2Transformer2DModel
from diffusers.quantizers import PipelineQuantizationConfig

from utils import (
    MemoryTracker,
    GiB,
    cachify,
    get_args,
    maybe_destroy_distributed,
    maybe_init_distributed,
    strify,
)

import cache_dit

args = get_args()
print(args)

rank, device = maybe_init_distributed(args)

if GiB() < 128:
    assert args.quantize, "Quantization is required to fit FLUX.2 in <128GB memory."
    assert args.quantize_type in ["bitsandbytes_4bit", "float8_weight_only"], (
        f"Unsupported quantization type: {args.quantize_type}, only supported"
        "'bitsandbytes_4bit (bnb_4bit)' and 'float8_weight_only'."
    )

pipe: Flux2Pipeline = Flux2Pipeline.from_pretrained(
    (
        args.model_path
        if args.model_path is not None
        else os.environ.get(
            "FLUX_2_DIR",
            "black-forest-labs/FLUX.2-dev",
        )
    ),
    torch_dtype=torch.bfloat16,
    quantization_config=(
        PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            # 112/4 = 28GB total for text_encoder + transformer in 4-bit
            components_to_quantize=["text_encoder", "transformer"],
        )
        if args.quantize and args.quantize_type in ("bitsandbytes_4bit",)
        else None
    ),
)

if args.quantize and args.quantize_type == "float8_weight_only":
    pipe.transformer = cache_dit.quantize(
        pipe.transformer,
        quant_type=args.quantize_type,
        exclude_layers=[
            "img_in",
            "txt_in",
        ],
    )
    pipe.text_encoder = cache_dit.quantize(
        pipe.text_encoder,
        quant_type=args.quantize_type,
    )

if args.cache or args.parallel_type is not None:
    from cache_dit import DBCacheConfig, ParamsModifier

    cachify(
        args,
        pipe,
        extra_parallel_modules=(
            # Specify extra modules to be parallelized in addition to the main transformer,
            # e.g., text_encoder_2 in FluxPipeline, text_encoder in Flux2Pipeline. Currently,
            # only supported in native pytorch backend (namely, Tensor Parallelism).
            [pipe.text_encoder]
            if args.parallel_type == "tp"
            else []
        ),
        params_modifiers=[
            ParamsModifier(
                # Modified config only for transformer_blocks
                # Must call the `reset` method of DBCacheConfig.
                cache_config=DBCacheConfig().reset(
                    residual_diff_threshold=args.rdt,
                ),
            ),
            ParamsModifier(
                # Modified config only for single_transformer_blocks
                # NOTE: FLUX.2, single_transformer_blocks should have `higher`
                # residual_diff_threshold because of the precision error
                # accumulation from previous transformer_blocks
                cache_config=DBCacheConfig().reset(
                    residual_diff_threshold=args.rdt * 3,
                ),
            ),
        ],
    )

torch.cuda.empty_cache()

if args.quantize_type == "bitsandbytes_4bit":
    pipe.to(device)
else:
    pipe.enable_model_cpu_offload(device=device)

assert isinstance(pipe.transformer, Flux2Transformer2DModel)

pipe.set_progress_bar_config(disable=rank != 0)

prompt = (
    "Realistic macro photograph of a hermit crab using a soda can as its shell, "
    "partially emerging from the can, captured with sharp detail and natural colors, "
    "on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean "
    "waves in the background. The can has the text `BFL Diffusers` on it and it has a color "
    "gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
)

if args.prompt is not None:
    prompt = args.prompt


def run_pipe(warmup: bool = False):
    generator = torch.Generator("cpu").manual_seed(0)
    image = pipe(
        prompt=prompt,
        # 28 steps can be a good trade-off
        num_inference_steps=5 if warmup else (28 if args.steps is None else args.steps),
        guidance_scale=4,
        generator=generator,
    ).images[0]
    return image


if args.compile:
    cache_dit.set_compile_configs()
    pipe.transformer = torch.compile(pipe.transformer)

# warmup
_ = run_pipe(warmup=True)

memory_tracker = MemoryTracker() if args.track_memory else None
if memory_tracker:
    memory_tracker.__enter__()

start = time.time()
image = run_pipe()
end = time.time()

if memory_tracker:
    memory_tracker.__exit__(None, None, None)
    memory_tracker.report()

if rank == 0:
    cache_dit.summary(pipe)

time_cost = end - start
save_path = f"flux2.{strify(args, pipe)}.png"
print(f"Time cost: {time_cost:.2f}s")
print(f"Saving image to {save_path}")
image.save(save_path)

maybe_destroy_distributed()

examples/utils.py

Lines changed: 6 additions & 1 deletion
@@ -88,6 +88,7 @@ def get_args(
             "int4",
             "int4_weight_only",
             "bitsandbytes_4bit",
+            "bnb_4bit",  # alias for bitsandbytes_4bit
         ],
     )
     parser.add_argument(
@@ -150,7 +151,11 @@
         default=False,
         help="Disable compute-communication overlap during compilation",
     )
-    return parser.parse_args() if parse else parser
+    args_or_parser = parser.parse_args() if parse else parser
+    if parse:
+        if args_or_parser.quantize_type == "bnb_4bit":  # alias
+            args_or_parser.quantize_type = "bitsandbytes_4bit"
+    return args_or_parser


 def cachify(

src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@
 logger = init_logger(__name__)


-@ContextParallelismPlannerRegister.register("Flux")
+@ContextParallelismPlannerRegister.register("FluxTransformer2DModel")
 class FluxContextParallelismPlanner(ContextParallelismPlanner):
     def apply(
         self,
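
A note on why the registration key above becomes the full class name: with a dedicated FLUX.2 planner added in this commit, a short key such as "Flux" would be ambiguous if the registry matches planners against the transformer's class name by prefix. The snippet below is a self-contained toy, not cache-dit's registry; it only illustrates, under that assumed name-based matching, why the exact class name is the safer key.

```python
# Toy illustration only (NOT the cache-dit registry implementation).
# Assumed behavior: planner keys are matched against the transformer class name.
def matches(key: str, transformer_cls_name: str) -> bool:
    return transformer_cls_name.startswith(key)


# A short "Flux" key is a prefix of both class names, so it would also claim FLUX.2:
assert matches("Flux", "FluxTransformer2DModel")
assert matches("Flux", "Flux2Transformer2DModel")

# The exact class-name key only matches the intended FLUX.1 transformer:
assert matches("FluxTransformer2DModel", "FluxTransformer2DModel")
assert not matches("FluxTransformer2DModel", "Flux2Transformer2DModel")
```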
src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux2.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
import torch
from typing import Optional
from diffusers.models.modeling_utils import ModelMixin
from diffusers import Flux2Transformer2DModel

try:
    from diffusers.models._modeling_parallel import (
        ContextParallelInput,
        ContextParallelOutput,
        ContextParallelModelPlan,
    )
except ImportError:
    raise ImportError(
        "Context parallelism requires the 'diffusers>=0.36.dev0'."
        "Please install latest version of diffusers from source: \n"
        "pip3 install git+https://github.com/huggingface/diffusers.git"
    )
from .cp_plan_registers import (
    ContextParallelismPlanner,
    ContextParallelismPlannerRegister,
)

from cache_dit.logger import init_logger

logger = init_logger(__name__)


@ContextParallelismPlannerRegister.register("Flux2Transformer2DModel")
class Flux2ContextParallelismPlanner(ContextParallelismPlanner):
    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:

        # NOTE: Diffusers native CP plan still have bugs for Flux2 now.
        self._cp_planner_preferred_native_diffusers = False

        if transformer is not None and self._cp_planner_preferred_native_diffusers:
            assert isinstance(
                transformer, Flux2Transformer2DModel
            ), "Transformer must be an instance of Flux2Transformer2DModel"
            if hasattr(transformer, "_cp_plan"):
                if transformer._cp_plan is not None:
                    return transformer._cp_plan

        # Otherwise, use the custom CP plan defined here, this maybe
        # a little different from the native diffusers implementation
        # for some models.
        _cp_plan = {
            # Here is a Transformer level CP plan for Flux, which will
            # only apply the only 1 split hook (pre_forward) on the forward
            # of Transformer, and gather the output after Transformer forward.
            # Pattern of transformer forward, split_output=False:
            # un-split input -> splited input (inside transformer)
            # Pattern of the transformer_blocks, single_transformer_blocks:
            # splited input (previous splited output) -> to_qkv/...
            # -> all2all
            # -> attn (local head, full seqlen)
            # -> all2all
            # -> splited output
            # The `hidden_states` and `encoder_hidden_states` will still keep
            # itself splited after block forward (namely, automatic split by
            # the all2all comm op after attn) for the all blocks.
            # img_ids and txt_ids will only be splited once at the very beginning,
            # and keep splited through the whole transformer forward. The all2all
            # comm op only happens on the `out` tensor after local attn not on
            # img_ids and txt_ids.
            "": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
                "encoder_hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
                "img_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
                "txt_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
            },
            # Then, the final proj_out will gather the splited output.
            # splited input (previous splited output)
            # -> all gather
            # -> un-split output
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }
        return _cp_plan


# TODO: Add async Ulysses QKV proj for FLUX2 model
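
The split/gather semantics described in the comments of this plan can be pictured with a plain-tensor toy, not cache-dit code: `ContextParallelInput(split_dim=1, ...)` shards the sequence dimension across ranks, and `ContextParallelOutput(gather_dim=1, ...)` concatenates the shards back after `proj_out`. The snippet below mimics that round trip in a single process; the shapes are made-up examples.

```python
# Single-process toy of the CP plan's split/gather round trip (not cache-dit code).
# Illustrative shape: (batch, seq_len, hidden_dim).
import torch

world_size = 2
hidden_states = torch.randn(1, 8, 64)

# What the pre-forward split hook conceptually does: shard the sequence dimension.
shards = hidden_states.chunk(world_size, dim=1)
assert shards[0].shape == (1, 4, 64)

# Each rank attends over the full sequence via all2all exchanges; the gather hooked
# onto `proj_out` conceptually concatenates the sequence shards back together.
restored = torch.cat(shards, dim=1)
assert torch.equal(restored, hidden_states)
```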

src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py

Lines changed: 2 additions & 0 deletions
@@ -76,6 +76,7 @@
 from .cp_plan_dit import DiTContextParallelismPlanner
 from .cp_plan_kandinsky import Kandinsky5ContextParallelismPlanner
 from .cp_plan_skyreels import SkyReelsV2ContextParallelismPlanner
+from .cp_plan_flux2 import Flux2ContextParallelismPlanner

 try:
     import nunchaku  # noqa: F401
@@ -112,6 +113,7 @@
     "DiTContextParallelismPlanner",
     "Kandinsky5ContextParallelismPlanner",
     "SkyReelsV2ContextParallelismPlanner",
+    "Flux2ContextParallelismPlanner",
 ]

 if _nunchaku_available:

src/cache_dit/utils.py

Lines changed: 8 additions & 0 deletions
@@ -54,6 +54,14 @@ def print_tensor(
     if disable:
         return

+    if x is None:
+        print(f"{name} is None")
+        return
+
+    if not isinstance(x, torch.Tensor):
+        print(f"{name} is not a tensor, type: {type(x)}")
+        return
+
     x = x.contiguous()
     if torch.distributed.is_initialized():
         # all gather hidden_states and check values mean
