Skip to content

Commit ea84e12

Browse files
author
dmoi
committed
fixing fp16 errors
1 parent 2b8979e commit ea84e12

5 files changed

Lines changed: 35 additions & 35 deletions

File tree

config_notebook_1k_epochs.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ overwrite: true
1212
# Training hyperparameters (from notebook)
1313
epochs: 1000
1414
batch_size: 10
15-
gradient_accumulation_steps: 1
15+
gradient_accumulation_steps: 2
1616
seed: 0
1717

1818
# Model architecture (from notebook)

foldtree2/learn_monodecoder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,8 @@ def analyze_gradient_norms(model, top_k=3):
733733
if out.get('ss_pred') is not None:
734734
if args.mask_plddt:
735735
mask = (data['plddt'].x >= args.plddt_threshold).squeeze()
736-
ss_loss = F.cross_entropy(out['ss_pred'][mask], data['ss'].x[mask])
736+
if mask.sum() > 0:
737+
ss_loss = F.cross_entropy(out['ss_pred'][mask], data['ss'].x[mask])
737738
else:
738739
ss_loss = F.cross_entropy(out['ss_pred'], data['ss'].x)
739740

@@ -771,7 +772,8 @@ def analyze_gradient_norms(model, top_k=3):
771772
if out.get('ss_pred') is not None:
772773
if args.mask_plddt:
773774
mask = (data['plddt'].x >= args.plddt_threshold).squeeze()
774-
ss_loss = F.cross_entropy(out['ss_pred'][mask], data['ss'].x[mask])
775+
if mask.sum() > 0:
776+
ss_loss = F.cross_entropy(out['ss_pred'][mask], data['ss'].x[mask])
775777
else:
776778
ss_loss = F.cross_entropy(out['ss_pred'], data['ss'].x)
777779

foldtree2/src/encoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@ def __init__(self, in_channels, hidden_channels, out_channels,
8282
self.input = nn.ModuleDict()
8383

8484
self.input['dropout'] = nn.Dropout(p=dropout_p)
85-
self.input['ln'] = nn.LayerNorm(self.in_channels)
85+
self.input['ln'] = nn.LayerNorm(self.in_channels, eps=1e-6)
8686

8787
self.input['inmlp'] = nn.Sequential(
8888
nn.Dropout(dropout_p),
89-
nn.LayerNorm(self.in_with_positions),
89+
nn.LayerNorm(self.in_with_positions, eps=1e-6),
9090
nn.Linear(self.in_with_positions, hidden_channels[0] * 2),
9191
nn.GELU(),
9292
nn.Linear(hidden_channels[0] * 2, hidden_channels[0]),
@@ -96,7 +96,7 @@ def __init__(self, in_channels, hidden_channels, out_channels,
9696
if self.fftin:
9797
self.input['ffin'] = nn.Sequential(
9898
nn.Dropout(dropout_p),
99-
nn.LayerNorm(2 * 80),
99+
nn.LayerNorm(2 * 80, eps=1e-6),
100100
nn.Linear(2 * 80, hidden_channels[0] * 2),
101101
nn.GELU(),
102102
nn.Linear(hidden_channels[0] * 2, hidden_channels[0]),

foldtree2/src/losses/losses.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,10 @@ def ss_reconstruction_loss(ss, recon_ss, mask_plddt=False, plddt_threshold=0.3 ,
234234
"""
235235
if mask_plddt:
236236
mask = (plddt_mask > plddt_threshold).squeeze()
237-
ss_loss = F.cross_entropy(recon_ss[mask], ss[mask])
237+
if mask.sum() > 0:
238+
ss_loss = F.cross_entropy(recon_ss[mask], ss[mask])
239+
else:
240+
ss_loss = torch.tensor(0.0, device=recon_ss.device)
238241
else:
239242
ss_loss = F.cross_entropy(recon_ss, ss)
240243
return ss_loss
@@ -251,7 +254,10 @@ def angles_reconstruction_loss(true, pred, beta=0.5 , plddt_mask = None , plddt_
251254
if plddt_mask is not None:
252255
mask = plddt_mask > plddt_thresh
253256
mask = mask.squeeze(1) # Ensure mask is 1D
254-
delta = delta[mask]
257+
if mask.sum() > 0:
258+
delta = delta[mask]
259+
else:
260+
return torch.tensor(0.0, device=pred.device)
255261
loss = F.smooth_l1_loss(delta, torch.zeros_like(delta), beta=beta)
256262

257263
return loss.mean()

foldtree2/src/mono_decoders.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from Bio.PDB import PDBParser
4444
from foldtree2.src.chebconv import StableChebConv
4545
from scipy.spatial.distance import cdist
46-
EPS = 1e-15
46+
EPS = 1e-6
4747
datadir = '../../datasets/foldtree2/'
4848

4949

@@ -197,13 +197,12 @@ def __init__(self, in_channels = {'res':10 , 'godnode4decoder':5 , 'foldx':23 },
197197
if output_ss == True:
198198
self.output_ss = True
199199
self.ss_mlp = torch.nn.Sequential(
200-
torch.nn.LayerNorm(lastlin),
200+
torch.nn.LayerNorm(lastlin, eps=1e-6),
201201
torch.nn.Linear(lastlin, 128),
202202
torch.nn.GELU(),
203203
torch.nn.Linear(128,64),
204204
torch.nn.GELU(),
205-
torch.nn.Linear(64,3),
206-
torch.nn.LogSoftmax(dim=1)
205+
torch.nn.Linear(64,3)
207206
)
208207
else:
209208
self.output_ss = False
@@ -231,7 +230,7 @@ def __init__(self, in_channels = {'res':10 , 'godnode4decoder':5 , 'foldx':23 },
231230
self.output_edge_logits = True
232231
self.edge_logits_mlp = torch.nn.Sequential(
233232
#layernorm
234-
torch.nn.LayerNorm(2*lastlin),
233+
torch.nn.LayerNorm(2*lastlin, eps=1e-6),
235234
torch.nn.Linear(2*lastlin, anglesdecoder_hidden[0]),
236235
torch.nn.GELU(),
237236
torch.nn.Linear(anglesdecoder_hidden[0],anglesdecoder_hidden[1]),
@@ -276,7 +275,7 @@ def forward(self, data , contact_pred_index, **kwargs):
276275
if self.residual == True:
277276
z = z + inz
278277
if self.normalize == True:
279-
z = z / ( torch.norm(z, dim=1, keepdim=True) + 1e-10)
278+
z = z / ( torch.norm(z, dim=1, keepdim=True) + 1e-6)
280279
#decoder_in = torch.cat( [inz, z] , axis = 1)
281280
#amino acid prediction removed
282281

@@ -420,13 +419,10 @@ def __init__(self, in_channels={'res': 10, 'godnode4decoder': 5, 'foldx': 23},
420419
padding=kernel_size//2
421420
)
422421
)
423-
self.body['norms'].append(nn.LayerNorm(channels))
422+
self.body['norms'].append(nn.LayerNorm(channels, eps=1e-6))
424423

425-
finalout = conv_channels[-1]
426-
427-
# Intermediate projection
428424
self.body['lin'] = nn.Sequential(
429-
nn.Linear(finalout, Xdecoder_hidden[0]),
425+
nn.Linear(conv_channels[-1], Xdecoder_hidden[0]),
430426
nn.GELU(),
431427
nn.Linear(Xdecoder_hidden[0], Xdecoder_hidden[1]),
432428
nn.GELU(),
@@ -474,13 +470,12 @@ def __init__(self, in_channels={'res': 10, 'godnode4decoder': 5, 'foldx': 23},
474470
# Secondary structure prediction
475471
if output_ss:
476472
self.head['ss_mlp'] = nn.Sequential(
477-
nn.LayerNorm(lastlin),
473+
nn.LayerNorm(lastlin, eps=1e-6),
478474
nn.Linear(lastlin, anglesdecoder_hidden[0]),
479475
nn.GELU(),
480476
nn.Linear(anglesdecoder_hidden[0], anglesdecoder_hidden[1]),
481477
nn.GELU(),
482-
nn.Linear(anglesdecoder_hidden[1], 3),
483-
nn.LogSoftmax(dim=1)
478+
nn.Linear(anglesdecoder_hidden[1], 3)
484479
)
485480

486481
# Bond angles prediction
@@ -499,7 +494,7 @@ def __init__(self, in_channels={'res': 10, 'godnode4decoder': 5, 'foldx': 23},
499494
# Edge logits prediction
500495
if output_edge_logits:
501496
self.head['edge_logits_mlp'] = nn.Sequential(
502-
nn.LayerNorm(2*lastlin),
497+
nn.LayerNorm(2*lastlin, eps=1e-6),
503498
nn.Linear(2*lastlin, anglesdecoder_hidden[0]),
504499
nn.GELU(),
505500
nn.Linear(anglesdecoder_hidden[0], anglesdecoder_hidden[1]),
@@ -584,7 +579,7 @@ def forward(self, data, contact_pred_index, **kwargs):
584579
if self.residual:
585580
z = z + inz
586581
if self.normalize:
587-
z = z / (torch.norm(z, dim=1, keepdim=True) + 1e-10)
582+
z = z / (torch.norm(z, dim=1, keepdim=True) + 1e-6)
588583

589584
# ===================== HEAD PROCESSING =====================
590585
# Godnode/FFT decoder
@@ -730,7 +725,7 @@ def __init__(
730725

731726
# Optional CNN decoder
732727
if use_cnn_decoder := kwargs.get('use_cnn_decoder', False):
733-
self.head['prenorm'] = nn.LayerNorm(d_model)
728+
self.head['prenorm'] = nn.LayerNorm(d_model, eps=1e-6)
734729
self.head['cnn_decoder'] = nn.Sequential(
735730
# Conv1d expects (batch, channels, seq_len)
736731
nn.Conv1d(d_model, AAdecoder_hidden[0], kernel_size=3, padding=1),
@@ -750,8 +745,7 @@ def __init__(
750745
nn.GELU(),
751746
nn.Linear(AAdecoder_hidden[1], AAdecoder_hidden[2]),
752747
nn.GELU(),
753-
nn.Linear(AAdecoder_hidden[2], 20),
754-
nn.LogSoftmax(dim=1)
748+
nn.Linear(AAdecoder_hidden[2], 20)
755749
)
756750

757751
# Optional secondary structure prediction head
@@ -763,8 +757,7 @@ def __init__(
763757
nn.GELU(),
764758
nn.Linear(AAdecoder_hidden[1], AAdecoder_hidden[2]),
765759
nn.GELU(),
766-
nn.Linear(AAdecoder_hidden[2], 3),
767-
nn.LogSoftmax(dim=1)
760+
nn.Linear(AAdecoder_hidden[2], 3)
768761
)
769762

770763
def forward(self, data, **kwargs):
@@ -814,7 +807,7 @@ def forward(self, data, **kwargs):
814807

815808
# Apply normalization
816809
if self.normalize:
817-
x = x / (torch.norm(x, dim=-1, keepdim=True) + 1e-10)
810+
x = x / (torch.norm(x, dim=-1, keepdim=True) + 1e-6)
818811

819812
# ===================== HEAD PROCESSING =====================
820813
if batch is not None:
@@ -975,15 +968,14 @@ def __init__(
975968
if not isinstance(ssdecoder_hidden, list):
976969
ssdecoder_hidden = [ssdecoder_hidden, ssdecoder_hidden]
977970
self.head['ss_head'] = nn.Sequential(
978-
nn.LayerNorm(d_model),
971+
nn.LayerNorm(d_model, eps=1e-6),
979972
nn.Linear(d_model, ssdecoder_hidden[0]),
980973
nn.GELU(),
981974
nn.Linear(ssdecoder_hidden[0], ssdecoder_hidden[1]),
982975
nn.GELU(),
983976
nn.Linear(ssdecoder_hidden[1], ssdecoder_hidden[2] if len(ssdecoder_hidden) > 2 else ssdecoder_hidden[1]),
984977
nn.GELU(),
985-
nn.Linear(ssdecoder_hidden[2] if len(ssdecoder_hidden) > 2 else ssdecoder_hidden[1], 3),
986-
nn.LogSoftmax(dim=1)
978+
nn.Linear(ssdecoder_hidden[2] if len(ssdecoder_hidden) > 2 else ssdecoder_hidden[1], 3)
987979
)
988980

989981
# Bond angles prediction head (phi, psi, omega)
@@ -1044,7 +1036,7 @@ def forward(self, data, contact_pred_index=None, **kwargs):
10441036

10451037
# Apply normalization
10461038
if self.normalize:
1047-
x = x / (torch.norm(x, dim=-1, keepdim=True) + 1e-10)
1039+
x = x / (torch.norm(x, dim=-1, keepdim=True) + 1e-6)
10481040

10491041
# ===================== HEAD PROCESSING =====================
10501042
rt_pred = None
@@ -1171,7 +1163,7 @@ def forward(self, data, **kwargs):
11711163
# No residual connection here, as pooled is a single vector
11721164
pass
11731165
if self.normalize:
1174-
pooled = pooled / (torch.norm(pooled, dim=-1, keepdim=True) + 1e-10)
1166+
pooled = pooled / (torch.norm(pooled, dim=-1, keepdim=True) + 1e-6)
11751167
foldx_out = self.lin(pooled)
11761168
return { 'foldx_out' : foldx_out }
11771169

0 commit comments

Comments
 (0)