# File: train.py
"""
train.py - Final, fully integrated fusion-model training script
====================================================
• Incorporates all known fixes, including LoRA target-module names, gradient checkpointing, and argument handling.
• Implements multiple feature-fusion strategies, selectable from the command line.
• Includes numerically stable metric computation plus a robust data collator and trainer.
• This is the final training script the project needs.
"""
import argparse
import os
import math
from typing import Dict, List, Tuple, Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.special import softmax  # numerically stable softmax
from sklearn.metrics import f1_score, roc_auc_score
from transformers import (
AutoTokenizer, AutoModel, Trainer, TrainingArguments, set_seed,
EarlyStoppingCallback, PreTrainedTokenizer
)
from peft import LoraConfig, get_peft_model, TaskType
# dataset.py, collator.py, and losses.py must live in the same directory.
from dataset import ProteinLIPDataset
from collator import FusionDataCollator
from losses import FocalLoss
# --- 1. Fusion model definition ---
class ModelFusion(nn.Module):
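    """Fuses per-residue ProtBERT and ESM-2 features for token-level classification."""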
def __init__(self, protbert_name, esm_name, num_labels, lora_r, lora_alpha, fusion_strategy="attention"):
super().__init__()
self.protbert = AutoModel.from_pretrained(protbert_name)
self.esm = AutoModel.from_pretrained(esm_name)
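        # Wrap both backbones with LoRA adapters. "query"/"key"/"value"/"dense"
        # match the attention-projection module names in both BERT- and ESM-style
        # encoders, so the same target list works for each backbone.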
protbert_targets = ["query", "key", "value", "dense"]
esm_targets = ["query", "key", "value", "dense"]
protbert_lora_config = LoraConfig(r=lora_r, lora_alpha=lora_alpha, target_modules=protbert_targets, lora_dropout=0.1, task_type=TaskType.FEATURE_EXTRACTION)
esm_lora_config = LoraConfig(r=lora_r, lora_alpha=lora_alpha, target_modules=esm_targets, lora_dropout=0.1, task_type=TaskType.FEATURE_EXTRACTION)
self.protbert = get_peft_model(self.protbert, protbert_lora_config)
self.esm = get_peft_model(self.esm, esm_lora_config)
protbert_dim = self.protbert.config.hidden_size
esm_dim = self.esm.config.hidden_size
        self.fusion_strategy = fusion_strategy
        # Project ESM features to ProtBERT's hidden size when the two dims differ.
        self.esm_proj = nn.Identity()
        if protbert_dim != esm_dim:
            self.esm_proj = nn.Linear(esm_dim, protbert_dim)
        if fusion_strategy == "concat":
            # Concatenation doubles the feature width; project back down to protbert_dim.
            self.fusion_layer = nn.Linear(protbert_dim * 2, protbert_dim)
        elif fusion_strategy == "attention":
            self.fusion_layer = nn.MultiheadAttention(embed_dim=protbert_dim, num_heads=8, dropout=0.1, batch_first=True)
        elif fusion_strategy == "weighted_sum":
            self.fusion_weights = nn.Parameter(torch.tensor([0.5, 0.5]))
        # Every strategy produces features of width protbert_dim, so the
        # classifier input size does not depend on the chosen strategy.
        self.classifier = nn.Linear(protbert_dim, num_labels)
self.dropout = nn.Dropout(0.1)
self.config = self.protbert.config
self.config.num_labels = num_labels
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
if hasattr(self.protbert, 'gradient_checkpointing_enable'):
self.protbert.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
if hasattr(self.esm, 'gradient_checkpointing_enable'):
self.esm.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
def forward(self, protbert_input_ids, protbert_attention_mask, esm_input_ids, esm_attention_mask, **kwargs):
protbert_features = self.protbert(input_ids=protbert_input_ids, attention_mask=protbert_attention_mask).last_hidden_state
esm_features = self.esm(input_ids=esm_input_ids, attention_mask=esm_attention_mask).last_hidden_state
esm_features_proj = self.esm_proj(esm_features)
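        # Fuse the two token-level feature streams. "attention" uses cross-attention
        # (ProtBERT features as queries, projected ESM features as keys/values);
        # "concat" concatenates then projects back down; "weighted_sum" learns a
        # softmax-normalized convex combination; "add" sums the streams directly.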
if self.fusion_strategy == "attention":
fused_features, _ = self.fusion_layer(protbert_features, esm_features_proj, esm_features_proj)
elif self.fusion_strategy == "concat":
fused_features = torch.cat([protbert_features, esm_features_proj], dim=-1)
fused_features = self.fusion_layer(fused_features)
elif self.fusion_strategy == "weighted_sum":
weights = torch.softmax(self.fusion_weights, dim=0)
fused_features = weights[0] * protbert_features + weights[1] * esm_features_proj
else: # add
fused_features = protbert_features + esm_features_proj
logits = self.classifier(self.dropout(fused_features))
        # Return a dict containing the logits; loss computation happens in the Trainer.
return {"logits": logits}
# --- 2. Evaluation metrics ---
def compute_metrics(eval_pred):
    # Accept either a Hugging Face EvalPrediction object or a (predictions, labels) tuple.
predictions = getattr(eval_pred, "predictions", eval_pred[0])
labels = getattr(eval_pred, "label_ids", eval_pred[1])
    # Use softmax probabilities (for binary classification, take the probability of class 1).
probs = softmax(predictions, axis=-1)[..., 1]
    # Handle both 2D (N, num_labels) and 3D (B, S, num_labels) prediction shapes.
if predictions.ndim == 3:
preds = np.argmax(predictions, axis=2)
else:
preds = np.argmax(predictions, axis=-1)
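    # Positions labeled -100 (padding / special tokens) are excluded from all metrics.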
mask = labels != -100
true_labels, true_preds = labels[mask], preds[mask]
if len(true_labels) == 0:
return {"f1": 0.0, "auc": 0.0}
if len(np.unique(true_labels)) < 2:
return {"f1": f1_score(true_labels, true_preds, average="binary", zero_division=0)}
true_probs = probs[mask]
return {
"f1": f1_score(true_labels, true_preds, average="binary", zero_division=0),
"auc": roc_auc_score(true_labels, true_probs)
}
# --- 3. Custom trainer ---
class CustomTrainer(Trainer):
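    """Trainer that optionally replaces cross-entropy with class-weighted focal loss."""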
def __init__(self, *args, use_focal_loss=False, focal_gamma=2.0, class_weights=None, **kwargs):
super().__init__(*args, **kwargs)
self.use_focal_loss = use_focal_loss
if self.use_focal_loss:
self.focal_loss = FocalLoss(gamma=focal_gamma, reduction="mean")
self.class_weights = class_weights
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
if self.use_focal_loss and "labels" in inputs:
            # Fix: split off the labels without mutating the original inputs dict.
labels = inputs.get("labels")
model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
outputs = model(**model_inputs)
logits = outputs["logits"]
mask = labels.view(-1) != -100
if not mask.any():
loss = torch.tensor(0.0, device=logits.device, requires_grad=True)
else:
                # Fix: infer the number of classes from the logits shape (DDP-safe).
num_labels = logits.shape[-1]
active_logits = logits.view(-1, num_labels)[mask]
active_labels = labels.view(-1)[mask]
                # Keep the class weights on the same device as the logits.
                alpha = None if self.class_weights is None else self.class_weights.to(active_logits.device)
                loss = self.focal_loss(active_logits, active_labels, alpha=alpha)
outputs["loss"] = loss
return (loss, outputs) if return_outputs else loss
        # Without focal loss, fall back to standard cross-entropy.
        # Handle the inputs just as carefully: read the labels without mutating the dict.
        labels = inputs.get("labels")
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        outputs = model(**model_inputs)
logits = outputs["logits"]
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
outputs["loss"] = loss
return (loss, outputs) if return_outputs else loss
# --- 4. Main entry point ---
def main():
    parser = argparse.ArgumentParser(description="Train a protein-lipid interaction prediction model")
parser.add_argument("--protbert_model", default="Rostlab/prot_bert_bfd")
parser.add_argument("--esm_model", default="facebook/esm2_t33_650M_UR50D")
parser.add_argument("--train_pkl", required=True)
parser.add_argument("--eval_pkl", required=True)
parser.add_argument("--output_dir", default="./models/fusion_model_final")
parser.add_argument("--max_len", type=int, default=1024)
parser.add_argument("--epochs", type=int, default=15)
parser.add_argument("--per_device_batch_size", type=int, default=4)
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--weight_decay", type=float, default=0.01)
parser.add_argument("--lora_r", type=int, default=16)
parser.add_argument("--lora_alpha", type=int, default=32)
parser.add_argument("--focal_gamma", type=float, default=2.0)
parser.add_argument("--fusion_strategy", choices=["attention", "concat", "weighted_sum", "add"], default="attention")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--warmup_ratio", type=float, default=0.1)
    # Boolean flags (corrected to use store_true, so they default to off)
parser.add_argument("--bf16", action="store_true", default=False)
parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
parser.add_argument("--use_focal_loss", action="store_true", default=False)
args = parser.parse_args()
set_seed(args.seed)
os.makedirs(args.output_dir, exist_ok=True)
protbert_tokenizer = AutoTokenizer.from_pretrained(args.protbert_model)
esm_tokenizer = AutoTokenizer.from_pretrained(args.esm_model)
    # Make sure a pad token exists so the collator can pad (without growing the vocabulary).
if getattr(protbert_tokenizer, "pad_token", None) is None:
fallback = (
getattr(protbert_tokenizer, "eos_token", None)
or getattr(protbert_tokenizer, "sep_token", None)
or getattr(protbert_tokenizer, "cls_token", None)
or getattr(protbert_tokenizer, "unk_token", "<unk>")
)
protbert_tokenizer.pad_token = fallback
if getattr(esm_tokenizer, "pad_token", None) is None:
fallback = (
getattr(esm_tokenizer, "eos_token", None)
or getattr(esm_tokenizer, "sep_token", None)
or getattr(esm_tokenizer, "cls_token", None)
or getattr(esm_tokenizer, "unk_token", "<unk>")
)
esm_tokenizer.pad_token = fallback
train_ds = ProteinLIPDataset(protbert_tokenizer, esm_tokenizer, args.max_len, pkl_path=args.train_pkl)
eval_ds = ProteinLIPDataset(protbert_tokenizer, esm_tokenizer, args.max_len, pkl_path=args.eval_pkl)
data_collator = FusionDataCollator(protbert_tokenizer, esm_tokenizer)
class_weights = None
if args.use_focal_loss:
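        # Weight the positive class by the negative:positive count ratio so the
        # rarer class contributes proportionally more to the focal loss.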
valid_labels = [l for chunk in train_ds.processed_data for l in chunk["labels"] if l in (0, 1)]
pos_count = sum(1 for l in valid_labels if l == 1)
neg_count = sum(1 for l in valid_labels if l == 0)
pos_weight = (neg_count / max(pos_count, 1)) if (neg_count + pos_count) > 0 else 1.0
class_weights = torch.tensor([1.0, pos_weight], dtype=torch.float32)
print(f"Calculated positive class weight for FocalLoss: {pos_weight:.2f}")
model = ModelFusion(
protbert_name=args.protbert_model,
esm_name=args.esm_model,
num_labels=2,
lora_r=args.lora_r,
lora_alpha=args.lora_alpha,
fusion_strategy=args.fusion_strategy
)
if args.gradient_checkpointing:
print("Manually enabling gradient checkpointing for fusion model.")
model.gradient_checkpointing_enable()
training_args = TrainingArguments(
output_dir=args.output_dir,
num_train_epochs=args.epochs,
per_device_train_batch_size=args.per_device_batch_size,
per_device_eval_batch_size=args.per_device_batch_size * 2,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
bf16=args.bf16,
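        # Checkpointing is enabled manually on the fusion model above, so the
        # Trainer-level flag stays off to avoid enabling it twice.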
gradient_checkpointing=False,
eval_strategy="epoch",
save_strategy="epoch",
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="f1",
greater_is_better=True,
logging_strategy="epoch",
report_to="tensorboard",
remove_unused_columns=False,
lr_scheduler_type="cosine",
warmup_ratio=args.warmup_ratio,
)
trainer = CustomTrainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
data_collator=data_collator,
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
use_focal_loss=args.use_focal_loss,
class_weights=class_weights,
focal_gamma=args.focal_gamma,
)
print("🚀 开始训练...")
trainer.train()
print("✅ 训练完成!")
final_model_path = os.path.join(args.output_dir, "final_model")
trainer.save_model(final_model_path)
print(f"模型已保存到: {final_model_path}")
if __name__ == "__main__":
main()