diff --git a/.github/workflows/CPUTests.yml b/.github/workflows/CPUTests.yml index 79eda01b..df0bccf2 100644 --- a/.github/workflows/CPUTests.yml +++ b/.github/workflows/CPUTests.yml @@ -12,7 +12,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: ['3.10'] + python-version: ['3.12'] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -22,7 +22,7 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip - pip install pylint pyink pytype==2024.2.27 + pip install pylint pyink==23.10.0 pytype==2024.2.27 # - name: Typecheck the code with pytype # run: | # pytype --jobs auto --disable import-error src/maxdiffusion/ diff --git a/end_to_end/tpu/eval_assert.py b/end_to_end/tpu/eval_assert.py index 20fd0d8a..33f4c0b2 100644 --- a/end_to_end/tpu/eval_assert.py +++ b/end_to_end/tpu/eval_assert.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Example to run diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 42e50d77..a1a2c2f5 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" __version__ = "0.22.0.dev0" @@ -84,25 +84,23 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["models"].extend( - [ - "AsymmetricAutoencoderKL", - "AutoencoderKL", - "AutoencoderTiny", - "ControlNetModel", - "ModelMixin", - "MultiAdapter", - "PriorTransformer", - "T2IAdapter", - "T5FilmDecoder", - "Transformer2DModel", - "UNet1DModel", - "UNet2DConditionModel", - "UNet2DModel", - "UNet3DConditionModel", - "VQModel", - ] - ) + _import_structure["models"].extend([ + "AsymmetricAutoencoderKL", + "AutoencoderKL", + "AutoencoderTiny", + "ControlNetModel", + "ModelMixin", + "MultiAdapter", + "PriorTransformer", + "T2IAdapter", + "T5FilmDecoder", + "Transformer2DModel", + "UNet1DModel", + "UNet2DConditionModel", + "UNet2DModel", + "UNet3DConditionModel", + "VQModel", + ]) _import_structure["optimization"] = [ "get_constant_schedule", "get_constant_schedule_with_warmup", @@ -113,56 +111,52 @@ "get_scheduler", ] - _import_structure["pipelines"].extend( - [ - "AudioPipelineOutput", - "AutoPipelineForImage2Image", - "AutoPipelineForInpainting", - "AutoPipelineForText2Image", - "ConsistencyModelPipeline", - "DanceDiffusionPipeline", - "DDIMPipeline", - "DDPMPipeline", - "DiffusionPipeline", - "DiTPipeline", - "ImagePipelineOutput", - "KarrasVePipeline", - "LDMPipeline", - "LDMSuperResolutionPipeline", - "PNDMPipeline", - "RePaintPipeline", - "ScoreSdeVePipeline", - ] - ) - _import_structure["schedulers"].extend( - [ - "CMStochasticIterativeScheduler", - "DDIMInverseScheduler", - "DDIMParallelScheduler", - "DDIMScheduler", - "DDPMParallelScheduler", - "DDPMScheduler", - "DDPMWuerstchenScheduler", - "DEISMultistepScheduler", - "DPMSolverMultistepInverseScheduler", - "DPMSolverMultistepScheduler", - "DPMSolverSinglestepScheduler", - "EulerAncestralDiscreteScheduler", - "EulerDiscreteScheduler", - "HeunDiscreteScheduler", - "IPNDMScheduler", - "KarrasVeScheduler", - "KDPM2AncestralDiscreteScheduler", - "KDPM2DiscreteScheduler", - "PNDMScheduler", - "RePaintScheduler", - "SchedulerMixin", - "ScoreSdeVeScheduler", - "UnCLIPScheduler", - "UniPCMultistepScheduler", - "VQDiffusionScheduler", - ] - ) + _import_structure["pipelines"].extend([ + "AudioPipelineOutput", + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + "ConsistencyModelPipeline", + "DanceDiffusionPipeline", + "DDIMPipeline", + "DDPMPipeline", + "DiffusionPipeline", + "DiTPipeline", + "ImagePipelineOutput", + "KarrasVePipeline", + "LDMPipeline", + "LDMSuperResolutionPipeline", + "PNDMPipeline", + "RePaintPipeline", + "ScoreSdeVePipeline", + ]) + _import_structure["schedulers"].extend([ + "CMStochasticIterativeScheduler", + "DDIMInverseScheduler", + "DDIMParallelScheduler", + "DDIMScheduler", + "DDPMParallelScheduler", + "DDPMScheduler", + "DDPMWuerstchenScheduler", + "DEISMultistepScheduler", + "DPMSolverMultistepInverseScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverSinglestepScheduler", + "EulerAncestralDiscreteScheduler", + "EulerDiscreteScheduler", + "HeunDiscreteScheduler", + "IPNDMScheduler", + "KarrasVeScheduler", + "KDPM2AncestralDiscreteScheduler", + "KDPM2DiscreteScheduler", + "PNDMScheduler", + "RePaintScheduler", + "SchedulerMixin", + "ScoreSdeVeScheduler", + "UnCLIPScheduler", + "UniPCMultistepScheduler", + "VQDiffusionScheduler", + ]) _import_structure["training_utils"] = ["EMAModel"] try: @@ -202,100 +196,98 @@ ] else: - _import_structure["pipelines"].extend( - [ - 
"AltDiffusionImg2ImgPipeline", - "AltDiffusionPipeline", - "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", - "AudioLDMPipeline", - "BlipDiffusionControlNetPipeline", - "BlipDiffusionPipeline", - "CLIPImageProjection", - "CycleDiffusionPipeline", - "IFImg2ImgPipeline", - "IFImg2ImgSuperResolutionPipeline", - "IFInpaintingPipeline", - "IFInpaintingSuperResolutionPipeline", - "IFPipeline", - "IFSuperResolutionPipeline", - "ImageTextPipelineOutput", - "KandinskyCombinedPipeline", - "KandinskyImg2ImgCombinedPipeline", - "KandinskyImg2ImgPipeline", - "KandinskyInpaintCombinedPipeline", - "KandinskyInpaintPipeline", - "KandinskyPipeline", - "KandinskyPriorPipeline", - "KandinskyV22CombinedPipeline", - "KandinskyV22ControlnetImg2ImgPipeline", - "KandinskyV22ControlnetPipeline", - "KandinskyV22Img2ImgCombinedPipeline", - "KandinskyV22Img2ImgPipeline", - "KandinskyV22InpaintCombinedPipeline", - "KandinskyV22InpaintPipeline", - "KandinskyV22Pipeline", - "KandinskyV22PriorEmb2EmbPipeline", - "KandinskyV22PriorPipeline", - "LDMTextToImagePipeline", - "MusicLDMPipeline", - "PaintByExamplePipeline", - "SemanticStableDiffusionPipeline", - "ShapEImg2ImgPipeline", - "ShapEPipeline", - "StableDiffusionAdapterPipeline", - "StableDiffusionAttendAndExcitePipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionControlNetPipeline", - "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", - "StableDiffusionImageVariationPipeline", - "StableDiffusionImg2ImgPipeline", - "StableDiffusionInpaintPipeline", - "StableDiffusionInpaintPipelineLegacy", - "StableDiffusionInstructPix2PixPipeline", - "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionModelEditingPipeline", - "StableDiffusionPanoramaPipeline", - "StableDiffusionParadigmsPipeline", - "StableDiffusionPipeline", - "StableDiffusionPipelineSafe", - "StableDiffusionPix2PixZeroPipeline", - "StableDiffusionSAGPipeline", - "StableDiffusionUpscalePipeline", - "StableDiffusionXLAdapterPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - "StableDiffusionXLControlNetPipeline", - "StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLPipeline", - "StableUnCLIPImg2ImgPipeline", - "StableUnCLIPPipeline", - "TextToVideoSDPipeline", - "TextToVideoZeroPipeline", - "UnCLIPImageVariationPipeline", - "UnCLIPPipeline", - "UniDiffuserModel", - "UniDiffuserPipeline", - "UniDiffuserTextDecoder", - "VersatileDiffusionDualGuidedPipeline", - "VersatileDiffusionImageVariationPipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionTextToImagePipeline", - "VideoToVideoSDPipeline", - "VQDiffusionPipeline", - "WuerstchenCombinedPipeline", - "WuerstchenDecoderPipeline", - "WuerstchenPriorPipeline", - ] - ) + _import_structure["pipelines"].extend([ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + "AudioLDMPipeline", + "BlipDiffusionControlNetPipeline", + "BlipDiffusionPipeline", + "CLIPImageProjection", + "CycleDiffusionPipeline", + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + 
"ImageTextPipelineOutput", + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + "LDMTextToImagePipeline", + "MusicLDMPipeline", + "PaintByExamplePipeline", + "SemanticStableDiffusionPipeline", + "ShapEImg2ImgPipeline", + "ShapEPipeline", + "StableDiffusionAdapterPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPipelineSafe", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableDiffusionXLAdapterPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "UnCLIPImageVariationPipeline", + "UnCLIPPipeline", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + "VideoToVideoSDPipeline", + "VQDiffusionPipeline", + "WuerstchenCombinedPipeline", + "WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ]) try: if not (is_torch_available() and is_k_diffusion_available()): @@ -321,16 +313,14 @@ ] else: - _import_structure["pipelines"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) + _import_structure["pipelines"].extend([ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ]) try: if not (is_torch_available() and is_librosa_available()): @@ -376,19 +366,17 @@ _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) - 
_import_structure["schedulers"].extend( - [ - "FlaxDDIMScheduler", - "FlaxDDPMScheduler", - "FlaxDPMSolverMultistepScheduler", - "FlaxEulerDiscreteScheduler", - "FlaxKarrasVeScheduler", - "FlaxLMSDiscreteScheduler", - "FlaxPNDMScheduler", - "FlaxSchedulerMixin", - "FlaxScoreSdeVeScheduler", - ] - ) + _import_structure["schedulers"].extend([ + "FlaxDDIMScheduler", + "FlaxDDPMScheduler", + "FlaxDPMSolverMultistepScheduler", + "FlaxEulerDiscreteScheduler", + "FlaxKarrasVeScheduler", + "FlaxLMSDiscreteScheduler", + "FlaxPNDMScheduler", + "FlaxSchedulerMixin", + "FlaxScoreSdeVeScheduler", + ]) try: @@ -403,16 +391,14 @@ else: - _import_structure["pipelines"].extend( - [ - "FlaxStableDiffusionControlNetPipeline", - "FlaxStableDiffusionXLControlNetPipeline", - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - "FlaxStableDiffusionXLPipeline", - ] - ) + _import_structure["pipelines"].extend([ + "FlaxStableDiffusionControlNetPipeline", + "FlaxStableDiffusionXLControlNetPipeline", + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + "FlaxStableDiffusionXLPipeline", + ]) try: if not (is_note_seq_available()): diff --git a/src/maxdiffusion/checkpointing/__init__.py b/src/maxdiffusion/checkpointing/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/checkpointing/__init__.py +++ b/src/maxdiffusion/checkpointing/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py index 9faba8bc..baf5bdd6 100644 --- a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py +++ b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from abc import ABC from contextlib import nullcontext @@ -66,7 +66,6 @@ def __init__(self, config, checkpoint_type): ) def _create_optimizer(self, config, learning_rate): - learning_rate_scheduler = max_utils.create_learning_rate_schedule( learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps ) @@ -74,7 +73,6 @@ def _create_optimizer(self, config, learning_rate): return tx, learning_rate_scheduler def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training): - tx, learning_rate_scheduler = None, None if is_training: learning_rate = self.config.learning_rate @@ -96,7 +94,6 @@ def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training) return unet_state, state_mesh_shardings, learning_rate_scheduler def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False): - # Currently VAE training is not supported. weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng) return max_utils.setup_initial_state( @@ -112,7 +109,6 @@ def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=F ) def create_text_encoder_state(self, pipeline, params, checkpoint_item_name, is_training): - tx = None if is_training: learning_rate = self.config.text_encoder_learning_rate @@ -260,11 +256,9 @@ def config_to_json(model_or_config): self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) def load_params(self, step=None): - self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX def load_checkpoint(self, step=None, scheduler_class=None): - pipeline_class = self._get_pipeline_class() self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index bbad3ad1..960e0692 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -1,19 +1,19 @@ # ruff: noqa """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer.""" diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py index 89ac3764..78ad000b 100644 --- a/src/maxdiffusion/checkpointing/flux_checkpointer.py +++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from abc import ABC from contextlib import nullcontext @@ -67,7 +67,6 @@ def __init__(self, config, checkpoint_type): ) def _create_optimizer(self, config, learning_rate): - learning_rate_scheduler = max_utils.create_learning_rate_schedule( learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps ) @@ -109,7 +108,6 @@ def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training) return flux_state, state_mesh_shardings, learning_rate_scheduler def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False): - # Currently VAE training is not supported. 
weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng) return max_utils.setup_initial_state( @@ -163,7 +161,6 @@ def config_to_json(model_or_config): self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) def load_params(self, step=None): - self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX def load_flux_configs_from_orbax(self, step): @@ -243,7 +240,6 @@ def load_diffusers_checkpoint(self): return pipeline, params def load_checkpoint(self, step=None, scheduler_class=None): - model_configs = self.load_flux_configs_from_orbax(step) pipeline, params = None, {} diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 006b3ec8..4ab90971 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
""" from abc import ABC, abstractmethod @@ -35,14 +35,12 @@ def __init__(self, config, checkpoint_type: str = WAN_CHECKPOINT): self.checkpoint_type = checkpoint_type self.opt_state = None - self.checkpoint_manager: ocp.CheckpointManager = ( - create_orbax_checkpoint_manager( - self.config.checkpoint_dir, - enable_checkpointing=True, - save_interval_steps=1, - checkpoint_type=checkpoint_type, - dataset_type=config.dataset_type, - ) + self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( + self.config.checkpoint_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=checkpoint_type, + dataset_type=config.dataset_type, ) def _create_optimizer(self, model, config, learning_rate): @@ -61,13 +59,18 @@ def load_diffusers_checkpoint(self): raise NotImplementedError @abstractmethod - def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2 | WanPipelineI2V_2_1 | WanPipelineI2V_2_2], Optional[dict], Optional[int]]: + def load_checkpoint( + self, step=None + ) -> Tuple[ + Optional[WanPipeline2_1 | WanPipeline2_2 | WanPipelineI2V_2_1 | WanPipelineI2V_2_2], Optional[dict], Optional[int] + ]: raise NotImplementedError @abstractmethod def save_checkpoint(self, train_step, pipeline, train_states: dict): raise NotImplementedError + def save_checkpoint_orig(self, train_step, pipeline, train_states: dict): """Saves the training state and model configurations.""" diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py index a8e2a297..da30567b 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
""" import json @@ -24,6 +24,7 @@ from etils import epath from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer + class WanCheckpointer2_1(WanCheckpointer): def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py index 30cff387..533a00db 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ import json @@ -24,6 +24,7 @@ from etils import epath from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer + class WanCheckpointer2_2(WanCheckpointer): def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: @@ -38,7 +39,9 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic # Handle low_noise_transformer low_noise_transformer_metadata = metadatas.low_noise_transformer_state - abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) + abstract_tree_structure_low_params = jax.tree_util.tree_map( + ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata + ) low_params_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), @@ -48,7 +51,9 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic # Handle high_noise_transformer high_noise_transformer_metadata = metadatas.high_noise_transformer_state - abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) + abstract_tree_structure_high_params = jax.tree_util.tree_map( + ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata + ) high_params_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), @@ -67,10 +72,18 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic ), ) max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") - max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"restored checkpoint high_noise_transformer_state 
{restored_checkpoint.high_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log( + f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}" + ) + max_logging.log( + f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}" + ) + max_logging.log( + f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}" + ) + max_logging.log( + f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}" + ) max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") return restored_checkpoint, step diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py index 6f4bbc90..5850692f 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ import json @@ -24,6 +24,7 @@ from etils import epath from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer + class WanCheckpointerI2V_2_1(WanCheckpointer): def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py index a55048cf..98f76f48 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ import json @@ -24,6 +24,7 @@ from etils import epath from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer + class WanCheckpointerI2V_2_2(WanCheckpointer): def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: @@ -38,7 +39,9 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic # Handle low_noise_transformer low_noise_transformer_metadata = metadatas.low_noise_transformer_state - abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) + abstract_tree_structure_low_params = jax.tree_util.tree_map( + ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata + ) low_params_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), @@ -48,7 +51,9 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic # Handle high_noise_transformer high_noise_transformer_metadata = metadatas.high_noise_transformer_state - abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) + abstract_tree_structure_high_params = jax.tree_util.tree_map( + ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata + ) high_params_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), @@ -67,10 +72,18 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic ), ) max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") - max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log( + f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}" + ) + max_logging.log( + f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}" + ) + max_logging.log( + f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}" + ) + max_logging.log( + f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}" + ) max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") 
return restored_checkpoint, step diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index 15553727..2be883e5 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -66,19 +66,19 @@ ### Common axis rules for ring attention ### RING_ATTENTION_AXIS_RULES = [ - [SELF_ATTN_HEAD, None], - [SELF_ATTN_Q_LENGTH, FSDP], - [SELF_ATTN_KV_LENGTH, FSDP], - [CROSS_ATTN_HEAD, None], - [CROSS_ATTN_Q_LENGTH, FSDP], - [CROSS_ATTN_KV_LENGTH, FSDP], + [SELF_ATTN_HEAD, None], + [SELF_ATTN_Q_LENGTH, FSDP], + [SELF_ATTN_KV_LENGTH, FSDP], + [CROSS_ATTN_HEAD, None], + [CROSS_ATTN_Q_LENGTH, FSDP], + [CROSS_ATTN_KV_LENGTH, FSDP], ] SEQUENCE_PARALLEL_AXIS_RULES = [ - [SELF_ATTN_HEAD, None], - [SELF_ATTN_Q_LENGTH, FSDP], - [SELF_ATTN_KV_LENGTH, None], - [CROSS_ATTN_HEAD, None], - [CROSS_ATTN_Q_LENGTH, FSDP], - [CROSS_ATTN_KV_LENGTH, None], + [SELF_ATTN_HEAD, None], + [SELF_ATTN_Q_LENGTH, FSDP], + [SELF_ATTN_KV_LENGTH, None], + [CROSS_ATTN_HEAD, None], + [CROSS_ATTN_Q_LENGTH, FSDP], + [CROSS_ATTN_KV_LENGTH, None], ] diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py index 3a495e02..0e8c9968 100644 --- a/src/maxdiffusion/configuration_utils.py +++ b/src/maxdiffusion/configuration_utils.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" ConfigMixin base class and utilities.""" +"""ConfigMixin base class and utilities.""" import dataclasses import functools import importlib @@ -611,7 +611,6 @@ def to_json_saveable(value): config_dict.pop(key) try: - json_str = json.dumps(config_dict, indent=2, sort_keys=True, cls=CustomEncoder) except Exception as e: max_logging.log(f"Error serializing config to JSON: {e}") diff --git a/src/maxdiffusion/controlnet/__init__.py b/src/maxdiffusion/controlnet/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/controlnet/__init__.py +++ b/src/maxdiffusion/controlnet/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" diff --git a/src/maxdiffusion/controlnet/generate_controlnet_replicated.py b/src/maxdiffusion/controlnet/generate_controlnet_replicated.py index a3959cbb..bd4ef6eb 100644 --- a/src/maxdiffusion/controlnet/generate_controlnet_replicated.py +++ b/src/maxdiffusion/controlnet/generate_controlnet_replicated.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence from absl import app @@ -28,7 +28,6 @@ def run(config): - rng = jax.random.PRNGKey(config.seed) # get canny image diff --git a/src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py b/src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py index b38202c8..235159a9 100644 --- a/src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py +++ b/src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" from typing import Sequence from absl import app diff --git a/src/maxdiffusion/data_preprocessing/__init__.py b/src/maxdiffusion/data_preprocessing/__init__.py index 7e4185f3..4a62083b 100644 --- a/src/maxdiffusion/data_preprocessing/__init__.py +++ b/src/maxdiffusion/data_preprocessing/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py index e0191373..64e9d54b 100644 --- a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py +++ b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
""" """ @@ -71,7 +71,6 @@ def create_example(latent, hidden_states, timestep=None): def generate_dataset(config): - tfrecords_dir = config.tfrecords_dir if not os.path.exists(tfrecords_dir): os.makedirs(tfrecords_dir) diff --git a/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py b/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py index ae0b15f4..23baaffb 100644 --- a/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py +++ b/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ """ @@ -84,7 +84,6 @@ def vae_encode(video, rng, vae, vae_cache): def generate_dataset(config, pipeline): - tfrecords_dir = config.tfrecords_dir if not os.path.exists(tfrecords_dir): os.makedirs(tfrecords_dir) diff --git a/src/maxdiffusion/dreambooth/__init__.py b/src/maxdiffusion/dreambooth/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/dreambooth/__init__.py +++ b/src/maxdiffusion/dreambooth/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" diff --git a/src/maxdiffusion/dreambooth/dreambooth_constants.py b/src/maxdiffusion/dreambooth/dreambooth_constants.py index 72ac6003..bb366e15 100644 --- a/src/maxdiffusion/dreambooth/dreambooth_constants.py +++ b/src/maxdiffusion/dreambooth/dreambooth_constants.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" INSTANCE_IMAGES = "instance_images" INSTANCE_IMAGE_LATENTS = "instance_image_latents" diff --git a/src/maxdiffusion/dreambooth/train_dreambooth.py b/src/maxdiffusion/dreambooth/train_dreambooth.py index 5cb7e233..d9b17475 100644 --- a/src/maxdiffusion/dreambooth/train_dreambooth.py +++ b/src/maxdiffusion/dreambooth/train_dreambooth.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence diff --git a/src/maxdiffusion/generate.py b/src/maxdiffusion/generate.py index ac4fbb7f..7b1f1f62 100644 --- a/src/maxdiffusion/generate.py +++ b/src/maxdiffusion/generate.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import functools import time @@ -86,7 +86,6 @@ def tokenize(prompt, tokenizer): def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size): - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) @@ -132,7 +131,6 @@ def vae_decode(latents, state, pipeline): def run_inference(states, pipeline, params, config, rng, mesh, batch_size): - unet_state = states["unet_state"] vae_state = states["vae_state"] @@ -158,7 +156,6 @@ def run_inference(states, pipeline, params, config, rng, mesh, batch_size): def run(config): - checkpoint_loader = GenerateSD(config, STABLE_DIFFUSION_CHECKPOINT) pipeline, params = checkpoint_loader.load_checkpoint() diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index b248156e..0ba8a7a8 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from typing import Callable, List, Union, Sequence @@ -137,7 +137,6 @@ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: flo def run_inference( states, transformer, vae, config, mesh, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts ): - transformer_state = states["transformer"] vae_state = states["vae"] @@ -175,7 +174,6 @@ def pack_latents( def prepare_latents( batch_size: int, num_channels_latents: int, height: int, width: int, vae_scale_factor: int, dtype: jnp.dtype, rng: Array ): - # VAE applies 8x compression on images but we must also account for packing which # requires latent height and width to be divisibly by 2. 
height = 2 * (height // (vae_scale_factor * 2)) @@ -223,7 +221,6 @@ def get_t5_prompt_embeds( text_encoder: T5EncoderModel, max_sequence_length: int = 512, ): - prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = tokenizer( @@ -256,7 +253,6 @@ def encode_prompt( num_images_per_prompt: int = 1, max_sequence_length: int = 512, ): - prompt = [prompt] if isinstance(prompt, str) else prompt prompt_2 = prompt or prompt_2 prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 diff --git a/src/maxdiffusion/generate_flux_multi_res.py b/src/maxdiffusion/generate_flux_multi_res.py index 7d07883c..33179295 100644 --- a/src/maxdiffusion/generate_flux_multi_res.py +++ b/src/maxdiffusion/generate_flux_multi_res.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from typing import List, Union, Sequence @@ -154,7 +154,6 @@ def run_inference( p_ts, vae_scale_factor, ): - transformer_state = states["transformer"] vae_state = states["vae"] @@ -194,7 +193,6 @@ def pack_latents( def prepare_latents( batch_size: int, num_channels_latents: int, height: int, width: int, vae_scale_factor: int, dtype: jnp.dtype, rng: Array ): - # VAE applies 8x compression on images but we must also account for packing which # requires latent height and width to be divisibly by 2. height = 2 * (height // (vae_scale_factor * 2)) @@ -270,7 +268,6 @@ def get_t5_prompt_embeds( text_encoder: T5EncoderModel, max_sequence_length: int = 512, ): - prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = tokenizer( @@ -303,7 +300,6 @@ def encode_prompt( num_images_per_prompt: int = 1, max_sequence_length: int = 512, ): - prompt = [prompt] if isinstance(prompt, str) else prompt prompt_2 = prompt or prompt_2 prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 diff --git a/src/maxdiffusion/generate_flux_pipeline.py b/src/maxdiffusion/generate_flux_pipeline.py index e6b8d4e2..c89f413a 100644 --- a/src/maxdiffusion/generate_flux_pipeline.py +++ b/src/maxdiffusion/generate_flux_pipeline.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from typing import Sequence diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 6ecc6666..93753f0c 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import numpy as np from absl import app @@ -34,7 +34,6 @@ def calculate_padding( source_height: int, source_width: int, target_height: int, target_width: int ) -> tuple[int, int, int, int]: - # Calculate total padding needed pad_height = target_height - source_height pad_width = target_width - source_width diff --git a/src/maxdiffusion/generate_sdxl.py b/src/maxdiffusion/generate_sdxl.py index 9ad1022d..3ab70370 100644 --- a/src/maxdiffusion/generate_sdxl.py +++ b/src/maxdiffusion/generate_sdxl.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import functools from absl import app @@ -115,7 +115,6 @@ def tokenize(prompt, pipeline): def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size): - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) @@ -189,7 +188,6 @@ def vae_decode(latents, state, pipeline): def run_inference(states, pipeline, params, config, rng, mesh, batch_size): - unet_state = states["unet_state"] vae_state = states["vae_state"] diff --git a/src/maxdiffusion/generate_sdxl_replicated.py b/src/maxdiffusion/generate_sdxl_replicated.py index d17fc02d..83df3a99 100644 --- a/src/maxdiffusion/generate_sdxl_replicated.py +++ b/src/maxdiffusion/generate_sdxl_replicated.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import time diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index d3aad31d..1075ac9d 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -66,10 +66,11 @@ def delete_file(file_path: str): else: max_logging.log(f"The file '{file_path}' does not exist.") + def get_git_commit_hash(): """Tries to get the current Git commit hash.""" try: - commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip().decode('utf-8') + commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8") return commit_hash except subprocess.CalledProcessError: max_logging.log("Warning: 'git rev-parse HEAD' failed. 
Not running in a git repo?") @@ -78,8 +79,10 @@ def get_git_commit_hash(): max_logging.log("Warning: 'git' command not found.") return None + jax.config.update("jax_use_shardy_partitioner", True) + def call_pipeline(config, pipeline, prompt, negative_prompt): model_key = config.model_name model_type = config.model_type diff --git a/src/maxdiffusion/input_pipeline/__init__.py b/src/maxdiffusion/input_pipeline/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/input_pipeline/__init__.py +++ b/src/maxdiffusion/input_pipeline/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/input_pipeline/_grain_data_processing.py b/src/maxdiffusion/input_pipeline/_grain_data_processing.py index 5ba3b637..6498b263 100644 --- a/src/maxdiffusion/input_pipeline/_grain_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_grain_data_processing.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" import dataclasses import glob diff --git a/src/maxdiffusion/input_pipeline/_hf_data_processing.py b/src/maxdiffusion/input_pipeline/_hf_data_processing.py index e0f1d725..10f276d6 100644 --- a/src/maxdiffusion/input_pipeline/_hf_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_hf_data_processing.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import warnings import datasets diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index b8992415..dae9a3a1 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" import os import tensorflow as tf @@ -41,7 +41,6 @@ def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_cou def make_tf_iterator( config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, tokenize_fn, image_transforms_fn ): - if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location): train_ds = load_from_disk(config.dataset_save_location) else: diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py index 27f2ad25..5486307d 100644 --- a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py +++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" import os from functools import partial diff --git a/src/maxdiffusion/loaders/flux_lora_pipeline.py b/src/maxdiffusion/loaders/flux_lora_pipeline.py index 5f449ee9..56844db7 100644 --- a/src/maxdiffusion/loaders/flux_lora_pipeline.py +++ b/src/maxdiffusion/loaders/flux_lora_pipeline.py @@ -22,7 +22,6 @@ class FluxLoraLoaderMixin(LoRABaseMixin): - _lora_lodable_modules = ["transformer", "text_encoder"] def load_lora_weights( @@ -98,7 +97,6 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas, adapter_name @classmethod @validate_hf_hub_args def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs): - cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py index 7feb20ca..2d8c1c75 100644 --- a/src/maxdiffusion/loaders/lora_pipeline.py +++ b/src/maxdiffusion/loaders/lora_pipeline.py @@ -134,7 +134,6 @@ def rename_for_interceptor(params_keys, network_alphas, adapter_name): @classmethod def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name): - network_alphas_for_interceptor = {} unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys() diff --git a/src/maxdiffusion/max_logging.py b/src/maxdiffusion/max_logging.py index 32ac3d8f..2edb43f4 100644 --- a/src/maxdiffusion/max_logging.py +++ b/src/maxdiffusion/max_logging.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Stub for logging utilities. Right now just meant to avoid raw prints""" diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fb7266a1..37dbda94 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -1,19 +1,19 @@ # ruff: noqa """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=bare-except, consider-using-generator """ Common Max Utils needed by multiple modules""" @@ -502,17 +502,20 @@ def get_flash_block_sizes(config): flash_block_sizes = None if len(config.flash_block_sizes.keys()) > 0: attention_is_tokamax = "tokamax" in config.attention - user_block_sizes:Dict[str, int] = config.flash_block_sizes + user_block_sizes: Dict[str, int] = config.flash_block_sizes if attention_is_tokamax: - max_logging.log("Tokamax kernel specified, Note: Tokamax only supports fused backward kernel." - "Hence following flash block properties specified will be ignored:" - f"block_q: {user_block_sizes['block_q']}," - f"block_q_dq: {user_block_sizes.get('block_q_dq')}," - f"block_kv_dq: {user_block_sizes.get('block_kv_dq')}," - f"use_fused_bwd_kernel: {user_block_sizes.get('use_fused_bwd_kernel')}" - ) + max_logging.log( + "Tokamax kernel specified, Note: Tokamax only supports fused backward kernel." + "Hence following flash block properties specified will be ignored:" + f"block_q: {user_block_sizes['block_q']}," + f"block_q_dq: {user_block_sizes.get('block_q_dq')}," + f"block_kv_dq: {user_block_sizes.get('block_kv_dq')}," + f"use_fused_bwd_kernel: {user_block_sizes.get('use_fused_bwd_kernel')}" + ) flash_block_sizes = splash_attention_kernel.BlockSizes( - block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"]) if attention_is_tokamax else user_block_sizes["block_q"], + block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"]) + if attention_is_tokamax + else user_block_sizes["block_q"], block_kv_compute=user_block_sizes["block_kv_compute"], block_kv=user_block_sizes["block_kv"], block_q_dkv=user_block_sizes["block_q_dkv"], @@ -541,7 +544,6 @@ def get_memory_allocations(): def get_live_arrays(): - backend = jax.extend.backend.get_backend() live_arrays = backend.live_arrays() diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index b9b1abdc..32eed7f4 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import io from PIL import Image @@ -42,7 +42,6 @@ def load_sdxllightning_unet(config, pipeline, params): def maybe_load_sdxl_lora(config, pipeline, params): - def _noop_interceptor(next_fn, args, kwargs, context): return next_fn(*args, **kwargs) diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 96a6f128..7ff8fd8f 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -24,7 +24,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .controlnet_flax import FlaxControlNetModel from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae_flax import FlaxAutoencoderKL diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 2982e19e..428e4a3c 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -173,17 +173,20 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1): return tensor, kv_size, seq_len -def convert_to_tokamax_splash_config( block_sizes: BlockSizes, - q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, - k_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, - v_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, - residual_checkpoint_name: str | None = None, - attn_logits_soft_cap: float | None = None, - fuse_reciprocal: bool = True, - use_base2_exp: bool = False, - max_logit_const: float | None = None, - interpret: bool = False, - dq_reduction_steps: int | None = None) -> tokamax_splash_attention_kernel.SplashConfig: + +def convert_to_tokamax_splash_config( + block_sizes: BlockSizes, + q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, + k_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, + v_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, + residual_checkpoint_name: str | None = None, + attn_logits_soft_cap: float | None = None, + fuse_reciprocal: bool = True, + use_base2_exp: bool = False, + max_logit_const: float | None = None, + interpret: bool = False, + dq_reduction_steps: int | None = None, +) -> tokamax_splash_attention_kernel.SplashConfig: assert block_sizes.use_fused_bwd_kernel, "Tokamax Splash attention only supports fused bwd kernel." 
return tokamax_splash_attention_kernel.SplashConfig( block_q=block_sizes.block_q, @@ -192,7 +195,7 @@ def convert_to_tokamax_splash_config( block_sizes: BlockSizes, block_q_dkv=block_sizes.block_q_dkv, block_kv_dkv=block_sizes.block_kv_dkv, block_kv_dkv_compute=block_sizes.block_kv_dkv_compute, - block_q_dq= None if block_sizes.use_fused_bwd_kernel else block_sizes.block_q_dq, + block_q_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_q_dq, block_kv_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_kv_dq, use_fused_bwd_kernel=block_sizes.use_fused_bwd_kernel, q_layout=q_layout, @@ -319,7 +322,9 @@ def wrap_flash_attention(query, key, value): # make_splash_mha is wrapped around shardmap and seq and head is already # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1. if attention_kernel == "tokamax_flash": - mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),) + mask = tokamax_splash_attention_mask.FullMask( + _shape=(query.shape[2], key.shape[2]), + ) splash_kernel = tokamax_splash_attention_kernel.make_splash_mha( mask=mask, q_seq_shards=1, # the sizes of the axis is sharding over seq_len @@ -333,7 +338,7 @@ def wrap_flash_attention(query, key, value): q_seq_shards=1, # the sizes of the axis is sharding over seq_len block_sizes=block_sizes, save_residuals=True if attention_kernel == "ring" else False, - residual_checkpoint_name=residual_checkpoint_name + residual_checkpoint_name=residual_checkpoint_name, ) vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) @@ -559,7 +564,16 @@ def _apply_attention( ) elif attention_kernel == "ring": return _tpu_flash_attention( - query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, + query, + key * scale, + value, + heads, + mesh, + axis_names_q, + axis_names_kv, + flash_block_sizes, + dtype, + attention_kernel, mask_padding_tokens=mask_padding_tokens, ) elif attention_kernel == "cudnn_flash_te": @@ -671,9 +685,21 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) + # New Class for Wan I2V class NNXSimpleFeedForward(nnx.Module): - def __init__(self, rngs: nnx.Rngs, dim: int, dim_out: Optional[int] = None, mult: int = 4, activation_fn: str = "gelu", dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: Optional[jax.lax.Precision] = None): + + def __init__( + self, + rngs: nnx.Rngs, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + activation_fn: str = "gelu", + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: Optional[jax.lax.Precision] = None, + ): inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim self.net_0 = nnx.Linear( @@ -706,6 +732,7 @@ def __call__(self, hidden_states: Array) -> Array: hidden_states = self.net_2(hidden_states) return hidden_states + class NNXAttentionOp(nnx.Module): def __init__( @@ -864,8 +891,8 @@ def __init__( mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, enable_jax_named_scopes: bool = False, - added_kv_proj_dim: Optional[int] = None, # New for I2V - image_seq_len: Optional[int] = None, # New for I2V + added_kv_proj_dim: Optional[int] = None, # New for I2V + image_seq_len: Optional[int] = None, # New for I2V ): if attention_kernel == "cudnn_flash_te": raise NotImplementedError(f"Wan 2.1 has not been 
tested with {attention_kernel}") @@ -889,8 +916,8 @@ def __init__( else: axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV) axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV) - self.added_kv_proj_dim = added_kv_proj_dim # New for I2V - self.image_seq_len = image_seq_len # New for I2V + self.added_kv_proj_dim = added_kv_proj_dim # New for I2V + self.image_seq_len = image_seq_len # New for I2V self.attention_op = NNXAttentionOp( mesh=mesh, @@ -1006,23 +1033,35 @@ def __init__( self.norm_added_k = nnx.data(None) if self.added_kv_proj_dim is not None: self.add_k_proj = nnx.Linear( - self.added_kv_proj_dim, self.inner_dim, rngs=rngs, - dtype=dtype, param_dtype=weights_dtype, precision=precision, + self.added_kv_proj_dim, + self.inner_dim, + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, ("embed",), ), ) self.add_v_proj = nnx.Linear( - self.added_kv_proj_dim, self.inner_dim, rngs=rngs, - dtype=dtype, param_dtype=weights_dtype, precision=precision, + self.added_kv_proj_dim, + self.inner_dim, + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, ("embed",), ), ) self.norm_added_k = nnx.RMSNorm( - num_features=self.inner_dim, rngs=rngs, epsilon=eps, dtype=dtype, param_dtype=weights_dtype, + num_features=self.inner_dim, + rngs=rngs, + epsilon=eps, + dtype=dtype, + param_dtype=weights_dtype, scale_init=nnx.with_partitioning( nnx.initializers.ones, ("norm",), @@ -1120,10 +1159,10 @@ def __call__( # Use the passed encoder_attention_mask (created in embeddings_flax.py) if using Flash Attention # It contains the image mask: [1]*257 + [0]*127 for 257 real image tokens padded to 384 if encoder_attention_mask is not None: - encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len] + encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len] else: - # Fallback: no mask means treat all as valid (for dot product attention) - encoder_attention_mask_img = None + # Fallback: no mask means treat all as valid (for dot product attention) + encoder_attention_mask_img = None else: # If no image_seq_len is specified, treat all as text encoder_hidden_states_img = None @@ -1134,7 +1173,7 @@ def __call__( with self.conditional_named_scope("attn_q_norm"): query_proj_text = self.norm_q(query_proj_raw) else: - query_proj_text = query_proj_raw + query_proj_text = query_proj_raw # Text K/V with self.conditional_named_scope("proj_key"): @@ -1163,13 +1202,14 @@ def __call__( value_proj_img = checkpoint_name(value_proj_img, "value_proj_img") query_proj_img = checkpoint_name(query_proj_img, "query_proj_img") - # Attention - tensors are (B, S, D) with self.conditional_named_scope("cross_attn_text_apply"): attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text) with self.conditional_named_scope("cross_attn_img_apply"): # Pass encoder_attention_mask_img for image cross-attention to mask padded tokens - attn_output_img = self.attention_op.apply_attention(query_proj_img, key_proj_img, value_proj_img, attention_mask=encoder_attention_mask_img) + attn_output_img = self.attention_op.apply_attention( + query_proj_img, key_proj_img, value_proj_img, attention_mask=encoder_attention_mask_img + ) attn_output = attn_output_text + attn_output_img else: diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index 
21c67e10..41afa3b4 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -249,10 +249,32 @@ def get_1d_rotary_pos_embed( out = jnp.exp(1j * freqs) return out + class NNXWanImageEmbedding(nnx.Module): - def __init__(self, rngs: nnx.Rngs, in_features: int, out_features: int, dtype: jnp.dtype, weights_dtype: jnp.dtype, precision: jax.lax.Precision, pos_embed_seq_len=None, alignment: int = 128, flash_min_seq_length: int = 4096): + + def __init__( + self, + rngs: nnx.Rngs, + in_features: int, + out_features: int, + dtype: jnp.dtype, + weights_dtype: jnp.dtype, + precision: jax.lax.Precision, + pos_embed_seq_len=None, + alignment: int = 128, + flash_min_seq_length: int = 4096, + ): self.norm1 = FP32LayerNorm(rngs=rngs, dim=in_features, elementwise_affine=True, eps=1e-6) - self.ff = NNXSimpleFeedForward(rngs=rngs, dim=in_features, dim_out=out_features, mult=1, activation_fn="gelu", dtype=dtype, weights_dtype=weights_dtype, precision=precision) + self.ff = NNXSimpleFeedForward( + rngs=rngs, + dim=in_features, + dim_out=out_features, + mult=1, + activation_fn="gelu", + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) self.norm2 = FP32LayerNorm(rngs=rngs, dim=out_features, elementwise_affine=True, eps=1e-6) self.alignment = alignment self.flash_min_seq_length = flash_min_seq_length @@ -271,14 +293,14 @@ def __call__(self, encoder_hidden_states_image: jax.Array) -> tuple[jax.Array, j # Apply pos_embed to the original sequence length hidden_states = hidden_states.at[:, :add_len, :].add(self.pos_embed.value[:, :add_len, :]) if current_seq_len > pe_len: - print(f"[WARN] Input seq_len {current_seq_len} > pos_embed len {pe_len}") + print(f"[WARN] Input seq_len {current_seq_len} > pos_embed len {pe_len}") hidden_states = self.norm1(hidden_states) hidden_states = self.ff(hidden_states) hidden_states = self.norm2(hidden_states) # hidden_states shape: (B, current_seq_len, out_features) B, current_seq_len, D_out = hidden_states.shape - use_flash_attn = current_seq_len>=self.flash_min_seq_length + use_flash_attn = current_seq_len >= self.flash_min_seq_length if use_flash_attn: # --- Dynamic Padding to nearest multiple of self.alignment --- @@ -291,13 +313,13 @@ def __call__(self, encoder_hidden_states_image: jax.Array) -> tuple[jax.Array, j attention_mask = jnp.ones((B, current_seq_len), dtype=jnp.int32) if current_seq_len < target_seq_len: - padding_size = target_seq_len - current_seq_len - padding = jnp.zeros((B, padding_size, D_out), dtype=hidden_states.dtype) - hidden_states = jnp.concatenate([hidden_states, padding], axis=1) + padding_size = target_seq_len - current_seq_len + padding = jnp.zeros((B, padding_size, D_out), dtype=hidden_states.dtype) + hidden_states = jnp.concatenate([hidden_states, padding], axis=1) - # Extend mask with zeros for padded positions - padding_mask = jnp.zeros((B, padding_size), dtype=jnp.int32) - attention_mask = jnp.concatenate([attention_mask, padding_mask], axis=1) + # Extend mask with zeros for padded positions + padding_mask = jnp.zeros((B, padding_size), dtype=jnp.int32) + attention_mask = jnp.concatenate([attention_mask, padding_mask], axis=1) if not use_flash_attn: attention_mask = None return hidden_states, attention_mask diff --git a/src/maxdiffusion/models/flux/__init__.py b/src/maxdiffusion/models/flux/__init__.py index 84dd0f15..217c0ac8 100644 --- a/src/maxdiffusion/models/flux/__init__.py +++ b/src/maxdiffusion/models/flux/__init__.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC 
+Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from .transformers.transformer_flux_flax import FluxTransformer2DModel diff --git a/src/maxdiffusion/models/flux/transformers/__init__.py b/src/maxdiffusion/models/flux/transformers/__init__.py index 7e4185f3..4a62083b 100644 --- a/src/maxdiffusion/models/flux/transformers/__init__.py +++ b/src/maxdiffusion/models/flux/transformers/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 7f63da67..814e21ea 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Tuple import jax @@ -180,7 +180,6 @@ class FluxTransformerBlock(nn.Module): attention_kernel: str = "dot_product" def setup(self): - self.img_norm1 = AdaLayerNormZero(self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision) self.txt_norm1 = AdaLayerNormZero(self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision) @@ -203,29 +202,27 @@ def setup(self): dtype=self.dtype, param_dtype=self.weights_dtype, ) - self.img_mlp = nn.Sequential( - [ - nn.Dense( - int(self.dim * self.mlp_ratio), - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - nn.gelu, - nn.Dense( - self.dim, - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - ] - ) + self.img_mlp = nn.Sequential([ + nn.Dense( + int(self.dim * self.mlp_ratio), + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + nn.gelu, + nn.Dense( + self.dim, + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + ]) self.txt_norm2 = nn.LayerNorm( use_bias=False, @@ -234,29 +231,27 @@ def setup(self): dtype=self.dtype, param_dtype=self.weights_dtype, ) - self.txt_mlp = nn.Sequential( - [ - nn.Dense( - int(self.dim * self.mlp_ratio), - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - nn.gelu, - nn.Dense( - self.dim, - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - ] - ) + self.txt_mlp = nn.Sequential([ + nn.Dense( + int(self.dim * self.mlp_ratio), + use_bias=True, + 
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + nn.gelu, + nn.Dense( + self.dim, + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + ]) # let chunk size default to None self._chunk_size = None diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index 8f7d0bf5..a4f665c6 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # copied from https://github.com/ml-gde/jflux/blob/main/jflux/util.py import os diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py index 9162fbcb..18e5c7e6 100644 --- a/src/maxdiffusion/models/gradient_checkpoint.py +++ b/src/maxdiffusion/models/gradient_checkpoint.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
""" from enum import Enum, auto diff --git a/src/maxdiffusion/models/lora.py b/src/maxdiffusion/models/lora.py index 82d32e80..88a2b92a 100644 --- a/src/maxdiffusion/models/lora.py +++ b/src/maxdiffusion/models/lora.py @@ -1,17 +1,17 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from typing import Union, Tuple, Optional diff --git a/src/maxdiffusion/models/ltx_video/__init__.py b/src/maxdiffusion/models/ltx_video/__init__.py index 7e4185f3..4a62083b 100644 --- a/src/maxdiffusion/models/ltx_video/__init__.py +++ b/src/maxdiffusion/models/ltx_video/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/models/ltx_video/transformers/__init__.py b/src/maxdiffusion/models/ltx_video/transformers/__init__.py index 7e4185f3..4a62083b 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/__init__.py +++ b/src/maxdiffusion/models/ltx_video/transformers/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py index 1078f084..e392e4f6 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/adaln.py +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -126,7 +126,6 @@ def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: class AlphaCombinedTimestepSizeEmbeddings(nn.Module): - embedding_dim: int size_emb_dim: int dtype: jnp.dtype = jnp.float32 diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 8b12b1d8..67902936 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -227,7 +227,6 @@ def __call__( encoder_hidden_states = self.caption_projection(encoder_hidden_states) if self.num_layers > 0: - hidden_states = self.transformer_blocks( hidden_states, freqs_cis, diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py index 6241804b..ceede943 100644 --- a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" import inspect from importlib import import_module from typing import Any, Dict, Optional, Tuple diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 685b0c0b..1239ddbc 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch - Flax general utilities.""" +"""PyTorch - Flax general utilities.""" import re import torch @@ -348,10 +348,14 @@ def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alpha text_encoder_2_params = flatten_dict(unfreeze(params["text_encoder_2"])) else: text_encoder_2_params = None - (unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, network_alphas) = ( - create_flax_params_from_pytorch_state( - pt_state_dict, unet_params, text_encoder_params, text_encoder_2_params, network_alphas, adapter_name, is_lora=True - ) + ( + unet_state_dict, + text_encoder_state_dict, + text_encoder_2_state_dict, + rank, + network_alphas, + ) = create_flax_params_from_pytorch_state( + pt_state_dict, unet_params, text_encoder_params, text_encoder_2_params, network_alphas, adapter_name, is_lora=True ) params["unet"] = unflatten_dict(unet_state_dict) params["text_encoder"] = unflatten_dict(text_encoder_state_dict) diff --git a/src/maxdiffusion/models/normalization_flax.py b/src/maxdiffusion/models/normalization_flax.py index 2ba658d4..24f423f1 100644 --- a/src/maxdiffusion/models/normalization_flax.py +++ b/src/maxdiffusion/models/normalization_flax.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" import jax import jax.numpy as jnp diff --git a/src/maxdiffusion/models/vae_flax.py b/src/maxdiffusion/models/vae_flax.py index dc9b0063..013087ab 100644 --- a/src/maxdiffusion/models/vae_flax.py +++ b/src/maxdiffusion/models/vae_flax.py @@ -74,7 +74,6 @@ class FlaxUpsample2D(nn.Module): weights_dtype: jnp.dtype = jnp.float32 def setup(self): - self.conv = nn.Conv( self.in_channels, kernel_size=(3, 3), diff --git a/src/maxdiffusion/models/wan/__init__.py b/src/maxdiffusion/models/wan/__init__.py index 7e4185f3..4a62083b 100644 --- a/src/maxdiffusion/models/wan/__init__.py +++ b/src/maxdiffusion/models/wan/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 1da2d18f..f0ac43a0 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
""" from typing import Tuple, List, Sequence, Union, Optional @@ -28,7 +28,8 @@ BlockSizes = common_types.BlockSizes CACHE_T = 2 -flax.config.update('flax_always_shard_variable', False) +flax.config.update("flax_always_shard_variable", False) + # Helper to ensure kernel_size, stride, padding are tuples of 3 integers def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]: @@ -483,7 +484,6 @@ def __init__( ) def __call__(self, x: jax.Array): - identity = x batch_size, time, height, width, channels = x.shape diff --git a/src/maxdiffusion/models/wan/transformers/__init__.py b/src/maxdiffusion/models/wan/transformers/__init__.py index 9ff757fc..4a62083b 100644 --- a/src/maxdiffusion/models/wan/transformers/__init__.py +++ b/src/maxdiffusion/models/wan/transformers/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index a18b127c..cc237f71 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
""" from typing import Tuple, Optional, Dict, Union, Any @@ -104,7 +104,7 @@ def __init__( dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None, - flash_min_seq_length: int = 4096 + flash_min_seq_length: int = 4096, ): self.timesteps_proj = NNXFlaxTimesteps(dim=time_freq_dim, flip_sin_to_cos=True, freq_shift=0) self.time_embedder = NNXTimestepEmbedding( @@ -149,7 +149,7 @@ def __init__( dtype=dtype, weights_dtype=weights_dtype, precision=precision, - flash_min_seq_length=flash_min_seq_length + flash_min_seq_length=flash_min_seq_length, ) def __call__( @@ -261,11 +261,11 @@ def conditional_named_scope(self, name: str): return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array: - hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824) - hidden_states = checkpoint_name(hidden_states, "ffn_activation") - hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) - with jax.named_scope("proj_out"): - return self.proj_out(hidden_states) # output is (4, 75600, 5120) + hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824) + hidden_states = checkpoint_name(hidden_states, "ffn_activation") + hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) + with jax.named_scope("proj_out"): + return self.proj_out(hidden_states) # output is (4, 75600, 5120) class WanTransformerBlock(nnx.Module): @@ -292,7 +292,6 @@ def __init__( mask_padding_tokens: bool = True, enable_jax_named_scopes: bool = False, ): - self.enable_jax_named_scopes = enable_jax_named_scopes # 1. Self-attention @@ -412,7 +411,7 @@ def __call__( encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs, - encoder_attention_mask = encoder_attention_mask + encoder_attention_mask=encoder_attention_mask, ) with self.conditional_named_scope("cross_attn_residual"): hidden_states = hidden_states + attn_output @@ -504,7 +503,7 @@ def __init__( text_embed_dim=text_dim, image_embed_dim=image_dim, pos_embed_seq_len=pos_embed_seq_len, - flash_min_seq_length=flash_min_seq_length + flash_min_seq_length=flash_min_seq_length, ) # 3. 
Transformer blocks @@ -583,7 +582,7 @@ def conditional_named_scope(self, name: str): """Return a JAX named scope if enabled, otherwise a null context.""" return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() - @jax.named_scope('WanModel') + @jax.named_scope("WanModel") def __call__( self, hidden_states: jax.Array, @@ -609,24 +608,37 @@ def __call__( hidden_states = self.patch_embedding(hidden_states) hidden_states = jax.lax.collapse(hidden_states, 1, -1) with self.conditional_named_scope("condition_embedder"): - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask = self.condition_embedder( - timestep, encoder_hidden_states, encoder_hidden_states_image - ) + ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + encoder_attention_mask, + ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) if encoder_hidden_states_image is not None: - encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1) - if encoder_attention_mask is not None: - text_mask = jnp.ones((encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]), dtype=jnp.int32) - encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1) - encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype) + encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1) + if encoder_attention_mask is not None: + text_mask = jnp.ones( + (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]), + dtype=jnp.int32, + ) + encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1) + encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype) if self.scan_layers: def scan_fn(carry, block): hidden_states_carry, rngs_carry = carry hidden_states = block( - hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry, encoder_attention_mask + hidden_states_carry, + encoder_hidden_states, + timestep_proj, + rotary_emb, + deterministic, + rngs_carry, + encoder_attention_mask, ) new_carry = (hidden_states, rngs_carry) return new_carry, None @@ -647,7 +659,15 @@ def scan_fn(carry, block): for block in self.blocks: def layer_forward(hidden_states): - return block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs, encoder_attention_mask=encoder_attention_mask) + return block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + deterministic, + rngs, + encoder_attention_mask=encoder_attention_mask, + ) rematted_layer_forward = self.gradient_checkpoint.apply( layer_forward, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py index 1e1f7ae5..ce73ac5d 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py @@ -104,15 +104,11 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning( - nnx.initializers.xavier_uniform(), ("embed", None) - ), + 
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)), ) # 2. Self-attention - self.norm1 = FP32LayerNorm( - rngs=rngs, dim=dim, eps=eps, elementwise_affine=False - ) + self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) self.attn1 = FlaxWanAttention( rngs=rngs, query_dim=dim, @@ -150,9 +146,7 @@ def __init__( residual_checkpoint_name="cross_attn", ) assert cross_attn_norm is True, "cross_attn_norm must be True" - self.norm2 = FP32LayerNorm( - rngs=rngs, dim=dim, eps=eps, elementwise_affine=True - ) + self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) # 4. Feed-forward self.ffn = WanFeedForward( @@ -166,9 +160,7 @@ def __init__( dropout=dropout, ) - self.norm3 = FP32LayerNorm( - rngs=rngs, dim=dim, eps=eps, elementwise_affine=False - ) + self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) # 5. Output projection self.proj_out = nnx.data([None]) @@ -180,9 +172,7 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning( - nnx.initializers.xavier_uniform(), ("embed", None) - ), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)), ) key = rngs.params() @@ -205,19 +195,15 @@ def __call__( control_hidden_states = self.proj_in(control_hidden_states) control_hidden_states = control_hidden_states + hidden_states - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - jnp.split( - (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 - ) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( + (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 ) control_hidden_states = jax.lax.with_sharding_constraint( control_hidden_states, PartitionSpec("data", "fsdp", "tensor"), ) - control_hidden_states = checkpoint_name( - control_hidden_states, "control_hidden_states" - ) + control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states") encoder_hidden_states = jax.lax.with_sharding_constraint( encoder_hidden_states, PartitionSpec("data", "fsdp", None), @@ -225,11 +211,9 @@ def __call__( # 1. Self-attention with jax.named_scope("attn1"): - norm_hidden_states = ( - self.norm1(control_hidden_states.astype(jnp.float32)) - * (1 + scale_msa) - + shift_msa - ).astype(control_hidden_states.dtype) + norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype( + control_hidden_states.dtype + ) attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, @@ -237,15 +221,13 @@ def __call__( deterministic=deterministic, rngs=rngs, ) - control_hidden_states = ( - control_hidden_states.astype(jnp.float32) + attn_output * gate_msa - ).astype(control_hidden_states.dtype) + control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype( + control_hidden_states.dtype + ) # 2. Cross-attention with jax.named_scope("attn2"): - norm_hidden_states = self.norm2( - control_hidden_states.astype(jnp.float32) - ).astype(control_hidden_states.dtype) + norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype) attn_output = self.attn2( hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -256,17 +238,12 @@ def __call__( # 3. 
Feed-forward with jax.named_scope("ffn"): - norm_hidden_states = ( - self.norm3(control_hidden_states.astype(jnp.float32)) - * (1 + c_scale_msa) - + c_shift_msa - ).astype(control_hidden_states.dtype) - ff_output = self.ffn( - norm_hidden_states, deterministic=deterministic, rngs=rngs + norm_hidden_states = (self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype( + control_hidden_states.dtype ) + ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs) control_hidden_states = ( - control_hidden_states.astype(jnp.float32) - + ff_output.astype(jnp.float32) * c_gate_msa + control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa ).astype(control_hidden_states.dtype) conditioning_states = None if self.apply_output_projection: @@ -327,9 +304,7 @@ def __init__( self.scan_layers = scan_layers # 1. Patch & position embedding - self.rope = WanRotaryPosEmbed( - attention_head_dim, patch_size, rope_max_seq_len - ) + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) self.patch_embedding = nnx.Conv( in_channels, inner_dim, @@ -356,9 +331,7 @@ def __init__( pos_embed_seq_len=pos_embed_seq_len, ) - self.gradient_checkpoint = GradientCheckpointType.from_str( - remat_policy - ) + self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) self.names_which_can_be_offloaded = names_which_can_be_offloaded self.names_which_can_be_saved = names_which_can_be_saved @@ -432,9 +405,7 @@ def __init__( ), ) - self.norm_out = FP32LayerNorm( - rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False - ) + self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False) self.proj_out = nnx.Linear( rngs=rngs, in_features=inner_dim, @@ -442,16 +413,12 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning( - nnx.initializers.xavier_uniform(), ("embed", None) - ), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)), ) key = rngs.params() self.scale_shift_table = nnx.Param( jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5, - kernel_init=nnx.with_partitioning( - nnx.initializers.xavier_uniform(), (None, None, "embed") - ), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")), ) @jax.named_scope("WanVACEModel") @@ -468,9 +435,7 @@ def __call__( deterministic: bool = True, rngs: nnx.Rngs = None, ) -> jax.Array: - hidden_states = nn.with_logical_constraint( - hidden_states, ("batch", None, None, None, None) - ) + hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None)) batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config.patch_size post_patch_num_frames = num_frames // p_t @@ -478,9 +443,7 @@ def __call__( post_patch_width = width // p_w if control_hidden_states_scale is None: - control_hidden_states_scale = jnp.ones_like( - control_hidden_states, shape=(len(self.config.vace_layers),) - ) + control_hidden_states_scale = jnp.ones_like(control_hidden_states, shape=(len(self.config.vace_layers),)) if control_hidden_states_scale.shape[0] != len(self.config.vace_layers): raise ValueError( "Length of `control_hidden_states_scale`" @@ -489,9 +452,7 @@ def __call__( ) hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) - control_hidden_states = jnp.transpose( - control_hidden_states, (0, 2, 3, 4, 1) - ) + control_hidden_states = 
jnp.transpose(control_hidden_states, (0, 2, 3, 4, 1)) rotary_emb = self.rope(hidden_states) hidden_states = self.patch_embedding(hidden_states) @@ -505,15 +466,16 @@ def __call__( hidden_states.shape[2] - control_hidden_states.shape[2], )) - control_hidden_states = jnp.concatenate( - [control_hidden_states, control_hidden_states_padding], axis=2 - ) + control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2) # Condition embedder is a FC layer. - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = ( - self.condition_embedder( # We will need to mask out the text embedding. - timestep, encoder_hidden_states, encoder_hidden_states_image - ) + ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + ) = self.condition_embedder( # We will need to mask out the text embedding. + timestep, encoder_hidden_states, encoder_hidden_states_image ) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) @@ -526,6 +488,7 @@ def __call__( # Prepare VACE hints control_hidden_states_list = nnx.List([]) for i, vace_block in enumerate(self.vace_blocks): + def layer_forward(hidden_states, control_hidden_states): return vace_block( hidden_states=hidden_states, @@ -543,12 +506,8 @@ def layer_forward(hidden_states, control_hidden_states): self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers, ) - conditioning_states, control_hidden_states = rematted_layer_forward( - hidden_states, control_hidden_states - ) - control_hidden_states_list.append( - (conditioning_states, control_hidden_states_scale[i]) - ) + conditioning_states, control_hidden_states = rematted_layer_forward(hidden_states, control_hidden_states) + control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i])) control_hidden_states_list = control_hidden_states_list[::-1] @@ -576,13 +535,9 @@ def layer_forward_vace(hidden_states): hidden_states = hidden_states + control_hint * scale # 6. Output norm, projection & unpatchify - shift, scale = jnp.split( - self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1 - ) + shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) - hidden_states = ( - self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift - ).astype(hidden_states.dtype) + hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) with jax.named_scope("proj_out"): hidden_states = self.proj_out(hidden_states) # Linear layer. diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 7a4b8841..b12d1907 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ import os @@ -186,7 +186,6 @@ def load_wan_transformer( scan_layers: bool = True, subfolder: str = "", ): - if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH: return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers) elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH: @@ -260,23 +259,23 @@ def load_base_wan_transformer( renamed_pt_key = rename_key(pt_key) if "condition_embedder" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "time_embedder.linear_1") - renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "time_embedder.linear_2") - renamed_pt_key = renamed_pt_key.replace("time_projection_1", "time_proj") - renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "text_embedder.linear_1") - renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "text_embedder.linear_2") + renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "time_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "time_embedder.linear_2") + renamed_pt_key = renamed_pt_key.replace("time_projection_1", "time_proj") + renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "text_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "text_embedder.linear_2") if "image_embedder" in renamed_pt_key: - if "net.0.proj" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0") - elif "net_0.proj" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0") - if "net.2" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("net.2", "net_2") - renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm") - if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("weight", "scale") - renamed_pt_key = renamed_pt_key.replace("kernel", "scale") + if "net.0.proj" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0") + elif "net_0.proj" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0") + if "net.2" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net.2", "net_2") + renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm") + if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("weight", "scale") + renamed_pt_key = renamed_pt_key.replace("kernel", "scale") renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table") diff --git a/src/maxdiffusion/multihost_dataloading.py b/src/maxdiffusion/multihost_dataloading.py index 4be0ba8d..273ded82 100644 --- a/src/maxdiffusion/multihost_dataloading.py +++ b/src/maxdiffusion/multihost_dataloading.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=unused-import """SPMD Multihost Dataloading Utilities. diff --git a/src/maxdiffusion/pedagogical_examples/attention_comparison.py b/src/maxdiffusion/pedagogical_examples/attention_comparison.py index 024ef92a..07831550 100644 --- a/src/maxdiffusion/pedagogical_examples/attention_comparison.py +++ b/src/maxdiffusion/pedagogical_examples/attention_comparison.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import time diff --git a/src/maxdiffusion/pedagogical_examples/checkpoint_params_restore.py b/src/maxdiffusion/pedagogical_examples/checkpoint_params_restore.py index 28230251..47a521fb 100644 --- a/src/maxdiffusion/pedagogical_examples/checkpoint_params_restore.py +++ b/src/maxdiffusion/pedagogical_examples/checkpoint_params_restore.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """This script is used an example of how to restore params from a orbax train_state ckpt.""" diff --git a/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py index 16c015a7..2300d0bd 100644 --- a/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py +++ b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import argparse import tensorflow as tf diff --git a/src/maxdiffusion/pedagogical_examples/parameter_count.py b/src/maxdiffusion/pedagogical_examples/parameter_count.py index 8e591b4e..cf9f8b8d 100644 --- a/src/maxdiffusion/pedagogical_examples/parameter_count.py +++ b/src/maxdiffusion/pedagogical_examples/parameter_count.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence from absl import app import jax @@ -21,7 +21,6 @@ def run(config): - pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( config.pretrained_model_name_or_path, revision=config.revision, diff --git a/src/maxdiffusion/pedagogical_examples/save_sd_checkpoint.py b/src/maxdiffusion/pedagogical_examples/save_sd_checkpoint.py index 08a0f46a..350f6d0c 100644 --- a/src/maxdiffusion/pedagogical_examples/save_sd_checkpoint.py +++ b/src/maxdiffusion/pedagogical_examples/save_sd_checkpoint.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Load and save a checkpoint. This is useful for uploading checkpoints to gcs and later loading them from gcs directly. diff --git a/src/maxdiffusion/pedagogical_examples/save_sdxl_checkpoint.py b/src/maxdiffusion/pedagogical_examples/save_sdxl_checkpoint.py index 64aa0f0b..860f4beb 100644 --- a/src/maxdiffusion/pedagogical_examples/save_sdxl_checkpoint.py +++ b/src/maxdiffusion/pedagogical_examples/save_sdxl_checkpoint.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" """Load and save a checkpoint. This is useful for uploading checkpoints to gcs and later loading them from gcs directly. diff --git a/src/maxdiffusion/pedagogical_examples/to_tfrecords.py b/src/maxdiffusion/pedagogical_examples/to_tfrecords.py index 6298adda..a0a38021 100644 --- a/src/maxdiffusion/pedagogical_examples/to_tfrecords.py +++ b/src/maxdiffusion/pedagogical_examples/to_tfrecords.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Example file of how to prepare tfrecords with latents and hidden_states preprocessed. @@ -54,14 +54,12 @@ dl_manager = tfds.download.DownloadManager(download_dir="/tmp") tmp_dataset = "dataset" -TRANSFORMS = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(size=512), - transforms.Normalize([0.5], [0.5]), - ] -) +TRANSFORMS = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(size=512), + transforms.Normalize([0.5], [0.5]), +]) def delete_files(path): @@ -184,7 +182,6 @@ def img_to_latents(img, p_vae_apply, sample_rng): def run(config): - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( config.pretrained_model_name_or_path, revision=config.revision, diff --git a/src/maxdiffusion/pedagogical_examples/unet_shardings.py b/src/maxdiffusion/pedagogical_examples/unet_shardings.py index bc956b1f..38ed1af9 100644 --- a/src/maxdiffusion/pedagogical_examples/unet_shardings.py +++ b/src/maxdiffusion/pedagogical_examples/unet_shardings.py @@ -1,20 +1,20 @@ #!/usr/bin/python3 """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """This script is used an example of how to shard the UNET on TPU.""" diff --git a/src/maxdiffusion/pipelines/__init__.py b/src/maxdiffusion/pipelines/__init__.py index 227784ba..019c79a8 100644 --- a/src/maxdiffusion/pipelines/__init__.py +++ b/src/maxdiffusion/pipelines/__init__.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" from typing import TYPE_CHECKING @@ -51,16 +51,14 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) else: - _import_structure["stable_diffusion"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) + _import_structure["stable_diffusion"].extend([ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ]) try: if not is_flax_available(): @@ -82,20 +80,15 @@ _import_structure["controlnet"].extend( ["FlaxStableDiffusionControlNetPipeline", "FlaxStableDiffusionXLControlNetPipeline"] ) - _import_structure["stable_diffusion"].extend( - [ - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - ] - ) - _import_structure["stable_diffusion_xl"].extend( - [ - "FlaxStableDiffusionXLPipeline", - ] - ) + _import_structure["stable_diffusion"].extend([ + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + ]) + _import_structure["stable_diffusion_xl"].extend([ + "FlaxStableDiffusionXLPipeline", + ]) if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() diff --git a/src/maxdiffusion/pipelines/controlnet/__init__.py b/src/maxdiffusion/pipelines/controlnet/__init__.py index e650f9d5..0cf92cd4 100644 --- a/src/maxdiffusion/pipelines/controlnet/__init__.py +++ b/src/maxdiffusion/pipelines/controlnet/__init__.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import TYPE_CHECKING from ...utils import ( diff --git a/src/maxdiffusion/pipelines/controlnet/pipeline_flax_controlnet_sdxl.py b/src/maxdiffusion/pipelines/controlnet/pipeline_flax_controlnet_sdxl.py index 885b0b37..b8b1cc18 100644 --- a/src/maxdiffusion/pipelines/controlnet/pipeline_flax_controlnet_sdxl.py +++ b/src/maxdiffusion/pipelines/controlnet/pipeline_flax_controlnet_sdxl.py @@ -112,7 +112,6 @@ def __call__( output_type: str = None, jit: bool = False, ): - if isinstance(guidance_scale, float) and jit: # Convert to a tensor so each device gets a copy. 
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) diff --git a/src/maxdiffusion/pipelines/flux/__init__.py b/src/maxdiffusion/pipelines/flux/__init__.py index 5457eef5..c39cc364 100644 --- a/src/maxdiffusion/pipelines/flux/__init__.py +++ b/src/maxdiffusion/pipelines/flux/__init__.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" _import_structure = {"pipeline_jflux": "JfluxPipeline"} from .flux_pipeline import ( diff --git a/src/maxdiffusion/pipelines/flux/flux_pipeline.py b/src/maxdiffusion/pipelines/flux/flux_pipeline.py index 112338d5..15b2c4f5 100644 --- a/src/maxdiffusion/pipelines/flux/flux_pipeline.py +++ b/src/maxdiffusion/pipelines/flux/flux_pipeline.py @@ -131,7 +131,6 @@ def prepare_latents( dtype: jnp.dtype, rng: Array, ): - # VAE applies 8x compression on images but we must also account for packing which # requires latent height and width to be divisibly by 2. 
height = 2 * (height // (vae_scale_factor * 2)) @@ -194,7 +193,6 @@ def get_t5_prompt_embeds( encode_in_batches=False, encode_batch_size=None, ): - prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -243,7 +241,6 @@ def encode_prompt( encode_in_batches: bool = False, encode_batch_size: int = None, ): - if encode_in_batches: assert encode_in_batches is not None @@ -271,7 +268,6 @@ def encode_prompt( def _generate( self, flux_params, vae_params, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts ): - def loop_body( step, args, diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 1b8f4deb..4aa3baf1 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -901,7 +901,6 @@ def transformer_forward_pass( skip_layer_mask, skip_layer_strategy, ): - noise_pred = transformer.apply( {"params": state.params}, hidden_states=latents, diff --git a/src/maxdiffusion/pipelines/stable_diffusion/__init__.py b/src/maxdiffusion/pipelines/stable_diffusion/__init__.py index 9ac32eb7..cbef1d5f 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion/__init__.py +++ b/src/maxdiffusion/pipelines/stable_diffusion/__init__.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" from typing import TYPE_CHECKING from ...utils import ( @@ -84,13 +84,11 @@ StableDiffusionPix2PixZeroPipeline, ) - _dummy_objects.update( - { - "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, - "StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline, - "StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline, - } - ) + _dummy_objects.update({ + "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, + "StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline, + "StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline, + }) else: _import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"] _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] diff --git a/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py b/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py index 2eb01334..13e201ae 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py +++ b/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import TYPE_CHECKING from ...utils import ( diff --git a/src/maxdiffusion/pipelines/wan/__init__.py b/src/maxdiffusion/pipelines/wan/__init__.py index 83a537f8..9a17b1e7 100644 --- a/src/maxdiffusion/pipelines/wan/__init__.py +++ b/src/maxdiffusion/pipelines/wan/__init__.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
+Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from .wan_pipeline import WanPipeline diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 0bc93f0c..415bcfea 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -94,9 +94,13 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. def create_sharded_logical_transformer( - devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder: str = "" + devices_array: np.array, + mesh: Mesh, + rngs: nnx.Rngs, + config: HyperParameters, + restored_checkpoint=None, + subfolder: str = "", ): - def create_model(rngs: nnx.Rngs, wan_config: dict): wan_transformer = WanModel(**wan_config, rngs=rngs) return wan_transformer @@ -111,7 +115,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): # WAN 2.2 I2V uses VAE-encoded latent conditioning (image_dim and added_kv_proj_dim are None in the transformer config) if config.model_name == "wan2.1": if wan_config.get("image_seq_len") is None: - wan_config["image_seq_len"] = 257 + wan_config["image_seq_len"] = 257 wan_config["mesh"] = mesh wan_config["dtype"] = config.activations_dtype @@ -201,6 +205,7 @@ class WanPipeline: vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
""" + def __init__( self, tokenizer: AutoTokenizer, @@ -252,21 +257,18 @@ def load_tokenizer(cls, config: HyperParameters): @classmethod def load_image_encoder(cls, config: HyperParameters): - image_processor = CLIPImageProcessor.from_pretrained( - config.pretrained_model_name_or_path, subfolder="image_processor" - ) + image_processor = CLIPImageProcessor.from_pretrained(config.pretrained_model_name_or_path, subfolder="image_processor") try: - image_encoder = FlaxCLIPVisionModel.from_pretrained( - config.pretrained_model_name_or_path, subfolder="image_encoder", dtype=jnp.float32 - ) + image_encoder = FlaxCLIPVisionModel.from_pretrained( + config.pretrained_model_name_or_path, subfolder="image_encoder", dtype=jnp.float32 + ) except Exception as e: - max_logging.error(f"Failed to load FlaxCLIPVisionModel: {e}") - raise + max_logging.error(f"Failed to load FlaxCLIPVisionModel: {e}") + raise return image_processor, image_encoder @classmethod def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): - def create_model(rngs: nnx.Rngs, config: HyperParameters): wan_vae = AutoencoderKLWan.from_config( config.pretrained_model_name_or_path, @@ -384,10 +386,22 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline @classmethod def load_transformer( - cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder="transformer"): + cls, + devices_array: np.array, + mesh: Mesh, + rngs: nnx.Rngs, + config: HyperParameters, + restored_checkpoint=None, + subfolder="transformer", + ): with mesh: wan_transformer = create_sharded_logical_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder=subfolder + devices_array=devices_array, + mesh=mesh, + rngs=rngs, + config=config, + restored_checkpoint=restored_checkpoint, + subfolder=subfolder, ) return wan_transformer @@ -401,17 +415,16 @@ def load_scheduler(cls, config): return scheduler, scheduler_state def encode_image(self, image: PipelineImageInput, num_videos_per_prompt: int = 1): - if not isinstance(image, list): - image = [image] - image_inputs = self.image_processor(images=image, return_tensors="np") - pixel_values = jnp.array(image_inputs.pixel_values) - - image_encoder_output = self.image_encoder(pixel_values, output_hidden_states=True) - image_embeds = image_encoder_output.hidden_states[-2] + if not isinstance(image, list): + image = [image] + image_inputs = self.image_processor(images=image, return_tensors="np") + pixel_values = jnp.array(image_inputs.pixel_values) - image_embeds = jnp.repeat(image_embeds, num_videos_per_prompt, axis=0) - return image_embeds + image_encoder_output = self.image_encoder(pixel_values, output_hidden_states=True) + image_embeds = image_encoder_output.hidden_states[-2] + image_embeds = jnp.repeat(image_embeds, num_videos_per_prompt, axis=0) + return image_embeds def _get_t5_prompt_embeds( self, @@ -508,82 +521,90 @@ def prepare_latents_i2v_base( dtype: jnp.dtype, last_image: Optional[jax.Array] = None, ) -> Tuple[jax.Array, jax.Array]: - """ - Encodes the initial image(s) into latents to be used as conditioning. - Returns: - latent_condition: The VAE encoded latents of the image(s). - video_condition: The input to the VAE. 
- """ - height, width = image.shape[-2:] - image = image[:, :, jnp.newaxis, :, :] # [B, C, 1, H, W] - - if last_image is None: - video_condition = jnp.concatenate( - [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 1, height, width), dtype=image.dtype)], axis=2 - ) - else: - last_image = last_image[:, :, jnp.newaxis, :, :] - video_condition = jnp.concatenate( - [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 2, height, width), dtype=image.dtype), last_image], axis=2 - ) - - vae_dtype = getattr(self.vae, "dtype", jnp.float32) - video_condition = video_condition.astype(vae_dtype) - - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode() - - # Normalize latents - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim) - latents_std = jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim) - latent_condition = encoded_output - latent_condition = latent_condition.astype(dtype) - latent_condition = (latent_condition - latents_mean) / latents_std - - return latent_condition, video_condition + """ + Encodes the initial image(s) into latents to be used as conditioning. + Returns: + latent_condition: The VAE encoded latents of the image(s). + video_condition: The input to the VAE. + """ + height, width = image.shape[-2:] + image = image[:, :, jnp.newaxis, :, :] # [B, C, 1, H, W] + + if last_image is None: + video_condition = jnp.concatenate( + [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 1, height, width), dtype=image.dtype)], axis=2 + ) + else: + last_image = last_image[:, :, jnp.newaxis, :, :] + video_condition = jnp.concatenate( + [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 2, height, width), dtype=image.dtype), last_image], + axis=2, + ) + + vae_dtype = getattr(self.vae, "dtype", jnp.float32) + video_condition = video_condition.astype(vae_dtype) + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode() + + # Normalize latents + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim) + latents_std = jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim) + latent_condition = encoded_output + latent_condition = latent_condition.astype(dtype) + latent_condition = (latent_condition - latents_mean) / latents_std + + return latent_condition, video_condition def _denormalize_latents(self, latents: jax.Array) -> jax.Array: - """Denormalizes latents using VAE statistics.""" - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) - latents = latents / latents_std + latents_mean - latents = latents.astype(jnp.float32) - return latents + """Denormalizes latents using VAE statistics.""" + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) + latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) + latents = latents / latents_std + latents_mean + latents = latents.astype(jnp.float32) + return latents def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray: - """Decodes latents to video frames and postprocesses.""" - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - video = self.vae.decode(latents, self.vae_cache)[0] + """Decodes latents to 
video frames and postprocesses.""" + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + video = self.vae.decode(latents, self.vae_cache)[0] - video = jnp.transpose(video, (0, 4, 1, 2, 3)) - video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) - video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) - return self.video_processor.postprocess_video(video, output_type="np") + video = jnp.transpose(video, (0, 4, 1, 2, 3)) + video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) + video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) + return self.video_processor.postprocess_video(video, output_type="np") @classmethod def _create_common_components(cls, config, vae_only=False, i2v=False): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - components = { - "vae": wan_vae, "vae_cache": vae_cache, - "devices_array": devices_array, "rngs": rngs, "mesh": mesh, - "tokenizer": None, "text_encoder": None, "scheduler": None, "scheduler_state": None, - "image_processor": None, "image_encoder": None - } - - if not vae_only: - components["tokenizer"] = cls.load_tokenizer(config=config) - components["text_encoder"] = cls.load_text_encoder(config=config) - components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) - if i2v and config.model_name == 'wan2.1': - components["image_processor"], components["image_encoder"] = cls.load_image_encoder(config) - return components + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + rng = jax.random.key(config.seed) + rngs = nnx.Rngs(rng) + + with mesh: + wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + + components = { + "vae": wan_vae, + "vae_cache": vae_cache, + "devices_array": devices_array, + "rngs": rngs, + "mesh": mesh, + "tokenizer": None, + "text_encoder": None, + "scheduler": None, + "scheduler_state": None, + "image_processor": None, + "image_encoder": None, + } + + if not vae_only: + components["tokenizer"] = cls.load_tokenizer(config=config) + components["text_encoder"] = cls.load_text_encoder(config=config) + components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) + if i2v and config.model_name == "wan2.1": + components["image_processor"], components["image_encoder"] = cls.load_image_encoder(config) + return components @abstractmethod def _get_num_channel_latents(self) -> int: @@ -603,7 +624,7 @@ def _prepare_model_inputs_i2v( last_image: Optional[PIL.Image.Image] = None, ): if prompt is not None and isinstance(prompt, str): - prompt = [prompt] + prompt = [prompt] batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] // num_videos_per_prompt effective_batch_size = batch_size * num_videos_per_prompt @@ -617,30 +638,29 @@ def _prepare_model_inputs_i2v( negative_prompt_embeds=negative_prompt_embeds, ) - # 2. 
Encode Image (only for WAN 2.1 I2V which uses CLIP image embeddings) # WAN 2.2 I2V does not use CLIP image embeddings, it uses VAE latent conditioning instead transformer_dtype = self.config.activations_dtype if self.config.model_name == "wan2.1": - # WAN 2.1 I2V: Use CLIP image encoder - if image_embeds is None: - images_to_encode = [image] - if last_image is None: - images_to_encode = [image] - else: - images_to_encode = [image, last_image] - image_embeds = self.encode_image(images_to_encode, num_videos_per_prompt=num_videos_per_prompt) - self.image_seq_len = image_embeds.shape[1] - - if batch_size > 1: - image_embeds = jnp.tile(image_embeds, (batch_size, 1, 1)) - - image_embeds = image_embeds.astype(transformer_dtype) + # WAN 2.1 I2V: Use CLIP image encoder + if image_embeds is None: + images_to_encode = [image] + if last_image is None: + images_to_encode = [image] + else: + images_to_encode = [image, last_image] + image_embeds = self.encode_image(images_to_encode, num_videos_per_prompt=num_videos_per_prompt) + self.image_seq_len = image_embeds.shape[1] + + if batch_size > 1: + image_embeds = jnp.tile(image_embeds, (batch_size, 1, 1)) + + image_embeds = image_embeds.astype(transformer_dtype) else: - # WAN 2.2 I2V: No CLIP image embeddings, set to None or empty tensor - # The actual image conditioning happens via VAE latents in prepare_latents - image_embeds = None + # WAN 2.2 I2V: No CLIP image embeddings, set to None or empty tensor + # The actual image conditioning happens via VAE latents in prepare_latents + image_embeds = None prompt_embeds = prompt_embeds.astype(transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.astype(transformer_dtype) @@ -648,7 +668,7 @@ def _prepare_model_inputs_i2v( # Use same sharding logic as T2V pipeline for consistent behavior data_sharding = NamedSharding(self.mesh, P()) if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: - data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) prompt_embeds = jax.device_put(prompt_embeds, data_sharding) negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) @@ -656,22 +676,21 @@ def _prepare_model_inputs_i2v( return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size - def _prepare_model_inputs( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: jax.Array = None, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, - ): + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, + ): if not vae_only: if num_frames % self.vae_scale_factor_temporal != 1: max_logging.log( @@ -724,8 +743,9 @@ def _prepare_model_inputs( @abstractmethod def __call__(self, **kwargs): - """Runs the inference pipeline.""" - pass + """Runs the inference pipeline.""" + pass + 
@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) def transformer_forward_pass( @@ -740,7 +760,12 @@ def transformer_forward_pass( encoder_hidden_states_image=None, ): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) - noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=encoder_hidden_states_image) + noise_pred = wan_transformer( + hidden_states=latents, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=encoder_hidden_states_image, + ) if do_classifier_free_guidance: bsz = latents.shape[0] // 2 noise_cond = noise_pred[:bsz] # First half = conditional diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index 5617e3b7..c247facb 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -23,8 +23,10 @@ import jax.numpy as jnp from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler + class WanPipeline2_1(WanPipeline): """Pipeline for WAN 2.1 with a single transformer.""" + def __init__(self, config: HyperParameters, transformer: Optional[WanModel], **kwargs): super().__init__(config=config, **kwargs) self.transformer = transformer @@ -41,27 +43,27 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t rngs=common_components["rngs"], config=config, restored_checkpoint=restored_checkpoint, - subfolder="transformer" + subfolder="transformer", ) pipeline = cls( - tokenizer=common_components["tokenizer"], - text_encoder=common_components["text_encoder"], - transformer=transformer, - vae=common_components["vae"], - vae_cache=common_components["vae_cache"], - scheduler=common_components["scheduler"], - scheduler_state=common_components["scheduler_state"], - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - config=config, + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + transformer=transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, ) return pipeline, transformer @classmethod def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer) + pipeline, transformer = cls._load_and_init(config, None, vae_only, load_transformer) pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) return pipeline @@ -74,20 +76,20 @@ def _get_num_channel_latents(self) -> int: return self.transformer.config.in_channels def __call__( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: Optional[jax.Array] = None, - prompt_embeds: Optional[jax.Array] = None, - negative_prompt_embeds: Optional[jax.Array] = None, - vae_only: bool = False, + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: 
int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: Optional[jax.Array] = None, + prompt_embeds: Optional[jax.Array] = None, + negative_prompt_embeds: Optional[jax.Array] = None, + vae_only: bool = False, ): latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( prompt, @@ -126,6 +128,7 @@ def __call__( latents = self._denormalize_latents(latents) return self._decode_latents_to_video(latents) + def run_inference_2_1( graphdef, sharded_state, diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index c0400f60..e6514dae 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -23,9 +23,17 @@ import jax.numpy as jnp from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler + class WanPipeline2_2(WanPipeline): """Pipeline for WAN 2.2 with dual transformers.""" - def __init__(self, config: HyperParameters, low_noise_transformer: Optional[WanModel], high_noise_transformer: Optional[WanModel], **kwargs): + + def __init__( + self, + config: HyperParameters, + low_noise_transformer: Optional[WanModel], + high_noise_transformer: Optional[WanModel], + **kwargs + ): super().__init__(config=config, **kwargs) self.low_noise_transformer = low_noise_transformer self.high_noise_transformer = high_noise_transformer @@ -35,24 +43,24 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t common_components = cls._create_common_components(config, vae_only) low_noise_transformer, high_noise_transformer = None, None if not vae_only and load_transformer: - low_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer_2" - ) - high_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer" - ) - - pipeline = cls( + low_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer_2", + ) + high_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer", + ) + + pipeline = cls( tokenizer=common_components["tokenizer"], text_encoder=common_components["text_encoder"], low_noise_transformer=low_noise_transformer, @@ -64,7 +72,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t devices_array=common_components["devices_array"], mesh=common_components["mesh"], config=config, - ) + ) return pipeline, low_noise_transformer, high_noise_transformer @classmethod @@ -76,29 +84,31 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform @classmethod def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, 
vae_only=False, load_transformer=True): - pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) + pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init( + config, restored_checkpoint, vae_only, load_transformer + ) return pipeline def _get_num_channel_latents(self) -> int: return self.low_noise_transformer.config.in_channels def __call__( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - guidance_scale_low: float = 3.0, - guidance_scale_high: float = 4.0, - boundary: int = 875, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: jax.Array = None, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale_low: float = 3.0, + guidance_scale_high: float = 4.0, + boundary: int = 875, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, ): latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( prompt, @@ -143,6 +153,7 @@ def __call__( latents = self._denormalize_latents(latents) return self._decode_latents_to_video(latents) + def run_inference_2_2( low_noise_graphdef, low_noise_state, @@ -167,17 +178,27 @@ def run_inference_2_2( def low_noise_branch(operands): latents, timestep, prompt_embeds = operands return transformer_forward_pass( - low_noise_graphdef, low_noise_state, low_noise_rest, - latents, timestep, prompt_embeds, - do_classifier_free_guidance, guidance_scale_low + low_noise_graphdef, + low_noise_state, + low_noise_rest, + latents, + timestep, + prompt_embeds, + do_classifier_free_guidance, + guidance_scale_low, ) def high_noise_branch(operands): latents, timestep, prompt_embeds = operands return transformer_forward_pass( - high_noise_graphdef, high_noise_state, high_noise_rest, - latents, timestep, prompt_embeds, - do_classifier_free_guidance, guidance_scale_high + high_noise_graphdef, + high_noise_state, + high_noise_rest, + latents, + timestep, + prompt_embeds, + do_classifier_free_guidance, + guidance_scale_high, ) for step in range(num_inference_steps): @@ -192,10 +213,7 @@ def high_noise_branch(operands): # - high_noise_model: Used for early diffusion steps where t >= config.boundary_timestep (high noise). # - low_noise_model: Used for later diffusion steps where t < config.boundary_timestep (low noise). 
noise_pred, latents = jax.lax.cond( - use_high_noise, - high_noise_branch, - low_noise_branch, - (latents, timestep, prompt_embeds) + use_high_noise, high_noise_branch, low_noise_branch, (latents, timestep, prompt_embeds) ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index 0380a07c..0622ec79 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -26,8 +26,10 @@ from jax.sharding import NamedSharding, PartitionSpec as P from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler + class WanPipelineI2V_2_1(WanPipeline): """Pipeline for WAN 2.1 Image-to-Video.""" + def __init__(self, config: HyperParameters, transformer: Optional[WanModel], **kwargs): super().__init__(config=config, **kwargs) self.transformer = transformer @@ -44,28 +46,28 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t rngs=common_components["rngs"], config=config, restored_checkpoint=restored_checkpoint, - subfolder="transformer" + subfolder="transformer", ) pipeline = cls( - tokenizer=common_components["tokenizer"], - text_encoder=common_components["text_encoder"], - image_processor=common_components["image_processor"], - image_encoder=common_components["image_encoder"], - transformer=transformer, - vae=common_components["vae"], - vae_cache=common_components["vae_cache"], - scheduler=common_components["scheduler"], - scheduler_state=common_components["scheduler_state"], - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - config=config, + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + image_processor=common_components["image_processor"], + image_encoder=common_components["image_encoder"], + transformer=transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, ) return pipeline, transformer @classmethod def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer) + pipeline, transformer = cls._load_and_init(config, None, vae_only, load_transformer) pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) return pipeline @@ -87,110 +89,109 @@ def prepare_latents( last_image: Optional[jax.Array] = None, num_videos_per_prompt: int = 1, ) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]: - - if hasattr(image, "detach"): - image = image.detach().cpu().numpy() - image = jnp.array(image) - - if last_image is not None: - if hasattr(last_image, "detach"): - last_image = last_image.detach().cpu().numpy() - last_image = jnp.array(last_image) - - if num_videos_per_prompt > 1: - image = jnp.repeat(image, num_videos_per_prompt, axis=0) - if last_image is not None: - last_image = jnp.repeat(last_image, num_videos_per_prompt, axis=0) - - num_channels_latents = self.vae.z_dim - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - latent_height = height // self.vae_scale_factor_spatial - latent_width = width // self.vae_scale_factor_spatial - - 
shape = (batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents) - - if latents is None: - latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) - else: - latents = latents.astype(dtype) - latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image) - mask_lat_size = jnp.ones((batch_size, 1, num_frames, latent_height, latent_width), dtype=dtype) - if last_image is None: - mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0) - else: - mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0) - first_frame_mask = mask_lat_size[:, :, 0:1] - first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2) - mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2) - mask_lat_size = mask_lat_size.reshape( - batch_size, - 1, - num_latent_frames, - self.vae_scale_factor_temporal, - latent_height, - latent_width - ) - mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1) - condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1) - return latents, condition, None - + if hasattr(image, "detach"): + image = image.detach().cpu().numpy() + image = jnp.array(image) + + if last_image is not None: + if hasattr(last_image, "detach"): + last_image = last_image.detach().cpu().numpy() + last_image = jnp.array(last_image) + + if num_videos_per_prompt > 1: + image = jnp.repeat(image, num_videos_per_prompt, axis=0) + if last_image is not None: + last_image = jnp.repeat(last_image, num_videos_per_prompt, axis=0) + + num_channels_latents = self.vae.z_dim + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents) + + if latents is None: + latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) + else: + latents = latents.astype(dtype) + latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image) + mask_lat_size = jnp.ones((batch_size, 1, num_frames, latent_height, latent_width), dtype=dtype) + if last_image is None: + mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0) + else: + mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0) + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2) + mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2) + mask_lat_size = mask_lat_size.reshape( + batch_size, 1, num_latent_frames, self.vae_scale_factor_temporal, latent_height, latent_width + ) + mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1) + condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1) + return latents, condition, None def __call__( - self, - prompt: Union[str, List[str]], - image: PipelineImageInput, - negative_prompt: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 512, - latents: Optional[jax.Array] = None, - prompt_embeds: Optional[jax.Array] = None, - negative_prompt_embeds: Optional[jax.Array] = None, - image_embeds: Optional[jax.Array] = None, - last_image: Optional[PipelineImageInput] = None, - output_type: 
Optional[str] = "np", - rng: Optional[jax.Array] = None, + self, + prompt: Union[str, List[str]], + image: PipelineImageInput, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + latents: Optional[jax.Array] = None, + prompt_embeds: Optional[jax.Array] = None, + negative_prompt_embeds: Optional[jax.Array] = None, + image_embeds: Optional[jax.Array] = None, + last_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "np", + rng: Optional[jax.Array] = None, ): - height = height or self.config.height width = width or self.config.width num_frames = num_frames or self.config.num_frames # Validate and adjust num_frames to ensure proper reshaping in prepare_latents if num_frames % self.vae_scale_factor_temporal != 1: - max_logging.log( - f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " - f"Rounding {num_frames} to the nearest valid number." - ) - num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 - max_logging.log(f"Adjusted num_frames to: {num_frames}") + max_logging.log( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " + f"Rounding {num_frames} to the nearest valid number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + max_logging.log(f"Adjusted num_frames to: {num_frames}") num_frames = max(num_frames, 1) prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v( - prompt, image, negative_prompt, num_videos_per_prompt, max_sequence_length, - prompt_embeds, negative_prompt_embeds, image_embeds, last_image + prompt, + image, + negative_prompt, + num_videos_per_prompt, + max_sequence_length, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + last_image, ) def _process_image_input(img_input, height, width, num_videos_per_prompt): - if img_input is None: - return None - tensor = self.video_processor.preprocess(img_input, height=height, width=width) - jax_array = jnp.array(tensor.cpu().numpy()) - if jax_array.ndim == 3: - jax_array = jax_array[None, ...] # Add batch dimension - if num_videos_per_prompt > 1: - jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0) - return jax_array + if img_input is None: + return None + tensor = self.video_processor.preprocess(img_input, height=height, width=width) + jax_array = jnp.array(tensor.cpu().numpy()) + if jax_array.ndim == 3: + jax_array = jax_array[None, ...] # Add batch dimension + if num_videos_per_prompt > 1: + jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0) + return jax_array image_tensor = _process_image_input(image, height, width, effective_batch_size) last_image_tensor = _process_image_input(last_image, height, width, effective_batch_size) if rng is None: - rng = jax.random.key(self.config.seed) + rng = jax.random.key(self.config.seed) latents_rng, inference_rng = jax.random.split(rng) latents, condition, first_frame_mask = self.prepare_latents( @@ -213,7 +214,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) 
data_sharding = NamedSharding(self.mesh, P()) if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: - data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) latents = jax.device_put(latents, data_sharding) condition = jax.device_put(condition, data_sharding) @@ -221,7 +222,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) image_embeds = jax.device_put(image_embeds, data_sharding) if first_frame_mask is not None: - first_frame_mask = jax.device_put(first_frame_mask, data_sharding) + first_frame_mask = jax.device_put(first_frame_mask, data_sharding) p_run_inference = partial( run_inference_2_1_i2v, @@ -233,7 +234,6 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): scheduler=self.scheduler, ) - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): latents = p_run_inference( latents=latents, @@ -252,7 +252,9 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): def run_inference_2_1_i2v( - graphdef, sharded_state, rest_of_state, + graphdef, + sharded_state, + rest_of_state, latents: jnp.array, condition: jnp.array, prompt_embeds: jnp.array, @@ -273,14 +275,18 @@ def run_inference_2_1_i2v( t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] latents_input = latents if do_classifier_free_guidance: - latents_input = jnp.concatenate([latents, latents], axis=0) + latents_input = jnp.concatenate([latents, latents], axis=0) latent_model_input = jnp.concatenate([latents_input, condition], axis=-1) timestep = jnp.broadcast_to(t, latents_input.shape[0]) latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) noise_pred, _ = transformer_forward_pass( - graphdef, sharded_state, rest_of_state, - latent_model_input, timestep, prompt_embeds, + graphdef, + sharded_state, + rest_of_state, + latent_model_input, + timestep, + prompt_embeds, do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale, encoder_hidden_states_image=image_embeds, diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index ab24a651..1f65f452 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -26,9 +26,17 @@ from jax.sharding import NamedSharding, PartitionSpec as P from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler + class WanPipelineI2V_2_2(WanPipeline): """Pipeline for WAN 2.2 Image-to-Video.""" - def __init__(self, config: HyperParameters, low_noise_transformer: Optional[WanModel], high_noise_transformer: Optional[WanModel], **kwargs): + + def __init__( + self, + config: HyperParameters, + low_noise_transformer: Optional[WanModel], + high_noise_transformer: Optional[WanModel], + **kwargs, + ): super().__init__(config=config, **kwargs) self.low_noise_transformer = low_noise_transformer self.high_noise_transformer = high_noise_transformer @@ -39,26 +47,38 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t common_components = cls._create_common_components(config, vae_only, i2v=True) low_noise_transformer, high_noise_transformer = None, None if not vae_only: - if load_transformer: - high_noise_transformer = super().load_transformer( - 
devices_array=common_components["devices_array"], mesh=common_components["mesh"], - rngs=common_components["rngs"], config=config, restored_checkpoint=restored_checkpoint, - subfolder="transformer" - ) - low_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], mesh=common_components["mesh"], - rngs=common_components["rngs"], config=config, restored_checkpoint=restored_checkpoint, - subfolder="transformer_2" - ) + if load_transformer: + high_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer", + ) + low_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer_2", + ) pipeline = cls( - tokenizer=common_components["tokenizer"], text_encoder=common_components["text_encoder"], - image_processor=common_components["image_processor"], image_encoder=common_components["image_encoder"], - low_noise_transformer=low_noise_transformer, high_noise_transformer=high_noise_transformer, - vae=common_components["vae"], vae_cache=common_components["vae_cache"], - scheduler=common_components["scheduler"], scheduler_state=common_components["scheduler_state"], - devices_array=common_components["devices_array"], mesh=common_components["mesh"], - config=config, + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + image_processor=common_components["image_processor"], + image_encoder=common_components["image_encoder"], + low_noise_transformer=low_noise_transformer, + high_noise_transformer=high_noise_transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, ) return pipeline, low_noise_transformer, high_noise_transformer @@ -75,27 +95,26 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ return pipeline def prepare_latents( - self, - image: jax.Array, - batch_size: int, - height: int, - width: int, - num_frames: int, - dtype: jnp.dtype, - rng: jax.Array, - latents: Optional[jax.Array] = None, - last_image: Optional[jax.Array] = None, - num_videos_per_prompt: int = 1, -) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]: - + self, + image: jax.Array, + batch_size: int, + height: int, + width: int, + num_frames: int, + dtype: jnp.dtype, + rng: jax.Array, + latents: Optional[jax.Array] = None, + last_image: Optional[jax.Array] = None, + num_videos_per_prompt: int = 1, + ) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]: if hasattr(image, "detach"): - image = image.detach().cpu().numpy() + image = image.detach().cpu().numpy() image = jnp.array(image) if last_image is not None: - if hasattr(last_image, "detach"): - last_image = last_image.detach().cpu().numpy() - last_image = jnp.array(last_image) + if hasattr(last_image, "detach"): + last_image = last_image.detach().cpu().numpy() + last_image = jnp.array(last_image) num_channels_latents = self.vae.z_dim num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 @@ -105,16 +124,16 @@ def prepare_latents( shape = 
(batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents) if latents is None: - latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) + latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) else: - latents = latents.astype(dtype) + latents = latents.astype(dtype) latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image) mask_lat_size = jnp.ones((batch_size, 1, num_frames, latent_height, latent_width), dtype=dtype) if last_image is None: - mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0) + mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0) else: - mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0) + mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0) first_frame_mask = mask_lat_size[:, :, 0:1] first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2) @@ -127,59 +146,67 @@ def prepare_latents( return latents, condition, None def __call__( - self, - prompt: Union[str, List[str]], - image: PipelineImageInput, - negative_prompt: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale_low: float = 3.0, - guidance_scale_high: float = 4.0, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 512, - latents: Optional[jax.Array] = None, - prompt_embeds: Optional[jax.Array] = None, - negative_prompt_embeds: Optional[jax.Array] = None, - image_embeds: Optional[jax.Array] = None, - last_image: Optional[PipelineImageInput] = None, - output_type: Optional[str] = "np", - rng: Optional[jax.Array] = None, + self, + prompt: Union[str, List[str]], + image: PipelineImageInput, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale_low: float = 3.0, + guidance_scale_high: float = 4.0, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + latents: Optional[jax.Array] = None, + prompt_embeds: Optional[jax.Array] = None, + negative_prompt_embeds: Optional[jax.Array] = None, + image_embeds: Optional[jax.Array] = None, + last_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "np", + rng: Optional[jax.Array] = None, ): height = height or self.config.height width = width or self.config.width num_frames = num_frames or self.config.num_frames if num_frames % self.vae_scale_factor_temporal != 1: - max_logging.log( - f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " - f"Rounding {num_frames} to the nearest valid number." - ) - num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 - max_logging.log(f"Adjusted num_frames to: {num_frames}") + max_logging.log( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " + f"Rounding {num_frames} to the nearest valid number." 
+ ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + max_logging.log(f"Adjusted num_frames to: {num_frames}") num_frames = max(num_frames, 1) prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v( - prompt, image, negative_prompt, num_videos_per_prompt, max_sequence_length, - prompt_embeds, negative_prompt_embeds, image_embeds, last_image + prompt, + image, + negative_prompt, + num_videos_per_prompt, + max_sequence_length, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + last_image, ) + def _process_image_input(img_input, height, width, num_videos_per_prompt): - if img_input is None: - return None - tensor = self.video_processor.preprocess(img_input, height=height, width=width) - jax_array = jnp.array(tensor.cpu().numpy()) - if jax_array.ndim == 3: - jax_array = jax_array[None, ...] # Add batch dimension - if num_videos_per_prompt > 1: - jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0) - return jax_array + if img_input is None: + return None + tensor = self.video_processor.preprocess(img_input, height=height, width=width) + jax_array = jnp.array(tensor.cpu().numpy()) + if jax_array.ndim == 3: + jax_array = jax_array[None, ...] # Add batch dimension + if num_videos_per_prompt > 1: + jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0) + return jax_array image_tensor = _process_image_input(image, height, width, effective_batch_size) last_image_tensor = _process_image_input(last_image, height, width, effective_batch_size) if rng is None: - rng = jax.random.key(self.config.seed) + rng = jax.random.key(self.config.seed) latents_rng, inference_rng = jax.random.split(rng) # For WAN 2.2, image_embeds may be None (no CLIP image encoder) @@ -206,17 +233,16 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) 
data_sharding = NamedSharding(self.mesh, P()) if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: - data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) latents = jax.device_put(latents, data_sharding) condition = jax.device_put(condition, data_sharding) prompt_embeds = jax.device_put(prompt_embeds, data_sharding) negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) # WAN 2.2 I2V doesn't use image_embeds (it's None), but we still need to pass it to the function if image_embeds is not None: - image_embeds = jax.device_put(image_embeds, data_sharding) + image_embeds = jax.device_put(image_embeds, data_sharding) if first_frame_mask is not None: - first_frame_mask = jax.device_put(first_frame_mask, data_sharding) - + first_frame_mask = jax.device_put(first_frame_mask, data_sharding) boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps @@ -232,10 +258,16 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): latents = p_run_inference( - low_noise_graphdef=low_noise_graphdef, low_noise_state=low_noise_state, low_noise_rest=low_noise_rest, - high_noise_graphdef=high_noise_graphdef, high_noise_state=high_noise_state, high_noise_rest=high_noise_rest, - latents=latents, condition=condition, - prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + low_noise_graphdef=low_noise_graphdef, + low_noise_state=low_noise_state, + low_noise_rest=low_noise_rest, + high_noise_graphdef=high_noise_graphdef, + high_noise_state=high_noise_state, + high_noise_rest=high_noise_rest, + latents=latents, + condition=condition, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, scheduler_state=scheduler_state, ) latents = jnp.transpose(latents, (0, 4, 1, 2, 3)) @@ -245,9 +277,14 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): return latents return self._decode_latents_to_video(latents) + def run_inference_2_2_i2v( - low_noise_graphdef, low_noise_state, low_noise_rest, - high_noise_graphdef, high_noise_state, high_noise_rest, + low_noise_graphdef, + low_noise_state, + low_noise_rest, + high_noise_graphdef, + high_noise_state, + high_noise_rest, latents: jnp.array, condition: jnp.array, prompt_embeds: jnp.array, @@ -260,51 +297,59 @@ def run_inference_2_2_i2v( scheduler: FlaxUniPCMultistepScheduler, scheduler_state, ): - do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 - def high_noise_branch(operands): - latents_input, ts_input, pe_input, ie_input = operands - latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) - noise_pred, latents_out = transformer_forward_pass( - high_noise_graphdef, high_noise_state, high_noise_rest, - latents_input, ts_input, pe_input, - do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale_high, - encoder_hidden_states_image=ie_input - ) - return noise_pred, latents_out - - def low_noise_branch(operands): - latents_input, ts_input, pe_input, ie_input = operands - latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) - noise_pred, latents_out = transformer_forward_pass( - low_noise_graphdef, low_noise_state, low_noise_rest, - latents_input, ts_input, pe_input, - do_classifier_free_guidance=do_classifier_free_guidance, 
guidance_scale=guidance_scale_low, - encoder_hidden_states_image=ie_input - ) - return noise_pred, latents_out + do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 + + def high_noise_branch(operands): + latents_input, ts_input, pe_input, ie_input = operands + latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) + noise_pred, latents_out = transformer_forward_pass( + high_noise_graphdef, + high_noise_state, + high_noise_rest, + latents_input, + ts_input, + pe_input, + do_classifier_free_guidance=do_classifier_free_guidance, + guidance_scale=guidance_scale_high, + encoder_hidden_states_image=ie_input, + ) + return noise_pred, latents_out + + def low_noise_branch(operands): + latents_input, ts_input, pe_input, ie_input = operands + latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) + noise_pred, latents_out = transformer_forward_pass( + low_noise_graphdef, + low_noise_state, + low_noise_rest, + latents_input, + ts_input, + pe_input, + do_classifier_free_guidance=do_classifier_free_guidance, + guidance_scale=guidance_scale_low, + encoder_hidden_states_image=ie_input, + ) + return noise_pred, latents_out + if do_classifier_free_guidance: + prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + # WAN 2.2 I2V: image_embeds may be None since it doesn't use CLIP image encoder + if image_embeds is not None: + image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0) + condition = jnp.concatenate([condition] * 2) + + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + latents_input = latents if do_classifier_free_guidance: - prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) - # WAN 2.2 I2V: image_embeds may be None since it doesn't use CLIP image encoder - if image_embeds is not None: - image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0) - condition = jnp.concatenate([condition] * 2) - - for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - latents_input = latents - if do_classifier_free_guidance: - latents_input = jnp.concatenate([latents, latents], axis=0) - latent_model_input = jnp.concatenate([latents_input, condition], axis=-1) - timestep = jnp.broadcast_to(t, latents_input.shape[0]) - - use_high_noise = jnp.greater_equal(t, boundary) - noise_pred, _ = jax.lax.cond( - use_high_noise, - high_noise_branch, - low_noise_branch, - (latent_model_input, timestep, prompt_embeds, image_embeds) - ) - noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents + latents_input = jnp.concatenate([latents, latents], axis=0) + latent_model_input = jnp.concatenate([latents_input, condition], axis=-1) + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + use_high_noise = jnp.greater_equal(t, boundary) + noise_pred, _ = jax.lax.cond( + use_high_noise, high_noise_branch, low_noise_branch, (latent_model_input, timestep, prompt_embeds, image_embeds) + ) + noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py index 2fc3293b..487cc85e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py +++ 
b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py @@ -37,9 +37,7 @@ import PIL -def retrieve_latents( - encoder_output: torch.Tensor, rngs=None, sample_mode: str = "sample" -): +def retrieve_latents(encoder_output: torch.Tensor, rngs=None, sample_mode: str = "sample"): """Extracts the latent codes from the encoder object. From https://github.com/huggingface/diffusers/blob/8d415a6f481ff1b26168c046267628419650f930/src/diffusers/pipelines/wan/pipeline_wan_vace.py#L128C1-L128C4 @@ -56,9 +54,13 @@ def retrieve_latents( # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. def create_sharded_logical_transformer( - devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder: str = "" + devices_array: np.array, + mesh: Mesh, + rngs: nnx.Rngs, + config: HyperParameters, + restored_checkpoint=None, + subfolder: str = "", ): - def create_model(rngs: nnx.Rngs, wan_config: dict): wan_vace_transformer = WanVACEModel(**wan_config, rngs=rngs) return wan_vace_transformer @@ -148,33 +150,31 @@ class VaceWanPipeline2_1(WanPipeline2_1): """ def preprocess_conditions( - self, - video: Optional[PipelineImageInput] = None, - mask: Optional[PipelineImageInput] = None, - reference_images: Optional[PipelineImageInput] = None, - batch_size: int = 1, - height: int = 480, - width: int = 832, - num_frames: int = 81, - dtype = None, -): + self, + video: Optional[PipelineImageInput] = None, + mask: Optional[PipelineImageInput] = None, + reference_images: Optional[PipelineImageInput] = None, + batch_size: int = 1, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype=None, + ): """Prepares the conditional data for inference. Based on https://github.com/huggingface/diffusers/blob/17c0e79dbdf53fb6705e9c09cc1a854b84c39249/src/diffusers/pipelines/wan/pipeline_wan_vace.py#L414 """ if video is not None: - base = self.vae_scale_factor_spatial * ( - self.transformer.config.patch_size[1] - ) + base = self.vae_scale_factor_spatial * (self.transformer.config.patch_size[1]) video_height, video_width = self.video_processor.get_default_height_width(video[0]) if video_height * video_width > height * width: - scale = min(width / video_width, height / video_width) - video_height, video_width = int(video_height * scale), int(video_width * scale) + scale = min(width / video_width, height / video_width) + video_height, video_width = int(video_height * scale), int(video_width * scale) if video_height % base != 0 or video_width % base != 0: - video_height = (video_height // base) * base - video_width = (video_width // base) * base + video_height = (video_height // base) * base + video_width = (video_width // base) * base assert video_height * video_width <= height * width @@ -182,9 +182,7 @@ def preprocess_conditions( video = jnp.array(np.asarray(video), dtype=dtype) image_size = (video_height, video_width) # Use the height/width of video (with possible rescaling) else: - video = jnp.zeros( - (batch_size, 3, num_frames, height, width), dtype=dtype - ) + video = jnp.zeros((batch_size, 3, num_frames, height, width), dtype=dtype) image_size = (height, width) # Use the height/width provider by user if mask is not None: @@ -201,9 +199,7 @@ def preprocess_conditions( # per video if reference_images is None or isinstance(reference_images, PIL.Image.Image): reference_images = [[reference_images] for _ in range(video.shape[0])] - elif isinstance(reference_images, (list, tuple)) and isinstance( - 
next(iter(reference_images)), PIL.Image.Image - ): + elif isinstance(reference_images, (list, tuple)) and isinstance(next(iter(reference_images)), PIL.Image.Image): reference_images = [reference_images] elif ( isinstance(reference_images, (list, tuple)) @@ -242,14 +238,16 @@ def preprocess_conditions( # TODO: should we use jax/TF-based resizing here? resized_image = torch.nn.functional.interpolate( image, size=(new_height, new_width), mode="bilinear", align_corners=False - ).squeeze(0) # [C, H, W] + ).squeeze( + 0 + ) # [C, H, W] top = (image_size[0] - new_height) // 2 left = (image_size[1] - new_width) // 2 canvas = torch.ones(3, *image_size, dtype=torch.float32) canvas[:, top : top + new_height, left : left + new_width] = resized_image - canvas = canvas.permute(1, 2, 0) # Bring back to Jax + canvas = canvas.permute(1, 2, 0) # Bring back to Jax canvas = torch2jax(canvas) preprocessed_images.append(canvas) @@ -275,9 +273,7 @@ def prepare_masks( ) if mask.shape[0] != 1: - raise ValueError( - "Generating with more than one video is not yet supported. This may be supported in the future." - ) + raise ValueError("Generating with more than one video is not yet supported. This may be supported in the future.") transformer_patch_size = self.transformer.config.patch_size[1] @@ -288,14 +284,12 @@ def prepare_masks( new_height = height // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size new_width = width // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size mask_ = mask_[0, :, :, :] - mask_ = mask_.view( - num_frames, new_height, self.vae_scale_factor_spatial, new_width, self.vae_scale_factor_spatial - ) + mask_ = mask_.view(num_frames, new_height, self.vae_scale_factor_spatial, new_width, self.vae_scale_factor_spatial) # TODO: should we refactor to use Jax/TF? 
mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1) # [8x8, num_frames, new_height, new_width] mask_ = torch.nn.functional.interpolate( - mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact" - ).squeeze(0) + mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact" + ).squeeze(0) num_ref_images = len(reference_images_batch) if num_ref_images > 0: mask_padding = torch.zeros_like(mask_[:, :num_ref_images, :, :]) @@ -307,10 +301,22 @@ def prepare_masks( @classmethod def load_transformer( - cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder="transformer"): + cls, + devices_array: np.array, + mesh: Mesh, + rngs: nnx.Rngs, + config: HyperParameters, + restored_checkpoint=None, + subfolder="transformer", + ): with mesh: wan_transformer = create_sharded_logical_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder=subfolder + devices_array=devices_array, + mesh=mesh, + rngs=rngs, + config=config, + restored_checkpoint=restored_checkpoint, + subfolder=subfolder, ) return wan_transformer @@ -328,7 +334,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform if not vae_only: if load_transformer: with mesh: - transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") + transformer = cls.load_transformer( + devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer" + ) text_encoder = cls.load_text_encoder(config=config) tokenizer = cls.load_tokenizer(config=config) @@ -370,9 +378,7 @@ def check_inputs( if self.transformer is not None: base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1] else: - raise ValueError( - "`transformer` component must be set in order to run inference with this pipeline" - ) + raise ValueError("`transformer` component must be set in order to run inference with this pipeline") if height % base != 0 or width % base != 0: raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.") @@ -382,52 +388,50 @@ def check_inputs( if prompt is not None and prompt_embeds is not None: raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." ) elif negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" - " only forward one of the two." + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif negative_prompt is not None and ( - not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) - ): + elif negative_prompt is not None and (not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") if video is not None: if mask is not None: if len(video) != len(mask): raise ValueError( - f"Length of `video` {len(video)} and `mask` {len(mask)} do not match. Please make sure that" - " they have the same length." + f"Length of `video` {len(video)} and `mask` {len(mask)} do not match. Please make sure that" + " they have the same length." ) if reference_images is not None: is_pil_image = isinstance(reference_images, PIL.Image.Image) is_list_of_pil_images = isinstance(reference_images, list) and all( - isinstance(ref_img, PIL.Image.Image) for ref_img in reference_images + isinstance(ref_img, PIL.Image.Image) for ref_img in reference_images ) is_list_of_list_of_pil_images = isinstance(reference_images, list) and all( - isinstance(ref_img, list) and all(isinstance(ref_img_, PIL.Image.Image) for ref_img_ in ref_img) - for ref_img in reference_images + isinstance(ref_img, list) and all(isinstance(ref_img_, PIL.Image.Image) for ref_img_ in ref_img) + for ref_img in reference_images ) if not (is_pil_image or is_list_of_pil_images or is_list_of_list_of_pil_images): raise ValueError( - "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or " - "`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}" + "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or " + "`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}" ) if is_list_of_list_of_pil_images and len(reference_images) != 1: raise ValueError( - "The pipeline only supports generating one video at a time at the moment. When passing a list " - "of list of reference images, where the outer list corresponds to the batch size and the inner " - "list corresponds to list of conditioning images per video, please make sure to only pass " - "one inner list of reference images (i.e., `[[, , ...]]`" + "The pipeline only supports generating one video at a time at the moment. 
When passing a list " + "of list of reference images, where the outer list corresponds to the batch size and the inner " + "list corresponds to list of conditioning images per video, please make sure to only pass " + "one inner list of reference images (i.e., `[[, , ...]]`" ) elif mask is not None: raise ValueError("`mask` can only be passed if `video` is passed as well.") @@ -438,7 +442,6 @@ def __call__( mask: Optional[List[PipelineImageInput]] = None, reference_images: Optional[List[PipelineImageInput]] = None, conditioning_scale: Union[float, List[float], torch.Tensor] = 1.0, - prompt: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, height: int = 480, @@ -484,7 +487,7 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, video=video, mask=mask, - reference_images=reference_images + reference_images=reference_images, ) if not vae_only: if num_frames % self.vae_scale_factor_temporal != 1: @@ -500,9 +503,7 @@ def __call__( batch_size = len(prompt) if num_videos_per_prompt != 1: - raise ValueError( - "Generating multiple videos per prompt is not yet supported. This may be supported in the future." - ) + raise ValueError("Generating multiple videos per prompt is not yet supported. This may be supported in the future.") prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, @@ -520,8 +521,8 @@ def __call__( if isinstance(conditioning_scale, list): if len(conditioning_scale) != len(vace_layers): raise ValueError( - f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}." - ) + f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}." + ) conditioning_scale = jnp.array(conditioning_scale) if isinstance(conditioning_scale, jax.Array): if conditioning_scale.shape[0] != len(vace_layers): @@ -545,7 +546,9 @@ def __call__( if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding)) - conditioning_latents = self.prepare_video_latents(data_sharding=data_sharding, video=video, mask=mask, reference_images=reference_images, rngs=None) + conditioning_latents = self.prepare_video_latents( + data_sharding=data_sharding, video=video, mask=mask, reference_images=reference_images, rngs=None + ) mask = self.prepare_masks(mask, reference_images) conditioning_latents = conditioning_latents.transpose(0, 4, 1, 2, 3) @@ -628,7 +631,6 @@ def prepare_video_latents( reference_images: Optional[List[List[torch.Tensor]]] = None, rngs=None, ) -> jax.Array: - if reference_images is None: # For each batch of video, we set no re # ference image (as one or more can be passed by user) @@ -640,9 +642,7 @@ def prepare_video_latents( ) if video.shape[0] != 1: - raise ValueError( - "Generating with more than one video is not yet supported. This may be supported in the future." - ) + raise ValueError("Generating with more than one video is not yet supported. 
This may be supported in the future.") vae_dtype = self.vae.decoder.conv_in.conv.bias.dtype video = video.astype(dtype=vae_dtype) @@ -671,7 +671,9 @@ def prepare_video_latents( reference_image = jax.device_put(reference_image, data_sharding) reference_image = reference_image[None, None, :, :, :] # [1, 1, H, W, C] - reference_latent = retrieve_latents(self.vae.encode(reference_image, feat_cache=self.vae_cache), rngs=None, sample_mode="argmax") + reference_latent = retrieve_latents( + self.vae.encode(reference_image, feat_cache=self.vae_cache), rngs=None, sample_mode="argmax" + ) reference_latent = ((reference_latent.astype(jnp.float32) - latents_mean) * latents_std).astype(vae_dtype) @@ -727,7 +729,6 @@ def run_inference( num_inference_steps: int, scheduler: FlaxUniPCMultistepScheduler, scheduler_state, - control_hidden_states, control_hidden_states_scale, ): @@ -751,7 +752,6 @@ def run_inference( prompt_embeds, control_hidden_states=control_hidden_states, control_hidden_states_scale=control_hidden_states_scale, - do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale, ) diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 27c9f645..bcc409b1 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=missing-module-docstring import os @@ -32,6 +32,7 @@ _ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2} _ALLOWED_TRAINING_MODEL_NAMES = {WAN2_1} + def _validate_model_name(model_name: str | None): """Raise if model_name is not in the allowed list.""" if model_name is None: @@ -39,12 +40,16 @@ def _validate_model_name(model_name: str | None): if model_name not in _ALLOWED_MODEL_NAMES: raise ValueError(f"Invalid config.model_name '{model_name}'. Allowed values: {sorted(_ALLOWED_MODEL_NAMES)}") + def _validate_training_model_name(model_name: str | None): """Raise if model_name is not in the allowed training list.""" if model_name is None: return if model_name not in _ALLOWED_TRAINING_MODEL_NAMES: - raise ValueError(f"Invalid config.model_name '{model_name}' for training. Allowed values: {sorted(_ALLOWED_TRAINING_MODEL_NAMES)}") + raise ValueError( + f"Invalid config.model_name '{model_name}' for training. 
Allowed values: {sorted(_ALLOWED_TRAINING_MODEL_NAMES)}" + ) + def string_to_bool(s: str) -> bool: if s.lower() == "true": @@ -196,7 +201,9 @@ def user_init(raw_keys): raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"]) # Verify qkv is sharded across sequence. if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]: - max_logging.log(f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set.") + max_logging.log( + f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set." + ) logical_axis_rules = list(raw_keys["logical_axis_rules"]) max_logging.log(f"Initial logical axis rules: {logical_axis_rules}") new_rules = [] @@ -211,7 +218,7 @@ def user_init(raw_keys): if ring_attention_axis_rule not in logical_axis_rules: max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}") new_rules.append(ring_attention_axis_rule) - else: # attention =flash but sequence parallel sharding requested for both self and cross attention + else: # attention =flash but sequence parallel sharding requested for both self and cross attention for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES: if seq_parallel_axis_rule not in logical_axis_rules: max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}") @@ -244,9 +251,10 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) - raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"] = ( - _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"]) - ) + ( + raw_keys["global_batch_size_to_load"], + raw_keys["global_batch_size_to_train_on"], + ) = _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"]) def get_num_slices(raw_keys): diff --git a/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py index 218117eb..c55a49c4 100644 --- a/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -528,13 +528,11 @@ def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: ) def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: - timestep_list = jnp.array( - [ - state.timesteps[step_index - 2], - state.timesteps[step_index - 1], - state.timesteps[step_index], - ] - ) + timestep_list = jnp.array([ + state.timesteps[step_index - 2], + state.timesteps[step_index - 1], + state.timesteps[step_index], + ]) return self.multistep_dpm_solver_third_order_update( state, state.model_outputs, diff --git a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py index 863fa26c..f6ace5fc 100644 --- a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py @@ -256,7 +256,6 @@ def add_noise( timesteps: jnp.ndarray, flux: bool = False, ) -> jnp.ndarray: - if flux: t = state.timesteps[timesteps] t = t[:, None, None] diff --git a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py 
b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py index 03a47fd4..b2c7d96a 100644 --- a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py @@ -136,13 +136,11 @@ def __init__( if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") if ( - sum( - [ - self.config.use_beta_sigmas, - self.config.use_exponential_sigmas, - self.config.use_karras_sigmas, - ] - ) + sum([ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ]) > 1 ): raise ValueError( diff --git a/src/maxdiffusion/schedulers/scheduling_utils_flax.py b/src/maxdiffusion/schedulers/scheduling_utils_flax.py index e1690ba8..d38f1446 100644 --- a/src/maxdiffusion/schedulers/scheduling_utils_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_utils_flax.py @@ -262,7 +262,8 @@ def create(cls, scheduler): elif config.beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. betas = ( - jnp.linspace(config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype) ** 2 + jnp.linspace(config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype) + ** 2 ) elif config.beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule diff --git a/src/maxdiffusion/tests/__init__.py b/src/maxdiffusion/tests/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/tests/__init__.py +++ b/src/maxdiffusion/tests/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py index c2180240..65561eca 100644 --- a/src/maxdiffusion/tests/attention_test.py +++ b/src/maxdiffusion/tests/attention_test.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import unittest diff --git a/src/maxdiffusion/tests/configuration_utils_test.py b/src/maxdiffusion/tests/configuration_utils_test.py index 29f3f8a7..ee761a38 100644 --- a/src/maxdiffusion/tests/configuration_utils_test.py +++ b/src/maxdiffusion/tests/configuration_utils_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import json import os diff --git a/src/maxdiffusion/tests/flop_calculations_test.py b/src/maxdiffusion/tests/flop_calculations_test.py index 4cb290ef..ca0a5020 100644 --- a/src/maxdiffusion/tests/flop_calculations_test.py +++ b/src/maxdiffusion/tests/flop_calculations_test.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" import os import unittest from unittest.mock import Mock diff --git a/src/maxdiffusion/tests/generate_flux_smoke_test.py b/src/maxdiffusion/tests/generate_flux_smoke_test.py index 68968bfd..12bfe77b 100644 --- a/src/maxdiffusion/tests/generate_flux_smoke_test.py +++ b/src/maxdiffusion/tests/generate_flux_smoke_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import unittest import pytest diff --git a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py index a5bb289f..ff823010 100644 --- a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py +++ b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import unittest import pytest diff --git a/src/maxdiffusion/tests/generate_smoke_test.py b/src/maxdiffusion/tests/generate_smoke_test.py index d0c02044..2c5b783a 100644 --- a/src/maxdiffusion/tests/generate_smoke_test.py +++ b/src/maxdiffusion/tests/generate_smoke_test.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import unittest import pytest diff --git a/src/maxdiffusion/tests/gradient_checkpoint_test.py b/src/maxdiffusion/tests/gradient_checkpoint_test.py index ca237d52..a4d6f6cd 100644 --- a/src/maxdiffusion/tests/gradient_checkpoint_test.py +++ b/src/maxdiffusion/tests/gradient_checkpoint_test.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ import unittest diff --git a/src/maxdiffusion/tests/input_pipeline_interface_test.py b/src/maxdiffusion/tests/input_pipeline_interface_test.py index 1141ec8c..0b55c8f8 100644 --- a/src/maxdiffusion/tests/input_pipeline_interface_test.py +++ b/src/maxdiffusion/tests/input_pipeline_interface_test.py @@ -1,17 +1,17 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
+Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ import os @@ -70,7 +70,6 @@ def setUp(self): InputPipelineInterface.dummy_data = {} def test_make_dreambooth_train_iterator(self): - instance_class_gcs_dir = "gs://maxdiffusion-github-runner-test-assets/datasets/dreambooth/instance_class" class_class_gcs_dir = "gs://maxdiffusion-github-runner-test-assets/datasets/dreambooth/class_class" local_dir = "/tmp/" @@ -135,7 +134,9 @@ def test_make_dreambooth_train_iterator(self): cleanup(instance_class_local_dir) cleanup(class_class_local_dir) - @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") + @pytest.mark.skip( + "This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace" + ) def test_make_pokemon_hf_iterator(self): pyconfig.initialize( [ @@ -239,7 +240,9 @@ def test_make_pokemon_hf_iterator_sdxl(self): assert data["input_ids"].shape == (device_count, 2, 77) assert data["pixel_values"].shape == (device_count, 3, config.resolution, config.resolution) - @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") + @pytest.mark.skip( + "This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace" + ) def test_make_pokemon_tf_iterator_cache(self): pyconfig.initialize( [ @@ -302,7 +305,9 @@ def test_make_pokemon_tf_iterator_cache(self): config.resolution // vae_scale_factor, ) - @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") + @pytest.mark.skip( + "This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace" + ) def test_make_pokemon_iterator_no_cache(self): pyconfig.initialize( [ @@ -435,7 +440,9 @@ def test_make_pokemon_iterator_sdxl_cache(self): config.resolution // vae_scale_factor, ) - @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") + @pytest.mark.skip( + "This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace" + ) def test_make_laion_grain_iterator(self): try: subprocess.check_output( @@ -492,7 +499,9 @@ def test_make_laion_grain_iterator(self): 8, ) - @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") + @pytest.mark.skip( + "This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace" + ) def test_make_laion_tfrecord_iterator(self): pyconfig.initialize( [ @@ -553,7 +562,9 @@ def _parse_tfrecord_fn(example): 8, ) - @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") + @pytest.mark.skip( + "This test is deprecated and will be removed in a future version. 
Reason: stable diffusion 2 base is no longer in HuggingFace" + ) def test_tfrecord(self): """Validate latents match a deterministic output image""" diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index 9398c915..083ed265 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import torch diff --git a/src/maxdiffusion/tests/maxdiffusion_utils_test.py b/src/maxdiffusion/tests/maxdiffusion_utils_test.py index 23dd3ee3..4b29d365 100644 --- a/src/maxdiffusion/tests/maxdiffusion_utils_test.py +++ b/src/maxdiffusion/tests/maxdiffusion_utils_test.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import unittest diff --git a/src/maxdiffusion/tests/text_encoders_test.py b/src/maxdiffusion/tests/text_encoders_test.py index e7d3d6dd..c91bca9a 100644 --- a/src/maxdiffusion/tests/text_encoders_test.py +++ b/src/maxdiffusion/tests/text_encoders_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import unittest @@ -36,7 +36,6 @@ def setUp(self): @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_flux_t5_text_encoder(self): - text_encoder = FlaxT5EncoderModel.from_pretrained("ariG23498/t5-v1-1-xxl-flax") tokenizer_2 = T5TokenizerFast.from_pretrained("ariG23498/t5-v1-1-xxl-flax") @@ -47,7 +46,6 @@ def test_flux_t5_text_encoder(self): @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_flux_clip_text_encoder(self): - text_encoder = FlaxCLIPTextModel.from_pretrained( "black-forest-labs/FLUX.1-dev", subfolder="text_encoder", from_pt=True, dtype="bfloat16" ) diff --git a/src/maxdiffusion/tests/train_smoke_test.py b/src/maxdiffusion/tests/train_smoke_test.py index a7d0f4b8..f5f6df00 100644 --- a/src/maxdiffusion/tests/train_smoke_test.py +++ b/src/maxdiffusion/tests/train_smoke_test.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Smoke test """ import os diff --git a/src/maxdiffusion/tests/unet_test.py b/src/maxdiffusion/tests/unet_test.py index 562fb5a3..0bbf706f 100644 --- a/src/maxdiffusion/tests/unet_test.py +++ b/src/maxdiffusion/tests/unet_test.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Smoke test """ import os diff --git a/src/maxdiffusion/tests/vae_test.py b/src/maxdiffusion/tests/vae_test.py index e3a46b10..17e9b211 100644 --- a/src/maxdiffusion/tests/vae_test.py +++ b/src/maxdiffusion/tests/vae_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import unittest @@ -38,7 +38,6 @@ def setUp(self): @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_flux_vae(self): - img_url = os.path.join(THIS_DIR, "images", "test_hyper_sdxl.png") base_image = np.array(Image.open(img_url)).astype(np.uint8) img_min = np.min(base_image) diff --git a/src/maxdiffusion/tests/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan_checkpointer_test.py index 81a38670..a0a529f1 100644 --- a/src/maxdiffusion/tests/wan_checkpointer_test.py +++ b/src/maxdiffusion/tests/wan_checkpointer_test.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Copyright 2025 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import unittest from unittest.mock import patch, MagicMock @@ -19,6 +19,7 @@ from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2 from maxdiffusion.pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 + class WanCheckpointer2_1Test(unittest.TestCase): """Tests for WAN 2.1 checkpointer.""" @@ -237,6 +238,7 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, m self.assertEqual(opt_state["learning_rate"], 0.002) self.assertEqual(step, 1) + class WanCheckpointerI2V_2_1Test(unittest.TestCase): """Tests for WAN 2.1 I2V checkpointer.""" @@ -324,6 +326,7 @@ def test_load_checkpoint_with_optimizer(self, mock_from_checkpoint, mock_create_ self.assertEqual(opt_state["learning_rate"], 0.001) self.assertEqual(step, 1) + class WanCheckpointerI2V_2_2Test(unittest.TestCase): """Tests for WAN 2.2 I2V checkpointer.""" @@ -447,6 +450,7 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline_i2 self.assertEqual(opt_state["learning_rate"], 0.002) self.assertEqual(step, 1) + class WanCheckpointerEdgeCasesTest(unittest.TestCase): """Tests for edge cases and error handling.""" diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index c1044cc3..eb8b4973 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" import os import jax @@ -65,7 +65,6 @@ def setUp(self): devices_array = create_device_mesh(config) self.mesh = Mesh(devices_array, config.mesh_axes) - def test_rotary_pos_embed(self): batch_size = 1 channels = 16 @@ -198,12 +197,7 @@ def test_wan_block(self): def test_wan_attention(self): for attention_kernel in ["flash", "tokamax_flash"]: pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - f"attention={attention_kernel}" - ], - unittest=True + [None, os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), f"attention={attention_kernel}"], unittest=True ) config = pyconfig.config batch_size = 1 @@ -286,7 +280,9 @@ def test_wan_model(self): batch_size = 1 num_layers = 1 with nn_partitioning.axis_rules(config.logical_axis_rules): - wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers) + wan_model = WanModel( + rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers + ) dummy_timestep = jnp.ones((batch_size)) dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096)) diff --git a/src/maxdiffusion/tests/wan_vace_transformer_test.py b/src/maxdiffusion/tests/wan_vace_transformer_test.py index 9864e64c..05b04f76 100644 --- a/src/maxdiffusion/tests/wan_vace_transformer_test.py +++ b/src/maxdiffusion/tests/wan_vace_transformer_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import jax @@ -43,6 +43,7 @@ class WanVaceTransformerTest(unittest.TestCase): + def test_wan_vace_block_returns_the_correct_shape(self): key = jax.random.key(0) rngs = nnx.Rngs(key) @@ -117,5 +118,6 @@ def test_wan_vace_block_returns_the_correct_shape(self): assert conditioning_states.shape == dummy_hidden_states.shape assert control_hidden_states.shape == dummy_hidden_states.shape + if __name__ == "__main__": absltest.main() diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index b2ffbc3b..0f9158cb 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import functools @@ -211,7 +211,6 @@ def test_wanrms_norm(self): assert np.allclose(output_np, torch_output_np) is True def test_zero_padded_conv(self): - key = jax.random.key(0) rngs = nnx.Rngs(key) diff --git a/src/maxdiffusion/tpu_utils.py b/src/maxdiffusion/tpu_utils.py index 9ea03e7c..5697f60c 100644 --- a/src/maxdiffusion/tpu_utils.py +++ b/src/maxdiffusion/tpu_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import jax diff --git a/src/maxdiffusion/train.py b/src/maxdiffusion/train.py index 60657e0b..1bfcc942 100644 --- a/src/maxdiffusion/train.py +++ b/src/maxdiffusion/train.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence diff --git a/src/maxdiffusion/train_flux.py b/src/maxdiffusion/train_flux.py index 05cdae44..e341ae1f 100644 --- a/src/maxdiffusion/train_flux.py +++ b/src/maxdiffusion/train_flux.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence diff --git a/src/maxdiffusion/train_sdxl.py b/src/maxdiffusion/train_sdxl.py index 60170a85..ad11c1e4 100644 --- a/src/maxdiffusion/train_sdxl.py +++ b/src/maxdiffusion/train_sdxl.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index 79e65e99..8db92a40 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import numpy as np import jax diff --git a/src/maxdiffusion/train_wan.py b/src/maxdiffusion/train_wan.py index fea15720..9217fb4a 100644 --- a/src/maxdiffusion/train_wan.py +++ b/src/maxdiffusion/train_wan.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence diff --git a/src/maxdiffusion/trainers/__init__.py b/src/maxdiffusion/trainers/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/trainers/__init__.py +++ b/src/maxdiffusion/trainers/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py index a9f17adc..7bb2e26b 100644 --- a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py +++ b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from abc import abstractmethod import time diff --git a/src/maxdiffusion/trainers/dreambooth_trainer.py b/src/maxdiffusion/trainers/dreambooth_trainer.py index 40a40190..a2bd8991 100644 --- a/src/maxdiffusion/trainers/dreambooth_trainer.py +++ b/src/maxdiffusion/trainers/dreambooth_trainer.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" from pathlib import Path import time @@ -116,7 +116,6 @@ def get_data_shardings(self): return data_sharding def load_dataset(self, pipeline, params, train_states): - return make_dreambooth_train_iterator( self.config, self.mesh, @@ -183,7 +182,6 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da return p_train_step def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, learning_rate_scheduler): - writer = max_utils.initialize_summary_writer(self.config) unet_state = train_states["unet_state"] text_encoder_state = train_states["text_encoder_state"] @@ -265,7 +263,6 @@ def _train_step(unet_state, text_encoder_state, batch, train_rng, config, pipeli state_params = {"text_encoder": text_encoder_state.params, "unet": unet_state.params} def compute_loss(state_params): - encoder_hidden_states = encode(input_ids, pipeline.text_encoder, state_params["text_encoder"]) # Sample noise that we'll add to the latents diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py index 74b4f259..c9cbe871 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os from functools import partial @@ -76,7 +76,6 @@ def calculate_tflops(self, pipeline): return per_device_tflops def start_training(self): - # Hook # self.pre_training_steps() # Load checkpoint - will load or create states @@ -314,7 +313,6 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da return p_train_step def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler): - writer = max_utils.initialize_summary_writer(self.config) flux_state = train_states[FLUX_STATE_KEY] num_model_parameters = max_utils.calculate_num_params_from_pytree(flux_state.params) diff --git a/src/maxdiffusion/trainers/sdxl_trainer.py b/src/maxdiffusion/trainers/sdxl_trainer.py index a68cc617..fc442117 100644 --- a/src/maxdiffusion/trainers/sdxl_trainer.py +++ b/src/maxdiffusion/trainers/sdxl_trainer.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- https://www.apache.org/licenses/LICENSE-2.0
+ https://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 import os
 from functools import partial
@@ -176,7 +176,6 @@ def prepare_sample(features):
     return data_iterator

   def compile_train_step(self, pipeline, params, train_states, state_shardings, data_shardings):
-
     self.rng, train_rngs = jax.random.split(self.rng)
     with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       p_train_step = jax.jit(
@@ -208,7+207,6 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da
     return p_train_step

   def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler):
-
     writer = max_utils.initialize_summary_writer(self.config)
     writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
     writer_thread.start()
diff --git a/src/maxdiffusion/trainers/stable_diffusion_trainer.py b/src/maxdiffusion/trainers/stable_diffusion_trainer.py
index 5844df3d..a89c22ac 100644
--- a/src/maxdiffusion/trainers/stable_diffusion_trainer.py
+++ b/src/maxdiffusion/trainers/stable_diffusion_trainer.py
@@ -1,18 +1,18 @@
 """
- Copyright 2024 Google LLC
+Copyright 2024 Google LLC
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- https://www.apache.org/licenses/LICENSE-2.0
+ https://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" import os import sys diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index f23836a5..24852fde 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import datetime @@ -105,7 +105,6 @@ def create_scheduler(self): @staticmethod def calculate_tflops(pipeline): - maxdiffusion_config = pipeline.config # Model configuration height = pipeline.config.height @@ -210,7 +209,6 @@ def prepare_sample_eval(features): return data_iterator def start_training(self): - pipeline, opt_state, step = self.checkpointer.load_checkpoint() restore_args = {} if opt_state and step: diff --git a/src/maxdiffusion/utils/deprecation_utils.py b/src/maxdiffusion/utils/deprecation_utils.py index bd2f6e35..265a60b5 100644 --- a/src/maxdiffusion/utils/deprecation_utils.py +++ b/src/maxdiffusion/utils/deprecation_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" import inspect import warnings from typing import Any, Dict, Optional, Union diff --git a/src/maxdiffusion/utils/export_utils.py b/src/maxdiffusion/utils/export_utils.py index c540f5a9..51b05d30 100644 --- a/src/maxdiffusion/utils/export_utils.py +++ b/src/maxdiffusion/utils/export_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import io import random import struct diff --git a/src/maxdiffusion/utils/import_utils.py b/src/maxdiffusion/utils/import_utils.py index d83596e8..05ef72ec 100644 --- a/src/maxdiffusion/utils/import_utils.py +++ b/src/maxdiffusion/utils/import_utils.py @@ -512,29 +512,27 @@ def is_peft_available(): """ -BACKENDS_MAPPING = OrderedDict( - [ - ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), - ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), - ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), - ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), - ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), - ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), - ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), - ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), - ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), - ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), - ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), - ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), - ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), - ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), - ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), - ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), - ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), - ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), - ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), - ] -) +BACKENDS_MAPPING = OrderedDict([ + ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), + ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), + ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), + ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), + ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), + ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), + ("k_diffusion", 
(is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), + ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), + ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), + ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), + ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), + ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), + ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), + ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), + ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), + ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), +]) def requires_backends(obj, backends): diff --git a/src/maxdiffusion/utils/loading_utils.py b/src/maxdiffusion/utils/loading_utils.py index 85bddb87..f2b72cbd 100644 --- a/src/maxdiffusion/utils/loading_utils.py +++ b/src/maxdiffusion/utils/loading_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os from typing import Callable, List, Optional, Union diff --git a/src/maxdiffusion/utils/logging.py b/src/maxdiffusion/utils/logging.py index b9013a95..2fe7d87d 100644 --- a/src/maxdiffusion/utils/logging.py +++ b/src/maxdiffusion/utils/logging.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Logging utilities.""" +"""Logging utilities.""" import logging import os diff --git a/src/maxdiffusion/utils/pil_utils.py b/src/maxdiffusion/utils/pil_utils.py index cb44e025..a05aa47a 100644 --- a/src/maxdiffusion/utils/pil_utils.py +++ b/src/maxdiffusion/utils/pil_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 from typing import List
 import PIL.Image
diff --git a/src/maxdiffusion/utils/testing_utils.py b/src/maxdiffusion/utils/testing_utils.py
index a5e8aeae..6194a03a 100644
--- a/src/maxdiffusion/utils/testing_utils.py
+++ b/src/maxdiffusion/utils/testing_utils.py
@@ -1,18 +1,18 @@
 """
- Copyright 2024 Google LLC
+Copyright 2024 Google LLC
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- https://www.apache.org/licenses/LICENSE-2.0
+ https://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 import functools
 import importlib
 import inspect