diff --git a/README.md b/README.md
index 296a08f3..fffb0930 100644
--- a/README.md
+++ b/README.md
@@ -20,15 +20,16 @@ DOI: 10.1021/acs.jctc.2c01297`
+ [2. Installation](docs/user_guide/2.installation.md)
+ [3. Basic Usage](docs/user_guide/3.usage.md)
+ [4. Modules](docs/user_guide/4.modules.md)
- + [Classical](docs/user_guide/4.1classical.md)
- + [ADMP](docs/user_guide/4.2ADMPPmeForce.md)
- + [Qeq](docs/user_guide/4.3ADMPQeqForce.md)
- + [Machine Learning](docs/user_guide/4.4MLForce.md)
- + [Optimization](docs/user_guide/4.5Optimization.md)
- + [Mbar Estimator](docs/user_guide/4.6MBAR.md)
- + [OpenMM Plugin](docs/user_guide/4.7OpenMMplugin.md)
- + [DiffTraj](docs/user_guide/4.8DiffTraj.md)
- + [ASE MD interface](docs/user_guide/4.9ASE.md)
+ + [4.1 Classical](docs/user_guide/4.1classical.md)
+ + [4.2 ADMP](docs/user_guide/4.2ADMPPmeForce.md)
+ + [4.3 Qeq](docs/user_guide/4.3ADMPQeqForce.md)
+ + [4.4 Machine Learning Potential](docs/user_guide/4.4MLForce.md)
+ + [4.5 Custom Torch Force](docs/user_guide/4.5CustomTorch.md)
+ + [4.6 Optimization](docs/user_guide/4.6Optimization.md)
+ + [4.7 Mbar Estimator](docs/user_guide/4.7MBAR.md)
+ + [4.8 OpenMM Plugin](docs/user_guide/4.8OpenMMplugin.md)
+ + [4.9 DiffTraj](docs/user_guide/4.9DiffTraj.md)
+ + [4.10 ASE MD interface](docs/user_guide/4.10ASE.md)
+ [5. Advanced examples](docs/user_guide/DMFF_example.ipynb)
+ [Tutorial notebook: get started with the basic usage of DMFF](docs/user_guide/test.ipynb)
diff --git a/dmff/common/constants.py b/dmff/common/constants.py
index 91bdeb4a..286861a1 100644
--- a/dmff/common/constants.py
+++ b/dmff/common/constants.py
@@ -1,4 +1,8 @@
import numpy as np
DIELECTRIC = 1389.35455846
-SQRT_PI = np.sqrt(np.pi)
\ No newline at end of file
+SQRT_PI = np.sqrt(np.pi)
+
+# units
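+# conversion factor from eV to kJ/mol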
+EV2KJ = 96.48530749925791
+
diff --git a/dmff/generators/ml.py b/dmff/generators/ml.py
index d875bd0a..7689da7a 100644
--- a/dmff/generators/ml.py
+++ b/dmff/generators/ml.py
@@ -9,10 +9,33 @@
import openmm.app as app
import openmm.unit as unit
import pickle
+import re
+from functools import partial
+from collections import OrderedDict
+import copy
from ..sgnn.graph import MAX_VALENCE, TopGraph, from_pdb
from ..sgnn.gnn import MolGNNForce, prm_transform_f2i
from ..eann.eann import EANNForce, get_elem_indices
+from ..api.topology import elem_to_index
+from ..common.constants import EV2KJ
+
+# load torch-related module
+try:
+ import torch
+ import torch.nn as nn
+ from ..torch_tools import t2j_pytree, j2t_pytree, wrap_torch_potential_kernel, t2j_extract_grad
+ from torch2jax import t2j, j2t
+except ImportError:
+ pass
+
+# load base-related module
+try:
+ from base.inference.calculator import get_parser
+ import pymatgen.core.structure
+except ImportError:
+ pass
+
class SGNNGenerator:
def __init__(self, ffinfo: dict, paramset: ParamSet):
@@ -139,3 +162,186 @@ def getJaxPotential(self):
_DMFFGenerators["EANNForce"] = EANNGenerator
+class CustomTorchGenerator:
+
+ def __init__(self, ffinfo: dict, paramset: ParamSet, dtype=None):
+ """
+        A custom torch model is specified by a full model checkpoint file.
+        The xml front end should be:
+
+        ```xml
+        <CustomTorchForce ckpt="model.pt" torch_script="False" dtype="float32"/>
+        ```
+ """
+
+ self.name = "CustomTorchForce"
+ self.ffinfo = ffinfo
+ paramset.addField(self.name)
+ self.key_type = None
+ ffmeta = self.ffinfo["Forces"][self.name]["meta"]
+ self.ckpt_file = None
+ self.state_dict_file = None
+ self.config_file = None
+ self.torch_script = False
+
+ if dtype is None:
+ self.dtype = torch.float32
+ else:
+ self.dtype = dtype
+ # precision
+ if "dtype" in ffmeta["dtype"]:
+ if '32' in ffmeta["dtype"]:
+ self.dtype = torch.float32
+ elif '64' in ffmeta["dtype"]:
+ self.dtype = torch.float64
+ self.ckpt_file = ffmeta["ckpt"]
+ self.torch_script = (ffmeta["torch_script"] == 'True')
+
+ if self.torch_script:
+ self.load_method = torch.jit.load
+ self.save_method = torch.jit.save
+ else:
+ self.load_method = partial(torch.load, weights_only=False)
+ self.save_method = torch.save
+
+ self.model = self._initialize_model()
+
+ # now model is fully loaded, start to register parameters
+ named_parameters = self.model.named_parameters()
+ self.params_t = OrderedDict()
+ for name, param in named_parameters:
+ self.params_t[name] = param
+ self.params = t2j_pytree(self.params_t)
+ for k in self.params:
+ # set mask to all true
+ paramset.addParameter(self.params[k], k, field=self.name, mask=jnp.ones(self.params[k].shape))
+
+ self.params_noopt = OrderedDict()
+ state_dict = self.model.state_dict()
+ for k in state_dict:
+ if k not in self.params_t:
+ self.params_noopt[k] = state_dict[k]
+ return
+
+
+ def _initialize_model(self):
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = self.load_method(self.ckpt_file, map_location=self.device)
+ # state dictionary file
+        self.state_dict_file = re.sub(r'\.([a-zA-Z0-9]+)$', r'_sd.\g<1>', self.ckpt_file)
+ return model
+
+ def getName(self) -> str:
+ return self.name
+
+ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs):
+ self.topdata = topdata
+ self.n_atoms = topdata.getNumAtoms()
+
+        # topology data (primarily atom types)
+ self.atom_types = []
+ for atom in topdata.atoms():
+ element = atom.element.upper()
+ self.atom_types.append(elem_to_index[element])
+ self.atom_types = np.array(self.atom_types)
+
+ # torch kernel
+ def potential_torch_kernel(positions, box, pairs, params):
+
+ # load parameter to model
+ if self.name in params.keys():
+ state_dict = copy.deepcopy(params[self.name])
+ else:
+ state_dict = copy.deepcopy(params)
+ for k in self.params_noopt:
+ state_dict[k] = self.params_noopt[k]
+
+            # build a fresh model object for every invocation to avoid gradient accumulation
+ model = copy.deepcopy(self.model)
+ model.load_state_dict(state_dict)
+ results = model.forward(positions, box, self.atom_types)
+ return results, model
+
+ # jax wrapper
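+        # custom_vjp lets jax.grad flow through the torch kernel: the fwd pass runs
+        # torch autograd and caches energy, forces, virial and parameter gradients;
+        # the bwd pass converts them into jax cotangents.
+        # pairs (argnum 2) is treated as non-differentiable.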
+ @partial(jax.custom_vjp, nondiff_argnums=(2,))
+ def potential_fn(position, box, pairs, params):
+ position_t = j2t(position)
+ box_t = j2t(box)
+ params_t = j2t_pytree(params)
+ results, model = potential_torch_kernel(position_t, box_t, None, params_t)
+            return t2j(results['pred_energy'])
+
+ def potential_fwd(positions, box, pairs, params):
+            # gradients of positions and box are computed internally by torch autograd
+            # and returned as forces and virial
+ position_t = j2t(positions).detach()
+ box_t = j2t(box).detach()
+ position_t.requires_grad_(False)
+ box_t.requires_grad_(False)
+ params_t = j2t_pytree(params)
+ result, model = potential_torch_kernel(position_t, box_t, None, params_t)
+ model.zero_grad()
+ result['pred_energy'].backward()
+
+ inputs = {'pos': positions,
+ 'box': box,
+ 'params': params
+ }
+ energy = t2j(result['pred_energy'])
+ dE_dp = jax.tree.map(lambda x: jnp.zeros(x.shape), inputs['params'])
+ # read parameter gradient from the model
+ for name, param in model.named_parameters():
+ dE_dp[self.name][name] = t2j_extract_grad(param)
+ return energy, (t2j_pytree(result), inputs, dE_dp)
+
+ def potential_bwd(pairs, res, g):
+ preds = res[0]
+ inputs = res[1]
+ dE_dp = res[2]
+ force = preds['pred_forces']
+ # virial is in kJ/mol
+ virial = preds['pred_virial']
+ pos = inputs['pos'] # in nm
+ box = inputs['box'] # in nm
+ box_inv = jnp.linalg.inv(box) # in nm-1
+ # force in kJ/mol/nm, positions in nm
+ dE_dB = box_inv.T@(pos.T@force - virial)
+ dE_dr = -force
+            # scale all cotangents by the upstream gradient g
+            return dE_dr*g, dE_dB*g, jax.tree.map(lambda x: x*g, dE_dp)
+
+        potential_fn.defvjp(potential_fwd, potential_bwd)
+        self._jaxPotential = potential_fn
+
+        return potential_fn
+
+
+ def write_to(self, params, ckpt_file, state_dict_file):
+ if self.name in params:
+ self.params = params[self.name]
+ else:
+ self.params = params
+ self.params_t = j2t_pytree(self.params)
+ state_dict = copy.deepcopy(self.params_t)
+ for k in self.params_noopt:
+ state_dict[k] = self.params_noopt[k]
+        # save state dictionary file
+ torch.save(state_dict, state_dict_file)
+ # save the full model checkpoint
+ self.model.load_state_dict(state_dict)
+ self.save_method(self.model, ckpt_file)
+ return
+
+ def overwrite(self, params):
+        # xml is not used to store ML potential parameters;
+        # it only records the checkpoint file path, so for ML potentials
+        # the overwrite function writes the parameter files directly
+ self.write_to(params, self.ckpt_file, self.state_dict_file)
+ return
+
+
+ def getJaxPotential(self):
+ return self._jaxPotential
+
+_DMFFGenerators["CustomTorchForce"] = CustomTorchGenerator
diff --git a/docs/user_guide/3.usage.md b/docs/user_guide/3.usage.md
index 1c2af9f8..0314892f 100644
--- a/docs/user_guide/3.usage.md
+++ b/docs/user_guide/3.usage.md
@@ -116,25 +116,4 @@ print(pgrad["NonbondedForce"]["sigma"])
0.00000000e+00]
```
-### 3.4 Wrapping torch potential kernel
-
-Considering the popularity of pytorch, DMFF also provides a convenient wrapper to wrap a torch potential kernel into a DMFF potential function that is compatible with the JAX environment. It uses the [torch2jax](https://github.com/samuela/torch2jax) package to convert between tensors and jax ndarrays, and the `custom_jvp` function in jax to call torch gradient function, so one does not need to rewrite a torch potential kernel using jax. It may not be the most efficient way, but should work for efficiency-insensitive scenarios.
-
-Examples can be found in the `examples/torch_kernel` folder. And the basic usage is like below:
-
-```python
-from dmff.torch_tools import wrap_torch_potential_kernel
-
-def potential_t(positions, box, pairs, params):
- # Suppose potential_t is a torch kernel, with positions and box being torch tensors
- # and params is a multi-level dictionary with torch-tensor leaves
- ...
- return ene
-
-# Wrap it to make it compatible with jax environment
-potential_wrapped = wrap_torch_potential_kernel(potential_t)
-# then you can use it as a normal jax-based potential function, which can be fed to jax.grad
-ene, p_grad = jax.value_and_grad(potential_wrapped, argnums=3)(positions, box, nbl.pairs, params)
-
-```
diff --git a/docs/user_guide/4.9ASE.md b/docs/user_guide/4.10ASE.md
similarity index 100%
rename from docs/user_guide/4.9ASE.md
rename to docs/user_guide/4.10ASE.md
diff --git a/docs/user_guide/4.5CustomTorch.md b/docs/user_guide/4.5CustomTorch.md
new file mode 100644
index 00000000..c11fe8d0
--- /dev/null
+++ b/docs/user_guide/4.5CustomTorch.md
@@ -0,0 +1,141 @@
+# CustomTorchForce
+
+This module is designed to support PyTorch machine learning potential (MLP) kernels.
+It aims to perform top-down finetuning of MLP models written in torch.
+A simple, complete example can be found in `examples/torch_kernel/torch_mlp_frontend`.
+To use this module, you need to install both `torch` and [torch2jax](https://github.com/samuela/torch2jax).
+
+NOTE: currently we DO NOT support differentiating losses related to force and virial!
+
+## 1. Backend Definition
+
+To use this force, you need to first define a torch backend, which is a torch module in
+your own `force_model.py` file, like this:
+
+```python
+class ForceModel(torch.nn.Module):
+
+ def __init__(self):
+ super(ForceModel, self).__init__()
+        # define initialization method ...
+ return
+
+ def forward(self, positions, box, atomtypes):
+ """
+ define the calculation
+ """
+ results = {}
+        results['pred_energy'] = energy  # kJ/mol
+        results['pred_forces'] = forces  # kJ/mol/nm
+        results['pred_virial'] = virial  # kJ/mol, rij*fij
+ return results
+```
+
+The model should store all optimizable parameters as named parameters. The `forward`
+function should take `positions` (in nm, shape=(N, 3)), `box` (in nm, shape=(3, 3), with lattice vectors
+defined in rows), and `atomtypes` (a NumPy array of element indices) as inputs. The output should be a
+dictionary including:
+
+* `results['pred_energy']`: predicted energy in kJ/mol
+
+* `results['pred_forces']`: predicted atomic forces in kJ/mol/nm
+
+* `results['pred_virial']`: predicted virial tensor in kJ/mol
+
+Currently, the virial and forces are used to compute derivatives with respect to positions and box. If you
+do not need these two derivatives (as in most cases when doing thermodynamic reweighting), you can
+simply put zeros in these two terms as placeholders, as in the sketch below.
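+
+A minimal sketch of such placeholders inside `forward`, assuming an energy-only model:
+
+```python
+# zero placeholders: derivatives w.r.t. positions/box will be zero,
+# but energy and parameter gradients still work
+results['pred_forces'] = torch.zeros_like(positions)
+results['pred_virial'] = torch.zeros(3, 3)
+```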
+
+Then you have to save the model as a full model checkpoint file:
+
+```python
+model = ForceModel()
+torch.save(model, 'model.pt')
+```
+
+Then DMFF can read the `model.pt` checkpoint and create a differentiable potential function that runs in JAX.
+
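+If you prefer TorchScript (set `torch_script="True"` in the xml frontend below), a
+minimal sketch of the corresponding save call, assuming your model is scriptable:
+
+```python
+scripted = torch.jit.script(ForceModel())
+torch.jit.save(scripted, 'model.pt')
+```
+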
+## 2. Frontend Definition
+
+The xml frontend of the CustomTorchForce should be defined as:
+
+```xml
+<CustomTorchForce ckpt="model.pt" torch_script="False" dtype="float32"/>
+```
+
+Here, you need to specify:
+
+* `ckpt`: the model checkpoint file
+
+* `torch_script`: "True" or "False". Set it to "True" if you saved the model with `torch.jit.save` instead of `torch.save`
+
+* `dtype`: data type of the model; the default is float32, use float64 if your model is created in double precision
+
+You also need `structure.pdb` and `residues.xml` files to define the "topology" of the system. Note that most MLPs
+do not care about or use molecular topology; the topology definition here only serves to provide element information.
+Therefore, feel free to define each atom type as a "residue" if you do not care about bonding topology.
+
+Then, in the python script, you can create the potential function in the standard DMFF format:
+
+```python
+import jax
+import jax.numpy as jnp
+import openmm.app as app
+import openmm.unit as unit
+import dmff
+from dmff.api import Hamiltonian
+import torch
+
+# import your model
+from force_model import ForceModel
+
+# Hamiltonian and parameters
+H = Hamiltonian('forcefield.xml')
+params = H.getParameters().parameters
+app.Topology.loadBondDefinitions("residues.xml")
+pdb = app.PDBFile("structure.pdb")
+rc = 1.0
+
+# create potential object
+pots = H.createPotential(pdb.topology, \
+ nonbondedMethod=app.CutoffPeriodic, \
+ nonbondedCutoff=rc*unit.nanometer)
+efunc = pots.getPotentialFunc()
+positions = jnp.array(pdb.positions._value)
+box = jnp.array(pdb.topology.getPeriodicBoxVectors()._value)
+
+# calculate value and grad
+ene, pgrad = jax.value_and_grad(efunc, argnums=3)(positions, box, None, params)
+```
+
+NOTE: we assume you handle the neighbor searching inside your own model, so the `rc` and `pairs` variables
+here are merely placeholders (you can simply pass `None` for `pairs`).
+
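+Since the backward pass reads `pred_forces` and `pred_virial` from your model outputs,
+gradients with respect to positions are also available, e.g.:
+
+```python
+# atomic forces are minus the position gradient
+rgrad = jax.grad(efunc, argnums=0)(positions, box, None, params)
+forces = -rgrad
+```
+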
+You can also save the current parameters as both a full model checkpoint file and a state_dict file, using either `write_to` or the standard
+`overwrite` function:
+
+```python
+gen = H.getGenerators()[0]
+# save model to new.pt and state_dict to new_sd.pt
+gen.write_to(params, 'new.pt', 'new_sd.pt')
+```
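+
+The saved files can be reloaded later; a sketch, assuming a non-TorchScript checkpoint:
+
+```python
+model = torch.load('new.pt', weights_only=False)  # full model checkpoint
+model.load_state_dict(torch.load('new_sd.pt'))    # or load just the state_dict
+```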
+
+## 3. Convenient function wrappers
+
+DMFF also provides a convenient wrapper that turns a torch potential kernel into a DMFF potential function compatible with the JAX environment.
+
+Examples can be found in the `examples/torch_kernel` folder. Basic usage is as follows:
+
+```python
+from dmff.torch_tools import wrap_torch_potential_kernel
+
+def potential_t(positions, box, pairs, params):
+ # Suppose potential_t is a torch kernel, with positions and box being torch tensors
+ # and params is a multi-level dictionary with torch-tensor leaves
+ ...
+ return ene
+
+# Wrap it to make it compatible with jax environment
+potential_wrapped = wrap_torch_potential_kernel(potential_t)
+# then you can use it as a normal jax-based potential function, which can be fed to jax.grad
+ene, p_grad = jax.value_and_grad(potential_wrapped, argnums=3)(positions, box, nbl.pairs, params)
+
+```
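+
+For completeness, `nbl` above is a DMFF neighbor list. A construction sketch, assuming
+the `NeighborList` API in `dmff.common.nblist` and a covalent map in the potential metadata:
+
+```python
+from dmff.common.nblist import NeighborList
+
+nbl = NeighborList(box, rc, pots.meta["cov_map"])  # assumed signature
+nbl.allocate(positions)
+pairs = nbl.pairs
+```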
diff --git a/docs/user_guide/4.5Optimization.md b/docs/user_guide/4.6Optimization.md
similarity index 100%
rename from docs/user_guide/4.5Optimization.md
rename to docs/user_guide/4.6Optimization.md
diff --git a/docs/user_guide/4.6MBAR.md b/docs/user_guide/4.7MBAR.md
similarity index 100%
rename from docs/user_guide/4.6MBAR.md
rename to docs/user_guide/4.7MBAR.md
diff --git a/docs/user_guide/4.7OpenMMplugin.md b/docs/user_guide/4.8OpenMMplugin.md
similarity index 100%
rename from docs/user_guide/4.7OpenMMplugin.md
rename to docs/user_guide/4.8OpenMMplugin.md
diff --git a/docs/user_guide/4.8DiffTraj.md b/docs/user_guide/4.9DiffTraj.md
similarity index 100%
rename from docs/user_guide/4.8DiffTraj.md
rename to docs/user_guide/4.9DiffTraj.md
diff --git a/examples/config/freud.ini b/examples/config/freud.ini
new file mode 100644
index 00000000..8cbf638d
--- /dev/null
+++ b/examples/config/freud.ini
@@ -0,0 +1,40 @@
+
+[LAYOUT]
+# Width of sidebar that contains server names
+server_width = 15
+# Height of container that holds headers
+header_height = 10
+# Height of container that shows server summary
+summary_height = 10
+
+[KEYS]
+new_server = n
+edit_server = e
+send_request = r
+edit_authentication = a
+edit_headers = h
+edit_body = b
+delete_server = d
+open_response_body = o
+sort_servers = s
+key_quick_ref = c-f
+
+[DB]
+filename = requests.db
+
+[JSON]
+indentation = 2
+
+[SORT_BY]
+# Column options: name, timestamp, url, method, body, authtype, authuser,
+# authpass, headers
+# Order options: asc, desc
+column = timestamp
+order = asc
+
+[STYLE]
+# More styles here:
+# https://bitbucket.org/birkenfeld/pygments-main/src/stable/pygments/styles
+theme = default
+separator_line_fg = gray
+separator_line_bg = black
diff --git a/examples/torch_kernel/torch_mlp_frontend/force_model.py b/examples/torch_kernel/torch_mlp_frontend/force_model.py
new file mode 100755
index 00000000..004c6f45
--- /dev/null
+++ b/examples/torch_kernel/torch_mlp_frontend/force_model.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python
+import sys
+import numpy as np
+import torch
+import torch.nn as nn
+import MDAnalysis as mda
+import copy
+
+
+class ForceModel(torch.nn.Module):
+
+ def __init__(self):
+ super(ForceModel, self).__init__()
+        # harmonic pair parameters: force constants k and equilibrium distances r0
+ torch.manual_seed(12345)
+ self.k = nn.Parameter((torch.randn(3)*0.3 + 1.0) * 10)
+ self.r0 = nn.Parameter((torch.rand(3)) * 0.01 + 0.3) # in nm
+
+ def forward(self, positions, box, atomtypes):
+ # pairwise parameters
+ ks = torch.zeros(3)
+ r0s = torch.zeros(3)
+ ks[0:2] = self.k[0]
+ ks[2] = self.k[1]
+ r0s[0:2] = self.r0[0]
+ r0s[2] = self.r0[1]
+ # pairwise distances
+ rij = torch.zeros(3, 3)
+ rij[0] = positions[0] - positions[1]
+ rij[1] = positions[0] - positions[2]
+ rij[2] = positions[1] - positions[2]
+ rijnorm = torch.norm(rij, dim=1)
+ rij_normed = (rij.T / rijnorm).T
+ energy = torch.sum(0.5 * ks * (rijnorm - r0s)**2)
+        fij = (-ks * (rijnorm - r0s) * rij_normed.T).T
+ forces = torch.zeros(3, 3)
+ forces[0] = fij[0] + fij[1]
+ forces[1] = -fij[0] + fij[2]
+ forces[2] = -fij[1] - fij[2]
+ virial = rij.T @ fij
+ results = {}
+ results['pred_energy'] = energy
+ results['pred_forces'] = forces
+ results['pred_virial'] = virial
+ return results
+
+
+if __name__ == "__main__":
+ model = ForceModel()
+ u = mda.Universe('structure.pdb')
+ positions = torch.tensor(u.atoms.positions / 10)
+ box = torch.eye(3) * 3.0
+ atomtypes = np.array([0, 0, 1])
+ res = model(positions, box, atomtypes)
+ print('Predicted Energy')
+ print(res['pred_energy'])
+ res['pred_energy'].backward()
+ print('Parameter Gradient')
+ for pname, p in model.named_parameters():
+ print(pname, p.grad)
+ print('Forces:')
+ print(res['pred_forces'])
+ # # print(res['pred_forces'])
+ # print(res['pred_virial'])
+ # # check force
+ # x0 = copy.deepcopy(positions[1, 1])
+ # print('#', res['pred_forces'][1, 1])
+ # delta = 0.0005
+ # energies = torch.zeros(5)
+ # for i in range(-2, 3):
+ # positions[1, 1] = x0 + delta * i
+ # res = model(positions, box, atomtypes)
+ # energies[i+2] = res['pred_energy']
+ # # print(i*delta, '%.8f'%res['pred_energy'])
+ # print((energies[-1] - energies[0])/delta/4)
+ # # check virial
+
+ # save model
+ torch.save(model, 'model.pt')
diff --git a/examples/torch_kernel/torch_mlp_frontend/forcefield.xml b/examples/torch_kernel/torch_mlp_frontend/forcefield.xml
new file mode 100644
index 00000000..67c2bf4c
--- /dev/null
+++ b/examples/torch_kernel/torch_mlp_frontend/forcefield.xml
@@ -0,0 +1,15 @@
+<ForceField>
+  <AtomTypes>
+    <Type name="AR" class="AR" element="Ar" mass="39.948"/>
+    <Type name="NE" class="NE" element="Ne" mass="20.1797"/>
+  </AtomTypes>
+  <Residues>
+    <Residue name="ARG">
+      <Atom name="AR" type="AR"/>
+    </Residue>
+    <Residue name="NEO">
+      <Atom name="NE" type="NE"/>
+    </Residue>
+  </Residues>
+  <CustomTorchForce ckpt="model.pt" torch_script="False" dtype="float32"/>
+</ForceField>
diff --git a/examples/torch_kernel/torch_mlp_frontend/model.pt b/examples/torch_kernel/torch_mlp_frontend/model.pt
new file mode 100644
index 00000000..17dece63
Binary files /dev/null and b/examples/torch_kernel/torch_mlp_frontend/model.pt differ
diff --git a/examples/torch_kernel/torch_mlp_frontend/new.pt b/examples/torch_kernel/torch_mlp_frontend/new.pt
new file mode 100644
index 00000000..4d892a37
Binary files /dev/null and b/examples/torch_kernel/torch_mlp_frontend/new.pt differ
diff --git a/examples/torch_kernel/torch_mlp_frontend/new_sd.pt b/examples/torch_kernel/torch_mlp_frontend/new_sd.pt
new file mode 100644
index 00000000..ca93c710
Binary files /dev/null and b/examples/torch_kernel/torch_mlp_frontend/new_sd.pt differ
diff --git a/examples/torch_kernel/torch_mlp_frontend/residues.xml b/examples/torch_kernel/torch_mlp_frontend/residues.xml
new file mode 100644
index 00000000..e1bcbac5
--- /dev/null
+++ b/examples/torch_kernel/torch_mlp_frontend/residues.xml
@@ -0,0 +1,8 @@
+<Residues>
+  <Residue name="ARG">
+    <Atom name="AR" type="AR"/>
+  </Residue>
+  <Residue name="NEO">
+    <Atom name="NE" type="NE"/>
+  </Residue>
+</Residues>
diff --git a/examples/torch_kernel/torch_mlp_frontend/run.py b/examples/torch_kernel/torch_mlp_frontend/run.py
new file mode 100755
index 00000000..99e8cf26
--- /dev/null
+++ b/examples/torch_kernel/torch_mlp_frontend/run.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python
+import openmm.app as app
+import openmm.unit as unit
+from dmff.api import Hamiltonian
+from dmff.common import nblist
+import numpy as np
+import jax
+import jax.numpy as jnp
+import torch
+from torch2jax import j2t, t2j
+# from dmff.torch_tools import wrap_torch_potential_kernel, j2t_pytree, t2j_pytree
+
+from force_model import ForceModel
+
+
+if __name__ == '__main__':
+ H = Hamiltonian('forcefield.xml')
+ params = H.getParameters().parameters
+ app.Topology.loadBondDefinitions("residues.xml")
+ pdb = app.PDBFile("structure.pdb")
+ rc = 1.0
+
+ pots = H.createPotential(pdb.topology, \
+ nonbondedMethod=app.CutoffPeriodic, \
+ nonbondedCutoff=rc*unit.nanometer)
+
+ efunc = pots.getPotentialFunc()
+ positions = jnp.array(pdb.positions._value)
+ box = jnp.array(pdb.topology.getPeriodicBoxVectors()._value)
+
+ ene, pgrad = jax.value_and_grad(efunc, argnums=3)(positions, box, None, params)
+ ene, rgrad = jax.value_and_grad(efunc, argnums=0)(positions, box, None, params)
+ print('Predicted Energy:')
+ print(ene)
+ print('Parameter Gradient:')
+ print(pgrad)
+ print('Position Gradient:')
+ print(rgrad)
+
+ # saving new parameters
+ gen = H.getGenerators()[0]
+ gen.write_to(params, 'new.pt', 'new_sd.pt')
diff --git a/examples/torch_kernel/torch_mlp_frontend/structure.pdb b/examples/torch_kernel/torch_mlp_frontend/structure.pdb
new file mode 100644
index 00000000..9d1a42fd
--- /dev/null
+++ b/examples/torch_kernel/torch_mlp_frontend/structure.pdb
@@ -0,0 +1,6 @@
+REMARK 1 CREATED WITH OPENMM 8.0, 2023-08-25
+CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1
+HETATM 1 AR ARG A 1 4.125 13.679 13.761 1.00 0.00 Ar
+HETATM 2 AR ARG A 2 5.406 17.008 13.462 1.00 0.00 Ar
+HETATM 3 NE NEO A 3 7.292 15.267 11.931 1.00 0.00 Ne
+END