4343from Bio .PDB import PDBParser
4444from foldtree2 .src .chebconv import StableChebConv
4545from scipy .spatial .distance import cdist
46- EPS = 1e-15
46+ EPS = 1e-6
4747datadir = '../../datasets/foldtree2/'
4848
4949
@@ -197,13 +197,12 @@ def __init__(self, in_channels = {'res':10 , 'godnode4decoder':5 , 'foldx':23 },
197197 if output_ss == True :
198198 self .output_ss = True
199199 self .ss_mlp = torch .nn .Sequential (
200- torch .nn .LayerNorm (lastlin ),
200+ torch .nn .LayerNorm (lastlin , eps = 1e-6 ),
201201 torch .nn .Linear (lastlin , 128 ),
202202 torch .nn .GELU (),
203203 torch .nn .Linear (128 ,64 ),
204204 torch .nn .GELU (),
205- torch .nn .Linear (64 ,3 ),
206- torch .nn .LogSoftmax (dim = 1 )
205+ torch .nn .Linear (64 ,3 )
207206 )
208207 else :
209208 self .output_ss = False
@@ -231,7 +230,7 @@ def __init__(self, in_channels = {'res':10 , 'godnode4decoder':5 , 'foldx':23 },
231230 self .output_edge_logits = True
232231 self .edge_logits_mlp = torch .nn .Sequential (
233232 #layernorm
234- torch .nn .LayerNorm (2 * lastlin ),
233+ torch .nn .LayerNorm (2 * lastlin , eps = 1e-6 ),
235234 torch .nn .Linear (2 * lastlin , anglesdecoder_hidden [0 ]),
236235 torch .nn .GELU (),
237236 torch .nn .Linear (anglesdecoder_hidden [0 ],anglesdecoder_hidden [1 ]),
@@ -276,7 +275,7 @@ def forward(self, data , contact_pred_index, **kwargs):
276275 if self .residual == True :
277276 z = z + inz
278277 if self .normalize == True :
279- z = z / ( torch .norm (z , dim = 1 , keepdim = True ) + 1e-10 )
278+ z = z / ( torch .norm (z , dim = 1 , keepdim = True ) + 1e-6 )
280279 #decoder_in = torch.cat( [inz, z] , axis = 1)
281280 #amino acid prediction removed
282281
@@ -420,13 +419,10 @@ def __init__(self, in_channels={'res': 10, 'godnode4decoder': 5, 'foldx': 23},
420419 padding = kernel_size // 2
421420 )
422421 )
423- self .body ['norms' ].append (nn .LayerNorm (channels ))
422+ self .body ['norms' ].append (nn .LayerNorm (channels , eps = 1e-6 ))
424423
425- finalout = conv_channels [- 1 ]
426-
427- # Intermediate projection
428424 self .body ['lin' ] = nn .Sequential (
429- nn .Linear (finalout , Xdecoder_hidden [0 ]),
425+ nn .Linear (conv_channels [ - 1 ] , Xdecoder_hidden [0 ]),
430426 nn .GELU (),
431427 nn .Linear (Xdecoder_hidden [0 ], Xdecoder_hidden [1 ]),
432428 nn .GELU (),
@@ -474,13 +470,12 @@ def __init__(self, in_channels={'res': 10, 'godnode4decoder': 5, 'foldx': 23},
474470 # Secondary structure prediction
475471 if output_ss :
476472 self .head ['ss_mlp' ] = nn .Sequential (
477- nn .LayerNorm (lastlin ),
473+ nn .LayerNorm (lastlin , eps = 1e-6 ),
478474 nn .Linear (lastlin , anglesdecoder_hidden [0 ]),
479475 nn .GELU (),
480476 nn .Linear (anglesdecoder_hidden [0 ], anglesdecoder_hidden [1 ]),
481477 nn .GELU (),
482- nn .Linear (anglesdecoder_hidden [1 ], 3 ),
483- nn .LogSoftmax (dim = 1 )
478+ nn .Linear (anglesdecoder_hidden [1 ], 3 )
484479 )
485480
486481 # Bond angles prediction
@@ -499,7 +494,7 @@ def __init__(self, in_channels={'res': 10, 'godnode4decoder': 5, 'foldx': 23},
499494 # Edge logits prediction
500495 if output_edge_logits :
501496 self .head ['edge_logits_mlp' ] = nn .Sequential (
502- nn .LayerNorm (2 * lastlin ),
497+ nn .LayerNorm (2 * lastlin , eps = 1e-6 ),
503498 nn .Linear (2 * lastlin , anglesdecoder_hidden [0 ]),
504499 nn .GELU (),
505500 nn .Linear (anglesdecoder_hidden [0 ], anglesdecoder_hidden [1 ]),
@@ -584,7 +579,7 @@ def forward(self, data, contact_pred_index, **kwargs):
584579 if self .residual :
585580 z = z + inz
586581 if self .normalize :
587- z = z / (torch .norm (z , dim = 1 , keepdim = True ) + 1e-10 )
582+ z = z / (torch .norm (z , dim = 1 , keepdim = True ) + 1e-6 )
588583
589584 # ===================== HEAD PROCESSING =====================
590585 # Godnode/FFT decoder
@@ -730,7 +725,7 @@ def __init__(
730725
731726 # Optional CNN decoder
732727 if use_cnn_decoder := kwargs .get ('use_cnn_decoder' , False ):
733- self .head ['prenorm' ] = nn .LayerNorm (d_model )
728+ self .head ['prenorm' ] = nn .LayerNorm (d_model , eps = 1e-6 )
734729 self .head ['cnn_decoder' ] = nn .Sequential (
735730 # Conv1d expects (batch, channels, seq_len)
736731 nn .Conv1d (d_model , AAdecoder_hidden [0 ], kernel_size = 3 , padding = 1 ),
@@ -750,8 +745,7 @@ def __init__(
750745 nn .GELU (),
751746 nn .Linear (AAdecoder_hidden [1 ], AAdecoder_hidden [2 ]),
752747 nn .GELU (),
753- nn .Linear (AAdecoder_hidden [2 ], 20 ),
754- nn .LogSoftmax (dim = 1 )
748+ nn .Linear (AAdecoder_hidden [2 ], 20 )
755749 )
756750
757751 # Optional secondary structure prediction head
@@ -763,8 +757,7 @@ def __init__(
763757 nn .GELU (),
764758 nn .Linear (AAdecoder_hidden [1 ], AAdecoder_hidden [2 ]),
765759 nn .GELU (),
766- nn .Linear (AAdecoder_hidden [2 ], 3 ),
767- nn .LogSoftmax (dim = 1 )
760+ nn .Linear (AAdecoder_hidden [2 ], 3 )
768761 )
769762
770763 def forward (self , data , ** kwargs ):
@@ -814,7 +807,7 @@ def forward(self, data, **kwargs):
814807
815808 # Apply normalization
816809 if self .normalize :
817- x = x / (torch .norm (x , dim = - 1 , keepdim = True ) + 1e-10 )
810+ x = x / (torch .norm (x , dim = - 1 , keepdim = True ) + 1e-6 )
818811
819812 # ===================== HEAD PROCESSING =====================
820813 if batch is not None :
@@ -975,15 +968,14 @@ def __init__(
975968 if not isinstance (ssdecoder_hidden , list ):
976969 ssdecoder_hidden = [ssdecoder_hidden , ssdecoder_hidden ]
977970 self .head ['ss_head' ] = nn .Sequential (
978- nn .LayerNorm (d_model ),
971+ nn .LayerNorm (d_model , eps = 1e-6 ),
979972 nn .Linear (d_model , ssdecoder_hidden [0 ]),
980973 nn .GELU (),
981974 nn .Linear (ssdecoder_hidden [0 ], ssdecoder_hidden [1 ]),
982975 nn .GELU (),
983976 nn .Linear (ssdecoder_hidden [1 ], ssdecoder_hidden [2 ] if len (ssdecoder_hidden ) > 2 else ssdecoder_hidden [1 ]),
984977 nn .GELU (),
985- nn .Linear (ssdecoder_hidden [2 ] if len (ssdecoder_hidden ) > 2 else ssdecoder_hidden [1 ], 3 ),
986- nn .LogSoftmax (dim = 1 )
978+ nn .Linear (ssdecoder_hidden [2 ] if len (ssdecoder_hidden ) > 2 else ssdecoder_hidden [1 ], 3 )
987979 )
988980
989981 # Bond angles prediction head (phi, psi, omega)
@@ -1044,7 +1036,7 @@ def forward(self, data, contact_pred_index=None, **kwargs):
10441036
10451037 # Apply normalization
10461038 if self .normalize :
1047- x = x / (torch .norm (x , dim = - 1 , keepdim = True ) + 1e-10 )
1039+ x = x / (torch .norm (x , dim = - 1 , keepdim = True ) + 1e-6 )
10481040
10491041 # ===================== HEAD PROCESSING =====================
10501042 rt_pred = None
@@ -1171,7 +1163,7 @@ def forward(self, data, **kwargs):
11711163 # No residual connection here, as pooled is a single vector
11721164 pass
11731165 if self .normalize :
1174- pooled = pooled / (torch .norm (pooled , dim = - 1 , keepdim = True ) + 1e-10 )
1166+ pooled = pooled / (torch .norm (pooled , dim = - 1 , keepdim = True ) + 1e-6 )
11751167 foldx_out = self .lin (pooled )
11761168 return { 'foldx_out' : foldx_out }
11771169
0 commit comments