
Commit 063e48c

[megatron] fix: make bridge exported cloned weights store on CPU
There are many `clone` calls in the GPTBridge code, and if every copied weight is kept on the GPU, OOM can easily happen. This PR addresses the issue by cloning to CPU instead, so that exported weights are stored in CPU memory. It also adds `torch.cuda.empty_cache()` calls to reduce the likelihood of OOM during LoRA merge.

Signed-off-by: Hollow Man <[email protected]>
1 parent 6cfaeaa commit 063e48c
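
To make the motivation concrete, here is a minimal sketch of the memory behaviour the commit message describes (the `export_state_dict` helper and its arguments are hypothetical, not code from this commit): cloning every exported weight on the GPU keeps a second copy of each layer resident on the device, while detaching and moving the copy to CPU leaves only the original parameters there.

```python
import torch


def export_state_dict(named_weights, to_cpu: bool = True):
    """Collect detached copies of weights; `to_cpu=True` mirrors what this commit does."""
    exported = {}
    for name, weight in named_weights:
        if to_cpu:
            # The copy lands in host memory, so the GPU only ever holds the originals.
            exported[name] = weight.detach().to('cpu')
        else:
            # The copy stays on the GPU; with many layers this is where OOM starts.
            exported[name] = weight.detach().clone()
    if torch.cuda.is_available():
        # Return cached blocks freed by intermediate tensors to the allocator.
        torch.cuda.empty_cache()
    return exported
```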

File tree: 2 files changed, +80 -40 lines


swift/megatron/model/gpt_bridge.py

Lines changed: 73 additions & 40 deletions
```diff
@@ -157,6 +157,7 @@ def _set_weight(
                     group=tp_group,
                 )
                 del splited_weights
+                torch.cuda.empty_cache()
             else:
                 tensor = hf_weight
             if offset:
@@ -243,6 +244,7 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
                 )
                 tensor = torch.cat(output, dim=tp_dim)
                 del output
+                torch.cuda.empty_cache()
         # pp/ep
         if pp_size > 1:
             src_rank = torch.tensor([0 if tensor is None else pp_rank], dtype=torch.int64, device='cuda')
@@ -273,6 +275,22 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
             tensor = None
         return tensor
 
+    def _cpu_clone(self, tensor: Optional[torch.Tensor]):
+        if tensor is None:
+            return None
+        if not isinstance(tensor, torch.Tensor):
+            return tensor
+        # Detach to avoid any autograd references
+        t = tensor.detach()
+        if t.device.type != 'cpu':
+            # Move to CPU if not already (this will make a copy for sure)
+            # `non_blocking=True` attempts an asynchronous copy for GPU->CPU when destination is
+            # pinned memory; this is best-effort and will fall back to blocking if not possible.
+            # https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html
+            return t.to('cpu', non_blocking=True)
+        else:
+            return t.clone()
+
     def _set_state_dict(self,
                         mg_module,
                         mg_key: str,
```
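
An aside on the `non_blocking=True` comment in `_cpu_clone` above (an illustrative sketch, not part of the diff): a device-to-host copy only runs asynchronously when the CPU destination is pinned; `.to('cpu', non_blocking=True)` allocates ordinary pageable memory, so PyTorch falls back to a blocking copy, which is the safe behaviour the comment relies on.

```python
import torch

if torch.cuda.is_available():
    src = torch.randn(4096, 4096, device='cuda')

    # Pageable destination: non_blocking=True is effectively a no-op here, the copy
    # blocks, and the result is immediately safe to read.
    pageable = src.to('cpu', non_blocking=True)

    # Pinned destination: the copy can genuinely overlap with GPU work, so synchronize
    # before touching the CPU tensor.
    pinned = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)
    pinned.copy_(src, non_blocking=True)
    torch.cuda.synchronize()
```

The remaining hunks of `gpt_bridge.py` follow.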
```diff
@@ -412,25 +430,28 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int
         if lora_A is not None:
             self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'})
             for key in ['q_proj', 'k_proj', 'v_proj']:
-                hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
+                hf_state_dict[f'{key}.lora_A.weight'] = self._cpu_clone(lora_A)
             lora_B = lora_B.reshape((num_query_groups, -1, lora_B.shape[-1]))
-            hf_state_dict['q_proj.lora_B.weight'] = lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone()
-            hf_state_dict['k_proj.lora_B.weight'] = lora_B[:,
-                                                           q_dim:-kv_dim, :].reshape(-1,
-                                                                                     lora_B.shape[-1]).clone()
-            hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, -kv_dim:, :].reshape(-1, lora_B.shape[-1]).clone()
+            hf_state_dict['q_proj.lora_B.weight'] = self._cpu_clone(lora_B[:, :q_dim, :].reshape(
+                -1, lora_B.shape[-1]))
+            hf_state_dict['k_proj.lora_B.weight'] = self._cpu_clone(lora_B[:, q_dim:-kv_dim, :].reshape(
+                -1, lora_B.shape[-1]))
+            hf_state_dict['v_proj.lora_B.weight'] = self._cpu_clone(lora_B[:, -kv_dim:, :].reshape(
+                -1, lora_B.shape[-1]))
+            torch.cuda.empty_cache()
         elif not self._is_peft_format:
             mg_attn_weight = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.weight.data,
                                               'linear_qkv.weight')
             if mg_attn_weight is not None:
                 mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size))
-                hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size).clone()
-                hf_state_dict['k_proj.weight'] = mg_attn_weight[:,
-                                                                q_dim:-kv_dim, :].reshape(-1,
-                                                                                          args.hidden_size).clone()
-                hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1,
-                                                                                        args.hidden_size).clone()
+                hf_state_dict['q_proj.weight'] = self._cpu_clone(mg_attn_weight[:, :q_dim, :].reshape(
+                    -1, args.hidden_size))
+                hf_state_dict['k_proj.weight'] = self._cpu_clone(mg_attn_weight[:, q_dim:-kv_dim, :].reshape(
+                    -1, args.hidden_size))
+                hf_state_dict['v_proj.weight'] = self._cpu_clone(mg_attn_weight[:, -kv_dim:, :].reshape(
+                    -1, args.hidden_size))
                 del mg_attn_weight
+                torch.cuda.empty_cache()
             self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore)
 
         # Copy bias
@@ -448,9 +469,9 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int
                                               'linear_qkv.bias')
             if mg_attn_bias is not None:
                 mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1))
-                hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone()
-                hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone()
-                hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone()
+                hf_state_dict['q_proj.bias'] = self._cpu_clone(mg_attn_bias[:, :q_dim].reshape(-1))
+                hf_state_dict['k_proj.bias'] = self._cpu_clone(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1))
+                hf_state_dict['v_proj.bias'] = self._cpu_clone(mg_attn_bias[:, -kv_dim:].reshape(-1))
         if args.qk_layernorm:
             hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight'
             hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight'
@@ -626,6 +647,7 @@ def _set_mlp_state(self,
                     weight_list.append(torch.stack([gate_proj_weight, up_proj_weight], dim=0))
                 gate_up_proj_weight = torch.concat(weight_list, dim=0)
                 del weight_list
+                torch.cuda.empty_cache()
             else:
                 gate_proj_weight = hf_state_dict['gate_proj.weight'].load()
                 up_proj_weight = hf_state_dict['up_proj.weight'].load()
@@ -637,6 +659,7 @@ def _set_mlp_state(self,
                     getattr(mg_mlp.linear_fc1,
                             f'weight{i}').data.copy_(fc1_weight[i].view(-1, fc1_weight.shape[-1]))
                 del fc1_weight
+                torch.cuda.empty_cache()
             else:
                 mg_mlp.linear_fc1.weight.data.copy_(fc1_weight.view(-1, fc1_weight.shape[-1]))
         else:
@@ -678,28 +701,33 @@ def _set_mlp_state(self,
                 lora_B = lora_B.view(num_local_experts, -1, lora_B.shape[-1])
                 for i in range(num_local_experts):
                     hf_i = i + ep_rank * num_local_experts
-                    hf_state_dict[f'{hf_i}.gate_up_proj.lora_A.weight'] = lora_A[i].clone()
-                    hf_state_dict[f'{hf_i}.gate_up_proj.lora_B.weight'] = lora_B[i].clone()
+                    hf_state_dict[f'{hf_i}.gate_up_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                    hf_state_dict[f'{hf_i}.gate_up_proj.lora_B.weight'] = self._cpu_clone(lora_B[i])
+                torch.cuda.empty_cache()
 
             else:
-                hf_state_dict['gate_up_proj.lora_A.weight'] = lora_A.clone()
-                hf_state_dict['gate_up_proj.lora_B.weight'] = lora_B.view(-1, lora_B.shape[-1]).clone()
+                hf_state_dict['gate_up_proj.lora_A.weight'] = self._cpu_clone(lora_A)
+                hf_state_dict['gate_up_proj.lora_B.weight'] = self._cpu_clone(
+                    lora_B.view(-1, lora_B.shape[-1]))
+                torch.cuda.empty_cache()
         else:
             self._peft_target_modules.update({'gate_proj', 'up_proj'})
             if is_expert:
                 lora_A = lora_A.view(num_local_experts, -1, lora_A.shape[-1])
                 lora_B = lora_B.view(num_local_experts, 2, -1, lora_B.shape[-1])
                 for i in range(num_local_experts):
                     hf_i = i + ep_rank * num_local_experts
-                    hf_state_dict[f'{hf_i}.gate_proj.lora_A.weight'] = lora_A[i].clone()
-                    hf_state_dict[f'{hf_i}.up_proj.lora_A.weight'] = lora_A[i].clone()
-                    hf_state_dict[f'{hf_i}.gate_proj.lora_B.weight'] = lora_B[i][0].clone()
-                    hf_state_dict[f'{hf_i}.up_proj.lora_B.weight'] = lora_B[i][1].clone()
+                    hf_state_dict[f'{hf_i}.gate_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                    hf_state_dict[f'{hf_i}.up_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                    hf_state_dict[f'{hf_i}.gate_proj.lora_B.weight'] = self._cpu_clone(lora_B[i][0])
+                    hf_state_dict[f'{hf_i}.up_proj.lora_B.weight'] = self._cpu_clone(lora_B[i][1])
+                torch.cuda.empty_cache()
             else:
-                hf_state_dict['gate_proj.lora_A.weight'] = lora_A.clone()
-                hf_state_dict['up_proj.lora_A.weight'] = lora_A.clone()
-                hf_state_dict['gate_proj.lora_B.weight'] = lora_B[0].clone()
-                hf_state_dict['up_proj.lora_B.weight'] = lora_B[1].clone()
+                hf_state_dict['gate_proj.lora_A.weight'] = self._cpu_clone(lora_A)
+                hf_state_dict['up_proj.lora_A.weight'] = self._cpu_clone(lora_A)
+                hf_state_dict['gate_proj.lora_B.weight'] = self._cpu_clone(lora_B[0])
+                hf_state_dict['up_proj.lora_B.weight'] = self._cpu_clone(lora_B[1])
+                torch.cuda.empty_cache()
     elif not self._is_peft_format:
         if mg_mlp is None:
             fc1_weight = None
@@ -725,27 +753,29 @@ def _set_mlp_state(self,
                 if 'gate_up_proj' in hf_state_dict:
                     gate_up_proj_weight = torch.concat(
                         [hf_state_dict['gate_up_proj'], gate_up_proj_weight], dim=0)
-                hf_state_dict['gate_up_proj'] = gate_up_proj_weight.clone()
+                hf_state_dict['gate_up_proj'] = self._cpu_clone(gate_up_proj_weight)
             else:
                 for i in range(num_local_experts):
                     hf_i = i + ep_rank * num_local_experts
-                    hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = gate_up_proj_weight[i].clone()
-                del gate_up_proj_weight
+                    hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = self._cpu_clone(
+                        gate_up_proj_weight[i])
+                del gate_up_proj_weight
+                torch.cuda.empty_cache()
         else:
-            hf_state_dict['gate_up_proj.weight'] = gate_up_proj_weight.view(
-                -1, gate_up_proj_weight.shape[-1]).clone()
+            hf_state_dict['gate_up_proj.weight'] = self._cpu_clone(
+                gate_up_proj_weight.view(-1, gate_up_proj_weight.shape[-1]))
    else:
        if is_expert:
            gate_up_proj_weight = gate_up_proj_weight.view(num_local_experts, 2, -1,
                                                           gate_up_proj_weight.shape[-1])
            for i in range(num_local_experts):
                hf_i = i + ep_rank * num_local_experts
-               hf_state_dict[f'{hf_i}.gate_proj.weight'] = gate_up_proj_weight[i][0].clone()
-               hf_state_dict[f'{hf_i}.up_proj.weight'] = gate_up_proj_weight[i][1].clone()
+               hf_state_dict[f'{hf_i}.gate_proj.weight'] = self._cpu_clone(gate_up_proj_weight[i][0])
+               hf_state_dict[f'{hf_i}.up_proj.weight'] = self._cpu_clone(gate_up_proj_weight[i][1])
            del gate_up_proj_weight
        else:
-           hf_state_dict['gate_proj.weight'] = gate_up_proj_weight[0].clone()
-           hf_state_dict['up_proj.weight'] = gate_up_proj_weight[1].clone()
+           hf_state_dict['gate_proj.weight'] = self._cpu_clone(gate_up_proj_weight[0])
+           hf_state_dict['up_proj.weight'] = self._cpu_clone(gate_up_proj_weight[1])
    # linear_fc2
    if is_expert:
        if to_mcore:
@@ -825,8 +855,8 @@ def _set_mlp_state(self,
                lora_B = lora_B.view(num_local_experts, -1, lora_B.shape[-1])
                for i in range(num_local_experts):
                    hf_i = i + ep_rank * num_local_experts
-                   hf_state_dict[f'{hf_i}.down_proj.lora_A.weight'] = lora_A[i].clone()
-                   hf_state_dict[f'{hf_i}.down_proj.lora_B.weight'] = lora_B[i].clone()
+                   hf_state_dict[f'{hf_i}.down_proj.lora_A.weight'] = self._cpu_clone(lora_A[i])
+                   hf_state_dict[f'{hf_i}.down_proj.lora_B.weight'] = self._cpu_clone(lora_B[i])
    elif not self._is_peft_format:
        if mg_mlp is None:
            fc2_weight = None
@@ -838,17 +868,20 @@ def _set_mlp_state(self,
                                            dim=0)
        down_proj_weight = self._get_weight(fc2_weight, 'linear_fc2.weight', is_expert=is_expert)
        del fc2_weight
+       torch.cuda.empty_cache()
        if down_proj_weight is not None:
            down_proj_weight = down_proj_weight.view(num_local_experts, -1, down_proj_weight.shape[-1])
            if hf_grouped:
                down_proj_weight = down_proj_weight.transpose(1, 2)
                if 'down_proj' in hf_state_dict:
                    down_proj_weight = torch.concat([hf_state_dict['down_proj'], down_proj_weight], dim=0)
-               hf_state_dict['down_proj'] = down_proj_weight.clone()
+               hf_state_dict['down_proj'] = self._cpu_clone(down_proj_weight)
            else:
                for i in range(num_local_experts):
                    hf_i = i + ep_rank * num_local_experts
-                   hf_state_dict[f'{hf_i}.down_proj.weight'] = down_proj_weight[i].clone()
+                   hf_state_dict[f'{hf_i}.down_proj.weight'] = self._cpu_clone(down_proj_weight[i])
+               del down_proj_weight
+               torch.cuda.empty_cache()
    else:
        self._set_state_dict(
            mg_mlp, 'linear_fc2.weight', hf_state_dict, 'down_proj.weight', to_mcore, is_expert=is_expert)
```

swift/megatron/tuners/lora.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -370,6 +370,7 @@ def get_delta_weights(self, adapter) -> List[torch.Tensor]:
         assert len(weight_A) == len(weight_B)
         for i in range(len(weight_B)):
             output_tensor.append(transpose(weight_B[i] @ weight_A[i], self.fan_in_fan_out) * self.scaling[adapter])
+        torch.cuda.empty_cache()
 
         return output_tensor
 
@@ -417,10 +418,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                         weight.data = orig_weights[i]
                 else:
                     base_layer.weight.data = orig_weights[0]
+                del orig_weights
+                torch.cuda.empty_cache()
             else:
                 delta_weights = self.get_delta_weights(active_adapter)
                 for orig_weight, delta_weight in zip(orig_weights, delta_weights):
                     orig_weight.data += delta_weight
+                del delta_weights
+                torch.cuda.empty_cache()
             self.merged_adapters.append(active_adapter)
             if origin_device.type == 'cpu':
                 self.to(device=origin_device)
@@ -452,6 +457,8 @@ def unmerge(self) -> None:
             for orig_weight, delta_weight in zip(orig_weights, delta_weights):
                 # Subtract the delta weight to unmerge
                 orig_weight.data -= delta_weight
+            del delta_weights
+            torch.cuda.empty_cache()
 
         # Clear the merged adapters list
         self.merged_adapters = []
```
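
For context on the `swift/megatron/tuners/lora.py` change, a simplified sketch of one LoRA merge step (a standalone `merge_lora_layer` function with assumed shapes, not ms-swift's actual classes): the delta weight `lora_B @ lora_A`, scaled by the adapter's scaling factor, is added to the base weight in place, and the temporary is deleted and the CUDA cache emptied immediately, mirroring the `del` + `torch.cuda.empty_cache()` calls added in this commit.

```python
import torch


def merge_lora_layer(base_weight: torch.Tensor, lora_A: torch.Tensor,
                     lora_B: torch.Tensor, scaling: float) -> None:
    """Merge one LoRA adapter into its base weight in place (illustrative only)."""
    # lora_B: (out_features, r), lora_A: (r, in_features) -> delta: (out_features, in_features)
    delta_weight = (lora_B @ lora_A) * scaling
    base_weight.data += delta_weight
    # Drop the temporary right away so it cannot pile up across a long merge loop.
    del delta_weight
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```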
