1414import torch .nn as nn
1515from torch .distributions import Distribution
1616
17- from egnn_pytorch import EGNN_Network
17+ import mdtraj as md
18+ from moleculekit .molecule import Molecule
19+
20+ from egnn_pytorch import EGNN_Network , EGNN
1821from sklearn .mixture import GaussianMixture
1922from hydrantic .model import Model , ModelHparams
20- from moleculekit .molecule import Molecule
2123
2224from split_flows .utils .interpolant import Interpolation , Interpolant
2325from split_flows .mixins .continuous_flow import ContinuousFlowMixin
2426from split_flows .utils .utils import to_one_hot , sum_except_batch , match_dims
27+ from split_flows .utils .metrics import graph_edit_distance
2528
2629
2730logging .basicConfig (level = logging .INFO )
@@ -72,9 +75,9 @@ def augment_gmm(self, R: Tensor, gmm: GaussianMixture) -> Tensor:
7275 :return: Full set of coordinates with noise."""
7376
7477 z = torch .empty ((R .shape [0 ], self .num_particles , R .shape [2 ]), device = R .device )
75- z_gmm = torch .tensor (
76- gmm . sample ( R .shape [0 ])[ 0 ], dtype = R . dtype , device = R . device
77- ). view ( R . shape [ 0 ], - 1 , 3 )
78+ z_gmm = torch .tensor (gmm . sample ( R . shape [ 0 ])[ 0 ], dtype = R . dtype , device = R . device ). view (
79+ R .shape [0 ], - 1 , 3
80+ )
7881
7982 start_idx = 0
8083 for i , (cg_idx , noise_idx ) in enumerate (self .latent_groupings ):
@@ -115,9 +118,7 @@ def log_prob(self, value: Tensor) -> Tensor:
115118 R_cg = value [:, cg_idx , :]
116119 R_noise = value [:, noise_idx , :]
117120 exponential_term = (
118- - 0.5
119- * sum_except_batch ((R_noise - R_cg [:, None , :]) ** 2 )
120- / self .scale ** 2
121+ - 0.5 * sum_except_batch ((R_noise - R_cg [:, None , :]) ** 2 ) / self .scale ** 2
121122 )
122123 normalization_term = - torch .log (Z ) * R_noise .shape [1 ]
123124 log_prob += exponential_term + normalization_term
@@ -150,13 +151,56 @@ def __init__(
150151 self .atom_embedding = nn .Linear (atom_types .size (- 1 ), self .dim )
151152 self .bead_embedding = nn .Linear (bead_types .size (- 1 ), self .dim )
152153
154+ self ._init_weights ()
155+
def forward(self, x: Tensor, t: Tensor) -> Tensor:
    """Evaluate the wrapped network on coordinates ``x`` at time ``t``.

    Node features are the concatenation (along the last dimension) of the
    atom-type embedding, the bead-type embedding and the time value, with
    standard-normal noise added before they are passed to ``self.net``.

    :param x: input coordinates; ``x.size(0)`` is used as the repeat count for
        the embeddings and ``x.shape[1]`` as the repeat count for the time
        feature (presumably batch and particle dimensions, confirm with caller)
    :param t: time value, aligned to ``x`` by ``match_dims``
    :return: element ``[1]`` of the output of ``self.net``"""
    n_batch = x.size(0)

    # Type embeddings are computed once and tiled along the leading dimension.
    atom_feats = self.atom_embedding(self.atom_types).repeat(n_batch, 1, 1)
    bead_feats = self.bead_embedding(self.bead_types).repeat(n_batch, 1, 1)

    # One copy of the time feature per entry along dim 1 of x.
    time_feats = match_dims(t, x).repeat(1, x.shape[1], 1)

    node_feats = torch.cat([atom_feats, bead_feats, time_feats], dim=-1)

    # NOTE(review): noise is added on every call, including evaluation;
    # confirm this is intended.
    noisy_feats = node_feats + torch.randn_like(node_feats)

    outputs = self.net(noisy_feats, x)
    return outputs[1]
159162
def _init_weights(self):
    """Initialize the embedding layers and every EGNN block of ``self.net``.

    - Embedding layers: Xavier uniform weights, zero biases
    - Edge / node MLPs (phi_e, phi_h): Xavier uniform weights, zero biases
    - Coordinate MLP: Xavier uniform weights, with the last linear layer
      drawn with gain 0.01 (i.e. scaled down by 0.01); zero biases
    """

    def xavier_init_linears(module, last_gain=1.0):
        """Xavier-initialize all ``nn.Linear`` layers found inside ``module``.

        The last linear layer uses ``last_gain``; all others use gain 1.
        ``None`` is accepted and ignored so optional sub-MLPs can be passed.
        """
        if module is None:
            return
        linears = [m for m in module.modules() if isinstance(m, nn.Linear)]
        for idx, linear in enumerate(linears):
            gain = last_gain if idx == len(linears) - 1 else 1.0
            # Passing the gain to the initializer avoids touching `.data`.
            nn.init.xavier_uniform_(linear.weight, gain=gain)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)

    xavier_init_linears(self.atom_embedding)
    xavier_init_linears(self.bead_embedding)

    # Walk the whole module tree rather than only the direct children of
    # `self.net.layers`: the EGNN blocks may be nested inside container
    # modules, in which case a top-level isinstance check never matches and
    # the initialization below would be skipped silently.
    for layer in self.net.modules():
        if not isinstance(layer, EGNN):
            continue

        # Message MLPs (phi_e, phi_h)
        xavier_init_linears(layer.edge_mlp)
        xavier_init_linears(layer.node_mlp)

        # Coordinate MLP: small last layer so initial coordinate updates are small
        xavier_init_linears(layer.coors_mlp, last_gain=0.01)
203+
160204
161205class SplitFlowHparams (ModelHparams ):
162206 aa_topology_path : str
@@ -181,12 +225,14 @@ class SplitFlowHparams(ModelHparams):
181225class SplitFlow (Model [SplitFlowHparams ], ContinuousFlowMixin ):
182226 hparams_schema = SplitFlowHparams
183227
184- def __init__ (self , thparams : SplitFlowHparams ):
185- super (SplitFlow , self ).__init__ (thparams )
228+ def __init__ (self , hparams : SplitFlowHparams ):
229+ super (SplitFlow , self ).__init__ (hparams )
186230
187231 # Load the all-atom and coarse-grained topologies
188232 self .mol_aa = Molecule (self .thparams .aa_topology_path )
189233 self .mol_cg = Molecule (self .thparams .cg_topology_path )
234+ self .top_aa = md .load_topology (self .thparams .aa_topology_path )
235+ self .top_cg = md .load_topology (self .thparams .cg_topology_path )
190236
191237 # Define the CG mapping
192238 if not hasattr (self .thparams , "cg_map_matrix_path" ):
@@ -266,9 +312,7 @@ def velocity(self, xt: Tensor, t: Tensor) -> Tensor:
266312
267313 return self .velo_net (xt , t )
268314
269- def compute_metrics (
270- self , batch : tuple [Tensor , ...], batch_idx : int
271- ) -> dict [str , Tensor ]:
315+ def compute_metrics (self , batch : tuple [Tensor , ...], batch_idx : int ) -> dict [str , Tensor ]:
272316 """Compute training/validation metrics.
273317
274318 :param batch: Batch data tuple, expecting (r,) where r is a Tensor.
@@ -291,6 +335,11 @@ def compute_metrics(
291335 metrics ["loss_fm" ] = sum_except_batch (torch .pow (vt_hat - vt , 2 )).mean ()
292336 metrics ["loss" ] += metrics ["loss_fm" ]
293337
338+ if not self .training :
339+ x1 = self .compute_flow (x0 , return_intermediate = False , verbose = False )
340+ traj = md .Trajectory (x1 .cpu ().numpy (), self .top_aa )
341+ metrics ["ged" ] = torch .mean (torch .tensor (graph_edit_distance (traj = traj , verbose = False )))
342+
294343 return metrics
295344
296345 @property
@@ -322,3 +371,36 @@ def indices_split(self) -> tuple[Tensor, Tensor]:
322371 noise_indices = torch .tensor (noise_list , device = self .device , dtype = torch .long )
323372
324373 return cg_indices , noise_indices
374+
def fit_latent_gmm(
    self,
    r: Tensor,
    n_components: int,
    chunk_size: int | None = None,
    verbose: bool = False,
    *args,
    **kwargs,
) -> GaussianMixture:
    """Fit a Gaussian Mixture Model to the latent representations of the fine-grained data.

    The configurations are pushed through the reverse flow, converted with
    ``self.noise.to_standard_normal``, restricted to the indices in
    ``self.indices_split[1]`` and flattened to one row per sample before
    the mixture is fitted.

    :param r: fine-grained configurations
    :param n_components: number of GMM components
    :param chunk_size: chunk size forwarded to ``compute_flow``
    :param verbose: verbosity flag forwarded to ``compute_flow``
    :param args: additional positional arguments for GaussianMixture (its
        options other than ``n_components`` are expected to be keyword
        arguments, so prefer ``kwargs``)
    :param kwargs: additional keyword arguments for GaussianMixture
    :return: fitted GMM"""

    # Gradients are not needed for fitting, so skip autograd bookkeeping.
    with torch.no_grad():
        x1 = r.to(self.device)
        x0 = self.compute_flow(x1, reverse=True, chunk_size=chunk_size, verbose=verbose).cpu()
        latent = self.noise.to_standard_normal(x0)
        # Index on CPU so the indices live on the same device as `latent`.
        noise_idx = self.indices_split[1].cpu()
        eps_sn = latent[:, noise_idx, :].view(x0.shape[0], -1)

    # `n_components` is passed positionally: unpacking *args after a keyword
    # argument would bind the first extra positional to `n_components` again.
    # GaussianMixture comes from the module-level import.
    gmm = GaussianMixture(n_components, *args, **kwargs)
    gmm.fit(eps_sn.numpy())

    return gmm
0 commit comments