diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 0837a503..353ea0b2 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -31,7 +31,7 @@ jobs: strategy: fail-fast: false matrix: - tpu-type: ["v5p-8"] + tpu-type: ["v4-8"] name: "TPU test (${{ matrix.tpu-type }})" runs-on: ["self-hosted","${{ matrix.tpu-type }}"] steps: diff --git a/README.md b/README.md index b5720a6a..1b0dc69c 100644 --- a/README.md +++ b/README.md @@ -279,7 +279,7 @@ After installation completes, run the training script. ### Deploying with XPK - This assummes the user has already created an xpk cluster, installed all dependencies and the also created the dataset from the step above. For getting started with MaxDiffusion and xpk see [this guide](docs/getting_started/run_maxdiffusion_via_xpk.md). + This assumes the user has already created an xpk cluster, installed all dependencies and the also created the dataset from the step above. For getting started with MaxDiffusion and xpk see [this guide](docs/getting_started/run_maxdiffusion_via_xpk.md). Using v5p-256 Then the command to run on xpk is as follows: diff --git a/docs/data_README.md b/docs/data_README.md index 2b726dc2..5459f290 100644 --- a/docs/data_README.md +++ b/docs/data_README.md @@ -5,7 +5,7 @@ Currently MaxDiffusion supports 3 data input pipelines, controlled by the flag ` | Pipeline | Dataset Location | Dataset formats | Features or limitations | | -------- | ---------------- | --------------- | ----------------------- | | HuggingFace (hf) | datasets in HuggingFace Hub or local/Cloud Storage | Formats supported in HF Hub: parquet, arrow, json, csv, txt | data are not loaded in memory but streamed from the saved location, good for big dataset | -| tf | dataset will be downaloaded form HuggingFace Hub to disk | Formats supported in HF Hub: parquet, arrow, json, csv, txt | Will read the whole dataset into memory, works for small dataset | +| tf | dataset will be downloaded form HuggingFace Hub to disk | Formats supported in HF Hub: parquet, arrow, json, csv, txt | Will read the whole dataset into memory, works for small dataset | | tfrecord | local/Cloud Storage | TFRecord | data are not loaded in memory but streamed from the saved location, good for big dataset | | Grain | local/Cloud Storage | ArrayRecord (or any random access format) | data are not loaded in memory but streamed from the saved location, good for big dataset, supports global shuffle and data iterator checkpoint for determinism (see details in [doc](https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-pipeline---for-determinism)) | diff --git a/docs/train_README.md b/docs/train_README.md index 74ff8ce4..efd0d48b 100644 --- a/docs/train_README.md +++ b/docs/train_README.md @@ -129,7 +129,7 @@ Now let's change the configuration as follows: Then our mesh will look like `Mesh('data': 2, 'fsdp': 2, 'tensor': 1)`. -The `logical_axis_rules` specifies the sharding across the mesh. You are encouranged to add or remove rules and find what best works for you. +The `logical_axis_rules` specifies the sharding across the mesh. You are encouraged to add or remove rules and find what best works for you. ### Checkpointing diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 1b647424..f152ac73 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -74,13 +74,15 @@ attention_sharding_uniform: True dropout: 0.1 flash_block_sizes: { - "block_q" : 2048, + "block_q" : 512, "block_kv_compute" : 512, - "block_kv" : 2048, - "block_q_dkv" : 2048, - "block_kv_dkv" : 2048, + "block_kv" : 512, + "block_q_dkv" : 512, + "block_kv_dkv" : 512, "block_kv_dkv_compute" : 512, - "use_fused_bwd_kernel": True + "block_q_dq" : 512, + "block_kv_dq" : 512, + "use_fused_bwd_kernel": False, } # Use on v6e # flash_block_sizes: { diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 1b93a32a..314d1141 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -73,14 +73,15 @@ attention_sharding_uniform: True dropout: 0.1 flash_block_sizes: { - "block_q" : 1024, - "block_kv_compute" : 256, - "block_kv" : 1024, - "block_q_dkv" : 1024, - "block_kv_dkv" : 1024, - "block_kv_dkv_compute" : 256, - "block_q_dq" : 1024, - "block_kv_dq" : 1024 + "block_q" : 512, + "block_kv_compute" : 512, + "block_kv" : 512, + "block_q_dkv" : 512, + "block_kv_dkv" : 512, + "block_kv_dkv_compute" : 512, + "block_q_dq" : 512, + "block_kv_dq" : 512, + "use_fused_bwd_kernel": False, } # Use on v6e # flash_block_sizes: { diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py index 76c02d10..3a495e02 100644 --- a/src/maxdiffusion/configuration_utils.py +++ b/src/maxdiffusion/configuration_utils.py @@ -145,7 +145,7 @@ def __getattr__(self, name: str) -> Any: """The only reason we overwrite `getattr` here is to gracefully deprecate accessing config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 - Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite: + Tihs function is mostly copied from PyTorch's __getattr__ overwrite: https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module """ @@ -540,7 +540,7 @@ def extract_init_dict(cls, config_dict, **kwargs): f"{cls.config_name} configuration file." ) - # 5. Give nice info if config attributes are initiliazed to default because they have not been passed + # 5. Give nice info if config attributes are initialized to default because they have not been passed passed_keys = set(init_dict.keys()) if len(expected_keys - passed_keys) > 0: logger.info(f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values.") diff --git a/src/maxdiffusion/generate.py b/src/maxdiffusion/generate.py index be66ac0f..ac4fbb7f 100644 --- a/src/maxdiffusion/generate.py +++ b/src/maxdiffusion/generate.py @@ -66,7 +66,7 @@ def loop_body(step, args, model, pipeline, prompt_embeds, guidance_scale, guidan # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf # Helps solve overexposure problem when terminal SNR approaches zero. - # Empirical values recomended from the paper are guidance_scale=7.5 and guidance_rescale=0.7 + # Empirical values recommended from the paper are guidance_scale=7.5 and guidance_rescale=0.7 noise_pred = jax.lax.cond( guidance_rescale[0] > 0, lambda _: rescale_noise_cfg(noise_pred, noise_prediction_text, guidance_rescale), diff --git a/src/maxdiffusion/generate_sdxl_replicated.py b/src/maxdiffusion/generate_sdxl_replicated.py index 5ed710b2..d17fc02d 100644 --- a/src/maxdiffusion/generate_sdxl_replicated.py +++ b/src/maxdiffusion/generate_sdxl_replicated.py @@ -32,7 +32,7 @@ NUM_DEVICES = jax.device_count() # 1. Let's start by downloading the model and loading it into our pipeline class -# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and +# Adhering to JAX's functional approach, the model's parameters are returned separately and # will have to be passed to the pipeline during inference pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True @@ -83,7 +83,7 @@ def replicate_all(prompt_ids, neg_prompt_ids, seed): # to the function and tell JAX which are static arguments, that is, arguments that # are known at compile time and won't change. In our case, it is num_inference_steps, # height, width and return_latents. -# Once the function is compiled, these parameters are ommited from future calls and +# Once the function is compiled, these parameters are omitted from future calls and # cannot be changed without modifying the code and recompiling. def aot_compile( prompt=default_prompt, diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index cc0c8fd3..b8992415 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -97,7 +97,7 @@ def _make_tfrecord_iterator( # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator. # if is_training is True, loads the training dataset. If False, loads the evaluation dataset. - # checks that the dataset path is valid. In case of gcs, the existance of the dir is not checked. + # checks that the dataset path is valid. In case of gcs, the existence of the dir is not checked. is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location) # Determine whether to use the "cached" dataset, which requires externally diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py index 854aeaf6..5f9e72a6 100644 --- a/src/maxdiffusion/loaders/lora_conversion_utils.py +++ b/src/maxdiffusion/loaders/lora_conversion_utils.py @@ -504,7 +504,7 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd): ) if len(sds_sd) > 0: - max_logging.log(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}") + max_logging.log(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}") return ait_sd diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 48c6ca44..fb7266a1 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -222,7 +222,7 @@ def walk_and_upload_blobs(config, output_dir): def device_put_replicated(x, sharding): """ - Although the name indiciates replication, this function can be used + Although the name indicates replication, this function can be used to also shard an array based on sharding. """ return jax.make_array_from_callback(x.shape, sharding, lambda index: x[index]) diff --git a/src/maxdiffusion/models/controlnet_flax.py b/src/maxdiffusion/models/controlnet_flax.py index e37694a1..3ab58d33 100644 --- a/src/maxdiffusion/models/controlnet_flax.py +++ b/src/maxdiffusion/models/controlnet_flax.py @@ -427,7 +427,7 @@ def __call__( # 4. mid sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) - # 5. contronet blocks + # 5. ControlNet blocks controlnet_down_block_res_samples = () for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): down_block_res_sample = controlnet_block(down_block_res_sample) diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 505a4f4f..7f63da67 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -302,7 +302,7 @@ def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb= @flax_register_to_config class FluxTransformer2DModel(nn.Module, FlaxModelMixin, ConfigMixin): r""" - The Tranformer model introduced in Flux. + The Transformer model introduced in Flux. Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ diff --git a/src/maxdiffusion/models/quantizations.py b/src/maxdiffusion/models/quantizations.py index f7667d65..766842d1 100644 --- a/src/maxdiffusion/models/quantizations.py +++ b/src/maxdiffusion/models/quantizations.py @@ -221,7 +221,7 @@ def match_aqt_and_unquantized_param(aqt_params, params): ) param_tree_flat, _ = jax.tree_util.tree_flatten_with_path(params) aqt_paths = [] - # Orginal path of quantized AQT param path. + # Original path of quantized AQT param path. param_paths = [] for aqt_k, _ in aqt_param_flat: diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 77f35073..1da2d18f 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -157,7 +157,7 @@ class WanUpsample(nnx.Module): def __init__(self, scale_factor: Tuple[float, float], method: str = "nearest"): # scale_factor for (H, W) - # JAX resize works on spatial dims, H, W assumming (N, D, H, W, C) or (N, H, W, C) + # JAX resize works on spatial dims, H, W assuming (N, D, H, W, C) or (N, H, W, C) self.scale_factor = scale_factor self.method = method @@ -1116,7 +1116,7 @@ def _decode( # Ideally shouldn't need to do this however, can't find where the frame is going out of sync. # Most likely due to an incorrect reshaping in the decoder. fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :] - # When batch_size is 0, expand batch dim for contatenation + # When batch_size is 0, expand batch dim for concatenation # else, expand frame dim for concatenation so that batch dim stays intact. axis = 0 if fm1.shape[0] > 1: diff --git a/src/maxdiffusion/schedulers/scheduling_ddim_flax.py b/src/maxdiffusion/schedulers/scheduling_ddim_flax.py index a7379bfd..2c478b79 100644 --- a/src/maxdiffusion/schedulers/scheduling_ddim_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_ddim_flax.py @@ -37,7 +37,7 @@ class DDIMSchedulerState: common: CommonSchedulerState final_alpha_cumprod: jnp.ndarray - # setable values + # settable values init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray num_inference_steps: Optional[int] = None diff --git a/src/maxdiffusion/tests/generate_smoke_test.py b/src/maxdiffusion/tests/generate_smoke_test.py index 5b51f1d5..d305427c 100644 --- a/src/maxdiffusion/tests/generate_smoke_test.py +++ b/src/maxdiffusion/tests/generate_smoke_test.py @@ -22,7 +22,7 @@ def setUp(self): super().setUp() Generate.dummy_data = {} - @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") + @pytest.mark.skip("This test is deprecated and will be removed in a future version.") def test_sd14_config(self): img_url = os.path.join(THIS_DIR, "images", "test_gen_sd14.png") base_image = np.array(Image.open(img_url)).astype(np.uint8) @@ -42,7 +42,7 @@ def test_sd14_config(self): assert base_image.shape == test_image.shape assert ssim_compare >= 0.70 - @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") + @pytest.mark.skip("This test is deprecated and will be removed in a future version.") def test_sd_2_base_from_gcs(self): img_url = os.path.join(THIS_DIR, "images", "test_2_base.png") base_image = np.array(Image.open(img_url)).astype(np.uint8) @@ -64,7 +64,7 @@ def test_sd_2_base_from_gcs(self): assert base_image.shape == test_image.shape assert ssim_compare >= 0.70 - @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") + @pytest.mark.skip("This test is deprecated and will be removed in a future version.") def test_controlnet(self): img_url = os.path.join(THIS_DIR, "images", "cnet_test.png") base_image = np.array(Image.open(img_url)).astype(np.uint8) diff --git a/src/maxdiffusion/tests/input_pipeline_interface_test.py b/src/maxdiffusion/tests/input_pipeline_interface_test.py index 03825589..1141ec8c 100644 --- a/src/maxdiffusion/tests/input_pipeline_interface_test.py +++ b/src/maxdiffusion/tests/input_pipeline_interface_test.py @@ -20,6 +20,7 @@ import shutil import subprocess import unittest +import pytest from absl.testing import absltest import numpy as np import tensorflow as tf @@ -134,6 +135,7 @@ def test_make_dreambooth_train_iterator(self): cleanup(instance_class_local_dir) cleanup(class_class_local_dir) + @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") def test_make_pokemon_hf_iterator(self): pyconfig.initialize( [ @@ -237,6 +239,7 @@ def test_make_pokemon_hf_iterator_sdxl(self): assert data["input_ids"].shape == (device_count, 2, 77) assert data["pixel_values"].shape == (device_count, 3, config.resolution, config.resolution) + @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") def test_make_pokemon_tf_iterator_cache(self): pyconfig.initialize( [ @@ -299,6 +302,7 @@ def test_make_pokemon_tf_iterator_cache(self): config.resolution // vae_scale_factor, ) + @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") def test_make_pokemon_iterator_no_cache(self): pyconfig.initialize( [ @@ -431,6 +435,7 @@ def test_make_pokemon_iterator_sdxl_cache(self): config.resolution // vae_scale_factor, ) + @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") def test_make_laion_grain_iterator(self): try: subprocess.check_output( @@ -487,6 +492,7 @@ def test_make_laion_grain_iterator(self): 8, ) + @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") def test_make_laion_tfrecord_iterator(self): pyconfig.initialize( [ @@ -547,6 +553,7 @@ def _parse_tfrecord_fn(example): 8, ) + @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") def test_tfrecord(self): """Validate latents match a deterministic output image""" diff --git a/src/maxdiffusion/tests/train_smoke_test.py b/src/maxdiffusion/tests/train_smoke_test.py index d91aa8b3..a7d0f4b8 100644 --- a/src/maxdiffusion/tests/train_smoke_test.py +++ b/src/maxdiffusion/tests/train_smoke_test.py @@ -96,7 +96,7 @@ def test_sdxl_config(self): delete_blobs(os.path.join(output_dir, run_name)) - @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") + @pytest.mark.skip("This test is deprecated and will be removed in a future version.") def test_dreambooth_orbax(self): num_class_images = 100 output_dir = "gs://maxdiffusion-github-runner-test-assets" @@ -149,7 +149,7 @@ def test_dreambooth_orbax(self): cleanup(class_class_local_dir) delete_blobs(os.path.join(output_dir, run_name)) - @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") + @pytest.mark.skip("This test is deprecated and will be removed in a future version.") def test_sd15_orbax(self): output_dir = "gs://maxdiffusion-github-runner-test-assets" run_name = "sd15_orbax_smoke_test" diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index 9d2b8a3f..79e65e99 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -50,7 +50,7 @@ def _validate_gcs_bucket_name(bucket_name, config_var): assert ( config.max_train_steps > 0 or config.num_train_epochs > 0 - ), "You must set steps or learning_rate_schedule_steps to a positive interger." + ), "You must set steps or learning_rate_schedule_steps to a positive integer." if config.checkpoint_every > 0 and len(config.checkpoint_dir) <= 0: raise AssertionError("Need to set checkpoint_dir when checkpoint_every is set.") @@ -201,7 +201,7 @@ def generate_timestep_weights(config, num_timesteps): @contextmanager def transformer_engine_context(): - """If TransformerEngine is available, this context manager will provide the library with MaxDiffusion-specific details needed for correcct operation.""" + """If TransformerEngine is available, this context manager will provide the library with MaxDiffusion-specific details needed for correct operation.""" try: from transformer_engine.jax.sharding import global_shard_guard, MeshResource # Inform TransformerEngine of MaxDiffusion's physical mesh resources. diff --git a/src/maxdiffusion/utils/export_utils.py b/src/maxdiffusion/utils/export_utils.py index 5dfa3562..2106904e 100644 --- a/src/maxdiffusion/utils/export_utils.py +++ b/src/maxdiffusion/utils/export_utils.py @@ -154,7 +154,7 @@ def export_to_video( bitrate: Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead. Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter - rather than specifiying a fixed bitrate with this parameter. + rather than specifying a fixed bitrate with this parameter. macro_block_size: Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number diff --git a/src/maxdiffusion/utils/peft_utils.py b/src/maxdiffusion/utils/peft_utils.py index fd81b700..92a4b327 100644 --- a/src/maxdiffusion/utils/peft_utils.py +++ b/src/maxdiffusion/utils/peft_utils.py @@ -114,7 +114,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True r = lora_alpha = list(rank_dict.values())[0] if len(set(rank_dict.values())) > 1: - # get the rank occuring the most number of times + # get the rank occurring the most number of times r = collections.Counter(rank_dict.values()).most_common()[0][0] # for modules with rank different from the most occuring rank, add it to the `rank_pattern` @@ -178,7 +178,7 @@ def set_weights_and_activate_adapters(model, adapter_names, weights): for adapter_name, weight in zip(adapter_names, weights): for module in model.modules(): if isinstance(module, BaseTunerLayer): - # For backward compatbility with previous PEFT versions + # For backward compatibility with previous PEFT versions if hasattr(module, "set_adapter"): module.set_adapter(adapter_name) else: @@ -188,7 +188,7 @@ def set_weights_and_activate_adapters(model, adapter_names, weights): # set multiple active adapters for module in model.modules(): if isinstance(module, BaseTunerLayer): - # For backward compatbility with previous PEFT versions + # For backward compatibility with previous PEFT versions if hasattr(module, "set_adapter"): module.set_adapter(adapter_names) else: diff --git a/src/maxdiffusion/utils/torch_utils.py b/src/maxdiffusion/utils/torch_utils.py index 54b046a7..dfa79380 100644 --- a/src/maxdiffusion/utils/torch_utils.py +++ b/src/maxdiffusion/utils/torch_utils.py @@ -60,7 +60,7 @@ def randn_tensor( logger.info( f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" - f" slighly speed up this function by passing a generator that was created on the {device} device." + f" slightly speed up this function by passing a generator that was created on the {device} device." ) elif gen_device_type != device.type and gen_device_type == "cuda": raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") diff --git a/src/maxdiffusion/video_processor.py b/src/maxdiffusion/video_processor.py index c2948511..30368a76 100644 --- a/src/maxdiffusion/video_processor.py +++ b/src/maxdiffusion/video_processor.py @@ -67,7 +67,7 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[ # ensure the input is a list of videos: # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray) - # - if it is a single video, it is convereted to a list of one video. + # - if it is a single video, it is converted to a list of one video. if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5: video = list(video) elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): diff --git a/tests/schedulers/test_scheduler_flax.py b/tests/schedulers/test_scheduler_flax.py index 29fd446a..eab5cb91 100644 --- a/tests/schedulers/test_scheduler_flax.py +++ b/tests/schedulers/test_scheduler_flax.py @@ -919,7 +919,7 @@ def test_full_loop_with_set_alpha_to_one(self): result_mean = jnp.mean(jnp.abs(sample)) if jax_device == "tpu": - assert abs(result_sum - 186.94574) < 8e-2 + assert abs(result_sum - 186.83226) < 8e-2 assert abs(result_mean - 0.24327) < 1e-3 else: assert abs(result_sum - 186.9466) < 1e-2 @@ -932,7 +932,7 @@ def test_full_loop_with_no_set_alpha_to_one(self): result_mean = jnp.mean(jnp.abs(sample)) if jax_device == "tpu": - assert abs(result_sum - 186.94574) < 8e-2 + assert abs(result_sum - 186.83226) < 8e-2 assert abs(result_mean - 0.24327) < 1e-3 else: assert abs(result_sum - 186.9482) < 1e-2