@@ -157,6 +157,7 @@ def _set_weight(
                 group=tp_group,
             )
             del splited_weights
+            torch.cuda.empty_cache()
         else:
             tensor = hf_weight
             if offset:
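
For context on why this commit pairs `del` with `torch.cuda.empty_cache()`: deleting the last Python reference only returns the block to PyTorch's caching allocator, where it stays reserved by the process; `empty_cache()` hands cached, unused blocks back to the CUDA driver. A minimal standalone sketch (illustrative, not from this repo; assumes a fresh process with a CUDA device):

```python
import torch

# Allocate ~4 MiB on the GPU, then drop the only reference.
buf = torch.empty(1024, 1024, device='cuda')  # float32: 4 MiB
del buf

# The block is free for reuse by PyTorch but still counted as reserved.
print(torch.cuda.memory_allocated())  # ~0: no live tensors
print(torch.cuda.memory_reserved())   # still >= 4 MiB: block sits in the allocator cache

torch.cuda.empty_cache()              # release cached blocks back to the driver
print(torch.cuda.memory_reserved())   # drops, freeing VRAM for other processes
```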
@@ -243,6 +244,7 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
             )
             tensor = torch.cat(output, dim=tp_dim)
             del output
+            torch.cuda.empty_cache()
         # pp/ep
         if pp_size > 1:
             src_rank = torch.tensor([0 if tensor is None else pp_rank], dtype=torch.int64, device='cuda')
@@ -273,6 +275,22 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
                 tensor = None
         return tensor
 
+    def _cpu_clone(self, tensor: Optional[torch.Tensor]):
+        if tensor is None:
+            return None
+        if not isinstance(tensor, torch.Tensor):
+            return tensor
+        # Detach to avoid any autograd references
+        t = tensor.detach()
+        if t.device.type != 'cpu':
+            # Move to CPU if not already there (this always makes a copy).
+            # `non_blocking=True` attempts an asynchronous copy for GPU->CPU when the destination
+            # is pinned memory; this is best-effort and falls back to blocking if not possible.
+            # https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html
+            return t.to('cpu', non_blocking=True)
+        else:
+            return t.clone()
+
     def _set_state_dict(self,
                         mg_module,
                         mg_key: str,
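
A note on the `non_blocking=True` path in `_cpu_clone`: a GPU-to-CPU copy into freshly allocated pageable memory is effectively synchronous, but if the destination is pinned the copy is queued asynchronously and must be synchronized before the bytes are read. A hedged sketch of the safe consumption pattern (names are illustrative, not from the source):

```python
import torch

def safe_cpu_copy(t: torch.Tensor) -> torch.Tensor:
    # detach() drops autograd history; .to('cpu', ...) always materializes a new tensor.
    out = t.detach().to('cpu', non_blocking=True)
    # If the destination were pinned, the copy could still be in flight here;
    # synchronize before serializing or reading the result. For the default
    # pageable destination this costs essentially nothing.
    torch.cuda.synchronize()
    return out
```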
@@ -412,25 +430,28 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int
             if lora_A is not None:
                 self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'})
                 for key in ['q_proj', 'k_proj', 'v_proj']:
-                    hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
+                    hf_state_dict[f'{key}.lora_A.weight'] = self._cpu_clone(lora_A)
                 lora_B = lora_B.reshape((num_query_groups, -1, lora_B.shape[-1]))
-                hf_state_dict['q_proj.lora_B.weight'] = lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone()
-                hf_state_dict['k_proj.lora_B.weight'] = lora_B[:,
-                                                               q_dim:-kv_dim, :].reshape(-1,
-                                                                                         lora_B.shape[-1]).clone()
-                hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, -kv_dim:, :].reshape(-1, lora_B.shape[-1]).clone()
+                hf_state_dict['q_proj.lora_B.weight'] = self._cpu_clone(lora_B[:, :q_dim, :].reshape(
+                    -1, lora_B.shape[-1]))
+                hf_state_dict['k_proj.lora_B.weight'] = self._cpu_clone(lora_B[:, q_dim:-kv_dim, :].reshape(
+                    -1, lora_B.shape[-1]))
+                hf_state_dict['v_proj.lora_B.weight'] = self._cpu_clone(lora_B[:, -kv_dim:, :].reshape(
+                    -1, lora_B.shape[-1]))
+                torch.cuda.empty_cache()
             elif not self._is_peft_format:
                 mg_attn_weight = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.weight.data,
                                                   'linear_qkv.weight')
                 if mg_attn_weight is not None:
                     mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size))
-                    hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size).clone()
-                    hf_state_dict['k_proj.weight'] = mg_attn_weight[:,
-                                                                    q_dim:-kv_dim, :].reshape(-1,
-                                                                                              args.hidden_size).clone()
-                    hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1,
-                                                                                            args.hidden_size).clone()
+                    hf_state_dict['q_proj.weight'] = self._cpu_clone(mg_attn_weight[:, :q_dim, :].reshape(
+                        -1, args.hidden_size))
+                    hf_state_dict['k_proj.weight'] = self._cpu_clone(mg_attn_weight[:, q_dim:-kv_dim, :].reshape(
+                        -1, args.hidden_size))
+                    hf_state_dict['v_proj.weight'] = self._cpu_clone(mg_attn_weight[:, -kv_dim:, :].reshape(
+                        -1, args.hidden_size))
                     del mg_attn_weight
+                    torch.cuda.empty_cache()
             self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore)
 
         # Copy bias
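
The reshape-and-slice in the hunk above relies on Megatron packing `linear_qkv` per query group as [q heads | k head | v head]. A toy illustration of that assumed layout (sizes are made up, not from the source):

```python
import torch

# Assumed toy sizes: 2 query groups, 2 q heads per group, head_dim 8, hidden 32.
num_query_groups, heads_per_group, head_dim, hidden_size = 2, 2, 8, 32
q_dim = heads_per_group * head_dim   # q rows per group
kv_dim = head_dim                    # k rows (== v rows) per group

rows_per_group = q_dim + 2 * kv_dim
qkv = torch.randn(num_query_groups * rows_per_group, hidden_size)

# Same reshape/slices as above: per group, the first q_dim rows are q,
# the next kv_dim are k, and the last kv_dim are v.
grouped = qkv.reshape(num_query_groups, -1, hidden_size)
q = grouped[:, :q_dim, :].reshape(-1, hidden_size)          # -> q_proj.weight
k = grouped[:, q_dim:-kv_dim, :].reshape(-1, hidden_size)   # -> k_proj.weight
v = grouped[:, -kv_dim:, :].reshape(-1, hidden_size)        # -> v_proj.weight
assert q.shape[0] == num_query_groups * q_dim
```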
@@ -448,9 +469,9 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int
                                                 'linear_qkv.bias')
                 if mg_attn_bias is not None:
                     mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1))
-                    hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone()
-                    hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone()
-                    hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone()
+                    hf_state_dict['q_proj.bias'] = self._cpu_clone(mg_attn_bias[:, :q_dim].reshape(-1))
+                    hf_state_dict['k_proj.bias'] = self._cpu_clone(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1))
+                    hf_state_dict['v_proj.bias'] = self._cpu_clone(mg_attn_bias[:, -kv_dim:].reshape(-1))
         if args.qk_layernorm:
             hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight'
             hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight'
@@ -626,6 +647,7 @@ def _set_mlp_state(self,
                 weight_list.append(torch.stack([gate_proj_weight, up_proj_weight], dim=0))
             gate_up_proj_weight = torch.concat(weight_list, dim=0)
             del weight_list
+            torch.cuda.empty_cache()
         else:
             gate_proj_weight = hf_state_dict['gate_proj.weight'].load()
             up_proj_weight = hf_state_dict['up_proj.weight'].load()
@@ -637,6 +659,7 @@ def _set_mlp_state(self,
                     getattr(mg_mlp.linear_fc1,
                             f'weight{i}').data.copy_(fc1_weight[i].view(-1, fc1_weight.shape[-1]))
                 del fc1_weight
+                torch.cuda.empty_cache()
             else:
                 mg_mlp.linear_fc1.weight.data.copy_(fc1_weight.view(-1, fc1_weight.shape[-1]))
         else:
@@ -678,28 +701,33 @@ def _set_mlp_state(self,
                     lora_B = lora_B.view(num_local_experts, -1, lora_B.shape[-1])
                     for i in range(num_local_experts):
                         hf_i = i + ep_rank * num_local_experts
-                        hf_state_dict[f'{hf_i}.gate_up_proj.lora_A.weight'] = lora_A[i].clone()
-                        hf_state_dict[f'{hf_i}.gate_up_proj.lora_B.weight'] = lora_B[i].clone()
+                        hf_state_dict[f'{hf_i}.gate_up_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                        hf_state_dict[f'{hf_i}.gate_up_proj.lora_B.weight'] = self._cpu_clone(lora_B[i])
+                    torch.cuda.empty_cache()
 
                 else:
-                    hf_state_dict['gate_up_proj.lora_A.weight'] = lora_A.clone()
-                    hf_state_dict['gate_up_proj.lora_B.weight'] = lora_B.view(-1, lora_B.shape[-1]).clone()
+                    hf_state_dict['gate_up_proj.lora_A.weight'] = self._cpu_clone(lora_A)
+                    hf_state_dict['gate_up_proj.lora_B.weight'] = self._cpu_clone(
+                        lora_B.view(-1, lora_B.shape[-1]))
+                    torch.cuda.empty_cache()
             else:
                 self._peft_target_modules.update({'gate_proj', 'up_proj'})
                 if is_expert:
                     lora_A = lora_A.view(num_local_experts, -1, lora_A.shape[-1])
                     lora_B = lora_B.view(num_local_experts, 2, -1, lora_B.shape[-1])
                     for i in range(num_local_experts):
                         hf_i = i + ep_rank * num_local_experts
-                        hf_state_dict[f'{hf_i}.gate_proj.lora_A.weight'] = lora_A[i].clone()
-                        hf_state_dict[f'{hf_i}.up_proj.lora_A.weight'] = lora_A[i].clone()
-                        hf_state_dict[f'{hf_i}.gate_proj.lora_B.weight'] = lora_B[i][0].clone()
-                        hf_state_dict[f'{hf_i}.up_proj.lora_B.weight'] = lora_B[i][1].clone()
+                        hf_state_dict[f'{hf_i}.gate_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                        hf_state_dict[f'{hf_i}.up_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                        hf_state_dict[f'{hf_i}.gate_proj.lora_B.weight'] = self._cpu_clone(lora_B[i][0])
+                        hf_state_dict[f'{hf_i}.up_proj.lora_B.weight'] = self._cpu_clone(lora_B[i][1])
+                    torch.cuda.empty_cache()
                 else:
-                    hf_state_dict['gate_proj.lora_A.weight'] = lora_A.clone()
-                    hf_state_dict['up_proj.lora_A.weight'] = lora_A.clone()
-                    hf_state_dict['gate_proj.lora_B.weight'] = lora_B[0].clone()
-                    hf_state_dict['up_proj.lora_B.weight'] = lora_B[1].clone()
+                    hf_state_dict['gate_proj.lora_A.weight'] = self._cpu_clone(lora_A)
+                    hf_state_dict['up_proj.lora_A.weight'] = self._cpu_clone(lora_A)
+                    hf_state_dict['gate_proj.lora_B.weight'] = self._cpu_clone(lora_B[0])
+                    hf_state_dict['up_proj.lora_B.weight'] = self._cpu_clone(lora_B[1])
+                    torch.cuda.empty_cache()
         elif not self._is_peft_format:
             if mg_mlp is None:
                 fc1_weight = None
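
Why `gate_proj` and `up_proj` receive the *same* `lora_A` while `lora_B` is split: `linear_fc1`'s LoRA update is `B @ (A @ x)` on the fused gate/up output, so slicing B's output rows into the gate and up halves (with A shared) reproduces the fused update exactly. A small self-check with toy sizes, assuming the gate rows precede the up rows:

```python
import torch

rank, hidden, ffn = 4, 8, 16
A = torch.randn(rank, hidden)        # shared lora_A
B = torch.randn(2 * ffn, rank)       # fused lora_B: gate rows, then up rows
x = torch.randn(hidden)

fused_update = B @ (A @ x)           # what the fused linear_fc1 adapter computes
B_gate, B_up = B.view(2, ffn, rank)  # the row split performed above
assert torch.allclose(fused_update, torch.cat([B_gate @ (A @ x), B_up @ (A @ x)]))
```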
@@ -725,27 +753,29 @@ def _set_mlp_state(self,
                 if 'gate_up_proj' in hf_state_dict:
                     gate_up_proj_weight = torch.concat(
                         [hf_state_dict['gate_up_proj'], gate_up_proj_weight], dim=0)
-                    hf_state_dict['gate_up_proj'] = gate_up_proj_weight.clone()
+                    hf_state_dict['gate_up_proj'] = self._cpu_clone(gate_up_proj_weight)
                 else:
                     for i in range(num_local_experts):
                         hf_i = i + ep_rank * num_local_experts
-                        hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = gate_up_proj_weight[i].clone()
-                    del gate_up_proj_weight
+                        hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = self._cpu_clone(
+                            gate_up_proj_weight[i])
+                del gate_up_proj_weight
+                torch.cuda.empty_cache()
             else:
-                hf_state_dict['gate_up_proj.weight'] = gate_up_proj_weight.view(
-                    -1, gate_up_proj_weight.shape[-1]).clone()
+                hf_state_dict['gate_up_proj.weight'] = self._cpu_clone(
+                    gate_up_proj_weight.view(-1, gate_up_proj_weight.shape[-1]))
         else:
             if is_expert:
                 gate_up_proj_weight = gate_up_proj_weight.view(num_local_experts, 2, -1,
                                                                gate_up_proj_weight.shape[-1])
                 for i in range(num_local_experts):
                     hf_i = i + ep_rank * num_local_experts
-                    hf_state_dict[f'{hf_i}.gate_proj.weight'] = gate_up_proj_weight[i][0].clone()
-                    hf_state_dict[f'{hf_i}.up_proj.weight'] = gate_up_proj_weight[i][1].clone()
+                    hf_state_dict[f'{hf_i}.gate_proj.weight'] = self._cpu_clone(gate_up_proj_weight[i][0])
+                    hf_state_dict[f'{hf_i}.up_proj.weight'] = self._cpu_clone(gate_up_proj_weight[i][1])
                 del gate_up_proj_weight
             else:
-                hf_state_dict['gate_proj.weight'] = gate_up_proj_weight[0].clone()
-                hf_state_dict['up_proj.weight'] = gate_up_proj_weight[1].clone()
+                hf_state_dict['gate_proj.weight'] = self._cpu_clone(gate_up_proj_weight[0])
+                hf_state_dict['up_proj.weight'] = self._cpu_clone(gate_up_proj_weight[1])
         # linear_fc2
         if is_expert:
             if to_mcore:
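
The recurring `hf_i = i + ep_rank * num_local_experts` maps a rank-local expert index to the global HuggingFace expert id: under expert parallelism each rank owns a contiguous slab of experts. A quick illustration with assumed sizes:

```python
num_experts, ep_size = 8, 2
num_local_experts = num_experts // ep_size   # 4 experts per rank

for ep_rank in range(ep_size):
    owned = [i + ep_rank * num_local_experts for i in range(num_local_experts)]
    print(f'ep_rank={ep_rank} exports experts {owned}')
# ep_rank=0 exports experts [0, 1, 2, 3]
# ep_rank=1 exports experts [4, 5, 6, 7]
```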
@@ -825,8 +855,8 @@ def _set_mlp_state(self,
                     lora_B = lora_B.view(num_local_experts, -1, lora_B.shape[-1])
                     for i in range(num_local_experts):
                         hf_i = i + ep_rank * num_local_experts
-                        hf_state_dict[f'{hf_i}.down_proj.lora_A.weight'] = lora_A[i].clone()
-                        hf_state_dict[f'{hf_i}.down_proj.lora_B.weight'] = lora_B[i].clone()
+                        hf_state_dict[f'{hf_i}.down_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                        hf_state_dict[f'{hf_i}.down_proj.lora_B.weight'] = self._cpu_clone(lora_B[i])
         elif not self._is_peft_format:
             if mg_mlp is None:
                 fc2_weight = None
@@ -838,17 +868,20 @@ def _set_mlp_state(self,
                     dim=0)
             down_proj_weight = self._get_weight(fc2_weight, 'linear_fc2.weight', is_expert=is_expert)
             del fc2_weight
+            torch.cuda.empty_cache()
             if down_proj_weight is not None:
                 down_proj_weight = down_proj_weight.view(num_local_experts, -1, down_proj_weight.shape[-1])
                 if hf_grouped:
                     down_proj_weight = down_proj_weight.transpose(1, 2)
                     if 'down_proj' in hf_state_dict:
                         down_proj_weight = torch.concat([hf_state_dict['down_proj'], down_proj_weight], dim=0)
-                    hf_state_dict['down_proj'] = down_proj_weight.clone()
+                    hf_state_dict['down_proj'] = self._cpu_clone(down_proj_weight)
                 else:
                     for i in range(num_local_experts):
                         hf_i = i + ep_rank * num_local_experts
-                        hf_state_dict[f'{hf_i}.down_proj.weight'] = down_proj_weight[i].clone()
+                        hf_state_dict[f'{hf_i}.down_proj.weight'] = self._cpu_clone(down_proj_weight[i])
+                del down_proj_weight
+                torch.cuda.empty_cache()
         else:
             self._set_state_dict(
                 mg_mlp, 'linear_fc2.weight', hf_state_dict, 'down_proj.weight', to_mcore, is_expert=is_expert)
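
Taken together, every export site in this commit follows the same shape: materialize the full weight on GPU, keep only a CPU copy in the state dict, then drop the GPU temporary and release cached blocks, so peak VRAM stays bounded by one tensor's worth of conversion scratch. A condensed sketch of that flow (helper name is illustrative, not from the source):

```python
import torch

def export_to_state_dict(state_dict: dict, key: str, gpu_weight: torch.Tensor) -> None:
    # Keep only a CPU copy; the GPU tensor is just a conversion temporary.
    state_dict[key] = gpu_weight.detach().to('cpu', non_blocking=True)

# usage: export, then drop the GPU reference at the call site and release the cache
# export_to_state_dict(hf_state_dict, 'q_proj.weight', q_weight)
# del q_weight
# torch.cuda.empty_cache()
```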