@@ -157,6 +157,7 @@ def _set_weight(
                 group=tp_group,
             )
             del splited_weights
+            torch.cuda.empty_cache()
         else:
             tensor = hf_weight
             if offset:
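
A note on the recurring `del` + `torch.cuda.empty_cache()` pairing this patch introduces: `del` only drops the Python reference, returning the blocks to PyTorch's caching allocator for reuse inside the process; `empty_cache()` additionally hands the unused cached blocks back to the driver, so the memory shows up as free to `nvidia-smi` and to other processes sharing the GPU. A minimal sketch of the effect (the tensor size is arbitrary, CUDA required):

```python
import torch

x = torch.empty(1024, 1024, 256, device='cuda')  # ~1 GiB of fp32
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())

del x
# Allocated drops at once, but reserved stays high: the caching
# allocator keeps the freed blocks for later reuse.
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())

torch.cuda.empty_cache()
# Reserved shrinks too: unused cached blocks go back to the driver.
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
```
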
@@ -243,6 +244,7 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
             )
             tensor = torch.cat(output, dim=tp_dim)
             del output
+            torch.cuda.empty_cache()
         # pp/ep
         if pp_size > 1:
             src_rank = torch.tensor([0 if tensor is None else pp_rank], dtype=torch.int64, device='cuda')
@@ -273,6 +275,16 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
                 tensor = None
         return tensor
 
+    def _cpu_clone(self, tensor: Optional[torch.Tensor]):
+        if tensor is None:
+            return None
+        if not isinstance(tensor, torch.Tensor):
+            return tensor
+        # Detach to avoid any autograd references and create a copy on the CPU.
+        # `non_blocking=True` is a best-effort for async copy from GPU.
+        # `copy=True` ensures a new tensor is created even if it's already on CPU.
+        return tensor.detach().to('cpu', copy=True, non_blocking=True)
+
     def _set_state_dict(self,
                         mg_module,
                         mg_key: str,
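
A caveat on the new helper's `non_blocking=True`: device-to-host copies are only genuinely asynchronous when the destination is pinned memory, and in that case the tensor must not be read before the copy has been synchronized. Into the pageable CPU tensor that `.to('cpu', copy=True, non_blocking=True)` allocates, the copy completes before the call returns, so the flag is the harmless best-effort the comment describes. A sketch of both cases, purely illustrative:

```python
import torch

gpu_t = torch.randn(4, 4, device='cuda')

# Pageable destination (what _cpu_clone produces): the copy blocks
# until done, so the result is immediately safe to read.
cpu_t = gpu_t.detach().to('cpu', copy=True, non_blocking=True)

# Pinned destination: the copy can overlap with GPU work, but you
# must synchronize before trusting the contents.
pinned = torch.empty(gpu_t.shape, dtype=gpu_t.dtype, pin_memory=True)
pinned.copy_(gpu_t, non_blocking=True)
torch.cuda.synchronize()
assert torch.equal(pinned, cpu_t)
```
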
@@ -412,25 +424,27 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int
         if lora_A is not None:
             self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'})
             for key in ['q_proj', 'k_proj', 'v_proj']:
-                hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
+                hf_state_dict[f'{key}.lora_A.weight'] = self._cpu_clone(lora_A)
             lora_B = lora_B.reshape((num_query_groups, -1, lora_B.shape[-1]))
-            hf_state_dict['q_proj.lora_B.weight'] = lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone()
-            hf_state_dict['k_proj.lora_B.weight'] = lora_B[:,
-                                                           q_dim:-kv_dim, :].reshape(-1,
-                                                                                     lora_B.shape[-1]).clone()
-            hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, -kv_dim:, :].reshape(-1, lora_B.shape[-1]).clone()
+            hf_state_dict['q_proj.lora_B.weight'] = self._cpu_clone(lora_B[:, :q_dim, :].reshape(
+                -1, lora_B.shape[-1]))
+            hf_state_dict['k_proj.lora_B.weight'] = self._cpu_clone(lora_B[:, q_dim:-kv_dim, :].reshape(
+                -1, lora_B.shape[-1]))
+            hf_state_dict['v_proj.lora_B.weight'] = self._cpu_clone(lora_B[:, -kv_dim:, :].reshape(
+                -1, lora_B.shape[-1]))
         elif not self._is_peft_format:
             mg_attn_weight = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.weight.data,
                                               'linear_qkv.weight')
             if mg_attn_weight is not None:
                 mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size))
-                hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size).clone()
-                hf_state_dict['k_proj.weight'] = mg_attn_weight[:,
-                                                                q_dim:-kv_dim, :].reshape(-1,
-                                                                                          args.hidden_size).clone()
-                hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1,
-                                                                                        args.hidden_size).clone()
+                hf_state_dict['q_proj.weight'] = self._cpu_clone(mg_attn_weight[:, :q_dim, :].reshape(
+                    -1, args.hidden_size))
+                hf_state_dict['k_proj.weight'] = self._cpu_clone(mg_attn_weight[:, q_dim:-kv_dim, :].reshape(
+                    -1, args.hidden_size))
+                hf_state_dict['v_proj.weight'] = self._cpu_clone(mg_attn_weight[:, -kv_dim:, :].reshape(
+                    -1, args.hidden_size))
                 del mg_attn_weight
+                torch.cuda.empty_cache()
         self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore)
 
         # Copy bias
@@ -448,9 +462,9 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int
                                             'linear_qkv.bias')
             if mg_attn_bias is not None:
                 mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1))
-                hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone()
-                hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone()
-                hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone()
+                hf_state_dict['q_proj.bias'] = self._cpu_clone(mg_attn_bias[:, :q_dim].reshape(-1))
+                hf_state_dict['k_proj.bias'] = self._cpu_clone(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1))
+                hf_state_dict['v_proj.bias'] = self._cpu_clone(mg_attn_bias[:, -kv_dim:].reshape(-1))
         if args.qk_layernorm:
             hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight'
             hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight'
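
For readers of the two hunks above: the reshape/slice pattern works because Megatron's fused `linear_qkv` lays out each query group's rows as `[q..., k..., v...]`, so viewing the weight as `(num_query_groups, q_dim + 2 * kv_dim, hidden_size)` lets q/k/v be peeled off with plain slices; the bias follows the same per-group layout. A self-contained sketch with toy dimensions (all sizes here are assumptions for illustration):

```python
import torch

hidden_size = 8
num_query_groups = 2                # KV groups (GQA)
heads_per_group = 2                 # query heads per KV head
head_dim = 4
q_dim = heads_per_group * head_dim  # query rows per group
kv_dim = head_dim                   # key/value rows per group

# Fused weight: groups stacked on dim 0, rows [q..., k..., v...] per group.
qkv = torch.randn(num_query_groups * (q_dim + 2 * kv_dim), hidden_size)

g = qkv.reshape(num_query_groups, -1, hidden_size)
q = g[:, :q_dim, :].reshape(-1, hidden_size)
k = g[:, q_dim:-kv_dim, :].reshape(-1, hidden_size)
v = g[:, -kv_dim:, :].reshape(-1, hidden_size)
assert q.shape[0] == num_query_groups * q_dim
assert k.shape[0] == v.shape[0] == num_query_groups * kv_dim
```
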
@@ -630,6 +644,7 @@ def _set_mlp_state(self,
                 gate_proj_weight = hf_state_dict['gate_proj.weight'].load()
                 up_proj_weight = hf_state_dict['up_proj.weight'].load()
                 gate_up_proj_weight = torch.stack([gate_proj_weight, up_proj_weight], dim=0)
+                torch.cuda.empty_cache()
             self._set_weight(fc1_weight, gate_up_proj_weight, 'linear_fc1.weight', is_expert=is_expert)
             if is_expert:
                 fc1_weight = fc1_weight.view(num_local_experts, -1, fc1_weight.shape[-1])
@@ -668,6 +683,7 @@ def _set_mlp_state(self,
             lora_A = mg_mlp.linear_fc1.lora_A[self._adapter_name].weight
             lora_B = mg_mlp.linear_fc1.lora_B[self._adapter_name].weight
             lora_B = lora_B.view(num_local_experts * 2, -1, lora_B.shape[1])
+            torch.cuda.empty_cache()
             lora_A = self._get_weight(lora_A, f'linear_fc1.lora_A.{self._adapter_name}.weight', is_expert=is_expert)
             lora_B = self._get_weight(lora_B, f'linear_fc1.lora_B.{self._adapter_name}.weight', is_expert=is_expert)
             if lora_A is not None:
@@ -678,28 +694,29 @@ def _set_mlp_state(self,
                     lora_B = lora_B.view(num_local_experts, -1, lora_B.shape[-1])
                     for i in range(num_local_experts):
                         hf_i = i + ep_rank * num_local_experts
-                        hf_state_dict[f'{hf_i}.gate_up_proj.lora_A.weight'] = lora_A[i].clone()
-                        hf_state_dict[f'{hf_i}.gate_up_proj.lora_B.weight'] = lora_B[i].clone()
+                        hf_state_dict[f'{hf_i}.gate_up_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                        hf_state_dict[f'{hf_i}.gate_up_proj.lora_B.weight'] = self._cpu_clone(lora_B[i])
 
                 else:
-                    hf_state_dict['gate_up_proj.lora_A.weight'] = lora_A.clone()
-                    hf_state_dict['gate_up_proj.lora_B.weight'] = lora_B.view(-1, lora_B.shape[-1]).clone()
+                    hf_state_dict['gate_up_proj.lora_A.weight'] = self._cpu_clone(lora_A)
+                    hf_state_dict['gate_up_proj.lora_B.weight'] = self._cpu_clone(
+                        lora_B.view(-1, lora_B.shape[-1]))
             else:
                 self._peft_target_modules.update({'gate_proj', 'up_proj'})
                 if is_expert:
                     lora_A = lora_A.view(num_local_experts, -1, lora_A.shape[-1])
                     lora_B = lora_B.view(num_local_experts, 2, -1, lora_B.shape[-1])
                     for i in range(num_local_experts):
                         hf_i = i + ep_rank * num_local_experts
-                        hf_state_dict[f'{hf_i}.gate_proj.lora_A.weight'] = lora_A[i].clone()
-                        hf_state_dict[f'{hf_i}.up_proj.lora_A.weight'] = lora_A[i].clone()
-                        hf_state_dict[f'{hf_i}.gate_proj.lora_B.weight'] = lora_B[i][0].clone()
-                        hf_state_dict[f'{hf_i}.up_proj.lora_B.weight'] = lora_B[i][1].clone()
+                        hf_state_dict[f'{hf_i}.gate_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                        hf_state_dict[f'{hf_i}.up_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                        hf_state_dict[f'{hf_i}.gate_proj.lora_B.weight'] = self._cpu_clone(lora_B[i][0])
+                        hf_state_dict[f'{hf_i}.up_proj.lora_B.weight'] = self._cpu_clone(lora_B[i][1])
                 else:
-                    hf_state_dict['gate_proj.lora_A.weight'] = lora_A.clone()
-                    hf_state_dict['up_proj.lora_A.weight'] = lora_A.clone()
-                    hf_state_dict['gate_proj.lora_B.weight'] = lora_B[0].clone()
-                    hf_state_dict['up_proj.lora_B.weight'] = lora_B[1].clone()
+                    hf_state_dict['gate_proj.lora_A.weight'] = self._cpu_clone(lora_A)
+                    hf_state_dict['up_proj.lora_A.weight'] = self._cpu_clone(lora_A)
+                    hf_state_dict['gate_proj.lora_B.weight'] = self._cpu_clone(lora_B[0])
+                    hf_state_dict['up_proj.lora_B.weight'] = self._cpu_clone(lora_B[1])
         elif not self._is_peft_format:
             if mg_mlp is None:
                 fc1_weight = None
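
Side note on why the same `lora_A` is written for both `gate_proj` and `up_proj` above: for a fused `linear_fc1`, the LoRA A matrix projects the shared input into the low-rank space, and only B differs per output projection, so splitting the adapter means duplicating A and slicing B. A toy check (the shapes are made up):

```python
import torch

hidden, r, ffn = 8, 2, 16
lora_A = torch.randn(r, hidden)   # shared between gate and up
lora_B = torch.randn(2, ffn, r)   # [0] -> gate_proj, [1] -> up_proj

x = torch.randn(hidden)
z = lora_A @ x                    # one low-rank projection serves both
gate_delta = lora_B[0] @ z
up_delta = lora_B[1] @ z
assert gate_delta.shape == up_delta.shape == (ffn,)
```
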
@@ -725,27 +742,29 @@ def _set_mlp_state(self,
                     if 'gate_up_proj' in hf_state_dict:
                         gate_up_proj_weight = torch.concat(
                             [hf_state_dict['gate_up_proj'], gate_up_proj_weight], dim=0)
-                    hf_state_dict['gate_up_proj'] = gate_up_proj_weight.clone()
+                    hf_state_dict['gate_up_proj'] = self._cpu_clone(gate_up_proj_weight)
                 else:
                     for i in range(num_local_experts):
                         hf_i = i + ep_rank * num_local_experts
-                        hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = gate_up_proj_weight[i].clone()
+                        hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = self._cpu_clone(
+                            gate_up_proj_weight[i])
                 del gate_up_proj_weight
             else:
-                hf_state_dict['gate_up_proj.weight'] = gate_up_proj_weight.view(
-                    -1, gate_up_proj_weight.shape[-1]).clone()
+                hf_state_dict['gate_up_proj.weight'] = self._cpu_clone(
+                    gate_up_proj_weight.view(-1, gate_up_proj_weight.shape[-1]))
         else:
             if is_expert:
                 gate_up_proj_weight = gate_up_proj_weight.view(num_local_experts, 2, -1,
                                                                gate_up_proj_weight.shape[-1])
                 for i in range(num_local_experts):
                     hf_i = i + ep_rank * num_local_experts
-                    hf_state_dict[f'{hf_i}.gate_proj.weight'] = gate_up_proj_weight[i][0].clone()
-                    hf_state_dict[f'{hf_i}.up_proj.weight'] = gate_up_proj_weight[i][1].clone()
-                del gate_up_proj_weight
+                    hf_state_dict[f'{hf_i}.gate_proj.weight'] = self._cpu_clone(gate_up_proj_weight[i][0])
+                    hf_state_dict[f'{hf_i}.up_proj.weight'] = self._cpu_clone(gate_up_proj_weight[i][1])
             else:
-                hf_state_dict['gate_proj.weight'] = gate_up_proj_weight[0].clone()
-                hf_state_dict['up_proj.weight'] = gate_up_proj_weight[1].clone()
+                hf_state_dict['gate_proj.weight'] = self._cpu_clone(gate_up_proj_weight[0])
+                hf_state_dict['up_proj.weight'] = self._cpu_clone(gate_up_proj_weight[1])
+            del gate_up_proj_weight
+            torch.cuda.empty_cache()
         # linear_fc2
         if is_expert:
             if to_mcore:
@@ -825,8 +844,8 @@ def _set_mlp_state(self,
                     lora_B = lora_B.view(num_local_experts, -1, lora_B.shape[-1])
                     for i in range(num_local_experts):
                         hf_i = i + ep_rank * num_local_experts
-                        hf_state_dict[f'{hf_i}.down_proj.lora_A.weight'] = lora_A[i].clone()
-                        hf_state_dict[f'{hf_i}.down_proj.lora_B.weight'] = lora_B[i].clone()
+                        hf_state_dict[f'{hf_i}.down_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                        hf_state_dict[f'{hf_i}.down_proj.lora_B.weight'] = self._cpu_clone(lora_B[i])
         elif not self._is_peft_format:
             if mg_mlp is None:
                 fc2_weight = None
@@ -838,20 +857,23 @@ def _set_mlp_state(self,
                                        dim=0)
             down_proj_weight = self._get_weight(fc2_weight, 'linear_fc2.weight', is_expert=is_expert)
             del fc2_weight
+            torch.cuda.empty_cache()
             if down_proj_weight is not None:
                 down_proj_weight = down_proj_weight.view(num_local_experts, -1, down_proj_weight.shape[-1])
                 if hf_grouped:
                     down_proj_weight = down_proj_weight.transpose(1, 2)
                     if 'down_proj' in hf_state_dict:
                         down_proj_weight = torch.concat([hf_state_dict['down_proj'], down_proj_weight], dim=0)
-                    hf_state_dict['down_proj'] = down_proj_weight.clone()
+                    hf_state_dict['down_proj'] = self._cpu_clone(down_proj_weight)
                 else:
                     for i in range(num_local_experts):
                         hf_i = i + ep_rank * num_local_experts
-                        hf_state_dict[f'{hf_i}.down_proj.weight'] = down_proj_weight[i].clone()
+                        hf_state_dict[f'{hf_i}.down_proj.weight'] = self._cpu_clone(down_proj_weight[i])
+                del down_proj_weight
         else:
             self._set_state_dict(
                 mg_mlp, 'linear_fc2.weight', hf_state_dict, 'down_proj.weight', to_mcore, is_expert=is_expert)
+        torch.cuda.empty_cache()
         if to_mcore:
             hf_state_dict = {}
         else:
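
Finally, on the expert loops touched throughout: `hf_i = i + ep_rank * num_local_experts` maps each rank's local experts onto a contiguous block of global HF expert indices, so the per-expert keys never collide across expert-parallel ranks. A tiny illustration (an `ep_size` of 2 is assumed):

```python
num_local_experts = 4
for ep_rank in range(2):  # ep_size == 2
    keys = [f'{i + ep_rank * num_local_experts}.down_proj.weight'
            for i in range(num_local_experts)]
    print(ep_rank, keys)
# rank 0 -> experts 0..3, rank 1 -> experts 4..7
```
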