2 changes: 1 addition & 1 deletion .github/workflows/UnitTests.yml
@@ -31,7 +31,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        tpu-type: ["v5p-8"]
+        tpu-type: ["v4-8"]
     name: "TPU test (${{ matrix.tpu-type }})"
     runs-on: ["self-hosted","${{ matrix.tpu-type }}"]
     steps:
2 changes: 1 addition & 1 deletion README.md
@@ -279,7 +279,7 @@ After installation completes, run the training script.
 
 ### Deploying with XPK
 
-This assummes the user has already created an xpk cluster, installed all dependencies and the also created the dataset from the step above. For getting started with MaxDiffusion and xpk see [this guide](docs/getting_started/run_maxdiffusion_via_xpk.md).
+This assumes the user has already created an xpk cluster, installed all dependencies and the also created the dataset from the step above. For getting started with MaxDiffusion and xpk see [this guide](docs/getting_started/run_maxdiffusion_via_xpk.md).
 
 Using v5p-256 Then the command to run on xpk is as follows:
2 changes: 1 addition & 1 deletion docs/data_README.md
@@ -5,7 +5,7 @@ Currently MaxDiffusion supports 3 data input pipelines, controlled by the flag `
 | Pipeline | Dataset Location | Dataset formats | Features or limitations |
 | -------- | ---------------- | --------------- | ----------------------- |
 | HuggingFace (hf) | datasets in HuggingFace Hub or local/Cloud Storage | Formats supported in HF Hub: parquet, arrow, json, csv, txt | data are not loaded in memory but streamed from the saved location, good for big dataset |
-| tf | dataset will be downaloaded form HuggingFace Hub to disk | Formats supported in HF Hub: parquet, arrow, json, csv, txt | Will read the whole dataset into memory, works for small dataset |
+| tf | dataset will be downloaded form HuggingFace Hub to disk | Formats supported in HF Hub: parquet, arrow, json, csv, txt | Will read the whole dataset into memory, works for small dataset |
 | tfrecord | local/Cloud Storage | TFRecord | data are not loaded in memory but streamed from the saved location, good for big dataset |
 | Grain | local/Cloud Storage | ArrayRecord (or any random access format) | data are not loaded in memory but streamed from the saved location, good for big dataset, supports global shuffle and data iterator checkpoint for determinism (see details in [doc](https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-pipeline---for-determinism)) |
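Note: a minimal sketch of selecting one of these pipelines at run time, mirroring the `pyconfig.initialize` pattern used in the test diffs later in this PR. The flag name `dataset_type` and the config path are assumptions, not confirmed by this diff:

```python
from maxdiffusion import pyconfig

# Hypothetical invocation: pick the tf pipeline (whole dataset in memory, small datasets).
pyconfig.initialize(
    [
        "train.py",  # argv[0] placeholder
        "src/maxdiffusion/configs/base_2_base.yml",  # assumed config file
        "dataset_type=tf",  # one of: hf, tf, tfrecord, grain
    ]
)
config = pyconfig.config
print(config.dataset_type)
```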
2 changes: 1 addition & 1 deletion docs/train_README.md
@@ -129,7 +129,7 @@ Now let's change the configuration as follows:
 
 Then our mesh will look like `Mesh('data': 2, 'fsdp': 2, 'tensor': 1)`.
 
-The `logical_axis_rules` specifies the sharding across the mesh. You are encouranged to add or remove rules and find what best works for you.
+The `logical_axis_rules` specifies the sharding across the mesh. You are encouraged to add or remove rules and find what best works for you.
 
 ### Checkpointing
 
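Note: a self-contained sketch of the mesh shape quoted in this hunk, built directly with JAX (the 4-device count and reshape are illustrative assumptions):

```python
import numpy as np
import jax
from jax.sharding import Mesh

# Assumes 4 available devices, arranged as data=2, fsdp=2, tensor=1.
devices = np.asarray(jax.devices()[:4]).reshape(2, 2, 1)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))
print(mesh.shape)  # {'data': 2, 'fsdp': 2, 'tensor': 1}
```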
12 changes: 7 additions & 5 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -74,13 +74,15 @@ attention_sharding_uniform: True
 dropout: 0.1
 
 flash_block_sizes: {
-  "block_q" : 2048,
+  "block_q" : 512,
   "block_kv_compute" : 512,
-  "block_kv" : 2048,
-  "block_q_dkv" : 2048,
-  "block_kv_dkv" : 2048,
+  "block_kv" : 512,
+  "block_q_dkv" : 512,
+  "block_kv_dkv" : 512,
   "block_kv_dkv_compute" : 512,
-  "use_fused_bwd_kernel": True
+  "block_q_dq" : 512,
+  "block_kv_dq" : 512,
+  "use_fused_bwd_kernel": False,
 }
 # Use on v6e
 # flash_block_sizes: {
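Note: assuming these values feed JAX's Pallas splash-attention kernel on TPU (as they do in MaxText — the mapping is an assumption, not stated in this diff), the sketch below shows the correspondence. With the fused backward kernel disabled, the standalone dQ backward pass needs its own tiles, which would explain why this PR adds `block_q_dq`/`block_kv_dq` alongside `use_fused_bwd_kernel: False`:

```python
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

# Sketch: the YAML block above expressed as Pallas BlockSizes.
block_sizes = splash_attention_kernel.BlockSizes(
    block_q=512,
    block_kv_compute=512,
    block_kv=512,
    block_q_dkv=512,
    block_kv_dkv=512,
    block_kv_dkv_compute=512,
    block_q_dq=512,   # only used when the fused backward kernel is off
    block_kv_dq=512,
    use_fused_bwd_kernel=False,
)
```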
17 changes: 9 additions & 8 deletions src/maxdiffusion/configs/base_wan_27b.yml
@@ -73,14 +73,15 @@ attention_sharding_uniform: True
 dropout: 0.1
 
 flash_block_sizes: {
-  "block_q" : 1024,
-  "block_kv_compute" : 256,
-  "block_kv" : 1024,
-  "block_q_dkv" : 1024,
-  "block_kv_dkv" : 1024,
-  "block_kv_dkv_compute" : 256,
-  "block_q_dq" : 1024,
-  "block_kv_dq" : 1024
+  "block_q" : 512,
+  "block_kv_compute" : 512,
+  "block_kv" : 512,
+  "block_q_dkv" : 512,
+  "block_kv_dkv" : 512,
+  "block_kv_dkv_compute" : 512,
+  "block_q_dq" : 512,
+  "block_kv_dq" : 512,
+  "use_fused_bwd_kernel": False,
 }
 # Use on v6e
 # flash_block_sizes: {
4 changes: 2 additions & 2 deletions src/maxdiffusion/configuration_utils.py
@@ -145,7 +145,7 @@ def __getattr__(self, name: str) -> Any:
   """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
   config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
 
-  Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
+  Tihs function is mostly copied from PyTorch's __getattr__ overwrite:
   https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
   """
 
@@ -540,7 +540,7 @@ def extract_init_dict(cls, config_dict, **kwargs):
         f"{cls.config_name} configuration file."
       )
 
-    # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
+    # 5. Give nice info if config attributes are initialized to default because they have not been passed
     passed_keys = set(init_dict.keys())
     if len(expected_keys - passed_keys) > 0:
       logger.info(f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values.")
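Note: the docstring in the first hunk describes a deprecation shim. A minimal sketch of the pattern, with hypothetical names (not the class's actual attributes):

```python
import warnings

class ConfigSketch:
  """Hypothetical illustration of deprecating direct config-attribute access."""

  def __init__(self):
    self._internal_dict = {"sample_size": 64}

  def __getattr__(self, name):
    # Only called when normal attribute lookup fails, like PyTorch's Module.__getattr__.
    if "_internal_dict" in self.__dict__ and name in self.__dict__["_internal_dict"]:
      warnings.warn(f"Accessing `{name}` directly is deprecated; use `config.{name}`.", FutureWarning)
      return self.__dict__["_internal_dict"][name]
    raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

print(ConfigSketch().sample_size)  # warns, then prints 64
```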
2 changes: 1 addition & 1 deletion src/maxdiffusion/generate.py
@@ -66,7 +66,7 @@ def loop_body(step, args, model, pipeline, prompt_embeds, guidance_scale, guidan
 
   # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
   # Helps solve overexposure problem when terminal SNR approaches zero.
-  # Empirical values recomended from the paper are guidance_scale=7.5 and guidance_rescale=0.7
+  # Empirical values recommended from the paper are guidance_scale=7.5 and guidance_rescale=0.7
   noise_pred = jax.lax.cond(
       guidance_rescale[0] > 0,
       lambda _: rescale_noise_cfg(noise_pred, noise_prediction_text, guidance_rescale),
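Note: for context on what `rescale_noise_cfg` computes, here is a sketch following the reference implementation of §3.4 of the paper (the repo's own function may differ in detail):

```python
import jax.numpy as jnp

def rescale_noise_cfg_sketch(noise_cfg, noise_pred_text, guidance_rescale=0.7):
  # Rescale the CFG output so its per-sample std matches the text-conditioned
  # prediction, then blend back to avoid overly plain images.
  axes = tuple(range(1, noise_pred_text.ndim))
  std_text = jnp.std(noise_pred_text, axis=axes, keepdims=True)
  std_cfg = jnp.std(noise_cfg, axis=axes, keepdims=True)
  rescaled = noise_cfg * (std_text / std_cfg)
  return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
```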
4 changes: 2 additions & 2 deletions src/maxdiffusion/generate_sdxl_replicated.py
@@ -32,7 +32,7 @@
 NUM_DEVICES = jax.device_count()
 
 # 1. Let's start by downloading the model and loading it into our pipeline class
-# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
+# Adhering to JAX's functional approach, the model's parameters are returned separately and
 # will have to be passed to the pipeline during inference
 pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
@@ -83,7 +83,7 @@ def replicate_all(prompt_ids, neg_prompt_ids, seed):
 # to the function and tell JAX which are static arguments, that is, arguments that
 # are known at compile time and won't change. In our case, it is num_inference_steps,
 # height, width and return_latents.
-# Once the function is compiled, these parameters are ommited from future calls and
+# Once the function is compiled, these parameters are omitted from future calls and
 # cannot be changed without modifying the code and recompiling.
 def aot_compile(
     prompt=default_prompt,
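Note: the comment in the second hunk is about static compile-time arguments. A small self-contained illustration using `jax.jit` (the script itself uses AOT lowering, but the static-argument behavior is the same):

```python
import functools
import jax

# height/width are static: they may appear in shapes, and changing them recompiles.
@functools.partial(jax.jit, static_argnames=("height", "width"))
def make_latents(key, height, width):
  return jax.random.normal(key, (1, 4, height // 8, width // 8))

latents = make_latents(jax.random.PRNGKey(0), 1024, 1024)  # compiles for (1024, 1024)
latents = make_latents(jax.random.PRNGKey(1), 1024, 1024)  # reuses the cached executable
```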
2 changes: 1 addition & 1 deletion src/maxdiffusion/input_pipeline/_tfds_data_processing.py
@@ -97,7 +97,7 @@ def _make_tfrecord_iterator(
   # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
   # if is_training is True, loads the training dataset. If False, loads the evaluation dataset.
 
-  # checks that the dataset path is valid. In case of gcs, the existance of the dir is not checked.
+  # checks that the dataset path is valid. In case of gcs, the existence of the dir is not checked.
   is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)
 
   # Determine whether to use the "cached" dataset, which requires externally
2 changes: 1 addition & 1 deletion src/maxdiffusion/loaders/lora_conversion_utils.py
@@ -504,7 +504,7 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
   )
 
   if len(sds_sd) > 0:
-    max_logging.log(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
+    max_logging.log(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
 
   return ait_sd
 
2 changes: 1 addition & 1 deletion src/maxdiffusion/max_utils.py
@@ -222,7 +222,7 @@ def walk_and_upload_blobs(config, output_dir):
 
 def device_put_replicated(x, sharding):
   """
-  Although the name indiciates replication, this function can be used
+  Although the name indicates replication, this function can be used
   to also shard an array based on sharding.
   """
   return jax.make_array_from_callback(x.shape, sharding, lambda index: x[index])
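Note: a usage sketch for the function in this hunk. With a sharded `NamedSharding`, `jax.make_array_from_callback` places each shard on its device instead of replicating (the 1-D mesh layout here is an assumption):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.asarray(jax.devices()), axis_names=("data",))
x = np.arange(jax.device_count() * 4.0).reshape(jax.device_count(), 4)

# Shard rows of x across the "data" axis; PartitionSpec() would replicate instead.
sharded = jax.make_array_from_callback(
    x.shape, NamedSharding(mesh, PartitionSpec("data")), lambda index: x[index]
)
print(sharded.sharding)
```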
2 changes: 1 addition & 1 deletion src/maxdiffusion/models/controlnet_flax.py
@@ -427,7 +427,7 @@ def __call__(
     # 4. mid
     sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
 
-    # 5. contronet blocks
+    # 5. ControlNet blocks
     controlnet_down_block_res_samples = ()
     for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
       down_block_res_sample = controlnet_block(down_block_res_sample)
@@ -302,7 +302,7 @@ def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=
 @flax_register_to_config
 class FluxTransformer2DModel(nn.Module, FlaxModelMixin, ConfigMixin):
   r"""
-  The Tranformer model introduced in Flux.
+  The Transformer model introduced in Flux.
 
   Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
 
2 changes: 1 addition & 1 deletion src/maxdiffusion/models/quantizations.py
@@ -221,7 +221,7 @@ def match_aqt_and_unquantized_param(aqt_params, params):
   )
   param_tree_flat, _ = jax.tree_util.tree_flatten_with_path(params)
   aqt_paths = []
-  # Orginal path of quantized AQT param path.
+  # Original path of quantized AQT param path.
   param_paths = []
 
   for aqt_k, _ in aqt_param_flat:
4 changes: 2 additions & 2 deletions src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -157,7 +157,7 @@ class WanUpsample(nnx.Module):
 
   def __init__(self, scale_factor: Tuple[float, float], method: str = "nearest"):
     # scale_factor for (H, W)
-    # JAX resize works on spatial dims, H, W assumming (N, D, H, W, C) or (N, H, W, C)
+    # JAX resize works on spatial dims, H, W assuming (N, D, H, W, C) or (N, H, W, C)
     self.scale_factor = scale_factor
     self.method = method
 
@@ -1116,7 +1116,7 @@ def _decode(
     # Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
     # Most likely due to an incorrect reshaping in the decoder.
     fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
-    # When batch_size is 0, expand batch dim for contatenation
+    # When batch_size is 0, expand batch dim for concatenation
     # else, expand frame dim for concatenation so that batch dim stays intact.
     axis = 0
     if fm1.shape[0] > 1:
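Note: the first hunk's comment refers to JAX's spatial resize. A minimal sketch of nearest-neighbor upsampling on an (N, H, W, C) array, consistent with that comment (shapes are illustrative):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1, 8, 8, 3))   # (N, H, W, C)
scale_h, scale_w = 2.0, 2.0  # scale_factor for (H, W)
out = jax.image.resize(
    x,
    shape=(x.shape[0], int(x.shape[1] * scale_h), int(x.shape[2] * scale_w), x.shape[3]),
    method="nearest",
)
print(out.shape)  # (1, 16, 16, 3)
```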
2 changes: 1 addition & 1 deletion src/maxdiffusion/schedulers/scheduling_ddim_flax.py
@@ -37,7 +37,7 @@ class DDIMSchedulerState:
   common: CommonSchedulerState
   final_alpha_cumprod: jnp.ndarray
 
-  # setable values
+  # settable values
   init_noise_sigma: jnp.ndarray
   timesteps: jnp.ndarray
   num_inference_steps: Optional[int] = None
6 changes: 3 additions & 3 deletions src/maxdiffusion/tests/generate_smoke_test.py
@@ -22,7 +22,7 @@ def setUp(self):
     super().setUp()
     Generate.dummy_data = {}
 
-  @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
+  @pytest.mark.skip("This test is deprecated and will be removed in a future version.")
   def test_sd14_config(self):
     img_url = os.path.join(THIS_DIR, "images", "test_gen_sd14.png")
     base_image = np.array(Image.open(img_url)).astype(np.uint8)
@@ -42,7 +42,7 @@ def test_sd14_config(self):
     assert base_image.shape == test_image.shape
     assert ssim_compare >= 0.70
 
-  @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
+  @pytest.mark.skip("This test is deprecated and will be removed in a future version.")
   def test_sd_2_base_from_gcs(self):
     img_url = os.path.join(THIS_DIR, "images", "test_2_base.png")
     base_image = np.array(Image.open(img_url)).astype(np.uint8)
@@ -64,7 +64,7 @@ def test_sd_2_base_from_gcs(self):
     assert base_image.shape == test_image.shape
     assert ssim_compare >= 0.70
 
-  @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
+  @pytest.mark.skip("This test is deprecated and will be removed in a future version.")
   def test_controlnet(self):
     img_url = os.path.join(THIS_DIR, "images", "cnet_test.png")
     base_image = np.array(Image.open(img_url)).astype(np.uint8)
7 changes: 7 additions & 0 deletions src/maxdiffusion/tests/input_pipeline_interface_test.py
@@ -20,6 +20,7 @@
 import shutil
 import subprocess
 import unittest
+import pytest
 from absl.testing import absltest
 import numpy as np
 import tensorflow as tf
@@ -134,6 +135,7 @@ def test_make_dreambooth_train_iterator(self):
     cleanup(instance_class_local_dir)
     cleanup(class_class_local_dir)
 
+  @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace")
   def test_make_pokemon_hf_iterator(self):
     pyconfig.initialize(
         [
@@ -237,6 +239,7 @@ def test_make_pokemon_hf_iterator_sdxl(self):
     assert data["input_ids"].shape == (device_count, 2, 77)
     assert data["pixel_values"].shape == (device_count, 3, config.resolution, config.resolution)
 
+  @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace")
   def test_make_pokemon_tf_iterator_cache(self):
     pyconfig.initialize(
         [
@@ -299,6 +302,7 @@ def test_make_pokemon_tf_iterator_cache(self):
         config.resolution // vae_scale_factor,
     )
 
+  @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace")
   def test_make_pokemon_iterator_no_cache(self):
     pyconfig.initialize(
         [
@@ -431,6 +435,7 @@ def test_make_pokemon_iterator_sdxl_cache(self):
         config.resolution // vae_scale_factor,
     )
 
+  @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace")
   def test_make_laion_grain_iterator(self):
     try:
       subprocess.check_output(
@@ -487,6 +492,7 @@ def test_make_laion_grain_iterator(self):
         8,
     )
 
+  @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace")
   def test_make_laion_tfrecord_iterator(self):
     pyconfig.initialize(
         [
@@ -547,6 +553,7 @@ def _parse_tfrecord_fn(example):
         8,
     )
 
+  @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace")
   def test_tfrecord(self):
     """Validate latents match a deterministic output image"""
 
4 changes: 2 additions & 2 deletions src/maxdiffusion/tests/train_smoke_test.py
@@ -96,7 +96,7 @@ def test_sdxl_config(self):
 
     delete_blobs(os.path.join(output_dir, run_name))
 
-  @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
+  @pytest.mark.skip("This test is deprecated and will be removed in a future version.")
   def test_dreambooth_orbax(self):
     num_class_images = 100
     output_dir = "gs://maxdiffusion-github-runner-test-assets"
@@ -149,7 +149,7 @@ def test_dreambooth_orbax(self):
     cleanup(class_class_local_dir)
     delete_blobs(os.path.join(output_dir, run_name))
 
-  @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
+  @pytest.mark.skip("This test is deprecated and will be removed in a future version.")
   def test_sd15_orbax(self):
    output_dir = "gs://maxdiffusion-github-runner-test-assets"
    run_name = "sd15_orbax_smoke_test"
4 changes: 2 additions & 2 deletions src/maxdiffusion/train_utils.py
@@ -50,7 +50,7 @@ def _validate_gcs_bucket_name(bucket_name, config_var):
 
   assert (
       config.max_train_steps > 0 or config.num_train_epochs > 0
-  ), "You must set steps or learning_rate_schedule_steps to a positive interger."
+  ), "You must set steps or learning_rate_schedule_steps to a positive integer."
 
   if config.checkpoint_every > 0 and len(config.checkpoint_dir) <= 0:
     raise AssertionError("Need to set checkpoint_dir when checkpoint_every is set.")
@@ -201,7 +201,7 @@ def generate_timestep_weights(config, num_timesteps):
 
 @contextmanager
 def transformer_engine_context():
-  """If TransformerEngine is available, this context manager will provide the library with MaxDiffusion-specific details needed for correcct operation."""
+  """If TransformerEngine is available, this context manager will provide the library with MaxDiffusion-specific details needed for correct operation."""
   try:
     from transformer_engine.jax.sharding import global_shard_guard, MeshResource
     # Inform TransformerEngine of MaxDiffusion's physical mesh resources.
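Note: the second hunk truncates before the body of `transformer_engine_context`. A sketch of the general optional-dependency pattern it describes; the `MeshResource` keyword names below are assumptions, not taken from this diff:

```python
from contextlib import contextmanager

@contextmanager
def transformer_engine_context_sketch():
  try:
    # Optional dependency: only configure TE when it is importable.
    from transformer_engine.jax.sharding import global_shard_guard, MeshResource
  except ImportError:
    yield  # TransformerEngine absent; behave as a no-op context.
    return
  # Assumed mapping of MaxDiffusion mesh axes to TE resources (names hypothetical).
  with global_shard_guard(MeshResource(dp_resource="data", fsdp_resource="fsdp", tp_resource="tensor")):
    yield
```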
2 changes: 1 addition & 1 deletion src/maxdiffusion/utils/export_utils.py
@@ -154,7 +154,7 @@ def export_to_video(
   bitrate:
       Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead.
       Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter
-      rather than specifiying a fixed bitrate with this parameter.
+      rather than specifying a fixed bitrate with this parameter.
 
   macro_block_size:
       Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number
6 changes: 3 additions & 3 deletions src/maxdiffusion/utils/peft_utils.py
@@ -114,7 +114,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
   r = lora_alpha = list(rank_dict.values())[0]
 
   if len(set(rank_dict.values())) > 1:
-    # get the rank occuring the most number of times
+    # get the rank occurring the most number of times
     r = collections.Counter(rank_dict.values()).most_common()[0][0]
 
     # for modules with rank different from the most occuring rank, add it to the `rank_pattern`
@@ -178,7 +178,7 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
   for adapter_name, weight in zip(adapter_names, weights):
     for module in model.modules():
       if isinstance(module, BaseTunerLayer):
-        # For backward compatbility with previous PEFT versions
+        # For backward compatibility with previous PEFT versions
        if hasattr(module, "set_adapter"):
           module.set_adapter(adapter_name)
         else:
@@ -188,7 +188,7 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
   # set multiple active adapters
   for module in model.modules():
     if isinstance(module, BaseTunerLayer):
-      # For backward compatbility with previous PEFT versions
+      # For backward compatibility with previous PEFT versions
      if hasattr(module, "set_adapter"):
        module.set_adapter(adapter_names)
      else:
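Note: a quick illustration of the most-common-rank logic from the first hunk (the dict contents are made up):

```python
import collections

rank_dict = {"unet.down.0": 8, "unet.down.1": 8, "unet.mid": 16}  # hypothetical LoRA ranks
r = collections.Counter(rank_dict.values()).most_common()[0][0]
print(r)  # 8 -- the rank occurring the most number of times
```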
2 changes: 1 addition & 1 deletion src/maxdiffusion/utils/torch_utils.py
@@ -60,7 +60,7 @@ def randn_tensor(
       logger.info(
           f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
           f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
-          f" slighly speed up this function by passing a generator that was created on the {device} device."
+          f" slightly speed up this function by passing a generator that was created on the {device} device."
       )
     elif gen_device_type != device.type and gen_device_type == "cuda":
       raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
2 changes: 1 addition & 1 deletion src/maxdiffusion/video_processor.py
@@ -67,7 +67,7 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[
 
   # ensure the input is a list of videos:
   # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray)
-  # - if it is a single video, it is convereted to a list of one video.
+  # - if it is a single video, it is converted to a list of one video.
   if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
     video = list(video)
   elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video):
4 changes: 2 additions & 2 deletions tests/schedulers/test_scheduler_flax.py
@@ -919,7 +919,7 @@ def test_full_loop_with_set_alpha_to_one(self):
     result_mean = jnp.mean(jnp.abs(sample))
 
     if jax_device == "tpu":
-      assert abs(result_sum - 186.94574) < 8e-2
+      assert abs(result_sum - 186.83226) < 8e-2
       assert abs(result_mean - 0.24327) < 1e-3
     else:
       assert abs(result_sum - 186.9466) < 1e-2
@@ -932,7 +932,7 @@ def test_full_loop_with_no_set_alpha_to_one(self):
     result_mean = jnp.mean(jnp.abs(sample))
 
     if jax_device == "tpu":
-      assert abs(result_sum - 186.94574) < 8e-2
+      assert abs(result_sum - 186.83226) < 8e-2
       assert abs(result_mean - 0.24327) < 1e-3
     else:
      assert abs(result_sum - 186.9482) < 1e-2