# Patch for src/maxdiffusion/configs/base_wan_14b.yml: one hunk (@@ -74,14 +74,13 @@)
# that only touches entries of the `flash_block_sizes` mapping.
#
# Changes recorded by the hunk:
#   block_q               3024 -> 2048
#   block_kv_compute      1024 -> 512
#   block_q_dkv           3024 -> 2048
#   block_kv_dkv_compute  2048 -> 512
#   block_q_dq            removed (was 3024)
#   block_kv_dq           removed (was 2048)
#   use_fused_bwd_kernel  added, set to True
#   block_kv, block_kv_dkv are context lines and stay at 2048.
#
# NOTE(review): dropping both *_dq sizes in the same change that enables
# use_fused_bwd_kernel presumably reflects that a fused backward pass has no
# separate dq kernel -- confirm against the attention kernel's block-size
# validation before changing either side independently.
# NOTE(review): the added key is spelled `"use_fused_bwd_kernel": True` (no
# space before the colon) while its neighbours use `"key" : value`; purely a
# formatting inconsistency in the added line.
# NOTE(review): as shown here the patch's line breaks are collapsed into
# spaces. It is ambiguous whether `attention_sharding_uniform: True` is the
# hunk's section heading or its first context line (both readings satisfy the
# 14/13 line counts), so the original layout is deliberately not reconstructed.
diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 78dca3be..7222f291 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -74,14 +74,13 @@ attention_sharding_uniform: True dropout: 0.1 flash_block_sizes: { - "block_q" : 3024, - "block_kv_compute" : 1024, + "block_q" : 2048, + "block_kv_compute" : 512, "block_kv" : 2048, - "block_q_dkv" : 3024, + "block_q_dkv" : 2048, "block_kv_dkv" : 2048, - "block_kv_dkv_compute" : 2048, - "block_q_dq" : 3024, - "block_kv_dq" : 2048 + "block_kv_dkv_compute" : 512, + "use_fused_bwd_kernel": True } # Use on v6e # flash_block_sizes: {