Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ jobs:
pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
- name: PyTest
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
# add_pull_ready:
# if: github.ref != 'refs/heads/main'
Expand Down
15 changes: 7 additions & 8 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,13 @@ flash_min_seq_length: 4096
dropout: 0.1

flash_block_sizes: {
"block_q" : 1024,
"block_kv_compute" : 256,
"block_kv" : 1024,
"block_q_dkv" : 1024,
"block_kv_dkv" : 1024,
"block_kv_dkv_compute" : 256,
"block_q_dq" : 1024,
"block_kv_dq" : 1024
"block_q" : 2048,
"block_kv_compute" : 512,
"block_kv" : 2048,
"block_q_dkv" : 2048,
"block_kv_dkv" : 2048,
"block_kv_dkv_compute" : 512,
"use_fused_bwd_kernel" : True
}
# Use on v6e
# flash_block_sizes: {
Expand Down
8 changes: 4 additions & 4 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ def get_fp8_config(cls, config: HyperParameters):
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e5m2,
disable_channelwise_axes=True, # per_tensor calibration
weight_calibration_method=config.quantization_calibration_method,
act_calibration_method=config.quantization_calibration_method,
weight_calibration_method="fixed,-224,224",
act_calibration_method="fixed,-224,224",
bwd_calibration_method=config.quantization_calibration_method,
op_names=("dot_general", "einsum"),
),
Expand All @@ -313,8 +313,8 @@ def get_fp8_config(cls, config: HyperParameters):
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e4m3fn,
disable_channelwise_axes=True, # per_tensor calibration
weight_calibration_method=config.quantization_calibration_method,
act_calibration_method=config.quantization_calibration_method,
weight_calibration_method="fixed,-224,224",
act_calibration_method="fixed,-224,224",
bwd_calibration_method=config.quantization_calibration_method,
op_names=("conv_general_dilated"),
),
Expand Down
Loading