diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 1512485b..0837a503 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -58,6 +58,7 @@ jobs: pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets - name: PyTest run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py + export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x # add_pull_ready: # if: github.ref != 'refs/heads/main' diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 84d4505d..8cd7e70f 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -64,14 +64,13 @@ flash_min_seq_length: 4096 dropout: 0.1 flash_block_sizes: { - "block_q" : 1024, - "block_kv_compute" : 256, - "block_kv" : 1024, - "block_q_dkv" : 1024, - "block_kv_dkv" : 1024, - "block_kv_dkv_compute" : 256, - "block_q_dq" : 1024, - "block_kv_dq" : 1024 + "block_q" : 2048, + "block_kv_compute" : 512, + "block_kv" : 2048, + "block_q_dkv" : 2048, + "block_kv_dkv" : 2048, + "block_kv_dkv_compute" : 512, + "use_fused_bwd_kernel" : True } # Use on v6e # flash_block_sizes: { diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 153c225d..557a9dfe 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -302,8 +302,8 @@ def get_fp8_config(cls, config: HyperParameters): act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e5m2, disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method=config.quantization_calibration_method, - act_calibration_method=config.quantization_calibration_method, + weight_calibration_method="fixed,-224,224", + act_calibration_method="fixed,-224,224", bwd_calibration_method=config.quantization_calibration_method, op_names=("dot_general", "einsum"), ), @@ -313,8 +313,8 @@ def get_fp8_config(cls, config: HyperParameters): act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e4m3fn, disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method=config.quantization_calibration_method, - act_calibration_method=config.quantization_calibration_method, + weight_calibration_method="fixed,-224,224", + act_calibration_method="fixed,-224,224", bwd_calibration_method=config.quantization_calibration_method, op_names=("conv_general_dilated"), ),