Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions examples/common/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,10 +581,6 @@ struct SDContextParams {
"--vae-tile-overlap",
"tile overlap for vae tiling, in fraction of tile size (default: 0.5)",
&vae_tiling_params.target_overlap},
{"",
"--flow-shift",
"shift value for Flow models like SD3.x or WAN (default: auto)",
&flow_shift},
};

options.bool_options = {
Expand Down Expand Up @@ -903,7 +899,6 @@ struct SDContextParams {
<< " photo_maker_path: \"" << photo_maker_path << "\",\n"
<< " rng_type: " << sd_rng_type_name(rng_type) << ",\n"
<< " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n"
<< " flow_shift: " << (std::isinf(flow_shift) ? "INF" : std::to_string(flow_shift)) << "\n"
<< " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n"
<< " enable_mmap: " << (enable_mmap ? "true" : "false") << ",\n"
<< " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
Expand Down Expand Up @@ -986,7 +981,6 @@ struct SDContextParams {
chroma_use_t5_mask,
chroma_t5_mask_pad,
qwen_image_zero_cond_t,
flow_shift,
};
return sd_ctx_params;
}
Expand Down Expand Up @@ -1206,6 +1200,10 @@ struct SDGenerationParams {
"--eta",
"eta in DDIM, only for DDIM and TCD (default: 0)",
&sample_params.eta},
{"",
"--flow-shift",
"shift value for Flow models like SD3.x or WAN (default: auto)",
&sample_params.flow_shift},
{"",
"--high-noise-cfg-scale",
"(high noise) unconditional guidance scale: (default: 7.0)",
Expand Down Expand Up @@ -1606,6 +1604,7 @@ struct SDGenerationParams {
load_if_exists("cfg_scale", sample_params.guidance.txt_cfg);
load_if_exists("img_cfg_scale", sample_params.guidance.img_cfg);
load_if_exists("guidance", sample_params.guidance.distilled_guidance);
load_if_exists("flow_shift", sample_params.flow_shift);

auto load_sampler_if_exists = [&](const char* key, enum sample_method_t& out) {
if (j.contains(key) && j[key].is_string()) {
Expand Down
2 changes: 1 addition & 1 deletion include/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ typedef struct {
bool chroma_use_t5_mask;
int chroma_t5_mask_pad;
bool qwen_image_zero_cond_t;
float flow_shift;
} sd_ctx_params_t;

typedef struct {
Expand Down Expand Up @@ -235,6 +234,7 @@ typedef struct {
int shifted_timestep;
float* custom_sigmas;
int custom_sigmas_count;
float flow_shift;
} sd_sample_params_t;

typedef struct {
Expand Down
97 changes: 37 additions & 60 deletions src/denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,38 +651,19 @@ float time_snr_shift(float alpha, float t) {
return alpha * t / (1 + (alpha - 1) * t);
}

struct DiscreteFlowDenoiser : public Denoiser {
struct FlowDenoiser : public Denoiser {
float sigmas[TIMESTEPS];
float shift = 3.0f;

float shift = INFINITY;
float sigma_data = 1.0f;

DiscreteFlowDenoiser(float shift = 3.0f)
: shift(shift) {
set_parameters();
}

void set_parameters() {
for (int i = 1; i < TIMESTEPS + 1; i++) {
sigmas[i - 1] = t_to_sigma(static_cast<float>(i));
}
}
virtual void set_parameters(float shift) = 0;

float sigma_min() override {
return sigmas[0];
}

float sigma_max() override {
return sigmas[TIMESTEPS - 1];
}

float sigma_to_t(float sigma) override {
return sigma * 1000.f;
}

float t_to_sigma(float t) override {
t = t + 1;
return time_snr_shift(shift, t / 1000.f);
return sigmas[TIMESTEPS-1];
}

std::vector<float> get_scalings(float sigma) override {
Expand All @@ -706,39 +687,54 @@ struct DiscreteFlowDenoiser : public Denoiser {
}
};

struct DiscreteFlowDenoiser : public FlowDenoiser {

DiscreteFlowDenoiser() = default;

void set_parameters(float shift) override {
if (shift != this->shift) {
this->shift = shift;
for (int i = 0; i < TIMESTEPS; i++) {
sigmas[i] = t_to_sigma(static_cast<float>(i));
}
}
}

float sigma_to_t(float sigma) override {
return sigma * 1000.f;
}

float t_to_sigma(float t) override {
t = t + 1;
return time_snr_shift(shift, t / 1000.f);
}

};

float flux_time_shift(float mu, float sigma, float t) {
return ::expf(mu) / (::expf(mu) + ::powf((1.0f / t - 1.0f), sigma));
}

struct FluxFlowDenoiser : public Denoiser {
float sigmas[TIMESTEPS];
float shift = 1.15f;
struct FluxFlowDenoiser : public FlowDenoiser {

float sigma_data = 1.0f;
float shift_sigmas = INFINITY;

FluxFlowDenoiser(float shift = 1.15f) {
set_parameters(shift);
}
FluxFlowDenoiser() = default;

void set_shift(float shift) {
this->shift = shift;
}

void set_parameters(float shift) {
void set_parameters(float shift) override {
set_shift(shift);
for (int i = 0; i < TIMESTEPS; i++) {
sigmas[i] = t_to_sigma(static_cast<float>(i));
if (shift != shift_sigmas) {
shift_sigmas = shift;
for (int i = 0; i < TIMESTEPS; i++) {
sigmas[i] = t_to_sigma(static_cast<float>(i));
}
}
}

float sigma_min() override {
return sigmas[0];
}

float sigma_max() override {
return sigmas[TIMESTEPS - 1];
}

float sigma_to_t(float sigma) override {
return sigma;
}
Expand All @@ -748,25 +744,6 @@ struct FluxFlowDenoiser : public Denoiser {
return flux_time_shift(shift, 1.0f, t / TIMESTEPS);
}

std::vector<float> get_scalings(float sigma) override {
float c_skip = 1.0f;
float c_out = -sigma;
float c_in = 1.0f;
return {c_skip, c_out, c_in};
}

// this function will modify noise/latent
ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) override {
ggml_ext_tensor_scale_inplace(noise, sigma);
ggml_ext_tensor_scale_inplace(latent, 1.0f - sigma);
ggml_ext_tensor_add_inplace(latent, noise);
return latent;
}

ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) override {
ggml_ext_tensor_scale_inplace(latent, 1.0f / (1.0f - sigma));
return latent;
}
};

struct Flux2FlowDenoiser : public FluxFlowDenoiser {
Expand Down
50 changes: 32 additions & 18 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class StableDiffusionGGML {
int n_threads = -1;
float scale_factor = 0.18215f;
float shift_factor = 0.f;
float default_flow_shift = INFINITY;

std::shared_ptr<Conditioner> cond_stage_model;
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v
Expand Down Expand Up @@ -881,7 +882,6 @@ class StableDiffusionGGML {
// init denoiser
{
prediction_t pred_type = sd_ctx_params->prediction;
float flow_shift = sd_ctx_params->flow_shift;

if (pred_type == PREDICTION_COUNT) {
if (sd_version_is_sd2(version)) {
Expand All @@ -906,22 +906,19 @@ class StableDiffusionGGML {
sd_version_is_qwen_image(version) ||
sd_version_is_z_image(version)) {
pred_type = FLOW_PRED;
if (flow_shift == INFINITY) {
if (sd_version_is_wan(version)) {
flow_shift = 5.f;
} else {
flow_shift = 3.f;
}
if (sd_version_is_wan(version)) {
default_flow_shift = 5.f;
} else {
default_flow_shift = 3.f;
}
} else if (sd_version_is_flux(version)) {
pred_type = FLUX_FLOW_PRED;

if (flow_shift == INFINITY) {
flow_shift = 1.0f; // TODO: validate
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
flow_shift = 1.15f;
}
default_flow_shift = 1.0f; // TODO: validate
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
default_flow_shift = 1.15f;
break;
}
}
} else if (sd_version_is_flux2(version)) {
Expand All @@ -945,12 +942,12 @@ class StableDiffusionGGML {
break;
case FLOW_PRED: {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>(flow_shift);
denoiser = std::make_shared<DiscreteFlowDenoiser>();
break;
}
case FLUX_FLOW_PRED: {
LOG_INFO("running in Flux FLOW mode");
denoiser = std::make_shared<FluxFlowDenoiser>(flow_shift);
denoiser = std::make_shared<FluxFlowDenoiser>();
break;
}
case FLUX2_FLOW_PRED: {
Expand Down Expand Up @@ -2711,6 +2708,17 @@ class StableDiffusionGGML {
ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
return result;
}

void set_flow_shift(float flow_shift = INFINITY) {
auto flow_denoiser = std::dynamic_pointer_cast<FlowDenoiser>(denoiser);
if (flow_denoiser) {
if (flow_shift == INFINITY) {
flow_shift = default_flow_shift;
}
flow_denoiser->set_parameters(flow_shift);
}
}

};

/*================================================= SD API ==================================================*/
Expand Down Expand Up @@ -2931,7 +2939,6 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->chroma_use_dit_mask = true;
sd_ctx_params->chroma_use_t5_mask = false;
sd_ctx_params->chroma_t5_mask_pad = 1;
sd_ctx_params->flow_shift = INFINITY;
}

char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
Expand Down Expand Up @@ -3023,6 +3030,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
sample_params->sample_steps = 20;
sample_params->custom_sigmas = nullptr;
sample_params->custom_sigmas_count = 0;
sample_params->flow_shift = INFINITY;
}

char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
Expand All @@ -3043,7 +3051,8 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
"sample_method: %s, "
"sample_steps: %d, "
"eta: %.2f, "
"shifted_timestep: %d)",
"shifted_timestep: %d, ",
"flow_shift: %.2f)",
sample_params->guidance.txt_cfg,
std::isfinite(sample_params->guidance.img_cfg)
? sample_params->guidance.img_cfg
Expand All @@ -3057,7 +3066,8 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
sd_sample_method_name(sample_params->sample_method),
sample_params->sample_steps,
sample_params->eta,
sample_params->shifted_timestep);
sample_params->shifted_timestep,
sample_params->flow_shift);

return buf;
}
Expand Down Expand Up @@ -3528,6 +3538,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g

size_t t0 = ggml_time_ms();

sd_ctx->sd->set_flow_shift(sd_img_gen_params->sample_params.flow_shift);

// Apply lora
sd_ctx->sd->apply_loras(sd_img_gen_params->loras, sd_img_gen_params->lora_count);

Expand Down Expand Up @@ -3803,6 +3815,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
}
LOG_INFO("generate_video %dx%dx%d", width, height, frames);

sd_ctx->sd->set_flow_shift(sd_vid_gen_params->sample_params.flow_shift);

enum sample_method_t sample_method = sd_vid_gen_params->sample_params.sample_method;
if (sample_method == SAMPLE_METHOD_COUNT) {
sample_method = sd_get_default_sample_method(sd_ctx);
Expand Down
Loading