diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 8e47fe4..4240d69 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -15,7 +15,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: ["3.10"] + python-version: ["3.12"] runs-on: ${{ matrix.os }} steps: @@ -28,7 +28,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Pre-commit hooks - uses: pre-commit/action@v3.0.0 + uses: pre-commit/action@v3.0.1 build: needs: pre-commit @@ -36,7 +36,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] #, macos-latest, windows-latest] - python-version: ["3.10"] + python-version: ["3.12"] defaults: run: @@ -48,7 +48,7 @@ jobs: uses: actions/checkout@v4 - name: Set up environment - uses: mamba-org/setup-micromamba@v1 + uses: mamba-org/setup-micromamba@v2 with: environment-file: environment.yaml environment-name: gha-test-env @@ -94,7 +94,7 @@ jobs: # strategy: # matrix: # os: [ubuntu-latest] - # python-version: ["3.10"] + # python-version: ["3.12"] # runs-on: ${{ matrix.os }} # steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 816cac6..9772597 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,41 +1,13 @@ -# Pre-commit hooks for Python code -# Last revision by: Joao Morado -# Last revision date: 8.01.2023 -# See https://pre-commit.com for more information -# See https://pre-commit.com/hooks.html for more hooks - repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.2.0 - hooks: - - id: trailing-whitespace - - id: end-of-file-fixer - #- id: check-added-large-files -- repo: https://github.com/psf/black - rev: 23.1.0 - hooks: - - id: black - - id: black-jupyter -- repo: https://github.com/keewis/blackdoc - rev: v0.3.8 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.8 hooks: - - id: blackdoc - #- id: blackdoc-autoupdate-black -- repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", "--filter-files"] -- repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 - hooks: - - id: flake8 - additional_dependencies: [flake8-docstrings] - args: [--max-line-length=127, --exit-zero] - verbose: True -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.1.1 + - id: ruff + args: ["--fix"] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.17.1 hooks: - - id: mypy - args: [--no-strict-optional, --ignore-missing-imports, --namespace-packages, --explicit-package-bases] - additional_dependencies: ["types-PyYAML"] + - id: mypy + verbose: true + entry: bash -c 'mypy "$@" || true' -- \ No newline at end of file diff --git a/README.md b/README.md index cfbdc81..28219a6 100644 --- a/README.md +++ b/README.md @@ -11,12 +11,10 @@ A package to run hybrid ML/MM free energy simulations. 1. [Installation](#installation) 2. [Alchemical Modifications](#alchemical-modifications) 3. [Running a Multistate Equilibrium Free Energy Simulation](#running-a-multistate-equilibrium-free-energy-simulation) + 1. [Using Multiple Alchemical Groups](#using-multiple-alchemical-groups) 4. [Dynamics and EMLE settings](#dynamics-and-emle-settings) - 1. [Sire Strategy](#sire-strategy) 2. [OpenFF Strategy](#openff-strategy) - - 5. [Log Level](#log-level) ## Installation @@ -98,6 +96,78 @@ U_kln = fes.run_single_state(1000, 1000, 6) np.save("U_kln_mm_sol_6.npy", np.asarray(U_kln)) ``` +### Using Multiple Alchemical Groups + +For more complex transformations, you can define multiple alchemical groups that can be transformed independently or simultaneously. This is particularly useful when you want to apply different transformations to different regions of your system or transform multiple ligands separately. + +To use multiple alchemical groups, specify the group name as a suffix after a colon in the lambda schedule: + +```python +from fes_ml.fes import FES +import numpy as np + +# Define lambda schedule for multiple alchemical groups +lambda_schedule = { + # Group 1: Turn off LJ and charges for ligand 1 + "LJSoftCore:ligand1": [1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.0, 0.0], + "ChargeScaling:ligand1": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.0], + + # Group 2: Turn off LJ and charges for ligand 2 + "LJSoftCore:ligand2": [1.0, 1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.0], + "ChargeScaling:ligand2": [1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.33, 0.0], + + # Group 3: Interpolate between MM and ML for the entire system + "MLInterpolation:system": [0.0, 0.0, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0] +} + +# Define atom indices for each alchemical group +ligand1_atoms = [1, 2, 3, 4, 5] # Atoms belonging to first ligand +ligand2_atoms = [20, 21, 22, 23, 24] # Atoms belonging to second ligand +system_atoms = list(range(1, 50)) # All atoms for ML/MM interpolation + +# Define per-group alchemical atoms +modifications_kwargs = { + "LJSoftCore:ligand1": { + "alchemical_atoms": ligand1_atoms + }, + "ChargeScaling:ligand1": { + "alchemical_atoms": ligand1_atoms + }, + "LJSoftCore:ligand2": { + "alchemical_atoms": ligand2_atoms + }, + "ChargeScaling:ligand2": { + "alchemical_atoms": ligand2_atoms + }, + "MLInterpolation:system": { + "alchemical_atoms": system_atoms + } +} +``` + +#### Multiple Instances of the Same Modification Type + +You can also use multiple instances of the same modification type for the same group of atoms. For example, to interpolate between two sets of `CustomLJ` parameters: + +```python +lambda_schedule = { + "LJSoftCore:openff1": [1.0, 0.8, 0.6, 0.4, 0.2, 0.0], + "LJSoftCore:openff2": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], + "CustomLJ:openff1": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "CustomLJ:openff2": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], +} + +# Define different LJ parameters for each region +modifications_kwargs = { + "CustomLJ:openff1": { + "lj_offxml": "openff_unconstrained-1.0.0.offxml", + }, + "CustomLJ:openff2": { + "lj_offxml": "openff_unconstrained-2.0.0.offxml", + } +} +``` + ## Dynamics and EMLE settings In fes-ml, the default strategy for creating OpenMM systems is through Sire. Additionally, fes-ml offers the OpenFF strategy. You can select the desired creation strategy, either `'sire'` or `'openff'`, using the `strategy_name` argument when calling the `fes.create_alchemical_states` method to create the alchemical systems. Most other simulation configurations can also be set by passing additional arguments to this method. For details on customization, refer to the definitions of the `SireCreationStrategy` and `OpenFFCreationStrategy` classes. diff --git a/analysis/analyse.py b/analysis/analyse.py deleted file mode 100644 index 04d9884..0000000 --- a/analysis/analyse.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Script to calculate free energy differences using MBAR.""" - -import sys - -import matplotlib.pyplot as plt -import numpy as np -from openmm import unit -from pymbar import MBAR - -if len(sys.argv) != 3 or sys.argv[1] in ["-h", "--help"]: - print("Usage: python analyse.py ") - print( - "Please provide exactly two input arguments, which should be a file name and the temperature in Kelvin." - ) - sys.exit(1) - -input_file = sys.argv[1] -temperature = float(sys.argv[2]) -U_kn = np.load(sys.argv[1]) - -# Reformat array such that data is organised in the following away -# U_kn = [[U^{l}_{k,n} for n in nsamples for k in nstates] for l in nstates] -# where l (superscript) is the alchemical state at which the potential energy is evaluated -# and k (subscript) the alchemical state at which it is sampled -nstates, nstates, nsamples = U_kn.shape -U_kn = U_kn.transpose(1, 0, 2)[::-1, :, :] -U_kn = U_kn.reshape(nstates, nstates * nsamples) - -# Keep it in here to contemplate case where number of samples per alchemical state differs -# N_k = [ U_kn.shape[1]//nstates for _ in range(nstates)] -N_k = [nsamples for _ in range(nstates)] - -# Compute the overal -mbar = MBAR(U_kn, N_k) -overlap = mbar.compute_overlap() -plt.figure() -plt.title("Overlap") -plt.imshow(overlap["matrix"], vmin=0, vmax=1) -plt.colorbar() -plt.savefig("overlap.png") - -# If this fails try setting compute_uncertainty to false -# See this issue: https://github.com/choderalab/pymbar/issues/419 -results = mbar.compute_free_energy_differences(compute_uncertainty=True) - -# Calculate the free energy -kT = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA * unit.kelvin * temperature -print( - "Free energy = {}".format( - (results["Delta_f"][nstates - 1, 0] * kT).in_units_of(unit.kilocalorie_per_mole) - ) -) -print( - "Statistical uncertainty = {}".format( - (results["dDelta_f"][nstates - 1, 0] * kT).in_units_of( - unit.kilocalorie_per_mole - ) - ) -) diff --git a/environment.yaml b/environment.yaml index df58c44..668af44 100644 --- a/environment.yaml +++ b/environment.yaml @@ -6,29 +6,26 @@ channels: dependencies: - ambertools - - ase - compilers - - cudatoolkit<11.9 - - deepmd-kit + - cudatoolkit=11.8 - eigen - loguru - openmm>=8.1 - openmm-torch + - openmm-ml - openmmforcefields - openff-toolkit - openff-interchange - nnpops - pip - pybind11 - - pytorch + - pytorch=*=*cuda* - python - pyyaml - sire - torchani - pygit2 - - xtb-python - pip: - git+https://github.com/chemle/emle-engine.git - - git+https://github.com/openmm/openmm-ml.git - coloredlogs - mace-torch diff --git a/environment_rascal.yaml b/environment_rascal.yaml deleted file mode 100644 index 37c0662..0000000 --- a/environment_rascal.yaml +++ /dev/null @@ -1,34 +0,0 @@ -name: fes-ml - -channels: - - conda-forge - - openbiosim/label/emle - -dependencies: - - ambertools - - ase - - compilers - - cudatoolkit<11.9 - - deepmd-kit - - eigen - - loguru - - openmm>=8.1 - - openmm-torch - - openmmforcefields - - openff-toolkit - - openff-interchange - - nnpops - - pip - - pybind11 - - pytorch<3.11 - - python - - pyyaml - - sire - - torchani - - xtb-python - - pip: - - git+https://github.com/lab-cosmo/librascal.git - - git+https://github.com/chemle/emle-engine.git@main - - git+https://github.com/openmm/openmm-ml.git - - coloredlogs - - mace-torch diff --git a/examples/additional_scripts/analysis.py b/examples/additional_scripts/analysis.py index 807a9d8..b825fff 100644 --- a/examples/additional_scripts/analysis.py +++ b/examples/additional_scripts/analysis.py @@ -125,7 +125,7 @@ def water_atom( # Select atoms for which you want to calculate RDF water = universe.select_atoms(f"type O and not {ligand_selection}") - ligand = universe.select_atoms(ligand_selection) + # ligand = universe.select_atoms(ligand_selection) atoms = [ universe.select_atoms(f"index {atom.index} and {ligand_selection}") for atom in universe.select_atoms(ligand_selection) @@ -327,10 +327,10 @@ def __init__(self) -> None: def _plot_stdout_with_time_base(txt_file, save_location, data_name, prop): df = pd.read_csv(txt_file, sep=",") data = df[f"{data_name}"] - steps = df[f'#"Step"'] + steps = df['#"Step"'] plt.plot(steps, data, "r-", linewidth=2, color="maroon") - plt.xlabel(f"Steps") + plt.xlabel("Steps") plt.ylabel(f"{data_name}") plt.title(f"Plot of {prop} for {txt_file}") # plt.savefig(f"{save_location}") @@ -378,12 +378,12 @@ def plot_energy_all_windows( txt_file = f"{folder}/{base_name}{w}.txt" df = pd.read_csv(txt_file, sep=",") data = df[f"{data_name}"] - steps = df[f'#"Step"'] + steps = df['#"Step"'] x_axis = [r for r in range(1, len(steps) + 1, 1)] plt.plot(x_axis, data, linewidth=2, alpha=0.3, label=label, color=colors[w]) legend_labels.append(label) - plt.xlabel(f"Steps") + plt.xlabel("Steps") plt.ylabel(f"{data_name}") plt.title(f"Plot of {prop} for {txt_file}") plt.legend(bbox_to_anchor=(1, 0.5), loc="center left") diff --git a/examples/openff_strategy/mts_benchmark/ml/agg_output.py b/examples/openff_strategy/mts_benchmark/ml/agg_output.py index 3664e87..b298743 100644 --- a/examples/openff_strategy/mts_benchmark/ml/agg_output.py +++ b/examples/openff_strategy/mts_benchmark/ml/agg_output.py @@ -9,7 +9,7 @@ out = f"OUTPUT_{i}" try: f = glob.glob(out + "/*npy")[0] - except: + except Exception: break U_kln.append(np.load(f)[:, frames_disc::step]) diff --git a/examples/openff_strategy/mts_benchmark/ml/analysis.py b/examples/openff_strategy/mts_benchmark/ml/analysis.py index 759d665..05ca87c 100644 --- a/examples/openff_strategy/mts_benchmark/ml/analysis.py +++ b/examples/openff_strategy/mts_benchmark/ml/analysis.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt import numpy as np from openmm import unit -from pymbar import MBAR, timeseries +from pymbar import MBAR if len(sys.argv) != 3 or sys.argv[1] in ["-h", "--help"]: print("Usage: python analyse.py ") diff --git a/examples/openff_strategy/mts_benchmark/ml_mts/agg_output.py b/examples/openff_strategy/mts_benchmark/ml_mts/agg_output.py index 3664e87..b298743 100644 --- a/examples/openff_strategy/mts_benchmark/ml_mts/agg_output.py +++ b/examples/openff_strategy/mts_benchmark/ml_mts/agg_output.py @@ -9,7 +9,7 @@ out = f"OUTPUT_{i}" try: f = glob.glob(out + "/*npy")[0] - except: + except Exception: break U_kln.append(np.load(f)[:, frames_disc::step]) diff --git a/examples/openff_strategy/mts_benchmark/ml_mts/analysis.py b/examples/openff_strategy/mts_benchmark/ml_mts/analysis.py index 759d665..05ca87c 100644 --- a/examples/openff_strategy/mts_benchmark/ml_mts/analysis.py +++ b/examples/openff_strategy/mts_benchmark/ml_mts/analysis.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt import numpy as np from openmm import unit -from pymbar import MBAR, timeseries +from pymbar import MBAR if len(sys.argv) != 3 or sys.argv[1] in ["-h", "--help"]: print("Usage: python analyse.py ") diff --git a/examples/openff_strategy/mts_benchmark/mm/agg_output.py b/examples/openff_strategy/mts_benchmark/mm/agg_output.py index 3664e87..b298743 100644 --- a/examples/openff_strategy/mts_benchmark/mm/agg_output.py +++ b/examples/openff_strategy/mts_benchmark/mm/agg_output.py @@ -9,7 +9,7 @@ out = f"OUTPUT_{i}" try: f = glob.glob(out + "/*npy")[0] - except: + except Exception: break U_kln.append(np.load(f)[:, frames_disc::step]) diff --git a/examples/openff_strategy/mts_benchmark/mm/analysis.py b/examples/openff_strategy/mts_benchmark/mm/analysis.py index 759d665..05ca87c 100644 --- a/examples/openff_strategy/mts_benchmark/mm/analysis.py +++ b/examples/openff_strategy/mts_benchmark/mm/analysis.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt import numpy as np from openmm import unit -from pymbar import MBAR, timeseries +from pymbar import MBAR if len(sys.argv) != 3 or sys.argv[1] in ["-h", "--help"]: print("Usage: python analyse.py ") diff --git a/examples/openff_strategy/performance_benchmark_emle_aev.py b/examples/openff_strategy/performance_benchmark_emle_aev.py index 48b603f..e2d3632 100644 --- a/examples/openff_strategy/performance_benchmark_emle_aev.py +++ b/examples/openff_strategy/performance_benchmark_emle_aev.py @@ -10,13 +10,11 @@ import sys import time - import numpy as np import openff.units as offunit - import openmm as mm import openmm.app as app import openmm.unit as unit - from fes_ml import FES, MTS + from fes_ml import FES if len(sys.argv) == 1: raise ValueError("must pass window as positional arguments") diff --git a/examples/openff_strategy/run_scripts/agg_ouput.py b/examples/openff_strategy/run_scripts/agg_ouput.py index f3446ea..e20e657 100644 --- a/examples/openff_strategy/run_scripts/agg_ouput.py +++ b/examples/openff_strategy/run_scripts/agg_ouput.py @@ -15,10 +15,10 @@ def main(args): out = f"{folder}/{i}" try: f = glob.glob(f"{out}.npy")[0] - except: + except Exception: break print(np.load(f).shape) - U_kln.append(np.load(f)) # np.load(f)[:, frames_disc::step] + U_kln.append(np.load(f)[:, frames_disc::step]) U_kln = np.asarray(U_kln) print(U_kln.shape) diff --git a/examples/openff_strategy/run_scripts/analysis.py b/examples/openff_strategy/run_scripts/analysis.py index d9cc651..19c9e7e 100644 --- a/examples/openff_strategy/run_scripts/analysis.py +++ b/examples/openff_strategy/run_scripts/analysis.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt import numpy as np from openmm import unit -from pymbar import MBAR, timeseries +from pymbar import MBAR def main(args): diff --git a/examples/openff_strategy/run_scripts/create_run_system.py b/examples/openff_strategy/run_scripts/create_run_system.py index ec496a3..6d173fa 100644 --- a/examples/openff_strategy/run_scripts/create_run_system.py +++ b/examples/openff_strategy/run_scripts/create_run_system.py @@ -1,9 +1,7 @@ import logging import math import os -import sys from argparse import ArgumentParser -from typing import Optional, Union import numpy as np import openff.units as offunit @@ -33,7 +31,7 @@ def main(args): try: os.makedirs(folder) - except: + except Exception: pass # --------------------------------------------------------------- # diff --git a/examples/sire_strategy/benzene_ml_mm_sol_emle.py b/examples/sire_strategy/benzene_ml_mm_sol_emle.py index e21844c..ede045f 100644 --- a/examples/sire_strategy/benzene_ml_mm_sol_emle.py +++ b/examples/sire_strategy/benzene_ml_mm_sol_emle.py @@ -14,9 +14,6 @@ if __name__ == "__main__": import numpy as np - from fes_ml.alchemical.modifications.ml_interpolation import ( - MLInterpolationModification, - ) from fes_ml.fes import FES # Set up the alchemical modifications diff --git a/examples/sire_strategy/benzene_ml_mm_sol_mts.py b/examples/sire_strategy/benzene_ml_mm_sol_mts.py index a084ee3..a08dc26 100644 --- a/examples/sire_strategy/benzene_ml_mm_sol_mts.py +++ b/examples/sire_strategy/benzene_ml_mm_sol_mts.py @@ -10,7 +10,6 @@ if __name__ == "__main__": import numpy as np - import openmm as mm import openmm.unit as unit from fes_ml import FES, MTS diff --git a/examples/sire_strategy/benzene_ml_sol_mts.py b/examples/sire_strategy/benzene_ml_sol_mts.py index f824554..40fdd66 100644 --- a/examples/sire_strategy/benzene_ml_sol_mts.py +++ b/examples/sire_strategy/benzene_ml_sol_mts.py @@ -17,7 +17,6 @@ if __name__ == "__main__": import numpy as np - import openmm as mm import openmm.unit as unit from fes_ml import FES, MTS diff --git a/pyproject.toml b/pyproject.toml index 130652f..97ea86f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,17 +6,16 @@ build-backend = "setuptools.build_meta" name = "fes_ml" description = "A package to run hybrid ML/MM free energy simulations." requires-python = ">=3.9" -keywords = ["free energy simulation", "free energy", "force field", "machine learning", "openmm"] +keywords = ["free energy simulations", "free energy", "force field", "machine learning", "openmm"] authors = [{email = "jmorado@ed.ac.uk"},{name = "Joao Morado"}] maintainers = [{name = "Joao Morado", email = "jmorado@ed.ac.uk"}] classifiers = [ "License :: OSI Approved :: GPL License", "Intended Audience :: Science/Research", "Intended Audience :: Developers", - "Topic :: Oceanography Modeling", "Topic :: Scientific/Engineering", "Programming Language :: Python", - "Programming Language :: Python :: 3.10" + "Programming Language :: Python :: 2" ] dynamic = [ @@ -66,3 +65,21 @@ where = ["src"] include = ["*"] exclude = ["*__pycache__*"] namespaces = true + +[tool.ruff] +fix = true +line-length = 144 + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = ["E203", "I001"] + +[tool.ruff.lint.isort] +known-first-party = ["fes_ml"] +combine-as-imports = true +force-single-line = false + +[tool.mypy] +python_version = "3.12" +warn_return_any = true +warn_unused_configs = true \ No newline at end of file diff --git a/src/fes_ml/__init__.py b/src/fes_ml/__init__.py index 0322cba..1123f23 100644 --- a/src/fes_ml/__init__.py +++ b/src/fes_ml/__init__.py @@ -1,10 +1,10 @@ """fes_ml base package.""" -__version__ = "0.2.1" +__version__ = "0.2.2" __author__ = "Joao Morado" from .fes import FES from .log import config_logger from .mts import MTS -__all__ = ["FES", "MTS"] +__all__ = ["FES", "MTS", "config_logger"] diff --git a/src/fes_ml/alchemical/alchemist.py b/src/fes_ml/alchemical/alchemist.py index 528deae..33083c3 100644 --- a/src/fes_ml/alchemical/alchemist.py +++ b/src/fes_ml/alchemical/alchemist.py @@ -1,8 +1,8 @@ """Module for the Alchemist class.""" import logging -import sys from copy import deepcopy as _deepcopy +from importlib.metadata import entry_points from typing import Any, Dict, List, Optional import networkx as nx @@ -10,11 +10,6 @@ from .modifications.base_modification import BaseModification, BaseModificationFactory -if sys.version_info < (3, 10): - from importlib_metadata import entry_points -else: - from importlib.metadata import entry_points - logger = logging.getLogger(__name__) @@ -22,6 +17,7 @@ class Alchemist: """A class for applying alchemical modifications to an OpenMM system.""" _modification_factories: Dict[str, BaseModificationFactory] = {} + _DEFAULT_ALCHEMICAL_GROUP = ":default" @staticmethod def register_modification_factory( @@ -100,13 +96,13 @@ def add_modification_to_graph( lambda_value : float The value of the alchemical state parameter. """ - if modification.NAME in self._graph.nodes and lambda_value is None: - lambda_value = self._graph.nodes[modification.NAME].get( - "lambda_value", None - ) + node_name = modification.modification_name + + if node_name in self._graph.nodes and lambda_value is None: + lambda_value = self._graph.nodes[node_name].get("lambda_value", None) self._graph.add_node( - modification.NAME, modification=modification, lambda_value=lambda_value + node_name, modification=modification, lambda_value=lambda_value ) if modification.pre_dependencies is not None: @@ -119,10 +115,21 @@ def add_modification_to_graph( "modification is implemented and registered as an entry point." ) - factory = self._modification_factories[pre_dependency] - pre_modification = factory.create_modification() - self._graph.add_edge(pre_modification.NAME, modification.NAME) - self.add_modification_to_graph(pre_modification, None) + # Check if dependency already exists in graph + dep_modification_name = ( + f"{pre_dependency}:{modification.alchemical_group}" + ) + if dep_modification_name not in self._graph.nodes: + factory = self._modification_factories[pre_dependency] + pre_modification = factory.create_modification( + modification_name=dep_modification_name + ) + # Add edge first (like before), then recursively add dependency + self._graph.add_edge(dep_modification_name, node_name) + self.add_modification_to_graph(pre_modification, None) + else: + # Dependency already exists, just add the edge + self._graph.add_edge(dep_modification_name, node_name) if modification.post_dependencies is not None: for post_dependency in modification.post_dependencies: @@ -133,10 +140,22 @@ def add_modification_to_graph( "typos in the name of this post-dependency and that the target " "modification is implemented and registered as an entry point." ) - factory = self._modification_factories[post_dependency] - post_modification = factory.create_modification() - self._graph.add_edge(modification.NAME, post_modification.NAME) - self.add_modification_to_graph(post_modification, None) + + # Check if dependency already exists in graph + dep_modification_name = ( + f"{post_dependency}:{modification.alchemical_group}" + ) + if dep_modification_name not in self._graph.nodes: + factory = self._modification_factories[post_dependency] + post_modification = factory.create_modification( + modification_name=dep_modification_name + ) + # Add edge first (like before), then recursively add dependency + self._graph.add_edge(node_name, dep_modification_name) + self.add_modification_to_graph(post_modification, None) + else: + # Dependency already exists, just add the edge + self._graph.add_edge(node_name, dep_modification_name) def remove_modification_from_graph(self, modification: str) -> None: """ @@ -153,6 +172,7 @@ def create_alchemical_graph( self, lambda_schedule: Dict[str, float], additional_modifications: Optional[List[str]] = None, + modifications_kwargs: Optional[Dict[str, Dict[str, Any]]] = None, ): """ Create a graph of alchemical modifications to apply. @@ -163,6 +183,8 @@ def create_alchemical_graph( A dictionary of λ values to be applied to the system. additional_modifications : list of str Additional modifications to apply. + modifications_kwargs : dict + A dictionary of keyword arguments for the modifications. Returns ------- @@ -171,25 +193,63 @@ def create_alchemical_graph( """ logger.debug("Creating graph of alchemical modifications.") for name, lambda_value in lambda_schedule.items(): - if name in Alchemist._modification_factories: - factory = self._modification_factories[name] - modification = factory.create_modification() + if ":" in name: + base_name, modification_name = name.split(":", 1)[0], name + else: + base_name, modification_name = ( + name, + name + self._DEFAULT_ALCHEMICAL_GROUP, + ) + + if modifications_kwargs is not None and name in modifications_kwargs: + if modification_name not in modifications_kwargs: + modifications_kwargs[modification_name] = modifications_kwargs.pop( + name + ) + + if base_name in Alchemist._modification_factories: + factory = self._modification_factories[base_name] + modification = factory.create_modification( + modification_name=modification_name + ) self.add_modification_to_graph(modification, lambda_value=lambda_value) else: - raise ValueError(f"Modification {name} not found in the factories.") + raise ValueError( + f"Modification {base_name} not found in the factories." + ) if additional_modifications is not None: for name in additional_modifications: - if name in Alchemist._modification_factories: - factory = self._modification_factories[name] - modification = factory.create_modification() - self.add_modification_to_graph( - modification, lambda_value=lambda_value + if ":" in name: + base_name, modification_name = name.split(":", 1)[0], name + else: + base_name, modification_name = ( + name, + name + self._DEFAULT_ALCHEMICAL_GROUP, ) + + if modifications_kwargs is not None and name in modifications_kwargs: + if modification_name in modifications_kwargs: + raise ValueError( + f"Cannot rename '{name}' to '{modification_name}': " + "key already exists in modifications_kwargs." + ) + modifications_kwargs[modification_name] = modifications_kwargs.pop( + name + ) + + if base_name in Alchemist._modification_factories: + factory = self._modification_factories[base_name] + modification = factory.create_modification( + modification_name=modification_name + ) + self.add_modification_to_graph(modification, lambda_value=None) else: - raise ValueError(f"Modification {name} not found in the factories.") + raise ValueError( + f"Modification {base_name} not found in the factories." + ) - # After the graph is created, removed dependencies to skip + # After constructing the graph, remove dependencies to skip ref_graph = _deepcopy(self._graph) for _, data in ref_graph.nodes.data(): modification = data["modification"] @@ -197,6 +257,46 @@ def create_alchemical_graph( for skip_dependency in modification.skip_dependencies: self.remove_modification_from_graph(skip_dependency) + # After constructing the graph, remove redundant modifications + # Redundancy is determined using a binary overlap principle: + # - Total overlap: if two modifications have identical alchemical atoms, keep only one + # - No overlap: if the alchemical atoms are disjoint, keep both modifications + # - Partial overlap: if the alchemical atoms partially overlap, raise an error + # (this behavior may be implemented in the future) + modifications_kwargs = modifications_kwargs or {} + redundant_modifications = [ + name for name in self._graph.nodes if name not in lambda_schedule + ] + mod_atoms = { + name: set(modifications_kwargs.get(name, {}).get("alchemical_atoms", [])) + for name in redundant_modifications + } + to_remove = set() + for i, name in enumerate(redundant_modifications): + base_name = name.split(":", 1)[0] + set_a = mod_atoms[name] + for j in range(i + 1, len(redundant_modifications)): + other_name = redundant_modifications[j] + other_base_name = other_name.split(":", 1)[0] + if base_name != other_base_name: + continue + set_b = mod_atoms[other_name] + if set_a == set_b: + # If both modifications have the same alchemical atoms, keep only one + to_remove.add(other_name) + elif not set_a.isdisjoint(set_b): + # Partial overlap detected, raise an error + raise ValueError( + f"Partial overlap detected between modifications '{name}' and '{other_name}'. " + "Please ensure that alchemical atoms either fully overlap or are disjoint." + ) + + for name in to_remove: + self.remove_modification_from_graph(name) + + for _, data in list(self._graph.nodes.data()): + modification = data["modification"] + logger.debug("Created graph of alchemical modifications:\n") for line in nx.generate_network_text( self._graph, vertical_chains=False, ascii_only=True @@ -246,7 +346,10 @@ def apply_modifications( for mod in nx.topological_sort(self._graph): lambda_value = self._graph.nodes[mod]["lambda_value"] mod_instance = self._graph.nodes[mod]["modification"] - mod_kwargs = modifications_kwargs.get(mod, {}) + # Try both instance name and base name for kwargs lookup + mod_kwargs = modifications_kwargs.get( + mod, modifications_kwargs.get(mod_instance.NAME, {}) + ) if lambda_value is None: logger.debug(f"Applying {mod} modification") diff --git a/src/fes_ml/alchemical/modifications/__init__.py b/src/fes_ml/alchemical/modifications/__init__.py index 9e9014c..769f092 100644 --- a/src/fes_ml/alchemical/modifications/__init__.py +++ b/src/fes_ml/alchemical/modifications/__init__.py @@ -1,14 +1,14 @@ """Init file for the modification module.""" from . import ( - charge_scaling, - charge_transfer, - custom_lj, - emle_potential, - intramolecular, - lj_softcore, - ml_correction, - ml_interpolation, - ml_potential, + charge_scaling as charge_scaling, + charge_transfer as charge_transfer, + custom_lj as custom_lj, + emle_potential as emle_potential, + intramolecular as intramolecular, + lj_softcore as lj_softcore, + ml_correction as ml_correction, + ml_interpolation as ml_interpolation, + ml_potential as ml_potential, ) -from .emle_potential import _EMLE_CALCULATORS +from .emle_potential import _EMLE_CALCULATORS as _EMLE_CALCULATORS diff --git a/src/fes_ml/alchemical/modifications/base_modification.py b/src/fes_ml/alchemical/modifications/base_modification.py index f56c8c0..7fe00f3 100644 --- a/src/fes_ml/alchemical/modifications/base_modification.py +++ b/src/fes_ml/alchemical/modifications/base_modification.py @@ -19,12 +19,16 @@ class BaseModification(ABC): - If modification A requires B, add B to A's ``pre_dependencies`` class-level attribute. - If modification A is incompatible with B being present in the graph, - add B to A's ``skip_depencies`` class-level attribute. + add B to A's ``skip_dependencies`` class-level attribute. - If multiple modifications imply B, B is applied only once. - - Modification are applied in topologically sorted order based on + - Modifications are applied in topologically sorted order based on the dependency graph. - - A modification must an associated key controlled by the class-level + - A modification must have a default key controlled by the class-level attribute ``NAME``. + - A modication can also have a custom name by passing a string + ``modification_name`` when instantiating the modification. This + allows multiple instances of the same modification to coexist in + the same system, each with its own parameters. - Modifications can be applied by passing a lambda schedule dictionary when creating alchemical states. This dictionary maps ``NAME``s to λ values for each ``AlchemicalState``. @@ -43,6 +47,22 @@ class BaseModification(ABC): post_dependencies: List[str] = None skip_dependencies: List[str] = None + def __init__(self, modification_name: str = None): + """ + Initialize the BaseModification. + + Parameters + ---------- + modification_name : str, optional + Custom name for this modification instance. If not provided, + uses the class NAME. + """ + assert modification_name.count(":") <= 1, ( + f"Invalid modification_name '{modification_name}': " + "it may contain at most one ':' to indicate the alchemical group." + ) + self.modification_name = modification_name or self.NAME + def __init_subclass__(cls, **kwargs): """ Initialize the subclass. @@ -135,6 +155,74 @@ def remove_post_dependency(cls, name: str) -> None: else: raise ValueError(f"{name} is not a post-dependency of {cls.NAME}.") + @staticmethod + def find_forces_by_group(system: _mm.System, group: str) -> List[_mm.Force]: + """ + Find all forces in the system that belong to a given alchemical group. + + Notes + ----- + Alchemical groups are defined by suffixes in the force names, e.g., ':region1'. + + Parameters + ---------- + system : openmm.System + The OpenMM system to search. + group : str + The alchemical group to match (e.g., ':region1'). + + Returns + ------- + List[openmm.Force] + List of forces whose names end with the specified alchemical group. + """ + return [ + force + for force in system.getForces() + if force.getName().endswith(f":{group}") + ] + + @staticmethod + def find_force_by_name(system: _mm.System, force_name: str) -> _mm.Force: + """ + Find a force in the system by its full name. + + Parameters + ---------- + system : openmm.System + The OpenMM system to search. + force_name : str + The name of the force to find. + + Returns + ------- + openmm.Force + The force with the specified name. + + Raises + ------ + ValueError + If no force with the specified name is found. + """ + for force in system.getForces(): + if force.getName() == force_name: + return force + raise ValueError(f"Force with name '{force_name}' not found in system.") + + @property + def alchemical_group(self) -> str: + """ + Get the suffix of the current instance name. + + Returns + ------- + str + The suffix part after ':' if present, empty string otherwise. + """ + if ":" in self.modification_name: + return self.modification_name.split(":", 1)[1] + return "" + class BaseModificationFactory(ABC): """ @@ -145,10 +233,17 @@ class BaseModificationFactory(ABC): """ @abstractmethod - def create_modification(self, *args, **kwargs) -> BaseModification: + def create_modification( + self, modification_name: str = None, *args, **kwargs + ) -> BaseModification: """ Create an instance of the modification. + Parameters + ---------- + modification_name : str, optional + Custom name for this modification instance. + Returns ------- BaseModification diff --git a/src/fes_ml/alchemical/modifications/charge_scaling.py b/src/fes_ml/alchemical/modifications/charge_scaling.py index 60cac64..56e32dd 100644 --- a/src/fes_ml/alchemical/modifications/charge_scaling.py +++ b/src/fes_ml/alchemical/modifications/charge_scaling.py @@ -63,22 +63,22 @@ def apply( system : openmm.System The modified System with the charges scaled. """ - nb_forces = [ + nonbondedforces = [ force for force in system.getForces() if isinstance(force, _mm.NonbondedForce) ] - if len(nb_forces) > 1: + if len(nonbondedforces) > 1: raise ValueError( "The system must not contain more than one NonbondedForce." ) - elif len(nb_forces) == 0: + elif len(nonbondedforces) == 0: logger.warning( - "The system does not contain a NonbondedForce and therefore no charge scaling will be applied.=." + "The system does not contain a NonbondedForce and therefore no charge scaling will be applied." ) return system else: - force = nb_forces[0] + force = nonbondedforces[0] for index in range(system.getNumParticles()): [charge, sigma, epsilon] = force.getParticleParameters(index) if index in alchemical_atoms: diff --git a/src/fes_ml/alchemical/modifications/charge_transfer.py b/src/fes_ml/alchemical/modifications/charge_transfer.py index 12eaf16..a2b3d8d 100644 --- a/src/fes_ml/alchemical/modifications/charge_transfer.py +++ b/src/fes_ml/alchemical/modifications/charge_transfer.py @@ -2,7 +2,6 @@ from typing import List, Optional, Union import openmm as _mm -import openmm.app as _app import openmm.unit as _unit from openff.toolkit.topology import Topology as _Topology from openff.toolkit.typing.engines.smirnoff import ForceField as _ForceField @@ -73,11 +72,11 @@ def get_is_donor_acceptor( break # Acceptor: heavy atoms elif atom.atomic_number == 8 and (symmetric or idx not in alchemical_atoms): - num_H = sum(b.atomic_number == 1 for b in atom.bonded_atoms) + #num_H = sum(b.atomic_number == 1 for b in atom.bonded_atoms) is_acceptor[idx] = 1 elif atom.atomic_number == 7: # Nitrogen - num_H = sum(b.atomic_number == 1 for b in atom.bonded_atoms) + #num_H = sum(b.atomic_number == 1 for b in atom.bonded_atoms) is_acceptor[idx] = 1 # if num_H != 1 and num_H != 2: # optionally exclude primary/secondary amines # is_acceptor[idx] = 1 @@ -91,6 +90,7 @@ def apply( system: _mm.System, alchemical_atoms: List[int], original_offxml: List[str], + ct_offxml: str, topology_off: _Topology, lambda_value: Optional[Union[float, int]] = 1.0, *args, @@ -105,6 +105,8 @@ def apply( The system to be modified. alchemical_atoms : list of int The indices of the alchemical atoms in the system. + ct_offxml: str + Path to the offxml file containing the charge transfer parameters. original_offxml : List[str] List of paths to the original offxml files. topology_off : openff.toolkit.topology.Topology @@ -121,34 +123,12 @@ def apply( openmm.System The modified system. """ - ct_paraams = { - "n21": {"sigma": 3.1739, "eps": 169.0640}, - "n16": {"sigma": 4.3983, "eps": 42.4290}, - "n3": {"sigma": 5.9554, "eps": 0.1890}, - "n-tip3p-O": {"sigma": 5.5019, "eps": 119.9913}, - "n-tip3p-H": {"sigma": 5.6480, "eps": 3.4209}, - "n20": {"sigma": 4.1808, "eps": 71.9843}, - "n2": {"sigma": 7.6786, "eps": 0.1502}, - "n18": {"sigma": 4.7983, "eps": 78.6935}, - "n14": {"sigma": 3.9557, "eps": 50.0886}, - "n17": {"sigma": 4.6716, "eps": 61.3267}, - "n9": {"sigma": 5.2866, "eps": 0.8104}, - "n4": {"sigma": 5.8806, "eps": 0.4651}, - "n13": {"sigma": 5.3179, "eps": 1.1033}, - "n11": {"sigma": 5.2541, "eps": 0.7611}, - "n19": {"sigma": 4.9034, "eps": 75.1253}, - "n12": {"sigma": 4.6507, "eps": 1.0854}, - "n7": {"sigma": 6.8461, "eps": 0.5855}, - "n15": {"sigma": 4.3835, "eps": 59.0651}, - "n10": {"sigma": 5.3841, "eps": 0.4192}, - "n8": {"sigma": 6.7245, "eps": 0.7046}, - } - - energy_function = f"-{lambda_value}*donor_acceptor*epsilon*exp(-sigma*r);" - energy_function += "sigma = 0.5*(sigma1+sigma2);" - energy_function += "epsilon = sqrt(epsilon1*epsilon2);" + # Convert units + energy_function = f"-{lambda_value}*donor_acceptor*epsilon*exp(-r/sigma);" + energy_function += "sigma = sqrt(sigma1*sigma2);" + energy_function += "epsilon = (epsilon1*epsilon2);" energy_function += ( - "donor_acceptor = isDonor1*isAcceptor2 + isDonor2*isAcceptor1;" + "donor_acceptor = 1;"#isDonor1*isAcceptor2 + isDonor2*isAcceptor1;" ) logger.debug(f"Charge transfer function: {energy_function}") @@ -164,28 +144,43 @@ def apply( # Add per-particle parameters to the CustomNonbondedForce charge_transfer_force.addPerParticleParameter("sigma") charge_transfer_force.addPerParticleParameter("epsilon") - charge_transfer_force.addPerParticleParameter("isDonor") - charge_transfer_force.addPerParticleParameter("isAcceptor") + #charge_transfer_force.addPerParticleParameter("isDonor") + #charge_transfer_force.addPerParticleParameter("isAcceptor") # Update the Lennard-Jones parameters in the CustomNonbondedForce - force_field = _ForceField(*original_offxml) - labels = force_field.label_molecules(topology_off) + #force_field = _ForceField(*original_offxml) + #labels = force_field.label_molecules(topology_off) # Get atom types - atom_types = [val.id for mol in labels for _, val in mol["vdW"].items()] + #atom_types = [val.id for mol in labels for _, val in mol["vdW"].items()] # Get donor/acceptor flags is_donor, is_acceptor = ChargeTransferModification.get_is_donor_acceptor( topology_off, alchemical_atoms ) + # CT force field + ct_force_field = _ForceField(ct_offxml) + labels = ct_force_field.label_molecules(topology_off) + ct_params = { + p.id: { + "epsilon": p.epsilon.to_openmm().value_in_unit( + _unit.kilojoules_per_mole + ), + "sigma": p.sigma.to_openmm().value_in_unit(_unit.nanometer), + } + for p in ct_force_field.get_parameter_handler("vdW") + } + atom_types = [val.id for mol in labels for _, val in mol["vdW"].items()] + for index in range(system.getNumParticles()): + at_type = atom_types[index] charge_transfer_force.addParticle( [ - ct_paraams.get(atom_types[index], {}).get("sigma", 0) * 10, - ct_paraams.get(atom_types[index], {}).get("eps", 0) * 1e3, - is_donor[index], - is_acceptor[index], + ct_params[at_type]["sigma"], + ct_params[at_type]["epsilon"],# * 10.0, + #is_donor[index], + #is_acceptor[index], ] ) diff --git a/src/fes_ml/alchemical/modifications/custom_lj.py b/src/fes_ml/alchemical/modifications/custom_lj.py index f3e23ff..98567e0 100644 --- a/src/fes_ml/alchemical/modifications/custom_lj.py +++ b/src/fes_ml/alchemical/modifications/custom_lj.py @@ -19,6 +19,11 @@ def create_modification(self, *args, **kwargs) -> BaseModification: """ Create an instance of CustomLJModification. + Parameters + ---------- + modification_name : str, optional + Custom name for this modification instance. + Returns ------- CustomLJModification @@ -76,8 +81,19 @@ def apply( openmm.System The modified system. """ - forces = {force.__class__.__name__: force for force in system.getForces()} - custom_nb_force = forces["CustomNonbondedForce"] + # Find the related LJSoftCore force based on instance naming + alchemical_group = self.alchemical_group + group_forces = self.find_forces_by_group(system, self.alchemical_group) + custom_nb_forces = [ + force + for force in group_forces + if force.getName() == f"LJSoftCore:{alchemical_group}" + ] + if not custom_nb_forces: + raise ValueError( + f"Attempting to modify LJ parameters but no LJSoftCore force found for alchemical group '{alchemical_group}'" + ) + custom_nb_force = custom_nb_forces[0] # Create a dictionary with the optimized Lennard-Jones parameters for each atom type force_field_opt = _ForceField(lj_offxml) @@ -106,8 +122,9 @@ def apply( for index, parameters in enumerate(vdw_parameters): if alchemical_atoms_only and index not in alchemical_atoms: continue - print(f"Setting Lennard-Jones parameters for atom {index} to {parameters}") - print(f"Parameters are {parameters}") custom_nb_force.setParticleParameters(index, parameters) + # Update name (will override LJSoftCore name) + custom_nb_force.setName(self.modification_name) + return system diff --git a/src/fes_ml/alchemical/modifications/emle_potential.py b/src/fes_ml/alchemical/modifications/emle_potential.py index fccc090..535ae0d 100644 --- a/src/fes_ml/alchemical/modifications/emle_potential.py +++ b/src/fes_ml/alchemical/modifications/emle_potential.py @@ -192,10 +192,11 @@ def apply( emle_force, interpolation_force = engine.get_forces() # Add the EMLE force to the system. + emle_force.setName(self.modification_name) system.addForce(emle_force) # Add the interpolation force to the system, so that the EMLE force does not scale # with the MLInterpolation force. - interpolation_force.setName("EMLECustomBondForce") + interpolation_force.setName("EMLECustomBondForce_" + self.modification_name) system.addForce(interpolation_force) # Zero the charges on the atoms within the QM region diff --git a/src/fes_ml/alchemical/modifications/intramolecular.py b/src/fes_ml/alchemical/modifications/intramolecular.py index ee7c56d..1a55d5e 100644 --- a/src/fes_ml/alchemical/modifications/intramolecular.py +++ b/src/fes_ml/alchemical/modifications/intramolecular.py @@ -15,14 +15,8 @@ class IntraMolecularNonBondedExceptionsModificationFactory(BaseModificationFacto """Factory for creating IntraMolecularNonBondedModification instances.""" def create_modification(self, *args, **kwargs) -> BaseModification: - """Create an instance of IntraMolecularNonBondedModification. - - Parameters - ---------- - args : list - Additional arguments to be passed to the modification. - kwargs : dict - Additional keyword arguments to be passed to the modification. + """ + Create an instance of IntraMolecularNonBondedModification. Returns ------- @@ -36,14 +30,8 @@ class IntraMolecularNonBondedForcesModificationFactory(BaseModificationFactory): """Factory for creating IntraMolecularNonBondedForcesModification instances.""" def create_modification(self, *args, **kwargs) -> BaseModification: - """Create an instance of IntraMolecularNonBondedForcesModification. - - Parameters - ---------- - args : list - Additional arguments to be passed to the modification. - kwargs : dict - Additional keyword arguments to be passed to the modification. + """ + Create an instance of IntraMolecularNonBondedForcesModification. Returns ------- @@ -57,14 +45,8 @@ class IntraMolecularBondedRemovalModificationFactory(BaseModificationFactory): """Factory for creating IntraMolecularBondedRemovalModification instances.""" def create_modification(self, *args, **kwargs) -> BaseModification: - """Create an instance of IntraMolecularBondedRemovalModification. - - Parameters - ---------- - args : list - Additional arguments to be passed to the modification. - kwargs : dict - Additional keyword arguments to be passed to the modification. + """ + Create an instance of IntraMolecularBondedRemovalModification. Returns ------- @@ -218,6 +200,17 @@ class IntraMolecularBondedRemovalModification(BaseModification): NAME = "IntraMolecularBondedRemoval" + def __init__(self, modification_name: str = None): + """ + Initialize the IntraMolecularBondedRemovalModification. + + Parameters + ---------- + modification_name : str, optional + Custom name for this modification instance. + """ + super().__init__(modification_name) + @staticmethod def _should_remove(term_atoms: tuple, atom_set: set, remove_in_set: bool) -> bool: """ diff --git a/src/fes_ml/alchemical/modifications/lj_softcore.py b/src/fes_ml/alchemical/modifications/lj_softcore.py index a308e6e..7db6368 100644 --- a/src/fes_ml/alchemical/modifications/lj_softcore.py +++ b/src/fes_ml/alchemical/modifications/lj_softcore.py @@ -51,6 +51,8 @@ def apply( system: _mm.System, alchemical_atoms: List[int], lambda_value: Optional[Union[float, int]] = 1.0, + include_repulsion: bool = True, + include_attraction: bool = True, *args, **kwargs, ) -> _mm.System: @@ -65,6 +67,12 @@ def apply( The indices of the alchemical atoms in the system. lambda_value : float The value of the alchemical state parameter. + include_repulsion : bool, optional + Whether to include the short-range repulsive term (r^-12 component). + Default is True. + include_attraction : bool, optional + Whether to include the long-range attractive term (r^-6 component). + Default is True. args : tuple Additional arguments to be passed to the modification. kwargs : dict @@ -78,21 +86,54 @@ def apply( forces = {force.__class__.__name__: force for force in system.getForces()} nb_force = forces["NonbondedForce"] + if not include_repulsion and not include_attraction: + raise ValueError("At least one of include_repulsion or include_attraction must be True.") + # Define the softcore Lennard-Jones energy function - energy_function = ( - f"{lambda_value}*4*epsilon*x*(x-1.0); x = (sigma/reff_sterics)^6;" - ) - energy_function += ( - f"reff_sterics = sigma*(0.5*(1.0-{lambda_value}) + (r/sigma)^6)^(1/6);" - ) - energy_function += ( - "sigma = 0.5*(sigma1+sigma2); epsilon = sqrt(epsilon1*epsilon2);" + # Standard LJ: 4*epsilon*(x^2 - x) where x = (sigma/r)^6 + # Softcore LJ: 4*epsilon*(x_soft^2 - x_soft) where x_soft = (sigma/reff)^6 + terms = [] + # Main energy expression pieces + if include_repulsion: + terms.append(f"{lambda_value}*4*epsilon*x*x") + if include_attraction: + terms.append(f"-{lambda_value}*4*epsilon*x*fdamp") + + energy_expr = " + ".join(terms) + + # Full OpenMM expression: E; intermediate definitions follow + expr = ( + f"{energy_expr}; " + "x=(sigma/reff)^6; " + f"reff=sigma*(0.5*(1.0-{lambda_value})+(r/sigma)^6)^(1/6); " + "sigma=0.5*(sigma1+sigma2); " + "epsilon=sqrt(epsilon1*epsilon2); " ) - logger.debug(f"LJ softcore function: {energy_function}") + # Add damping if needed + if include_attraction and not include_repulsion: + expr += ( + # Tang-Toennies f6(b*r) expansion, truncated to 6th order + "fdamp=1.0 - exp(-xdamp)*(1.0 + xdamp + 0.5*xdamp^2 + " + "(1.0/6.0)*xdamp^3 + (1.0/24.0)*xdamp^4 + " + "(1.0/120.0)*xdamp^5 + (1.0/720.0)*xdamp^6);" + "xdamp=bdamp*r; " + "bdamp=1/0.021; " + ) + else: + expr += "fdamp=1.0;" + + """ + energy_function = 'lambda*4*epsilon*x*(x-1.0); x = (sigma/reff_sterics)^6;' + energy_function += 'reff_sterics = sigma*(0.5*(1.0-lambda) + (r/sigma)^6)^(1/6);' + energy_function += 'sigma = 0.5*(sigma1+sigma2); epsilon = sqrt(epsilon1*epsilon2);' + custom_force = CustomNonbondedForce(energy_function) + """ + + logger.debug(f"Softcore LJ expression: {expr}") # Create a CustomNonbondedForce to compute the softcore Lennard-Jones - soft_core_force = _mm.CustomNonbondedForce(energy_function) + soft_core_force = _mm.CustomNonbondedForce(expr) if self._NON_BONDED_METHODS[nb_force.getNonbondedMethod()] in [ self._NON_BONDED_METHODS[3], @@ -135,6 +176,7 @@ def apply( soft_core_force.addInteractionGroup(alchemical_atoms, mm_atoms) # Add the CustomNonbondedForce to the System + soft_core_force.setName(self.modification_name) system.addForce(soft_core_force) return system diff --git a/src/fes_ml/alchemical/modifications/ml_correction.py b/src/fes_ml/alchemical/modifications/ml_correction.py index 11439b4..31ea327 100644 --- a/src/fes_ml/alchemical/modifications/ml_correction.py +++ b/src/fes_ml/alchemical/modifications/ml_correction.py @@ -6,12 +6,9 @@ import openmm as _mm from .base_modification import BaseModification, BaseModificationFactory -from .intramolecular import ( - IntraMolecularBondedRemovalModification, - IntraMolecularNonBondedExceptionsModification, -) from .ml_base_modification import MLBaseModification from .ml_potential import MLPotentialModification +from .intramolecular import IntraMolecularBondedRemovalModification, IntraMolecularNonBondedExceptionsModification logger = logging.getLogger(__name__) @@ -42,12 +39,8 @@ class MLCorrectionModification(MLBaseModification, BaseModification): NAME = "MLCorrection" pre_dependencies: List[str] = [MLPotentialModification.NAME] - post_dependencies: List[str] = [ - IntraMolecularNonBondedExceptionsModification.NAME, - ] - skip_dependencies: List[str] = [ - IntraMolecularBondedRemovalModification.NAME, - ] + post_dependencies = [IntraMolecularNonBondedExceptionsModification.NAME] + skip_dependencies: List[str] = [IntraMolecularBondedRemovalModification.NAME] def apply( self, @@ -97,7 +90,7 @@ def apply( mm_sum = "+".join(mm_vars) if len(mm_vars) > 0 else "0" ml_interpolation_function = f"lambda_interpolate*({ml_sum} - ({mm_sum}))" cv.setEnergyFunction(ml_interpolation_function) - cv.setName(self.NAME) + cv.setName(self.modification_name) system.addForce(cv) logger.debug(f"ML correction function: {ml_interpolation_function}") diff --git a/src/fes_ml/alchemical/modifications/ml_interpolation.py b/src/fes_ml/alchemical/modifications/ml_interpolation.py index 49173af..3e7280d 100644 --- a/src/fes_ml/alchemical/modifications/ml_interpolation.py +++ b/src/fes_ml/alchemical/modifications/ml_interpolation.py @@ -7,6 +7,8 @@ from .base_modification import BaseModification, BaseModificationFactory from .ml_base_modification import MLBaseModification +from .ml_potential import MLPotentialModification +from .intramolecular import IntraMolecularBondedRemovalModification, IntraMolecularNonBondedExceptionsModification logger = logging.getLogger(__name__) @@ -36,10 +38,10 @@ class MLInterpolationModification(MLBaseModification, BaseModification): """Class to add a CustomCVForce to interpolate between ML and MM forces.""" NAME = "MLInterpolation" - pre_dependencies = ["MLPotential"] - post_dependencies = [ - "IntraMolecularNonBondedExceptions", - "IntraMolecularBondedRemoval", + pre_dependencies = [MLPotentialModification.NAME] + post_dependencies: List[str] = [ + IntraMolecularBondedRemovalModification.NAME, + IntraMolecularNonBondedExceptionsModification.NAME, ] def apply( @@ -92,7 +94,7 @@ def apply( ) cv.setEnergyFunction(ml_interpolation_function) system.addForce(cv) - cv.setName(self.NAME) + cv.setName(self.modification_name) logger.debug(f"ML interpolation function: {ml_interpolation_function}") diff --git a/src/fes_ml/alchemical/strategies/base_strategy.py b/src/fes_ml/alchemical/strategies/base_strategy.py index 602380b..a401cf6 100644 --- a/src/fes_ml/alchemical/strategies/base_strategy.py +++ b/src/fes_ml/alchemical/strategies/base_strategy.py @@ -33,6 +33,7 @@ def _run_alchemist( system: _mm.System, alchemical_atoms: List[int], lambda_schedule: Dict[str, Union[float, int]], + modifications_kwargs: Optional[Dict[str, Dict[str, Any]]] = None, *args, **kwargs, ) -> None: @@ -47,16 +48,34 @@ def _run_alchemist( The system to be modified. alchemical_atoms : list of int The list of alchemical atoms. + modifications_kwargs : dict, optional + A dictionary of keyword arguments for the modifications. + It is structured as follows: + { + "modification_name": { + "key1": value1, + "key2": value2, + ... + }, + ... + } args : list Additional arguments to be passed to the Alchemist ``apply_modifications`` method. kwargs : dict Additional keyword arguments to be passed to the Alchemist ``apply_modifications`` method. """ alchemist = Alchemist() - alchemist.create_alchemical_graph(lambda_schedule) + alchemist.create_alchemical_graph( + lambda_schedule, modifications_kwargs=modifications_kwargs + ) + + if modifications_kwargs is None: + modifications_kwargs = {} + alchemist.apply_modifications( system, alchemical_atoms, + modifications_kwargs, *args, **kwargs, ) @@ -87,6 +106,56 @@ def _report_dict( if initial: logger.debug("+" + "-" * 98 + "+") + @staticmethod + def _has_modification_type( + lambda_schedule: Dict[str, Union[float, int]], base_name: str + ) -> bool: + """ + Check if any modification of a given base type exists in the lambda schedule. + + Parameters + ---------- + lambda_schedule : Dict[str, Union[float, int]] + Dictionary mapping modification names to lambda values. + base_name : str + The base modification type name (e.g., "EMLEPotential", "CustomLJ"). + + Returns + ------- + bool + True if any modification of the base type exists, False otherwise. + """ + return any( + key == base_name or key.startswith(f"{base_name}:") + for key in lambda_schedule.keys() + ) + + @staticmethod + def _get_modification_instances( + lambda_schedule: Dict[str, Union[float, int]], base_name: str + ) -> List[str]: + """ + Get all modification instance names of a given base type from the lambda schedule. + + Parameters + ---------- + lambda_schedule : Dict[str, Union[float, int]] + Dictionary mapping modification names to lambda values. + base_name : str + The base modification type name (e.g., "EMLEPotential", "CustomLJ"). + + Returns + ------- + List[str] + List of all modification instance names matching the base type. + Examples: ["CustomLJ", "CustomLJ:region1", "CustomLJ:region2"] + """ + return [ + key + for key in lambda_schedule.keys() + if key == base_name or key.startswith(f"{base_name}:") + ] + @staticmethod def _report_energy_decomposition(context, system) -> None: """ diff --git a/src/fes_ml/alchemical/strategies/openff_strategy.py b/src/fes_ml/alchemical/strategies/openff_strategy.py index d702c87..321181a 100644 --- a/src/fes_ml/alchemical/strategies/openff_strategy.py +++ b/src/fes_ml/alchemical/strategies/openff_strategy.py @@ -333,11 +333,14 @@ def _solvate( packmol_kwargs_local.pop("number_of_copies") - topology_off = _pack_box( - molecules=mols, - number_of_copies=number_of_copies, - **packmol_kwargs_local, - ) + if number_of_copies[1] == 1: + topology_off = _Topology.from_molecules(mols) + else: + topology_off = _pack_box( + molecules=mols, + number_of_copies=number_of_copies, + **packmol_kwargs_local, + ) return topology_off @@ -593,7 +596,7 @@ def create_alchemical_state( if sdf_file_ligand is None: # Only generate conformers if no SDF file is provided # Otherwise, the geometry is taken from the SDF file - logger.debug("Generating conformers for the ligand") + logger.debug("Generating conformers for the ligand.") molecules["ligand"].generate_conformers() else: logger.debug(f"Using provided ligand geometry from {sdf_file_ligand}") @@ -681,21 +684,28 @@ def create_alchemical_state( # Create/update the modifications kwargs modifications_kwargs = _deepcopy(modifications_kwargs) or {} - if any(key in lambda_schedule for key in ["EMLEPotential"]): - modifications_kwargs["EMLEPotential"] = modifications_kwargs.get( - "EMLEPotential", {} - ) - - # Import required + # Handle EMLEPotential modifications + emle_instances = self._get_modification_instances( + lambda_schedule, "EMLEPotential" + ) + if emle_instances: import numpy as _np import sire as _sr + # For now, we apply the same kwargs to all instances + # In the future, this could be made per-instance specific + for modification_name in emle_instances: + modifications_kwargs[modification_name] = modifications_kwargs.get( + modification_name, {} + ) + # Write .top and .gro files via the OpenFF interchange if _os.path.exists(self._TMP_DIR): _shutil.rmtree(self._TMP_DIR) _os.makedirs(self._TMP_DIR, exist_ok=True) files_prefix = _os.path.join(self._TMP_DIR, "interchange") - interchange.to_gromacs(prefix=files_prefix) + interchange.to_gro(files_prefix + ".gro") + interchange.to_top(files_prefix + ".top") # Read back those files using Sire sr_mols = _sr.load( @@ -713,39 +723,58 @@ def create_alchemical_state( format=["prm7"], ) - # Add required EMLEPotential kwargs to the modifications_kwargs dict - modifications_kwargs["EMLEPotential"]["mols"] = sr_mols - modifications_kwargs["EMLEPotential"]["parm7"] = alchemical_prm7[0] - modifications_kwargs["EMLEPotential"]["mm_charges"] = _np.asarray( - [atom.charge().value() for atom in sr_mols.atoms(alchemical_atoms)] - ) - # Get the original charges of the OpenMM system openmm_charges = self._get_openmm_charges(system) - modifications_kwargs["EMLEPotential"]["openmm_charges"] = openmm_charges - if any( - key in lambda_schedule - for key in ["MLPotential", "MLInterpolation", "MLCorrection"] - ): + + # Add required EMLEPotential kwargs to all instances + for modification_name in emle_instances: + modifications_kwargs[modification_name]["mols"] = sr_mols + modifications_kwargs[modification_name]["parm7"] = alchemical_prm7[0] + modifications_kwargs[modification_name]["mm_charges"] = _np.asarray( + [atom.charge().value() for atom in sr_mols.atoms(alchemical_atoms)] + ) + modifications_kwargs[modification_name][ + "openmm_charges" + ] = openmm_charges + + # Handle ML-related modifications + ml_types = ["MLPotential", "MLInterpolation", "MLCorrection"] + ml_instances = [] + for ml_type in ml_types: + ml_instances.extend( + self._get_modification_instances(lambda_schedule, ml_type) + ) + + if ml_instances: modifications_kwargs["MLPotential"] = modifications_kwargs.get( "MLPotential", {} ) modifications_kwargs["MLPotential"]["topology"] = topology - if any(key in lambda_schedule for key in ["CustomLJ"]): - modifications_kwargs["CustomLJ"] = modifications_kwargs.get("CustomLJ", {}) - modifications_kwargs["CustomLJ"]["original_offxml"] = ffs - modifications_kwargs["CustomLJ"]["topology_off"] = topology_off - - # Get the positions of the alchemical atoms - modifications_kwargs["CustomLJ"]["positions"] = positions + # Handle CustomLJ modifications + customlj_instances = self._get_modification_instances( + lambda_schedule, "CustomLJ" + ) + if customlj_instances: + for modification_name in customlj_instances: + modifications_kwargs[modification_name] = modifications_kwargs.get( + modification_name, {} + ) + modifications_kwargs[modification_name]["original_offxml"] = ffs + modifications_kwargs[modification_name]["topology_off"] = topology_off + modifications_kwargs[modification_name]["positions"] = positions - if any(key in lambda_schedule for key in ["ChargeTransfer"]): - modifications_kwargs["ChargeTransfer"] = modifications_kwargs.get( - "ChargeTransfer", {} - ) - modifications_kwargs["ChargeTransfer"]["original_offxml"] = ffs - modifications_kwargs["ChargeTransfer"]["topology_off"] = topology_off + # Handle ChargeTransfer modifications + chargetransfer_instances = self._get_modification_instances( + lambda_schedule, "ChargeTransfer" + ) + if chargetransfer_instances: + for modification_name in chargetransfer_instances: + modifications_kwargs[modification_name] = modifications_kwargs.get( + modification_name, {} + ) + modifications_kwargs[modification_name]["original_offxml"] = ffs + modifications_kwargs[modification_name]["topology_off"] = topology_off # Run the Alchemist self._run_alchemist( diff --git a/src/fes_ml/alchemical/strategies/sire_strategy.py b/src/fes_ml/alchemical/strategies/sire_strategy.py index e92f8fb..8ad57a0 100644 --- a/src/fes_ml/alchemical/strategies/sire_strategy.py +++ b/src/fes_ml/alchemical/strategies/sire_strategy.py @@ -136,19 +136,31 @@ def create_alchemical_state( self._report_energy_decomposition(omm_context, omm_system) modifications_kwargs = _deepcopy(modifications_kwargs) or {} - if any(key in lambda_schedule for key in ["EMLEPotential", "MLInterpolation"]): - modifications_kwargs["EMLEPotential"] = modifications_kwargs.get( - "EMLEPotential", {} - ) - modifications_kwargs["EMLEPotential"]["mols"] = mols - modifications_kwargs["EMLEPotential"]["parm7"] = alchemical_prm7[0] - modifications_kwargs["EMLEPotential"]["mm_charges"] = _np.asarray( - [atom.charge().value() for atom in mols.atoms(alchemical_atoms)] + + # Handle EMLEPotential modifications + emle_instances = self._get_modification_instances( + lambda_schedule, "EMLEPotential" + ) + if emle_instances: + for modification_name in emle_instances: + modifications_kwargs[modification_name] = modifications_kwargs.get( + modification_name, {} + ) + modifications_kwargs[modification_name]["mols"] = mols + modifications_kwargs[modification_name]["parm7"] = alchemical_prm7[0] + modifications_kwargs[modification_name]["mm_charges"] = _np.asarray( + [atom.charge().value() for atom in mols.atoms(alchemical_atoms)] + ) + + # Handle ML-related modifications + ml_types = ["MLPotential", "MLInterpolation", "MLCorrection"] + ml_instances = [] + for ml_type in ml_types: + ml_instances.extend( + self._get_modification_instances(lambda_schedule, ml_type) ) - if any( - key in lambda_schedule - for key in ["MLPotential", "MLInterpolation", "MLCorrection"] - ): + + if ml_instances: modifications_kwargs["MLPotential"] = modifications_kwargs.get( "MLPotential", {} ) diff --git a/tests/test_alchemical_states.py b/tests/test_alchemical_states.py index 24d58cc..9b8519b 100644 --- a/tests/test_alchemical_states.py +++ b/tests/test_alchemical_states.py @@ -244,7 +244,7 @@ def _test_energy_decomposition( non_bonded_forces = [ "NonbondedForce", "CustomBondForce", - "CustomNonbondedForce", + "LJSoftCore:default", ] alc_nonbonded_force = sum( alc_energy_decom[force]._value