Skip to content

Commit 58da838

Browse files
add GED metrics; add GMM method
1 parent b148e84 commit 58da838

File tree

3 files changed

+305
-16
lines changed

3 files changed

+305
-16
lines changed

split_flows/models/split_flow.py

Lines changed: 95 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414
import torch.nn as nn
1515
from torch.distributions import Distribution
1616

17-
from egnn_pytorch import EGNN_Network
17+
import mdtraj as md
18+
from moleculekit.molecule import Molecule
19+
20+
from egnn_pytorch import EGNN_Network, EGNN
1821
from sklearn.mixture import GaussianMixture
1922
from hydrantic.model import Model, ModelHparams
20-
from moleculekit.molecule import Molecule
2123

2224
from split_flows.utils.interpolant import Interpolation, Interpolant
2325
from split_flows.mixins.continuous_flow import ContinuousFlowMixin
2426
from split_flows.utils.utils import to_one_hot, sum_except_batch, match_dims
27+
from split_flows.utils.metrics import graph_edit_distance
2528

2629

2730
logging.basicConfig(level=logging.INFO)
@@ -72,9 +75,9 @@ def augment_gmm(self, R: Tensor, gmm: GaussianMixture) -> Tensor:
7275
:return: Full set of coordinates with noise."""
7376

7477
z = torch.empty((R.shape[0], self.num_particles, R.shape[2]), device=R.device)
75-
z_gmm = torch.tensor(
76-
gmm.sample(R.shape[0])[0], dtype=R.dtype, device=R.device
77-
).view(R.shape[0], -1, 3)
78+
z_gmm = torch.tensor(gmm.sample(R.shape[0])[0], dtype=R.dtype, device=R.device).view(
79+
R.shape[0], -1, 3
80+
)
7881

7982
start_idx = 0
8083
for i, (cg_idx, noise_idx) in enumerate(self.latent_groupings):
@@ -115,9 +118,7 @@ def log_prob(self, value: Tensor) -> Tensor:
115118
R_cg = value[:, cg_idx, :]
116119
R_noise = value[:, noise_idx, :]
117120
exponential_term = (
118-
-0.5
119-
* sum_except_batch((R_noise - R_cg[:, None, :]) ** 2)
120-
/ self.scale**2
121+
-0.5 * sum_except_batch((R_noise - R_cg[:, None, :]) ** 2) / self.scale**2
121122
)
122123
normalization_term = -torch.log(Z) * R_noise.shape[1]
123124
log_prob += exponential_term + normalization_term
@@ -150,13 +151,56 @@ def __init__(
150151
self.atom_embedding = nn.Linear(atom_types.size(-1), self.dim)
151152
self.bead_embedding = nn.Linear(bead_types.size(-1), self.dim)
152153

154+
self._init_weights()
155+
153156
def forward(self, x: Tensor, t: Tensor) -> Tensor:
    """Evaluate the EGNN velocity network at coordinates ``x`` and time ``t``.

    :param x: particle coordinates; assumed shape (batch, particles, 3) — TODO confirm
    :param t: flow time(s); broadcast so every particle carries the time as a feature
    :return: the second output of the EGNN network — presumably the updated
        coordinates / velocity field; confirm against egnn_pytorch's
        EGNN_Network return signature
    """
    # Embed the fixed per-particle type features and tile them over the batch.
    atom_embeddings = self.atom_embedding(self.atom_types).repeat(x.size(0), 1, 1)
    bead_embeddings = self.bead_embedding(self.bead_types).repeat(x.size(0), 1, 1)
    # Match t's dimensionality to x and repeat it per particle.
    t = match_dims(t, x).repeat(1, x.shape[1], 1)
    # Node features: [atom embedding | bead embedding | time].
    h = torch.cat([atom_embeddings, bead_embeddings, t], dim=-1)
    # NOTE(review): Gaussian noise is added to the node features on EVERY call,
    # including evaluation — presumably deliberate feature regularization; confirm.
    return self.net(h + torch.randn_like(h), x)[1]
159162

163+
def _init_weights(self):
    """Initialize network weights following EGNN best practices.

    - Embedding layers and message MLPs (phi_e, phi_h): Xavier uniform
    - Coordinate MLP: Xavier uniform, with the last linear layer additionally
      scaled down by 0.01 so initial coordinate updates are small
    - All biases: zero initialization
    """

    # Embedding layers: Xavier uniform weights, zero biases.
    for embedding in (self.atom_embedding, self.bead_embedding):
        nn.init.xavier_uniform_(embedding.weight)
        nn.init.zeros_(embedding.bias)

    # EGNN layers inside the network.
    for layer in self.net.layers:
        if not isinstance(layer, EGNN):
            continue

        # Message MLPs (phi_e, phi_h): Xavier uniform, zero biases.
        for module in (layer.edge_mlp, layer.node_mlp):
            for m in module.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

        # Coordinate MLP: every linear layer gets the same Xavier init
        # (the original if/else duplicated this call in both branches);
        # only the final layer is then scaled down by 0.01.
        linear_layers = [
            m for m in layer.coors_mlp.modules() if isinstance(m, nn.Linear)
        ]
        for m in linear_layers:
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        if linear_layers:
            linear_layers[-1].weight.data *= 0.01
203+
160204

161205
class SplitFlowHparams(ModelHparams):
162206
aa_topology_path: str
@@ -181,12 +225,14 @@ class SplitFlowHparams(ModelHparams):
181225
class SplitFlow(Model[SplitFlowHparams], ContinuousFlowMixin):
182226
hparams_schema = SplitFlowHparams
183227

184-
def __init__(self, thparams: SplitFlowHparams):
185-
super(SplitFlow, self).__init__(thparams)
228+
def __init__(self, hparams: SplitFlowHparams):
229+
super(SplitFlow, self).__init__(hparams)
186230

187231
# Load the all-atom and coarse-grained topologies
188232
self.mol_aa = Molecule(self.thparams.aa_topology_path)
189233
self.mol_cg = Molecule(self.thparams.cg_topology_path)
234+
self.top_aa = md.load_topology(self.thparams.aa_topology_path)
235+
self.top_cg = md.load_topology(self.thparams.cg_topology_path)
190236

191237
# Define the CG mapping
192238
if not hasattr(self.thparams, "cg_map_matrix_path"):
@@ -266,9 +312,7 @@ def velocity(self, xt: Tensor, t: Tensor) -> Tensor:
266312

267313
return self.velo_net(xt, t)
268314

269-
def compute_metrics(
270-
self, batch: tuple[Tensor, ...], batch_idx: int
271-
) -> dict[str, Tensor]:
315+
def compute_metrics(self, batch: tuple[Tensor, ...], batch_idx: int) -> dict[str, Tensor]:
272316
"""Compute training/validation metrics.
273317
274318
:param batch: Batch data tuple, expecting (r,) where r is a Tensor.
@@ -291,6 +335,11 @@ def compute_metrics(
291335
metrics["loss_fm"] = sum_except_batch(torch.pow(vt_hat - vt, 2)).mean()
292336
metrics["loss"] += metrics["loss_fm"]
293337

338+
if not self.training:
339+
x1 = self.compute_flow(x0, return_intermediate=False, verbose=False)
340+
traj = md.Trajectory(x1.cpu().numpy(), self.top_aa)
341+
metrics["ged"] = torch.mean(torch.tensor(graph_edit_distance(traj=traj, verbose=False)))
342+
294343
return metrics
295344

296345
@property
@@ -322,3 +371,36 @@ def indices_split(self) -> tuple[Tensor, Tensor]:
322371
noise_indices = torch.tensor(noise_list, device=self.device, dtype=torch.long)
323372

324373
return cg_indices, noise_indices
374+
375+
def fit_latent_gmm(
    self,
    r: Tensor,
    n_components: int,
    chunk_size: int | None = None,
    verbose: bool = False,
    *args,
    **kwargs,
) -> GaussianMixture:
    """Fit a Gaussian Mixture Model to the latent representations of the fine-grained data.

    :param r: fine-grained configurations
    :param n_components: number of GMM components
    :param chunk_size: chunk size for processing data in batches
    :param verbose: whether to display a progress bar
    :param args: additional arguments for GaussianMixture
    :param kwargs: additional keyword arguments for GaussianMixture
    :return: fitted GMM"""

    # No gradients needed: we only map data into the latent space for fitting.
    with torch.no_grad():
        x1 = r.to(self.device)
        # Invert the flow to obtain latent samples x0 for the input configurations.
        x0 = self.compute_flow(x1, reverse=True, chunk_size=chunk_size, verbose=verbose).cpu()
        # Keep only the noise-particle coordinates, standard-normalized and
        # flattened to one feature vector per sample.
        eps_sn = self.noise.to_standard_normal(x0)[:, self.indices_split[1].cpu(), :].view(
            x0.shape[0], -1
        )

    # GaussianMixture is already imported at module level; the original's
    # redundant function-local re-import has been removed.
    gmm = GaussianMixture(n_components=n_components, *args, **kwargs)
    gmm.fit(eps_sn.numpy())

    return gmm

split_flows/utils/metrics.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
from tqdm import tqdm
2+
3+
import mdtraj as md
4+
import numpy as np
5+
import torch
6+
7+
8+
def compute_cg_rmsd(
    traj1: md.Trajectory, traj2: md.Trajectory, indices: list[int], splits: int = 5
) -> list[float]:
    """Computes the root mean squared deviation of configurations of the provided
    trajectories in the coarse-grained space.

    :param traj1: first trajectory
    :param traj2: second trajectory
    :param indices: indices of atoms to retain in the coarse-grained representation
    :param splits: number of splits to average over
    :return: list of RMSDs of coarse-grained representations per split"""

    frames_per_split = traj1.n_frames // splits

    averages = []
    for split_idx in range(splits):
        begin = split_idx * frames_per_split
        # The final split absorbs any leftover frames from the integer division.
        stop = traj1.n_frames if split_idx == splits - 1 else (split_idx + 1) * frames_per_split

        # Frame-by-frame RMSD restricted to the coarse-grained atom subset.
        per_frame = [
            md.rmsd(traj2[frame], traj1[frame], 0, atom_indices=indices)
            for frame in tqdm(range(begin, stop))
        ]
        averages.append(np.mean(per_frame))

    return averages
33+
34+
35+
# Per-element bond-distance cutoffs keyed by atomic number (1 = H ... 107).
# Values are in Ångström; consumers convert to nanometers for MDTraj via the
# "/ 10.0" factor where this table is read.
COVCUTOFFTABLE = {
    1: 0.23,
    2: 0.93,
    3: 0.68,
    4: 0.35,
    5: 0.83,
    6: 0.68,
    7: 0.68,
    8: 0.68,
    9: 0.64,
    10: 1.12,
    11: 0.97,
    12: 1.1,
    13: 1.35,
    14: 1.2,
    15: 0.75,
    16: 1.02,
    17: 0.99,
    18: 1.57,
    19: 1.33,
    20: 0.99,
    21: 1.44,
    22: 1.47,
    23: 1.33,
    24: 1.35,
    25: 1.35,
    26: 1.34,
    27: 1.33,
    28: 1.5,
    29: 1.52,
    30: 1.45,
    31: 1.22,
    32: 1.17,
    33: 1.21,
    34: 1.22,
    35: 1.21,
    36: 1.91,
    37: 1.47,
    38: 1.12,
    39: 1.78,
    40: 1.56,
    41: 1.48,
    42: 1.47,
    43: 1.35,
    44: 1.4,
    45: 1.45,
    46: 1.5,
    47: 1.59,
    48: 1.69,
    49: 1.63,
    50: 1.46,
    51: 1.46,
    52: 1.47,
    53: 1.4,
    54: 1.98,
    55: 1.67,
    56: 1.34,
    57: 1.87,
    58: 1.83,
    59: 1.82,
    60: 1.81,
    61: 1.8,
    62: 1.8,
    63: 1.99,
    64: 1.79,
    65: 1.76,
    66: 1.75,
    67: 1.74,
    68: 1.73,
    69: 1.72,
    70: 1.94,
    71: 1.72,
    72: 1.57,
    73: 1.43,
    74: 1.37,
    75: 1.35,
    76: 1.37,
    77: 1.32,
    78: 1.5,
    79: 1.5,
    80: 1.7,
    81: 1.55,
    82: 1.54,
    83: 1.54,
    84: 1.68,
    85: 1.7,
    86: 2.4,
    87: 2.0,
    88: 1.9,
    89: 1.88,
    90: 1.79,
    91: 1.61,
    92: 1.58,
    93: 1.55,
    94: 1.53,
    95: 1.51,
    # Elements 96-103 share a flat 1.5 Å placeholder cutoff.
    96: 1.5,
    97: 1.5,
    98: 1.5,
    99: 1.5,
    100: 1.5,
    101: 1.5,
    102: 1.5,
    103: 1.5,
    104: 1.57,
    105: 1.49,
    106: 1.43,
    107: 1.41,
}
144+
145+
146+
def compute_bond_cutoff_mdtraj(topology, scale=1.3):
    """Compute the pairwise bond-distance cutoff matrix for an MDTraj topology.

    :param topology: MDTraj topology providing per-atom elements
    :param scale: tolerance factor applied to the summed per-atom cutoffs
    :return: (n_atoms, n_atoms) tensor of cutoff distances in nanometers"""
    # COVCUTOFFTABLE values are in Angstroms; MDTraj works in nanometers.
    radii = torch.Tensor(
        [COVCUTOFFTABLE[int(atom.element.atomic_number)] / 10.0 for atom in topology.atoms]
    )
    # Pairwise sums of per-atom cutoffs, scaled by the tolerance factor.
    return (radii[None, :] + radii[:, None]) * scale
155+
156+
157+
def compute_distance_mat_mdtraj(xyz, device="cpu"):
    """Compute the pairwise Euclidean distance matrix for a set of coordinates.

    :param xyz: (n_atoms, 3) array-like of coordinates
    :param device: torch device on which to perform the computation
    :return: (n_atoms, n_atoms) tensor of pairwise distances"""
    coords = torch.Tensor(xyz).to(device)
    # Broadcast to all ordered pairs, then reduce over the coordinate axis.
    deltas = coords.unsqueeze(1) - coords.unsqueeze(0)
    return deltas.pow(2).sum(dim=-1).sqrt()
163+
164+
165+
def get_bond_graphs_mdtraj(traj, frame_idx=0, device="cpu", scale=1.3):
    """Infer the bonded-adjacency matrix for one frame of an MDTraj trajectory.

    Two atoms are considered bonded when their distance falls below the
    element-dependent cutoff.

    :param traj: MDTraj trajectory
    :param frame_idx: index of the frame to analyze
    :param device: torch device for the distance computation
    :param scale: tolerance factor forwarded to the cutoff computation
    :return: (n_atoms, n_atoms) long tensor on CPU with 1 for bonded pairs"""
    frame_xyz = traj.xyz[frame_idx]
    distances = compute_distance_mat_mdtraj(frame_xyz, device=device)
    cutoffs = compute_bond_cutoff_mdtraj(traj.topology, scale=scale)
    adjacency = distances < cutoffs.to(device)
    # An atom is never bonded to itself: clear the diagonal.
    adjacency[np.diag_indices(traj.n_atoms)] = 0

    # Drop the large intermediates promptly (may matter on GPU devices).
    del distances, cutoffs

    return adjacency.to(torch.long).to("cpu")
176+
177+
178+
def compare_graph_mdtraj(ref_traj, traj, ref_frame=0, frame=0, scale=1.3):
    """Count differing adjacency entries between two trajectory frames' bond graphs.

    :param ref_traj: reference MDTraj trajectory
    :param traj: trajectory to compare against the reference
    :param ref_frame: frame index in the reference trajectory
    :param frame: frame index in the compared trajectory
    :param scale: tolerance factor for bond detection
    :return: number of adjacency-matrix entries that differ"""
    reference = get_bond_graphs_mdtraj(ref_traj, frame_idx=ref_frame, scale=scale)
    candidate = get_bond_graphs_mdtraj(traj, frame_idx=frame, scale=scale)
    return (candidate != reference).sum().item()
186+
187+
188+
def graph_edit_distance(
    traj: md.Trajectory, scale: float = 1.3, verbose: bool = True
) -> list[float]:
    """Compare per-frame inferred bond graphs to the topology's bond graph.

    :param traj: MDTraj trajectory whose topology provides the reference bonds
    :param scale: tolerance factor for distance-based bond detection
    :param verbose: whether to display a tqdm progress bar
    :return: per-frame edit distances, normalized by the reference bond count
    :raises ValueError: if the topology declares no bonds (normalization undefined)
    """
    n_atoms = traj.n_atoms
    # Reference adjacency matrix built from the topology's bond list (symmetric).
    A_ref = np.zeros((n_atoms, n_atoms), dtype=int)
    for bond in traj.topology.bonds:
        i, j = bond[0].index, bond[1].index
        A_ref[i, j] = 1
        A_ref[j, i] = 1
    # Counts each bond twice (both triangles); the per-frame numerator below
    # double-counts identically, so the factor cancels in the ratio.
    n_bonds = A_ref.sum()
    if n_bonds == 0:
        raise ValueError("Topology declares no bonds; cannot normalize graph edit distance.")

    ged_list = []
    iterator = range(traj.n_frames)
    if verbose:
        iterator = tqdm(iterator, desc="Computing GED")
    for i in iterator:
        A = get_bond_graphs_mdtraj(traj, frame_idx=i, scale=scale).numpy()
        # Element-wise |A - A_ref| counts BOTH spurious and missing bonds.
        # The original np.abs((A - A_ref).sum()) took |sum of signed diffs|,
        # letting extra and missing bonds cancel and underestimating the
        # distance (inconsistent with compare_graph_mdtraj's element-wise count).
        ged = np.abs(A - A_ref).sum() / n_bonds
        ged_list.append(ged)
    return ged_list

split_flows/utils/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,5 @@ def gradient(
8080

8181
if grad_outputs is None:
8282
grad_outputs = torch.ones_like(output).detach()
83-
grad = torch.autograd.grad(
84-
output, x, grad_outputs=grad_outputs, create_graph=create_graph
85-
)[0]
83+
grad = torch.autograd.grad(output, x, grad_outputs=grad_outputs, create_graph=create_graph)[0]
8684
return grad

0 commit comments

Comments
 (0)