@@ -378,40 +378,7 @@ def get_extra_modules(self, block):
378378 def get_moe_gate (self , block ):
379379 return None
380380
381- def replace_vision_module_all (self , module , params_dict , keep_device = False ):
382- vision_model_linears = self .get_block_linears (self .vision_model )
383- for name , m in vision_model_linears .items ():
384- M = module .new (m , ** params_dict )
385-
386- name_tmp = name .rsplit ('.' , 1 )
387- if len (name_tmp ) == 2 :
388- parent_name = name_tmp [0 ]
389- parent = self .vision_model .get_submodule (parent_name )
390- child_name = name_tmp [1 ]
391- elif len (name_tmp ) == 1 :
392- parent = self .vision_model
393- child_name = name_tmp [0 ]
394-
395- setattr (parent , child_name , M )
396-
397- gc .collect ()
398- torch .cuda .empty_cache ()
399- logger .info (f'The Replaced vision_model: { self .vision_model } ' )
400-
401- def replace_language_module_all (self , module , params_dict , keep_device = False ):
402- for block_idx in range (len (self .blocks )):
403- logger .info (f'Replace block index: { block_idx } /{ len (self .blocks )} ' )
404- if keep_device :
405- self .replace_module_block (module , self .blocks [block_idx ], block_idx , params_dict )
406- else :
407- self .blocks [block_idx ].cuda ()
408- self .replace_module_block (module , self .blocks [block_idx ], block_idx , params_dict )
409- self .blocks [block_idx ].cpu ()
410- gc .collect ()
411- torch .cuda .empty_cache ()
412- logger .info (f'The Replaced model: { self .model } ' )
413-
414- def replace_video_gen_module_all (self , module , params_dict , keep_device = False ):
381+ def replace_module_all (self , module , params_dict , keep_device = False ):
415382 for block_idx in range (len (self .blocks )):
416383 logger .info (f'Replace block index: { block_idx } /{ len (self .blocks )} ' )
417384 if keep_device :
@@ -422,7 +389,6 @@ def replace_video_gen_module_all(self, module, params_dict, keep_device=False):
422389 self .blocks [block_idx ].cpu ()
423390 gc .collect ()
424391 torch .cuda .empty_cache ()
425- logger .info (f'The Replaced model: { self .model } ' )
426392
427393 def replace_module_block (self , module , block , block_idx , params_dict ):
428394 if module in _LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_ :
0 commit comments