Skip to content

Commit b2f9821

Browse files
simplify classes
1 parent 214c8b2 commit b2f9821

File tree

3 files changed

+40
-87
lines changed

3 files changed

+40
-87
lines changed

src/atomate2/forcefields/jobs.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from jobflow import job
1111

1212
from atomate2.ase.jobs import AseRelaxMaker
13-
from atomate2.forcefields.schemas import BaseForceFieldTaskDocument
13+
from atomate2.forcefields.schemas import ForceFieldTaskDocument
1414
from atomate2.forcefields.utils import _FORCEFIELD_DATA_OBJECTS, MLFF, ForceFieldMixin
1515

1616
if TYPE_CHECKING:
@@ -19,10 +19,7 @@
1919

2020
from pymatgen.core.structure import Molecule, Structure
2121

22-
from atomate2.forcefields.schemas import (
23-
ForceFieldMoleculeTaskDocument,
24-
ForceFieldTaskDocument,
25-
)
22+
from atomate2.forcefields.schemas import ForceFieldMoleculeTaskDocument
2623

2724
logger = logging.getLogger(__name__)
2825

@@ -144,7 +141,7 @@ def make(
144141
stacklevel=1,
145142
)
146143

147-
return BaseForceFieldTaskDocument.from_ase_compatible_result(
144+
return ForceFieldTaskDocument.from_ase_compatible_result(
148145
str(self.force_field_name), # make mypy happy
149146
ase_result,
150147
self.steps,

src/atomate2/forcefields/md.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,15 @@
99
from jobflow import job
1010

1111
from atomate2.ase.md import AseMDMaker, MDEnsemble
12-
from atomate2.forcefields.schemas import BaseForceFieldTaskDocument
12+
from atomate2.forcefields.schemas import ForceFieldTaskDocument
1313
from atomate2.forcefields.utils import _FORCEFIELD_DATA_OBJECTS, ForceFieldMixin
1414

1515
if TYPE_CHECKING:
1616
from pathlib import Path
1717

1818
from pymatgen.core.structure import Molecule, Structure
1919

20-
from atomate2.forcefields.schemas import (
21-
ForceFieldMoleculeTaskDocument,
22-
ForceFieldTaskDocument,
23-
)
20+
from atomate2.forcefields.schemas import ForceFieldMoleculeTaskDocument
2421

2522

2623
@dataclass
@@ -137,7 +134,7 @@ def make(
137134
stacklevel=1,
138135
)
139136

140-
return BaseForceFieldTaskDocument.from_ase_compatible_result(
137+
return ForceFieldTaskDocument.from_ase_compatible_result(
141138
str(self.force_field_name), # make mypy happy
142139
md_result,
143140
relax_cell=(self.ensemble == MDEnsemble.npt),

src/atomate2/forcefields/schemas.py

Lines changed: 34 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from __future__ import annotations
44

5-
from typing import Any
5+
from typing import TYPE_CHECKING, Any
66

77
from emmet.core.types.enums import StoreTrajectoryOption
88
from pydantic import BaseModel, Field
9-
from pymatgen.core import Molecule, Structure
9+
from pymatgen.core import Molecule
1010

1111
from atomate2.ase.schemas import (
1212
AseMoleculeTaskDoc,
@@ -18,6 +18,9 @@
1818
)
1919
from atomate2.forcefields import MLFF
2020

21+
if TYPE_CHECKING:
22+
from typing_extensions import Self
23+
2124

2225
class ForceFieldMeta(BaseModel):
2326
"""Add metadata to forcefield output documents."""
@@ -56,44 +59,33 @@ def forcefield_objects(self) -> dict[AseObject, Any] | None:
5659
return self.objects
5760

5861

59-
class ForceFieldTaskDocument(AseStructureTaskDoc, ForceFieldMeta):
60-
"""Document containing information on structure manipulation using a force field."""
61-
62-
6362
class ForceFieldMoleculeTaskDocument(AseMoleculeTaskDoc, ForceFieldMeta):
64-
"""Document containing information on structure manipulation using a force field."""
65-
66-
67-
class BaseForceFieldTaskDocument(AseTaskDoc):
68-
"""Document containing information on structure manipulation using a force field."""
69-
70-
forcefield_name: str | None = Field(
71-
None,
72-
description="name of the interatomic potential used for relaxation.",
73-
)
63+
"""Document containing information on molecule manipulation using a force field."""
7464

75-
forcefield_version: str | None = Field(
76-
"Unknown",
77-
description="version of the interatomic potential used for relaxation.",
78-
)
65+
@classmethod
66+
def from_ase_task_doc(
67+
cls, ase_task_doc: AseTaskDoc, **task_document_kwargs
68+
) -> Self:
69+
"""Create a ForceFieldMoleculeTaskDocument from an AseTaskDoc.
7970
80-
dir_name: str | None = Field(
81-
None, description="Directory where the force field calculations are performed."
82-
)
71+
Parameters
72+
----------
73+
ase_task_doc : AseTaskDoc
74+
Task doc for the calculation
75+
task_document_kwargs : dict
76+
Additional keyword args passed to :obj:`.AseStructureTaskDoc()`.
77+
"""
78+
task_document_kwargs.update(
79+
{k: getattr(ase_task_doc, k) for k in _task_doc_translation_keys},
80+
structure=ase_task_doc.mol_or_struct,
81+
)
82+
return cls.from_molecule(
83+
meta_molecule=ase_task_doc.mol_or_struct, **task_document_kwargs
84+
)
8385

84-
included_objects: list[AseObject] | None = Field(
85-
None, description="list of forcefield objects included with this task document"
86-
)
87-
objects: dict[AseObject, Any] | None = Field(
88-
None, description="Forcefield objects associated with this task"
89-
)
9086

91-
is_force_converged: bool | None = Field(
92-
None,
93-
description=(
94-
"Whether the calculation is converged with respect to interatomic forces."
95-
),
96-
)
87+
class ForceFieldTaskDocument(AseStructureTaskDoc, ForceFieldMeta):
88+
"""Document containing information on atomistic manipulation using a force field."""
9789

9890
@classmethod
9991
def from_ase_compatible_result( # type: ignore[override]
@@ -115,8 +107,8 @@ def from_ase_compatible_result( # type: ignore[override]
115107
store_trajectory: StoreTrajectoryOption = StoreTrajectoryOption.NO,
116108
tags: list[str] | None = None,
117109
**task_document_kwargs,
118-
) -> ForceFieldTaskDocument | ForceFieldMoleculeTaskDocument:
119-
"""Create an ForceField output for a task that has ASE-compatible outputs.
110+
) -> Self | ForceFieldMoleculeTaskDocument:
111+
"""Create forcefield output for a task that has ASE-compatible outputs.
120112
121113
Parameters
122114
----------
@@ -184,41 +176,8 @@ def from_ase_compatible_result( # type: ignore[override]
184176

185177
ff_kwargs["forcefield_version"] = importlib.metadata.version(pkg_name)
186178

187-
return cls.from_ase_task_doc(ase_task_doc, **ff_kwargs)
188-
189-
@classmethod
190-
def from_ase_task_doc(
191-
cls, ase_task_doc: AseTaskDoc, **task_document_kwargs
192-
) -> ForceFieldTaskDocument | ForceFieldMoleculeTaskDocument:
193-
"""Create an ForceField output for a task that has ASE-compatible outputs.
194-
195-
Parameters
196-
----------
197-
ase_task_doc : AseTaskDoc
198-
Task doc for the calculation
199-
task_document_kwargs : dict
200-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`
201-
or `.ForceFieldMoleculeTaskDocument()`.
202-
"""
203-
task_document_kwargs.update(
204-
{k: getattr(ase_task_doc, k) for k in _task_doc_translation_keys},
205-
)
206-
if isinstance(ase_task_doc.mol_or_struct, Structure):
207-
meta_class = ForceFieldTaskDocument
208-
k = "structure"
209-
if relax_cell := getattr(ase_task_doc, "relax_cell", None):
210-
task_document_kwargs.update({"relax_cell": relax_cell})
211-
task_document_kwargs.update(structure=ase_task_doc.mol_or_struct)
212-
elif isinstance(ase_task_doc.mol_or_struct, Molecule):
213-
meta_class = ForceFieldMoleculeTaskDocument
214-
k = "molecule"
215-
task_document_kwargs.update(molecule=ase_task_doc.mol_or_struct)
216-
task_document_kwargs.update(
217-
{k: ase_task_doc.mol_or_struct, f"meta_{k}": ase_task_doc.mol_or_struct}
218-
)
219-
return getattr(meta_class, f"from_{k}")(**task_document_kwargs)
220-
221-
@property
222-
def forcefield_objects(self) -> dict[AseObject, Any] | None:
223-
"""Alias `objects` attr for backwards compatibility."""
224-
return self.objects
179+
return (
180+
ForceFieldMoleculeTaskDocument
181+
if isinstance(result.final_mol_or_struct, Molecule)
182+
else cls
183+
).from_ase_task_doc(ase_task_doc, **ff_kwargs)

0 commit comments

Comments
 (0)