Skip to content

Commit c40a394

Browse files
authored
Feat: Add DeepMD MLFF support (#999)
1 parent 0030d68 commit c40a394

File tree

11 files changed

+127
-4
lines changed

11 files changed

+127
-4
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,5 @@ docs/reference/atomate2.*
7878

7979
.ipynb_checkpoints
8080
.aider*
81+
82+
tests/test_data/forcefields/deepmd_graph.pb

docs/user/codes/forcefields.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
(codes.forcefields)=
2+
3+
# Machine Learning forcefields / interatomic potentials
4+
5+
`atomate2` includes an interface to a few common machine learning interatomic potentials (MLIPs), also known variously as machine learning forcefields (MLFFs), or foundation potentials (FPs) for universal variants.
6+
Support is provided for the following models, which can be selected using `atomate2.forcefields.utils.MLFF`, as shown in the table below.
7+
**You need only install packages for the forcefields you wish to use.**
8+
9+
| Forcefield Name | `MLFF` | Reference | Description |
10+
| ---- | ---- | ---- | ---- |
11+
| CHGNet | `CHGNet` | [10.1038/s42256-023-00716-3](https://doi.org/10.1038/s42256-023-00716-3) | Available via the `chgnet` and `matgl` packages |
12+
| DeepMD | `MLFF.DeepMD` | [10.1103/PhysRevB.108.L180104](https://doi.org/10.1103/PhysRevB.108.L180104) | The Deep Potential model used for this test is `UniPero`, a universal interatomic potential for perovskite oxides. It can be downloaded [here](https://github.com/sliutheorygroup/UniPero) |
13+
| Gaussian Approximation Potential (GAP) | `GAP` | [10.1103/PhysRevLett.104.136403](https://doi.org/10.1103/PhysRevLett.104.136403) | Relies on `quippy-ase` package |
14+
| M3GNet | `M3GNet` | [10.1038/s43588-022-00349-3](https://doi.org/10.1038/s43588-022-00349-3) | Relies on `matgl` package |
15+
| MACE-MP-0 | `MACE` or `MACE_MP_0` (recommended) | [10.1063/5.0297006](https://doi.org/10.1063/5.0297006) | Relies on `mace_torch` and optionally `torch_dftd` packages |
16+
| MACE-MP-0b3 | `MACE_MP_0B3` | [10.1063/5.0297006](https://doi.org/10.1063/5.0297006) | Relies on `mace_torch` and optionally `torch_dftd` packages |
17+
| MACE-MPA-0 | `MACE_MPA_0` | [10.1063/5.0297006](https://doi.org/10.1063/5.0297006) | Relies on `mace_torch` and optionally `torch_dftd` packages |
18+
| MatPES-PBE | `MATPES_PBE` | [10.48550/arXiv.2503.04070](https://doi.org/10.48550/arXiv.2503.04070) | Relies on `matgl`. Defaults to TensorNet architecture, but can also use M3GNet or CHGNet architectures via kwargs. See `atomate2.forcefields.utils._DEFAULT_CALCULATOR_KWARGS` for more options. |
19+
| MatPES-r<sup>2</sup>SCAN | `MATPES_R2SCAN`| [10.48550/arXiv.2503.04070](https://doi.org/10.48550/arXiv.2503.04070) | Relies on `matgl`. Defaults to TensorNet architecture, but can also use M3GNet or CHGNet architectures via kwargs. See `atomate2.forcefields.utils._DEFAULT_CALCULATOR_KWARGS` for more options. |
20+
| Neuroevolution Potential (NEP) | `NEP` | [10.1103/PhysRevB.104.104309](https://doi.org/10.1103/PhysRevB.104.104309) | Relies on `calorine` package |
21+
| Neural Equivariant Interatomic Potentials (Nequip) | `Nequip` | [10.1038/s41467-022-29939-5](https://doi.org/10.1038/s41467-022-29939-5) | Relies on the `nequip` package |
22+
| SevenNet | `SevenNet` | [10.1021/acs.jctc.4c00190](https://doi.org/10.1021/acs.jctc.4c00190) | Relies on the `sevenn` package |

docs/user/codes/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ The section gives the instructions for codes supported by atomate2.
77
```{toctree}
88
vasp
99
openmm
10+
forcefields
1011
```

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ forcefields = [
6161
# quippy-ase support for py3.12 tracked in https://github.com/libAtoms/QUIP/issues/645
6262
"quippy-ase>=0.9.14; python_version < '3.12'",
6363
"sevenn>=0.9.3",
64+
"deepmd-kit>=2.1.4",
6465
]
6566
approxneb = ["pymatgen-analysis-diffusion>=2024.7.15"]
6667
ase = ["ase>=3.26.0"]
@@ -111,6 +112,8 @@ strict-forcefields = [
111112
"sevenn==0.10.4",
112113
"torch==2.2.0",
113114
"torchdata==0.7.1", # TODO: remove when issue fixed
115+
"deepmd-kit==2.2.11",
116+
"tensorflow-cpu==2.16.2",
114117
]
115118
strict = [
116119
"atomate2[strict-forcefields, docs, cclib, phonons, lobster, openmm, mp, defects, ase, ase-ext]",

src/atomate2/forcefields/schemas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def from_ase_compatible_result(
119119
MLFF.MACE_MP_0B3: "mace-torch",
120120
MLFF.GAP: "quippy-ase",
121121
MLFF.Nequip: "nequip",
122+
MLFF.DeepMD: "deepmd-kit",
122123
MLFF.MATPES_PBE: "matgl",
123124
MLFF.MATPES_R2SCAN: "matgl",
124125
}

src/atomate2/forcefields/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class MLFF(Enum): # TODO inherit from StrEnum when 3.11+
4040
SevenNet = "SevenNet"
4141
MATPES_R2SCAN = "MatPES-r2SCAN"
4242
MATPES_PBE = "MatPES-PBE"
43+
DeepMD = "DeepMD"
4344

4445
@classmethod
4546
def _missing_(cls, value: Any) -> Any:
@@ -273,6 +274,11 @@ def ase_calculator(
273274

274275
calculator = SevenNetCalculator(**{"model": "7net-0"} | kwargs)
275276

277+
elif calculator_name == MLFF.DeepMD:
278+
from deepmd.calculator import DP
279+
280+
calculator = DP(**kwargs)
281+
276282
elif isinstance(calculator_meta, dict):
277283
calc_cls = MontyDecoder().process_decoded(calculator_meta)
278284
calculator = calc_cls(**kwargs)

tests/forcefields/conftest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
from __future__ import annotations
22

3+
import hashlib
4+
import urllib.request
35
from typing import TYPE_CHECKING
46

7+
import pytest
58
import torch
9+
from emmet.core.utils import get_hash_blocked
610

711
if TYPE_CHECKING:
12+
from pathlib import Path
813
from typing import Any
914

1015

@@ -13,3 +18,18 @@ def pytest_runtest_setup(item: Any) -> None:
1318
torch.set_default_dtype(torch.float32)
1419
# For consistent performance across hardware, explicitly set device to CPU
1520
torch.set_default_device("cpu")
21+
22+
23+
@pytest.fixture(scope="session", autouse=True)
24+
def download_deepmd_pretrained_model(test_dir: Path) -> None:
25+
# Download DeepMD pretrained model from GitHub
26+
file_url = "https://raw.github.com/sliutheorygroup/UniPero/main/model/graph.pb"
27+
local_path = test_dir / "forcefields" / "deepmd_graph.pb"
28+
ref_md5 = "2814ae7f2eb1c605dd78f2964187de40"
29+
_, http_message = urllib.request.urlretrieve(file_url, local_path) # noqa: S310
30+
if "Content-Type: text/html" in http_message:
31+
raise RuntimeError(f"Failed to download from: {file_url}")
32+
33+
# Check MD5 to ensure file integrity
34+
if (file_md5 := get_hash_blocked(local_path, hasher=hashlib.md5())) != ref_md5:
35+
raise RuntimeError(f"MD5 mismatch: {file_md5} != {ref_md5}")

tests/forcefields/test_jobs.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,69 @@ def test_nequip_relax_maker(
522522
assert final_spg_num == 99
523523

524524

525+
def test_deepmd_static_maker(sr_ti_o3_structure: Structure, test_dir: Path):
526+
importorskip("deepmd")
527+
528+
# generate job
529+
job = ForceFieldStaticMaker(
530+
force_field_name="DeepMD",
531+
ionic_step_data=("structure", "energy"),
532+
calculator_kwargs={"model": test_dir / "forcefields" / "deepmd_graph.pb"},
533+
).make(sr_ti_o3_structure)
534+
535+
# run the flow or job and ensure that it finished running successfully
536+
responses = run_locally(job, ensure_success=True)
537+
538+
# validate the outputs of the job
539+
output1 = responses[job.uuid][1].output
540+
assert isinstance(output1, ForceFieldTaskDocument)
541+
assert output1.output.energy == approx(-3723.09868, rel=1e-4)
542+
assert output1.output.n_steps == 1
543+
assert output1.forcefield_version == get_imported_version("deepmd-kit")
544+
545+
546+
@pytest.mark.parametrize(
547+
("relax_cell", "fix_symmetry"),
548+
[(True, False), (False, True)],
549+
)
550+
def test_deepmd_relax_maker(
551+
sr_ti_o3_structure: Structure,
552+
test_dir: Path,
553+
relax_cell: bool,
554+
fix_symmetry: bool,
555+
):
556+
importorskip("deepmd")
557+
# translate one atom to ensure a small number of relaxation steps are taken
558+
sr_ti_o3_structure.translate_sites(0, [0, 0, 0.01])
559+
# generate job
560+
job = ForceFieldRelaxMaker(
561+
force_field_name="DeepMD",
562+
steps=25,
563+
optimizer_kwargs={"optimizer": "BFGSLineSearch"},
564+
relax_cell=relax_cell,
565+
fix_symmetry=fix_symmetry,
566+
calculator_kwargs={"model": test_dir / "forcefields" / "deepmd_graph.pb"},
567+
).make(sr_ti_o3_structure)
568+
569+
# run the flow or job and ensure that it finished running successfully
570+
responses = run_locally(job, ensure_success=True)
571+
572+
# validate the outputs of the job
573+
output1 = responses[job.uuid][1].output
574+
assert isinstance(output1, ForceFieldTaskDocument)
575+
if relax_cell:
576+
assert output1.output.energy == approx(-3723.099519623731, rel=1e-3)
577+
assert output1.output.n_steps == 3
578+
else:
579+
assert output1.output.energy == approx(-3723.0981880334643, rel=1e-4)
580+
assert output1.output.n_steps == 3
581+
582+
# fix_symmetry makes no difference for this structure relaxer combo
583+
# just testing that passing fix_symmetry doesn't break
584+
final_spg_num = output1.output.structure.get_space_group_info()[1]
585+
assert final_spg_num == 99
586+
587+
525588
@pytest.mark.parametrize("ref_func", ["PBE", "r2SCAN"])
526589
def test_matpes_relax_makers(
527590
sr_ti_o3_structure: Structure,

tests/forcefields/test_md.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def test_ml_ff_md_maker(
7171
MLFF.NEP: -3.966232215741286,
7272
MLFF.Nequip: -8.84670181274414,
7373
MLFF.SevenNet: -5.394115447998047,
74+
MLFF.DeepMD: -744.6197365326168,
7475
MLFF.MATPES_PBE: -5.230762481689453,
7576
MLFF.MATPES_R2SCAN: -8.561729431152344,
7677
}
@@ -101,6 +102,9 @@ def test_ml_ff_md_maker(
101102
"model_path": test_dir / "forcefields" / "nequip" / "nequip_ff_sr_ti_o3.pth"
102103
}
103104
unit_cell_structure = sr_ti_o3_structure.copy()
105+
elif ff_name == MLFF.DeepMD:
106+
calculator_kwargs = {"model": test_dir / "forcefields" / "deepmd_graph.pb"}
107+
unit_cell_structure = sr_ti_o3_structure.copy()
104108

105109
structure = unit_cell_structure.to_conventional() * (2, 2, 2)
106110

tests/forcefields/test_phonon.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_supercell_orthorhombic(clean_dir, si_structure: Structure):
2828
min_length=5,
2929
max_length=10,
3030
prefer_90_degrees=False,
31-
allow_orhtorhombic=True,
31+
allow_orthorhombic=True,
3232
)
3333

3434
# run the flow or job and ensure that it finished running successfully
@@ -43,14 +43,14 @@ def test_supercell_orthorhombic(clean_dir, si_structure: Structure):
4343
min_length=5,
4444
max_length=10,
4545
prefer_90_degrees=True,
46-
allow_orhtorhombic=True,
46+
allow_orthorhombic=True,
4747
)
4848

4949
# run the flow or job and ensure that it finished running successfully
5050
responses = run_locally(job2, create_folders=True, ensure_success=True)
5151

5252
assert_allclose(
53-
responses[job2.uuid][1].output, [[2, -1, 0], [0, 3, 0], [-1, -1, 2]]
53+
responses[job2.uuid][1].output, [[2, -1, 0], [0, 2, 0], [-1, -1, 2]]
5454
)
5555

5656

@@ -74,6 +74,7 @@ def test_phonon_maker_initialization_with_all_mlff(
7474
calc_kwargs = {
7575
MLFF.Nequip: {"model_path": f"{chk_pt_dir}/nequip/nequip_ff_sr_ti_o3.pth"},
7676
MLFF.NEP: {"model_filename": f"{test_dir}/forcefields/nep/nep.txt"},
77+
MLFF.DeepMD: {"model": test_dir / "forcefields" / "deepmd_graph.pb"},
7778
}.get(mlff, {})
7879
static_maker = ForceFieldStaticMaker(
7980
name=f"{mlff} static",

0 commit comments

Comments
 (0)