@@ -161,31 +161,79 @@ class CoefficientCalculator:
161161 "t2v" : {
162162 "1.3b" : {
163163 "default" : [
164- [- 5.21862437e04 , 9.23041404e03 , - 5.28275948e02 , 1.36987616e01 , - 4.99875664e-02 ],
165- [2.39676752e03 , - 1.31110545e03 , 2.01331979e02 , - 8.29855975e00 , 1.37887774e-01 ],
164+ [
165+ - 5.21862437e04 ,
166+ 9.23041404e03 ,
167+ - 5.28275948e02 ,
168+ 1.36987616e01 ,
169+ - 4.99875664e-02 ,
170+ ],
171+ [
172+ 2.39676752e03 ,
173+ - 1.31110545e03 ,
174+ 2.01331979e02 ,
175+ - 8.29855975e00 ,
176+ 1.37887774e-01 ,
177+ ],
166178 ]
167179 },
168180 "14b" : {
169181 "default" : [
170- [- 3.03318725e05 , 4.90537029e04 , - 2.65530556e03 , 5.87365115e01 , - 3.15583525e-01 ],
171- [- 5784.54975374 , 5449.50911966 , - 1811.16591783 , 256.27178429 , - 13.02252404 ],
182+ [
183+ - 3.03318725e05 ,
184+ 4.90537029e04 ,
185+ - 2.65530556e03 ,
186+ 5.87365115e01 ,
187+ - 3.15583525e-01 ,
188+ ],
189+ [
190+ - 5784.54975374 ,
191+ 5449.50911966 ,
192+ - 1811.16591783 ,
193+ 256.27178429 ,
194+ - 13.02252404 ,
195+ ],
172196 ]
173197 },
174198 },
175199 "i2v" : {
176200 "720p" : [
177- [8.10705460e03 , 2.13393892e03 , - 3.72934672e02 , 1.66203073e01 , - 4.17769401e-02 ],
201+ [
202+ 8.10705460e03 ,
203+ 2.13393892e03 ,
204+ - 3.72934672e02 ,
205+ 1.66203073e01 ,
206+ - 4.17769401e-02 ,
207+ ],
178208 [- 114.36346466 , 65.26524496 , - 18.82220707 , 4.91518089 , - 0.23412683 ],
179209 ],
180210 "480p" : [
181- [2.57151496e05 , - 3.54229917e04 , 1.40286849e03 , - 1.35890334e01 , 1.32517977e-01 ],
182- [- 3.02331670e02 , 2.23948934e02 , - 5.25463970e01 , 5.87348440e00 , - 2.01973289e-01 ],
211+ [
212+ 2.57151496e05 ,
213+ - 3.54229917e04 ,
214+ 1.40286849e03 ,
215+ - 1.35890334e01 ,
216+ 1.32517977e-01 ,
217+ ],
218+ [
219+ - 3.02331670e02 ,
220+ 2.23948934e02 ,
221+ - 5.25463970e01 ,
222+ 5.87348440e00 ,
223+ - 2.01973289e-01 ,
224+ ],
183225 ],
184226 },
185227 }
186228
187229 @classmethod
188- def get_coefficients (cls , task : str , model_size : str , resolution : Tuple [int , int ], use_ret_steps : bool ) -> List [List [float ]]:
230+ def get_coefficients (
231+ cls ,
232+ task : str ,
233+ model_size : str ,
234+ resolution : Tuple [int , int ],
235+ use_ret_steps : bool ,
236+ ) -> List [List [float ]]:
189237 """Get appropriate coefficients for TeaCache."""
190238 if task == "t2v" :
191239 coeffs = cls .COEFFICIENTS ["t2v" ].get (model_size , {}).get ("default" , None )
@@ -269,13 +317,15 @@ def apply_inference_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
             updates["target_video_length"] = config["video_length"]
         if "fps" in config:
             updates["fps"] = config["fps"]
-
+
         if "denoising_step_list" in config:
             updates["denoising_step_list"] = config["denoising_step_list"]

         return updates

-    def apply_teacache_config(self, config: Dict[str, Any], model_info: Dict[str, Any]) -> Dict[str, Any]:
+    def apply_teacache_config(
+        self, config: Dict[str, Any], model_info: Dict[str, Any]
+    ) -> Dict[str, Any]:
         """Apply TeaCache configuration."""
         updates = {}

@@ -286,16 +336,23 @@ def apply_teacache_config(self, config: Dict[str, Any], model_info: Dict[str, An

             task = model_info.get("task", "t2v")
             model_size = "14b" if "14b" in model_info.get("model_cls", "") else "1.3b"
-            resolution = (model_info.get("target_width", 832), model_info.get("target_height", 480))
+            resolution = (
+                model_info.get("target_width", 832),
+                model_info.get("target_height", 480),
+            )

-            coeffs = CoefficientCalculator.get_coefficients(task, model_size, resolution, updates["use_ret_steps"])
+            coeffs = CoefficientCalculator.get_coefficients(
+                task, model_size, resolution, updates["use_ret_steps"]
+            )
             updates["coefficients"] = coeffs
         else:
             updates["feature_caching"] = "NoCaching"

         return updates

-    def apply_quantization_config(self, config: Dict[str, Any], model_path: str) -> Dict[str, Any]:
+    def apply_quantization_config(
+        self, config: Dict[str, Any], model_path: str
+    ) -> Dict[str, Any]:
         """Apply quantization configuration."""
         updates = {}

@@ -309,14 +366,18 @@ def apply_quantization_config(self, config: Dict[str, Any], model_path: str) ->
         updates["t5_quantized"] = t5_scheme != "bf16"
         if t5_scheme != "bf16":
             t5_path = os.path.join(model_path, t5_scheme)
-            updates["t5_quantized_ckpt"] = os.path.join(t5_path, f"models_t5_umt5-xxl-enc-{t5_scheme}.pth")
+            updates["t5_quantized_ckpt"] = os.path.join(
+                t5_path, f"models_t5_umt5-xxl-enc-{t5_scheme}.pth"
+            )

         clip_scheme = config.get("clip_precision", "fp16")
         updates["clip_quant_scheme"] = clip_scheme
         updates["clip_quantized"] = clip_scheme != "fp16"
         if clip_scheme != "fp16":
             clip_path = os.path.join(model_path, clip_scheme)
-            updates["clip_quantized_ckpt"] = os.path.join(clip_path, f"clip-{clip_scheme}.pth")
+            updates["clip_quantized_ckpt"] = os.path.join(
+                clip_path, f"clip-{clip_scheme}.pth"
+            )

         quant_backend = config.get("quant_backend", "vllm")
         updates["quant_op"] = quant_backend
@@ -330,7 +391,9 @@ def apply_quantization_config(self, config: Dict[str, Any], model_path: str) ->
             else:
                 mm_type = f"W-{dit_scheme}-channel-sym-A-{dit_scheme}-channel-sym-dynamic-Sgl"
         elif quant_backend == "q8f":
-            mm_type = f"W-{dit_scheme}-channel-sym-A-{dit_scheme}-channel-sym-dynamic-Q8F"
+            mm_type = (
+                f"W-{dit_scheme}-channel-sym-A-{dit_scheme}-channel-sym-dynamic-Q8F"
+            )
         else:
             mm_type = "Default"

@@ -357,15 +420,21 @@ def apply_memory_optimization(self, config: Dict[str, Any]) -> Dict[str, Any]:
             updates["clean_cuda_cache"] = True

         # CPU offloading
-        if config.get("enable_cpu_offload", False) or level in ["medium", "high", "extreme"]:
+        if config.get("enable_cpu_offload", False) or level in [
+            "medium",
+            "high",
+            "extreme",
+        ]:
             updates["cpu_offload"] = True
             updates["offload_granularity"] = config.get("offload_granularity", "phase")
             updates["offload_ratio"] = config.get("offload_ratio", 1.0)

         # T5 offloading
         if level in ["high", "extreme"]:
             updates["t5_cpu_offload"] = True
-            updates["t5_offload_granularity"] = "block" if level == "extreme" else "model"
+            updates["t5_offload_granularity"] = (
+                "block" if level == "extreme" else "model"
+            )

         # Module management
         if config.get("lazy_load", False) or level == "extreme":
@@ -383,7 +452,9 @@ def apply_memory_optimization(self, config: Dict[str, Any]) -> Dict[str, Any]:

         return updates

-    def apply_vae_config(self, config: Dict[str, Any], model_path: str) -> Dict[str, Any]:
+    def apply_vae_config(
+        self, config: Dict[str, Any], model_path: str
+    ) -> Dict[str, Any]:
         """Apply VAE configuration."""
         updates = {}

@@ -413,7 +484,9 @@ def build_final_config(self, configs: Dict[str, Dict[str, Any]]) -> EasyDict:

         if "quantization" in configs:
             model_path = final_config.get("model_path", "")
-            quant_updates = self.apply_quantization_config(configs["quantization"], model_path)
+            quant_updates = self.apply_quantization_config(
+                configs["quantization"], model_path
+            )
             final_config.update(quant_updates)

         if "memory" in configs:
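
For context, here is a minimal usage sketch of the `CoefficientCalculator.get_coefficients` classmethod reformatted above. It assumes the class has already been imported from the module this commit touches; the task, model size, and resolution arguments are illustrative placeholders, not values prescribed by the diff.

```python
from typing import List

# Illustrative call site (assumed values; CoefficientCalculator comes from
# the file reformatted in this commit and must be imported from there).
coeffs: List[List[float]] = CoefficientCalculator.get_coefficients(
    task="t2v",
    model_size="1.3b",
    resolution=(832, 480),
    use_ret_steps=False,
)

# Per the List[List[float]] annotation, the lookup yields nested rows of
# TeaCache polynomial coefficients for the requested task and model size.
print(coeffs)
```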