
Commit 6f4a17f

[megatron] fix: make bridge exported cloned weights store on CPU
There are many `clone()` calls in the GPTBridge code, and making each copy of a weight on the GPU can easily cause OOM. This PR addresses the issue by cloning exported weights to the CPU instead, so that those copies live in host memory. It also adds `torch.cuda.empty_cache()` calls to reduce the chance of OOM during LoRA merge. Signed-off-by: Hollow Man <[email protected]>
1 parent 6cfaeaa commit 6f4a17f
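
For context, a minimal standalone sketch of the pattern this commit applies (the names `cpu_clone` and `export_state_dict` below are illustrative, not the bridge's actual API): instead of `tensor.clone()`, which allocates the copy on the tensor's current device, each exported weight is detached and copied to host memory, and the CUDA caching allocator is then asked to release freed blocks.

# Illustrative sketch only; helper names are hypothetical, not ms-swift APIs.
import torch


def cpu_clone(tensor: torch.Tensor) -> torch.Tensor:
    # Detach from autograd and copy to host memory; `copy=True` guarantees a
    # new tensor even if the input already lives on the CPU.
    return tensor.detach().to('cpu', copy=True, non_blocking=True)


def export_state_dict(module: torch.nn.Module) -> dict:
    # Keep exported copies on the CPU so only the original parameters stay on the GPU.
    exported = {name: cpu_clone(param.data) for name, param in module.named_parameters()}
    if torch.cuda.is_available():
        # Return cached blocks freed by intermediate GPU tensors to the driver.
        torch.cuda.empty_cache()
    return exported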

2 files changed (+67, -40 lines)


swift/megatron/model/gpt_bridge.py

Lines changed: 62 additions & 40 deletions
@@ -157,6 +157,7 @@ def _set_weight(
                 group=tp_group,
             )
             del splited_weights
+            torch.cuda.empty_cache()
         else:
             tensor = hf_weight
         if offset:
@@ -243,6 +244,7 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
             )
             tensor = torch.cat(output, dim=tp_dim)
             del output
+            torch.cuda.empty_cache()
         # pp/ep
         if pp_size > 1:
             src_rank = torch.tensor([0 if tensor is None else pp_rank], dtype=torch.int64, device='cuda')
@@ -273,6 +275,16 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
             tensor = None
         return tensor
 
+    def _cpu_clone(self, tensor: Optional[torch.Tensor]):
+        if tensor is None:
+            return None
+        if not isinstance(tensor, torch.Tensor):
+            return tensor
+        # Detach to avoid any autograd references and create a copy on the CPU.
+        # `non_blocking=True` is a best-effort for async copy from GPU.
+        # `copy=True` ensures a new tensor is created even if it's already on CPU.
+        return tensor.detach().to('cpu', copy=True, non_blocking=True)
+
     def _set_state_dict(self,
                         mg_module,
                         mg_key: str,
@@ -412,25 +424,27 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int
         if lora_A is not None:
             self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'})
             for key in ['q_proj', 'k_proj', 'v_proj']:
-                hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
+                hf_state_dict[f'{key}.lora_A.weight'] = self._cpu_clone(lora_A)
             lora_B = lora_B.reshape((num_query_groups, -1, lora_B.shape[-1]))
-            hf_state_dict['q_proj.lora_B.weight'] = lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone()
-            hf_state_dict['k_proj.lora_B.weight'] = lora_B[:,
-                                                           q_dim:-kv_dim, :].reshape(-1,
-                                                                                     lora_B.shape[-1]).clone()
-            hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, -kv_dim:, :].reshape(-1, lora_B.shape[-1]).clone()
+            hf_state_dict['q_proj.lora_B.weight'] = self._cpu_clone(lora_B[:, :q_dim, :].reshape(
+                -1, lora_B.shape[-1]))
+            hf_state_dict['k_proj.lora_B.weight'] = self._cpu_clone(lora_B[:, q_dim:-kv_dim, :].reshape(
+                -1, lora_B.shape[-1]))
+            hf_state_dict['v_proj.lora_B.weight'] = self._cpu_clone(lora_B[:, -kv_dim:, :].reshape(
+                -1, lora_B.shape[-1]))
         elif not self._is_peft_format:
             mg_attn_weight = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.weight.data,
                                               'linear_qkv.weight')
             if mg_attn_weight is not None:
                 mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size))
-                hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size).clone()
-                hf_state_dict['k_proj.weight'] = mg_attn_weight[:,
-                                                                q_dim:-kv_dim, :].reshape(-1,
-                                                                                          args.hidden_size).clone()
-                hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1,
-                                                                                        args.hidden_size).clone()
+                hf_state_dict['q_proj.weight'] = self._cpu_clone(mg_attn_weight[:, :q_dim, :].reshape(
+                    -1, args.hidden_size))
+                hf_state_dict['k_proj.weight'] = self._cpu_clone(mg_attn_weight[:, q_dim:-kv_dim, :].reshape(
+                    -1, args.hidden_size))
+                hf_state_dict['v_proj.weight'] = self._cpu_clone(mg_attn_weight[:, -kv_dim:, :].reshape(
+                    -1, args.hidden_size))
                 del mg_attn_weight
+                torch.cuda.empty_cache()
         self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore)
 
         # Copy bias
@@ -448,9 +462,9 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int
                                             'linear_qkv.bias')
             if mg_attn_bias is not None:
                 mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1))
-                hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone()
-                hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone()
-                hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone()
+                hf_state_dict['q_proj.bias'] = self._cpu_clone(mg_attn_bias[:, :q_dim].reshape(-1))
+                hf_state_dict['k_proj.bias'] = self._cpu_clone(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1))
+                hf_state_dict['v_proj.bias'] = self._cpu_clone(mg_attn_bias[:, -kv_dim:].reshape(-1))
         if args.qk_layernorm:
             hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight'
             hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight'
@@ -630,6 +644,7 @@ def _set_mlp_state(self,
             gate_proj_weight = hf_state_dict['gate_proj.weight'].load()
             up_proj_weight = hf_state_dict['up_proj.weight'].load()
             gate_up_proj_weight = torch.stack([gate_proj_weight, up_proj_weight], dim=0)
+            torch.cuda.empty_cache()
             self._set_weight(fc1_weight, gate_up_proj_weight, 'linear_fc1.weight', is_expert=is_expert)
             if is_expert:
                 fc1_weight = fc1_weight.view(num_local_experts, -1, fc1_weight.shape[-1])
@@ -668,6 +683,7 @@ def _set_mlp_state(self,
                 lora_A = mg_mlp.linear_fc1.lora_A[self._adapter_name].weight
                 lora_B = mg_mlp.linear_fc1.lora_B[self._adapter_name].weight
                 lora_B = lora_B.view(num_local_experts * 2, -1, lora_B.shape[1])
+                torch.cuda.empty_cache()
             lora_A = self._get_weight(lora_A, f'linear_fc1.lora_A.{self._adapter_name}.weight', is_expert=is_expert)
             lora_B = self._get_weight(lora_B, f'linear_fc1.lora_B.{self._adapter_name}.weight', is_expert=is_expert)
             if lora_A is not None:
@@ -678,28 +694,29 @@ def _set_mlp_state(self,
                         lora_B = lora_B.view(num_local_experts, -1, lora_B.shape[-1])
                         for i in range(num_local_experts):
                             hf_i = i + ep_rank * num_local_experts
-                            hf_state_dict[f'{hf_i}.gate_up_proj.lora_A.weight'] = lora_A[i].clone()
-                            hf_state_dict[f'{hf_i}.gate_up_proj.lora_B.weight'] = lora_B[i].clone()
+                            hf_state_dict[f'{hf_i}.gate_up_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                            hf_state_dict[f'{hf_i}.gate_up_proj.lora_B.weight'] = self._cpu_clone(lora_B[i])
 
                     else:
-                        hf_state_dict['gate_up_proj.lora_A.weight'] = lora_A.clone()
-                        hf_state_dict['gate_up_proj.lora_B.weight'] = lora_B.view(-1, lora_B.shape[-1]).clone()
+                        hf_state_dict['gate_up_proj.lora_A.weight'] = self._cpu_clone(lora_A)
+                        hf_state_dict['gate_up_proj.lora_B.weight'] = self._cpu_clone(
+                            lora_B.view(-1, lora_B.shape[-1]))
                 else:
                     self._peft_target_modules.update({'gate_proj', 'up_proj'})
                     if is_expert:
                         lora_A = lora_A.view(num_local_experts, -1, lora_A.shape[-1])
                         lora_B = lora_B.view(num_local_experts, 2, -1, lora_B.shape[-1])
                         for i in range(num_local_experts):
                             hf_i = i + ep_rank * num_local_experts
-                            hf_state_dict[f'{hf_i}.gate_proj.lora_A.weight'] = lora_A[i].clone()
-                            hf_state_dict[f'{hf_i}.up_proj.lora_A.weight'] = lora_A[i].clone()
-                            hf_state_dict[f'{hf_i}.gate_proj.lora_B.weight'] = lora_B[i][0].clone()
-                            hf_state_dict[f'{hf_i}.up_proj.lora_B.weight'] = lora_B[i][1].clone()
+                            hf_state_dict[f'{hf_i}.gate_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                            hf_state_dict[f'{hf_i}.up_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                            hf_state_dict[f'{hf_i}.gate_proj.lora_B.weight'] = self._cpu_clone(lora_B[i][0])
+                            hf_state_dict[f'{hf_i}.up_proj.lora_B.weight'] = self._cpu_clone(lora_B[i][1])
                     else:
-                        hf_state_dict['gate_proj.lora_A.weight'] = lora_A.clone()
-                        hf_state_dict['up_proj.lora_A.weight'] = lora_A.clone()
-                        hf_state_dict['gate_proj.lora_B.weight'] = lora_B[0].clone()
-                        hf_state_dict['up_proj.lora_B.weight'] = lora_B[1].clone()
+                        hf_state_dict['gate_proj.lora_A.weight'] = self._cpu_clone(lora_A)
+                        hf_state_dict['up_proj.lora_A.weight'] = self._cpu_clone(lora_A)
+                        hf_state_dict['gate_proj.lora_B.weight'] = self._cpu_clone(lora_B[0])
+                        hf_state_dict['up_proj.lora_B.weight'] = self._cpu_clone(lora_B[1])
             elif not self._is_peft_format:
                 if mg_mlp is None:
                     fc1_weight = None
@@ -725,27 +742,29 @@ def _set_mlp_state(self,
                     if 'gate_up_proj' in hf_state_dict:
                         gate_up_proj_weight = torch.concat(
                             [hf_state_dict['gate_up_proj'], gate_up_proj_weight], dim=0)
-                    hf_state_dict['gate_up_proj'] = gate_up_proj_weight.clone()
+                    hf_state_dict['gate_up_proj'] = self._cpu_clone(gate_up_proj_weight)
                 else:
                     for i in range(num_local_experts):
                         hf_i = i + ep_rank * num_local_experts
-                        hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = gate_up_proj_weight[i].clone()
+                        hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = self._cpu_clone(
+                            gate_up_proj_weight[i])
                     del gate_up_proj_weight
             else:
-                hf_state_dict['gate_up_proj.weight'] = gate_up_proj_weight.view(
-                    -1, gate_up_proj_weight.shape[-1]).clone()
+                hf_state_dict['gate_up_proj.weight'] = self._cpu_clone(
+                    gate_up_proj_weight.view(-1, gate_up_proj_weight.shape[-1]))
         else:
             if is_expert:
                 gate_up_proj_weight = gate_up_proj_weight.view(num_local_experts, 2, -1,
                                                                gate_up_proj_weight.shape[-1])
                 for i in range(num_local_experts):
                     hf_i = i + ep_rank * num_local_experts
-                    hf_state_dict[f'{hf_i}.gate_proj.weight'] = gate_up_proj_weight[i][0].clone()
-                    hf_state_dict[f'{hf_i}.up_proj.weight'] = gate_up_proj_weight[i][1].clone()
-                del gate_up_proj_weight
+                    hf_state_dict[f'{hf_i}.gate_proj.weight'] = self._cpu_clone(gate_up_proj_weight[i][0])
+                    hf_state_dict[f'{hf_i}.up_proj.weight'] = self._cpu_clone(gate_up_proj_weight[i][1])
             else:
-                hf_state_dict['gate_proj.weight'] = gate_up_proj_weight[0].clone()
-                hf_state_dict['up_proj.weight'] = gate_up_proj_weight[1].clone()
+                hf_state_dict['gate_proj.weight'] = self._cpu_clone(gate_up_proj_weight[0])
+                hf_state_dict['up_proj.weight'] = self._cpu_clone(gate_up_proj_weight[1])
+            del gate_up_proj_weight
+            torch.cuda.empty_cache()
         # linear_fc2
         if is_expert:
             if to_mcore:
@@ -825,8 +844,8 @@ def _set_mlp_state(self,
                     lora_B = lora_B.view(num_local_experts, -1, lora_B.shape[-1])
                     for i in range(num_local_experts):
                         hf_i = i + ep_rank * num_local_experts
-                        hf_state_dict[f'{hf_i}.down_proj.lora_A.weight'] = lora_A[i].clone()
-                        hf_state_dict[f'{hf_i}.down_proj.lora_B.weight'] = lora_B[i].clone()
+                        hf_state_dict[f'{hf_i}.down_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                        hf_state_dict[f'{hf_i}.down_proj.lora_B.weight'] = self._cpu_clone(lora_B[i])
             elif not self._is_peft_format:
                 if mg_mlp is None:
                     fc2_weight = None
@@ -838,20 +857,23 @@ def _set_mlp_state(self,
                                        dim=0)
             down_proj_weight = self._get_weight(fc2_weight, 'linear_fc2.weight', is_expert=is_expert)
             del fc2_weight
+            torch.cuda.empty_cache()
             if down_proj_weight is not None:
                 down_proj_weight = down_proj_weight.view(num_local_experts, -1, down_proj_weight.shape[-1])
                 if hf_grouped:
                     down_proj_weight = down_proj_weight.transpose(1, 2)
                     if 'down_proj' in hf_state_dict:
                         down_proj_weight = torch.concat([hf_state_dict['down_proj'], down_proj_weight], dim=0)
-                    hf_state_dict['down_proj'] = down_proj_weight.clone()
+                    hf_state_dict['down_proj'] = self._cpu_clone(down_proj_weight)
                 else:
                     for i in range(num_local_experts):
                         hf_i = i + ep_rank * num_local_experts
-                        hf_state_dict[f'{hf_i}.down_proj.weight'] = down_proj_weight[i].clone()
+                        hf_state_dict[f'{hf_i}.down_proj.weight'] = self._cpu_clone(down_proj_weight[i])
+                del down_proj_weight
         else:
             self._set_state_dict(
                 mg_mlp, 'linear_fc2.weight', hf_state_dict, 'down_proj.weight', to_mcore, is_expert=is_expert)
+        torch.cuda.empty_cache()
         if to_mcore:
             hf_state_dict = {}
         else:

swift/megatron/tuners/lora.py

Lines changed: 5 additions & 0 deletions
@@ -417,10 +417,13 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                         weight.data = orig_weights[i]
                 else:
                     base_layer.weight.data = orig_weights[0]
+                del orig_weights
             else:
                 delta_weights = self.get_delta_weights(active_adapter)
                 for orig_weight, delta_weight in zip(orig_weights, delta_weights):
                     orig_weight.data += delta_weight
+                del delta_weights
+            torch.cuda.empty_cache()
             self.merged_adapters.append(active_adapter)
         if origin_device.type == 'cpu':
             self.to(device=origin_device)
@@ -452,6 +455,8 @@ def unmerge(self) -> None:
             for orig_weight, delta_weight in zip(orig_weights, delta_weights):
                 # Subtract the delta weight to unmerge
                 orig_weight.data -= delta_weight
+            del delta_weights
+            torch.cuda.empty_cache()
 
         # Clear the merged adapters list
         self.merged_adapters = []
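
The lora.py change applies the same idea during merge: drop the large delta tensors as soon as they have been added into the base weights, then call `torch.cuda.empty_cache()` so the caching allocator can release the freed blocks. A rough standalone sketch of that pattern (the function and argument names here are assumed for illustration, not the patched Megatron LoRA layer):

import torch


def merge_lora_delta(base_weight: torch.Tensor, lora_A: torch.Tensor,
                     lora_B: torch.Tensor, scaling: float) -> None:
    # The delta has the same shape as the base weight, so it can be large.
    delta_weight = (lora_B @ lora_A) * scaling
    base_weight.data += delta_weight
    # Drop the temporary and let the caching allocator release its memory.
    del delta_weight
    if torch.cuda.is_available():
        torch.cuda.empty_cache()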
