diffusion model: support stable-diffusion-3-medium-diffusers #13422
base: main
Conversation
…ediffusion3medium
fix spelling error Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…edium' into support_stablediffusion3medium
[MultiModal]Support stable-diffusion-3-medium-diffusers
…tablediffusion3medium
…edium' into support_stablediffusion3medium
Format code with pre-commit
Summary of Changes

Hello @IPostYellow, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates Stable Diffusion 3 Medium into SGLang, expanding its multimodal generation capabilities to include state-of-the-art text-to-image synthesis. The changes involve adding new model configurations, implementing a dedicated pipeline, and adapting core runtime components to support SD3's complex multi-text-encoder architecture and VAE processing, ensuring seamless operation and high-quality image generation.

Highlights
Code Review
This pull request adds support for Stable Diffusion 3 Medium for text-to-image generation. The changes are comprehensive, touching configuration, model implementation, and pipeline stages. The implementation correctly handles the three text encoders required by SD3. I've identified a few issues, including a bug in the SD3 transformer's forward pass return value and some brittle file loading logic. Additionally, I've provided suggestions to improve code clarity and maintainability by removing dead code and simplifying some expressions. Overall, this is a great contribution.
```python
if isinstance(server_args.pipeline_config, StableDiffusion3PipelineConfig):
    precision = server_args.pipeline_config.vae_precision
    base_name = "diffusion_pytorch_model"

    # Priority: fp16 > full precision > any matching file
    if precision == "fp16":
        fp16_path = os.path.join(
            str(model_path), f"{base_name}.fp16.safetensors"
        )
        target_files = [fp16_path] if os.path.exists(fp16_path) else []
    else:
        full_path = os.path.join(str(model_path), f"{base_name}.safetensors")
        target_files = [full_path] if os.path.exists(full_path) else []
    safetensors_list = target_files
```
The current logic for finding the VAE's safetensors file is brittle. If the specific precision file (.fp16.safetensors or .safetensors) is not found, it results in an empty list, which will cause the assertion on line 491 to fail with a generic message. The comment on line 480 suggests a priority-based fallback, which is not fully implemented. I suggest a more robust implementation that correctly applies the priority and provides a better fallback.
```python
if isinstance(server_args.pipeline_config, StableDiffusion3PipelineConfig):
    precision = server_args.pipeline_config.vae_precision
    base_name = "diffusion_pytorch_model"
    # Priority: fp16 > full precision > any matching file
    fp16_path = os.path.join(str(model_path), f"{base_name}.fp16.safetensors")
    full_path = os.path.join(str(model_path), f"{base_name}.safetensors")
    if precision == "fp16" and os.path.exists(fp16_path):
        safetensors_list = [fp16_path]
    elif os.path.exists(full_path):
        safetensors_list = [full_path]
    elif os.path.exists(fp16_path):
        safetensors_list = [fp16_path]
    else:
        # Fallback to any safetensors file if specific ones are not found
        safetensors_list = glob.glob(
            os.path.join(str(model_path), f"{base_name}*.safetensors")
        )
```

```python
if not return_dict:
    return (output,)

return output
```
When return_dict is True, the function should return a Transformer2DModelOutput object, but it currently returns a raw tensor. This can lead to AttributeError if the caller expects an object with a .sample attribute. Please wrap the output tensor in Transformer2DModelOutput.
```diff
 if not return_dict:
     return (output,)
-return output
+return Transformer2DModelOutput(sample=output)
```
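For context, diffusers-style callers typically read `.sample` from the returned object when `return_dict=True`, which is why a raw tensor breaks that path. A minimal illustration (not code from this PR; the import path is the one used by recent diffusers releases):

```python
# Illustrative only, not code from this PR: callers expect an output object with
# a `.sample` attribute when return_dict=True, hence the Transformer2DModelOutput wrapper.
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput

output = torch.randn(2, 16, 64, 64)   # stand-in for the transformer's predicted latent
wrapped = Transformer2DModelOutput(sample=output)

assert wrapped.sample is output        # return_dict=True path: attribute access works
assert (output,)[0] is output          # return_dict=False path: plain tuple indexing
```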
```python
_IMAGE_ENCODER_MODELS: dict[str, tuple] = {
    # "HunyuanVideoTransformer3DModel": ("image_encoder", "hunyuanvideo", "HunyuanVideoImageEncoder"),
    "CLIPVisionModelWithProjection": ("encoders", "clip", "CLIPVisionModel"),
    "CLIPTextModelWithProjection": ("encoders", "clip", "CLIPTextModel"),
```
```python
# if batch.do_classifier_free_guidance:
#     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
#     pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
#     batch.prompt_embeds = [prompt_embeds]
#     batch.pooled_embeds = [pooled_prompt_embeds]
```
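For reference, the commented-out block above follows the standard classifier-free guidance batching pattern: negative and positive embeddings are concatenated so one batched forward pass produces both predictions, which are then combined in the guidance step. A minimal sketch of that combination step (a generic pattern, not SGLang code; names are illustrative):

```python
# Generic classifier-free guidance step; illustrative only, not SGLang code.
import torch

def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float = 7.0) -> torch.Tensor:
    # `noise_pred` is the output of a single batched forward pass over
    # torch.cat([negative_prompt_embeds, prompt_embeds], dim=0).
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```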
```python
vae_scale_factor = (
    server_args.pipeline_config.vae_config.get_vae_scale_factor()
    if server_args.pipeline_config.vae_config.get_vae_scale_factor()
    else 8
)
```
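The conditional expression above calls `get_vae_scale_factor()` twice; a simpler equivalent calls it once and uses `or` for the default. This is a suggested simplification for illustration, not necessarily the PR's final form, and the helper name is illustrative:

```python
# Suggested simplification, illustrative only: `or` falls back to the default when
# get_vae_scale_factor() returns None or 0, and the getter is called exactly once.
def resolve_vae_scale_factor(vae_config, default: int = 8) -> int:
    return vae_config.get_vae_scale_factor() or default
```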
Awesome job, thanks! We'll get back to this PR once the necessary CI tests and refactors are added.
Remove unnecessary comments
Format code with pre-commit
Compatible with the latest code
…on3medium_fn2

# Conflicts:
#	python/sglang/multimodal_gen/configs/pipeline_configs/__init__.py
#	python/sglang/multimodal_gen/configs/pipeline_configs/stablediffusion3.py
#	python/sglang/multimodal_gen/registry.py
#	python/sglang/multimodal_gen/runtime/pipelines_core/stages/conditioning.py
#	python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py
```python
logger = init_logger(__name__)


class TestStableDiffusionT2Image(TestGenerateBase):
```
Sorry, the CLI test is deprecated. Could we add it to test_server_a.py? Thanks.
ca55334 to 8a52f14
Motivation
This PR introduces support for Stable Diffusion 3 Medium (stabilityai/stable-diffusion-3-medium-diffusers) text-to-image (t2i) generation in SGLang.
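For reference, the upstream checkpoint can be exercised directly through diffusers' StableDiffusion3Pipeline, which wires up SD3's three text encoders, the diffusion transformer, and the VAE. This is plain diffusers usage for comparison, not SGLang's API; the prompt and file name are arbitrary:

```python
# Plain diffusers usage of the upstream model, for comparison; not SGLang's API.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    "a photo of a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3_t2i.png")
```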
Run with the CLI:
Output:

Or start a model inference server and generate an image via an API call.
Start the server:
Send a generation request:
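A hypothetical request sketch in Python: the endpoint path, payload fields, and response format below are assumptions for illustration only, since the actual request body is not shown in this excerpt:

```python
# Hypothetical request; endpoint, payload fields, and response format are assumptions.
import base64
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "a photo of an astronaut riding a horse on Mars",
        "sampling_params": {"num_inference_steps": 28, "guidance_scale": 7.0},
    },
    timeout=300,
)
resp.raise_for_status()

# Assuming the image comes back base64-encoded; decode and save it.
with open("sd3_t2i_output.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["image"]))
```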
Output:

Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist