19 changes: 10 additions & 9 deletions README.md
@@ -20,15 +20,16 @@ DOI: 10.1021/acs.jctc.2c01297`
+ [2. Installation](docs/user_guide/2.installation.md)
+ [3. Basic Usage](docs/user_guide/3.usage.md)
+ [4. Modules](docs/user_guide/4.modules.md)
+ [Classical](docs/user_guide/4.1classical.md)
+ [ADMP](docs/user_guide/4.2ADMPPmeForce.md)
+ [Qeq](docs/user_guide/4.3ADMPQeqForce.md)
+ [Machine Learning](docs/user_guide/4.4MLForce.md)
+ [Optimization](docs/user_guide/4.5Optimization.md)
+ [Mbar Estimator](docs/user_guide/4.6MBAR.md)
+ [OpenMM Plugin](docs/user_guide/4.7OpenMMplugin.md)
+ [DiffTraj](docs/user_guide/4.8DiffTraj.md)
+ [ASE MD interface](docs/user_guide/4.9ASE.md)
+ [4.1 Classical](docs/user_guide/4.1classical.md)
+ [4.2 ADMP](docs/user_guide/4.2ADMPPmeForce.md)
+ [4.3 Qeq](docs/user_guide/4.3ADMPQeqForce.md)
+ [4.4 Machine Learning Potential](docs/user_guide/4.4MLForce.md)
+ [4.5 Custom Torch Force](docs/user_guide/4.5CustomTorch.md)
+ [4.6 Optimization](docs/user_guide/4.6Optimization.md)
+ [4.7 Mbar Estimator](docs/user_guide/4.7MBAR.md)
+ [4.8 OpenMM Plugin](docs/user_guide/4.8OpenMMplugin.md)
+ [4.9 DiffTraj](docs/user_guide/4.9DiffTraj.md)
+ [4.10 ASE MD interface](docs/user_guide/4.10ASE.md)
+ [5. Advanced examples](docs/user_guide/DMFF_example.ipynb)
+ [A tutorial notebook covering the basic usage of DMFF, a good place to get started.](docs/user_guide/test.ipynb)

6 changes: 5 additions & 1 deletion dmff/common/constants.py
@@ -1,4 +1,8 @@
import numpy as np

DIELECTRIC = 1389.35455846
SQRT_PI = np.sqrt(np.pi)

# units
EV2KJ = 96.48530749925791
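
For reference, `EV2KJ` is simply the eV-to-kJ/mol conversion factor. A minimal usage sketch (the variable names below are illustrative, not taken from this PR):

```python
from dmff.common.constants import EV2KJ

energy_ev = 1.25                    # hypothetical model output in eV
energy_kjmol = energy_ev * EV2KJ    # ~120.6 kJ/mol
```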

206 changes: 206 additions & 0 deletions dmff/generators/ml.py
@@ -9,10 +9,33 @@
import openmm.app as app
import openmm.unit as unit
import pickle
import re
from functools import partial
from collections import OrderedDict
import copy

from ..sgnn.graph import MAX_VALENCE, TopGraph, from_pdb
from ..sgnn.gnn import MolGNNForce, prm_transform_f2i
from ..eann.eann import EANNForce, get_elem_indices
from ..api.topology import elem_to_index
from ..common.constants import EV2KJ

# load torch-related module
try:
import torch
import torch.nn as nn
from ..torch_tools import t2j_pytree, j2t_pytree, wrap_torch_potential_kernel, t2j_extract_grad
from torch2jax import t2j, j2t
except ImportError:
pass

# load base-related module
try:
from base.inference.calculator import get_parser
import pymatgen.core.structure
except ImportError:
pass


class SGNNGenerator:
def __init__(self, ffinfo: dict, paramset: ParamSet):
@@ -139,3 +162,186 @@ def getJaxPotential(self):
_DMFFGenerators["EANNForce"] = EANNGenerator


class CustomTorchGenerator:

def __init__(self, ffinfo: dict, paramset: ParamSet, dtype=None):
"""
A custom torch model is specified by a full model checkpoint file
The xml front end should be:

```xml
<ForceField>
<CustomTorchForce ckpt="model.pt" torch_script="True" dtype="float32"/>
</ForceField>
```
"""

self.name = "CustomTorchForce"
self.ffinfo = ffinfo
paramset.addField(self.name)
self.key_type = None
ffmeta = self.ffinfo["Forces"][self.name]["meta"]
self.ckpt_file = None
self.state_dict_file = None
self.config_file = None
self.torch_script = False

if dtype is None:
self.dtype = torch.float32
else:
self.dtype = dtype
# precision
if "dtype" in ffmeta["dtype"]:
if '32' in ffmeta["dtype"]:
self.dtype = torch.float32
elif '64' in ffmeta["dtype"]:
self.dtype = torch.float64
self.ckpt_file = ffmeta["ckpt"]
self.torch_script = (ffmeta["torch_script"] == 'True')

if self.torch_script:
self.load_method = torch.jit.load
self.save_method = torch.jit.save
else:
self.load_method = partial(torch.load, weights_only=False)
self.save_method = torch.save

self.model = self._initialize_model()

# now model is fully loaded, start to register parameters
named_parameters = self.model.named_parameters()
self.params_t = OrderedDict()
for name, param in named_parameters:
self.params_t[name] = param
self.params = t2j_pytree(self.params_t)
for k in self.params:
# set mask to all true
paramset.addParameter(self.params[k], k, field=self.name, mask=jnp.ones(self.params[k].shape))

self.params_noopt = OrderedDict()
state_dict = self.model.state_dict()
for k in state_dict:
if k not in self.params_t:
self.params_noopt[k] = state_dict[k]
return


def _initialize_model(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
model = self.load_method(self.ckpt_file, map_location=self.device)
# state dictionary file
self.state_dict_file = re.sub(r'\.([a-zA-Z0-9]+)$', r'_sd.\g<1>', self.ckpt_file)
return model

def getName(self) -> str:
return self.name

def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs):
self.topdata = topdata
self.n_atoms = topdata.getNumAtoms()

# topology data (primarily atom types)
self.atom_types = []
for atom in topdata.atoms():
element = atom.element.upper()
self.atom_types.append(elem_to_index[element])
self.atom_types = np.array(self.atom_types)

# torch kernel
def potential_torch_kernel(positions, box, pairs, params):

# load parameter to model
if self.name in params.keys():
state_dict = copy.deepcopy(params[self.name])
else:
state_dict = copy.deepcopy(params)
for k in self.params_noopt:
state_dict[k] = self.params_noopt[k]

# build a fresh model object for every invocation to avoid gradient accumulation
model = copy.deepcopy(self.model)
model.load_state_dict(state_dict)
results = model.forward(positions, box, self.atom_types)
return results, model

# jax wrapper
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def potential_fn(position, box, pairs, params):
position_t = j2t(position)
box_t = j2t(box)
params_t = j2t_pytree(params)
results, model = potential_torch_kernel(position_t, box_t, None, params_t)
return t2j(results['pred_energy'])

def potential_fwd(positions, box, pairs, params):
# gradient of positions and box will be computed internally and returned
# by the model as force and virial
position_t = j2t(positions).detach()
box_t = j2t(box).detach()
position_t.requires_grad_(False)
box_t.requires_grad_(False)
params_t = j2t_pytree(params)
result, model = potential_torch_kernel(position_t, box_t, None, params_t)
model.zero_grad()
result['pred_energy'].backward()

inputs = {'pos': positions,
'box': box,
'params': params
}
energy = t2j(result['pred_energy'])
dE_dp = jax.tree.map(lambda x: jnp.zeros(x.shape), inputs['params'])
# read parameter gradient from the model
for name, param in model.named_parameters():
dE_dp[self.name][name] = t2j_extract_grad(param)
return energy, (t2j_pytree(result), inputs, dE_dp)

def potential_bwd(pairs, res, g):
preds = res[0]
inputs = res[1]
dE_dp = res[2]
force = preds['pred_forces']
# virial is in kJ/mol
virial = preds['pred_virial']
pos = inputs['pos'] # in nm
box = inputs['box'] # in nm
box_inv = jnp.linalg.inv(box) # in nm-1
# force in kJ/mol/nm, positions in nm
dE_dB = box_inv.T@(pos.T@force - virial)
dE_dr = -force
# scale all cotangents by the incoming output cotangent g
return dE_dr*g, dE_dB*g, jax.tree.map(lambda x: x*g, dE_dp)

potential_fn.defvjp(potential_fwd, potential_bwd)

self._jaxPotential = potential_fn
return potential_fn


def write_to(self, params, ckpt_file, state_dict_file):
if self.name in params:
self.params = params[self.name]
else:
self.params = params
self.params_t = j2t_pytree(self.params)
state_dict = copy.deepcopy(self.params_t)
for k in self.params_noopt:
state_dict[k] = self.params_noopt[k]
# save state dictionary file
torch.save(state_dict, state_dict_file)
# save the full model checkpoint
self.model.load_state_dict(state_dict)
self.save_method(self.model, ckpt_file)
return

def overwrite(self, params):
# do not use xml to handle ML potentials
# for ML potentials, xml only documents param file path
# so for ML potentials, overwrite function overwrites the file directly
self.write_to(params, self.ckpt_file, self.state_dict_file)
return


def getJaxPotential(self):
return self._jaxPotential

_DMFFGenerators["CustomTorchForce"] = CustomTorchGenerator
21 changes: 0 additions & 21 deletions docs/user_guide/3.usage.md
@@ -116,25 +116,4 @@ print(pgrad["NonbondedForce"]["sigma"])
0.00000000e+00]
```

### 3.4 Wrapping torch potential kernel

Considering the popularity of PyTorch, DMFF also provides a convenient wrapper that turns a torch potential kernel into a DMFF potential function compatible with the JAX environment. It uses the [torch2jax](https://github.com/samuela/torch2jax) package to convert between torch tensors and jax ndarrays, and JAX's `custom_vjp` mechanism to call the torch gradient function, so one does not need to rewrite a torch potential kernel in JAX. This may not be the most efficient approach, but it works well for efficiency-insensitive scenarios.

Examples can be found in the `examples/torch_kernel` folder. The basic usage is as follows:

```python
from dmff.torch_tools import wrap_torch_potential_kernel

def potential_t(positions, box, pairs, params):
# Suppose potential_t is a torch kernel, with positions and box being torch tensors
# and params is a multi-level dictionary with torch-tensor leaves
...
return ene

# Wrap it to make it compatible with jax environment
potential_wrapped = wrap_torch_potential_kernel(potential_t)
# then you can use it as a normal jax-based potential function, which can be fed to jax.grad
ene, p_grad = jax.value_and_grad(potential_wrapped, argnums=3)(positions, box, nbl.pairs, params)

```
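
To make the stub above concrete, a toy kernel with the same signature could look like the following; the harmonic restraint form and the `"restraint"`/`"k"` parameter names are invented purely for illustration and are not part of DMFF:

```python
import torch
from dmff.torch_tools import wrap_torch_potential_kernel

def potential_t(positions, box, pairs, params):
    # toy potential: harmonic restraint of every atom to the origin,
    # E = 0.5 * k * sum_i |r_i|^2, with k stored as a torch-tensor leaf
    k = params["restraint"]["k"]
    return 0.5 * k * torch.sum(positions * positions)

potential_wrapped = wrap_torch_potential_kernel(potential_t)
# potential_wrapped now behaves like a jax potential, e.g.
#   jax.value_and_grad(potential_wrapped, argnums=3)(positions, box, nbl.pairs, params)
```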

File renamed without changes.