From 124e6dd4e4851a714e0e33c64e0ade40c6ab28d0 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 16:43:17 +0100 Subject: [PATCH 01/20] Add alchemical groups --- src/fes_ml/alchemical/alchemist.py | 160 ++++++++++++++---- .../modifications/base_modification.py | 103 ++++++++++- .../modifications/charge_scaling.py | 10 +- .../alchemical/modifications/custom_lj.py | 25 ++- .../modifications/emle_potential.py | 3 +- .../modifications/intramolecular.py | 41 ++--- .../alchemical/modifications/lj_softcore.py | 1 + .../alchemical/modifications/ml_correction.py | 2 +- .../modifications/ml_interpolation.py | 2 +- .../alchemical/strategies/base_strategy.py | 71 +++++++- .../alchemical/strategies/openff_strategy.py | 95 +++++++---- .../alchemical/strategies/sire_strategy.py | 45 +++-- 12 files changed, 437 insertions(+), 121 deletions(-) diff --git a/src/fes_ml/alchemical/alchemist.py b/src/fes_ml/alchemical/alchemist.py index 528deae..991877e 100644 --- a/src/fes_ml/alchemical/alchemist.py +++ b/src/fes_ml/alchemical/alchemist.py @@ -3,6 +3,7 @@ import logging import sys from copy import deepcopy as _deepcopy +from importlib.metadata import entry_points from typing import Any, Dict, List, Optional import networkx as nx @@ -10,11 +11,6 @@ from .modifications.base_modification import BaseModification, BaseModificationFactory -if sys.version_info < (3, 10): - from importlib_metadata import entry_points -else: - from importlib.metadata import entry_points - logger = logging.getLogger(__name__) @@ -22,6 +18,7 @@ class Alchemist: """A class for applying alchemical modifications to an OpenMM system.""" _modification_factories: Dict[str, BaseModificationFactory] = {} + _DEFAULT_ALCHEMICAL_GROUP = ":default" @staticmethod def register_modification_factory( @@ -100,13 +97,13 @@ def add_modification_to_graph( lambda_value : float The value of the alchemical state parameter. """ - if modification.NAME in self._graph.nodes and lambda_value is None: - lambda_value = self._graph.nodes[modification.NAME].get( - "lambda_value", None - ) + node_name = modification.modification_name + + if node_name in self._graph.nodes and lambda_value is None: + lambda_value = self._graph.nodes[node_name].get("lambda_value", None) self._graph.add_node( - modification.NAME, modification=modification, lambda_value=lambda_value + node_name, modification=modification, lambda_value=lambda_value ) if modification.pre_dependencies is not None: @@ -119,10 +116,18 @@ def add_modification_to_graph( "modification is implemented and registered as an entry point." ) - factory = self._modification_factories[pre_dependency] - pre_modification = factory.create_modification() - self._graph.add_edge(pre_modification.NAME, modification.NAME) - self.add_modification_to_graph(pre_modification, None) + # Check if dependency already exists in graph + dep_modification_name = ( + f"{pre_dependency}:{modification.alchemical_group}" + ) + if dep_modification_name not in self._graph.nodes: + factory = self._modification_factories[pre_dependency] + pre_modification = factory.create_modification( + modification_name=dep_modification_name + ) + self.add_modification_to_graph(pre_modification, None) + + self._graph.add_edge(dep_modification_name, node_name) if modification.post_dependencies is not None: for post_dependency in modification.post_dependencies: @@ -133,10 +138,19 @@ def add_modification_to_graph( "typos in the name of this post-dependency and that the target " "modification is implemented and registered as an entry point." ) - factory = self._modification_factories[post_dependency] - post_modification = factory.create_modification() - self._graph.add_edge(modification.NAME, post_modification.NAME) - self.add_modification_to_graph(post_modification, None) + + # Check if dependency already exists in graph + dep_modification_name = ( + f"{post_dependency}:{modification.alchemical_group}" + ) + if dep_modification_name not in self._graph.nodes: + factory = self._modification_factories[post_dependency] + post_modification = factory.create_modification( + modification_name=dep_modification_name + ) + self.add_modification_to_graph(post_modification, None) + + self._graph.add_edge(node_name, dep_modification_name) def remove_modification_from_graph(self, modification: str) -> None: """ @@ -153,6 +167,7 @@ def create_alchemical_graph( self, lambda_schedule: Dict[str, float], additional_modifications: Optional[List[str]] = None, + modifications_kwargs: Optional[Dict[str, Dict[str, Any]]] = None, ): """ Create a graph of alchemical modifications to apply. @@ -163,6 +178,8 @@ def create_alchemical_graph( A dictionary of λ values to be applied to the system. additional_modifications : list of str Additional modifications to apply. + modifications_kwargs : dict + A dictionary of keyword arguments for the modifications. Returns ------- @@ -171,25 +188,65 @@ def create_alchemical_graph( """ logger.debug("Creating graph of alchemical modifications.") for name, lambda_value in lambda_schedule.items(): - if name in Alchemist._modification_factories: - factory = self._modification_factories[name] - modification = factory.create_modification() + if ":" in name: + base_name, modification_name = name.split(":", 1)[0], name + else: + base_name, modification_name = ( + name, + name + self._DEFAULT_ALCHEMICAL_GROUP, + ) + + if modifications_kwargs is not None and name in modifications_kwargs: + if modification_name in modifications_kwargs: + raise ValueError( + f"Cannot rename '{name}' to '{modification_name}': " + "key already exists in modifications_kwargs." + ) + modifications_kwargs[modification_name] = modifications_kwargs.pop(name) + + if base_name in Alchemist._modification_factories: + factory = self._modification_factories[base_name] + modification = factory.create_modification( + modification_name=modification_name + ) self.add_modification_to_graph(modification, lambda_value=lambda_value) else: - raise ValueError(f"Modification {name} not found in the factories.") + raise ValueError( + f"Modification {base_name} not found in the factories." + ) if additional_modifications is not None: for name in additional_modifications: - if name in Alchemist._modification_factories: - factory = self._modification_factories[name] - modification = factory.create_modification() - self.add_modification_to_graph( - modification, lambda_value=lambda_value + if ":" in name: + base_name, modification_name = name.split(":", 1)[0], name + else: + base_name, modification_name = ( + name, + name + self._DEFAULT_ALCHEMICAL_GROUP, + ) + + if modifications_kwargs is not None and name in modifications_kwargs: + if modification_name in modifications_kwargs: + raise ValueError( + f"Cannot rename '{name}' to '{modification_name}': " + "key already exists in modifications_kwargs." + ) + modifications_kwargs[modification_name] = modifications_kwargs.pop( + name ) + + if base_name in Alchemist._modification_factories: + factory = self._modification_factories[base_name] + modification = factory.create_modification( + modification_name=modification_name + ) + self.add_modification_to_graph(modification, lambda_value=None) else: - raise ValueError(f"Modification {name} not found in the factories.") + raise ValueError( + f"Modification {base_name} not found in the factories." + ) - # After the graph is created, removed dependencies to skip + # After constructing the graph, remove dependencies to skip ref_graph = _deepcopy(self._graph) for _, data in ref_graph.nodes.data(): modification = data["modification"] @@ -197,6 +254,46 @@ def create_alchemical_graph( for skip_dependency in modification.skip_dependencies: self.remove_modification_from_graph(skip_dependency) + # After constructing the graph, remove redundant modifications + # Redundancy is determined using a binary overlap principle: + # - Total overlap: if two modifications have identical alchemical atoms, keep only one + # - No overlap: if the alchemical atoms are disjoint, keep both modifications + # - Partial overlap: if the alchemical atoms partially overlap, raise an error + # (this behavior may be implemented in the future) + modifications_kwargs = modifications_kwargs or {} + redundant_modifications = [ + name for name in self._graph.nodes if name not in lambda_schedule + ] + mod_atoms = { + name: set(modifications_kwargs.get(name, {}).get("alchemical_atoms", [])) + for name in redundant_modifications + } + to_remove = set() + for i, name in enumerate(redundant_modifications): + base_name = name.split(":", 1)[0] + set_a = mod_atoms[name] + for j in range(i + 1, len(redundant_modifications)): + other_name = redundant_modifications[j] + other_base_name = other_name.split(":", 1)[0] + if base_name != other_base_name: + continue + set_b = mod_atoms[other_name] + if set_a == set_b: + # If both modifications have the same alchemical atoms, keep only one + to_remove.add(other_name) + elif not set_a.isdisjoint(set_b): + # Partial overlap detected, raise an error + raise ValueError( + f"Partial overlap detected between modifications '{name}' and '{other_name}'. " + "Please ensure that alchemical atoms either fully overlap or are disjoint." + ) + + for name in to_remove: + self.remove_modification_from_graph(name) + + for _, data in list(self._graph.nodes.data()): + modification = data["modification"] + logger.debug("Created graph of alchemical modifications:\n") for line in nx.generate_network_text( self._graph, vertical_chains=False, ascii_only=True @@ -246,7 +343,10 @@ def apply_modifications( for mod in nx.topological_sort(self._graph): lambda_value = self._graph.nodes[mod]["lambda_value"] mod_instance = self._graph.nodes[mod]["modification"] - mod_kwargs = modifications_kwargs.get(mod, {}) + # Try both instance name and base name for kwargs lookup + mod_kwargs = modifications_kwargs.get( + mod, modifications_kwargs.get(mod_instance.NAME, {}) + ) if lambda_value is None: logger.debug(f"Applying {mod} modification") diff --git a/src/fes_ml/alchemical/modifications/base_modification.py b/src/fes_ml/alchemical/modifications/base_modification.py index f56c8c0..7fe00f3 100644 --- a/src/fes_ml/alchemical/modifications/base_modification.py +++ b/src/fes_ml/alchemical/modifications/base_modification.py @@ -19,12 +19,16 @@ class BaseModification(ABC): - If modification A requires B, add B to A's ``pre_dependencies`` class-level attribute. - If modification A is incompatible with B being present in the graph, - add B to A's ``skip_depencies`` class-level attribute. + add B to A's ``skip_dependencies`` class-level attribute. - If multiple modifications imply B, B is applied only once. - - Modification are applied in topologically sorted order based on + - Modifications are applied in topologically sorted order based on the dependency graph. - - A modification must an associated key controlled by the class-level + - A modification must have a default key controlled by the class-level attribute ``NAME``. + - A modication can also have a custom name by passing a string + ``modification_name`` when instantiating the modification. This + allows multiple instances of the same modification to coexist in + the same system, each with its own parameters. - Modifications can be applied by passing a lambda schedule dictionary when creating alchemical states. This dictionary maps ``NAME``s to λ values for each ``AlchemicalState``. @@ -43,6 +47,22 @@ class BaseModification(ABC): post_dependencies: List[str] = None skip_dependencies: List[str] = None + def __init__(self, modification_name: str = None): + """ + Initialize the BaseModification. + + Parameters + ---------- + modification_name : str, optional + Custom name for this modification instance. If not provided, + uses the class NAME. + """ + assert modification_name.count(":") <= 1, ( + f"Invalid modification_name '{modification_name}': " + "it may contain at most one ':' to indicate the alchemical group." + ) + self.modification_name = modification_name or self.NAME + def __init_subclass__(cls, **kwargs): """ Initialize the subclass. @@ -135,6 +155,74 @@ def remove_post_dependency(cls, name: str) -> None: else: raise ValueError(f"{name} is not a post-dependency of {cls.NAME}.") + @staticmethod + def find_forces_by_group(system: _mm.System, group: str) -> List[_mm.Force]: + """ + Find all forces in the system that belong to a given alchemical group. + + Notes + ----- + Alchemical groups are defined by suffixes in the force names, e.g., ':region1'. + + Parameters + ---------- + system : openmm.System + The OpenMM system to search. + group : str + The alchemical group to match (e.g., ':region1'). + + Returns + ------- + List[openmm.Force] + List of forces whose names end with the specified alchemical group. + """ + return [ + force + for force in system.getForces() + if force.getName().endswith(f":{group}") + ] + + @staticmethod + def find_force_by_name(system: _mm.System, force_name: str) -> _mm.Force: + """ + Find a force in the system by its full name. + + Parameters + ---------- + system : openmm.System + The OpenMM system to search. + force_name : str + The name of the force to find. + + Returns + ------- + openmm.Force + The force with the specified name. + + Raises + ------ + ValueError + If no force with the specified name is found. + """ + for force in system.getForces(): + if force.getName() == force_name: + return force + raise ValueError(f"Force with name '{force_name}' not found in system.") + + @property + def alchemical_group(self) -> str: + """ + Get the suffix of the current instance name. + + Returns + ------- + str + The suffix part after ':' if present, empty string otherwise. + """ + if ":" in self.modification_name: + return self.modification_name.split(":", 1)[1] + return "" + class BaseModificationFactory(ABC): """ @@ -145,10 +233,17 @@ class BaseModificationFactory(ABC): """ @abstractmethod - def create_modification(self, *args, **kwargs) -> BaseModification: + def create_modification( + self, modification_name: str = None, *args, **kwargs + ) -> BaseModification: """ Create an instance of the modification. + Parameters + ---------- + modification_name : str, optional + Custom name for this modification instance. + Returns ------- BaseModification diff --git a/src/fes_ml/alchemical/modifications/charge_scaling.py b/src/fes_ml/alchemical/modifications/charge_scaling.py index 60cac64..56e32dd 100644 --- a/src/fes_ml/alchemical/modifications/charge_scaling.py +++ b/src/fes_ml/alchemical/modifications/charge_scaling.py @@ -63,22 +63,22 @@ def apply( system : openmm.System The modified System with the charges scaled. """ - nb_forces = [ + nonbondedforces = [ force for force in system.getForces() if isinstance(force, _mm.NonbondedForce) ] - if len(nb_forces) > 1: + if len(nonbondedforces) > 1: raise ValueError( "The system must not contain more than one NonbondedForce." ) - elif len(nb_forces) == 0: + elif len(nonbondedforces) == 0: logger.warning( - "The system does not contain a NonbondedForce and therefore no charge scaling will be applied.=." + "The system does not contain a NonbondedForce and therefore no charge scaling will be applied." ) return system else: - force = nb_forces[0] + force = nonbondedforces[0] for index in range(system.getNumParticles()): [charge, sigma, epsilon] = force.getParticleParameters(index) if index in alchemical_atoms: diff --git a/src/fes_ml/alchemical/modifications/custom_lj.py b/src/fes_ml/alchemical/modifications/custom_lj.py index f3e23ff..98567e0 100644 --- a/src/fes_ml/alchemical/modifications/custom_lj.py +++ b/src/fes_ml/alchemical/modifications/custom_lj.py @@ -19,6 +19,11 @@ def create_modification(self, *args, **kwargs) -> BaseModification: """ Create an instance of CustomLJModification. + Parameters + ---------- + modification_name : str, optional + Custom name for this modification instance. + Returns ------- CustomLJModification @@ -76,8 +81,19 @@ def apply( openmm.System The modified system. """ - forces = {force.__class__.__name__: force for force in system.getForces()} - custom_nb_force = forces["CustomNonbondedForce"] + # Find the related LJSoftCore force based on instance naming + alchemical_group = self.alchemical_group + group_forces = self.find_forces_by_group(system, self.alchemical_group) + custom_nb_forces = [ + force + for force in group_forces + if force.getName() == f"LJSoftCore:{alchemical_group}" + ] + if not custom_nb_forces: + raise ValueError( + f"Attempting to modify LJ parameters but no LJSoftCore force found for alchemical group '{alchemical_group}'" + ) + custom_nb_force = custom_nb_forces[0] # Create a dictionary with the optimized Lennard-Jones parameters for each atom type force_field_opt = _ForceField(lj_offxml) @@ -106,8 +122,9 @@ def apply( for index, parameters in enumerate(vdw_parameters): if alchemical_atoms_only and index not in alchemical_atoms: continue - print(f"Setting Lennard-Jones parameters for atom {index} to {parameters}") - print(f"Parameters are {parameters}") custom_nb_force.setParticleParameters(index, parameters) + # Update name (will override LJSoftCore name) + custom_nb_force.setName(self.modification_name) + return system diff --git a/src/fes_ml/alchemical/modifications/emle_potential.py b/src/fes_ml/alchemical/modifications/emle_potential.py index fccc090..535ae0d 100644 --- a/src/fes_ml/alchemical/modifications/emle_potential.py +++ b/src/fes_ml/alchemical/modifications/emle_potential.py @@ -192,10 +192,11 @@ def apply( emle_force, interpolation_force = engine.get_forces() # Add the EMLE force to the system. + emle_force.setName(self.modification_name) system.addForce(emle_force) # Add the interpolation force to the system, so that the EMLE force does not scale # with the MLInterpolation force. - interpolation_force.setName("EMLECustomBondForce") + interpolation_force.setName("EMLECustomBondForce_" + self.modification_name) system.addForce(interpolation_force) # Zero the charges on the atoms within the QM region diff --git a/src/fes_ml/alchemical/modifications/intramolecular.py b/src/fes_ml/alchemical/modifications/intramolecular.py index ee7c56d..1a55d5e 100644 --- a/src/fes_ml/alchemical/modifications/intramolecular.py +++ b/src/fes_ml/alchemical/modifications/intramolecular.py @@ -15,14 +15,8 @@ class IntraMolecularNonBondedExceptionsModificationFactory(BaseModificationFacto """Factory for creating IntraMolecularNonBondedModification instances.""" def create_modification(self, *args, **kwargs) -> BaseModification: - """Create an instance of IntraMolecularNonBondedModification. - - Parameters - ---------- - args : list - Additional arguments to be passed to the modification. - kwargs : dict - Additional keyword arguments to be passed to the modification. + """ + Create an instance of IntraMolecularNonBondedModification. Returns ------- @@ -36,14 +30,8 @@ class IntraMolecularNonBondedForcesModificationFactory(BaseModificationFactory): """Factory for creating IntraMolecularNonBondedForcesModification instances.""" def create_modification(self, *args, **kwargs) -> BaseModification: - """Create an instance of IntraMolecularNonBondedForcesModification. - - Parameters - ---------- - args : list - Additional arguments to be passed to the modification. - kwargs : dict - Additional keyword arguments to be passed to the modification. + """ + Create an instance of IntraMolecularNonBondedForcesModification. Returns ------- @@ -57,14 +45,8 @@ class IntraMolecularBondedRemovalModificationFactory(BaseModificationFactory): """Factory for creating IntraMolecularBondedRemovalModification instances.""" def create_modification(self, *args, **kwargs) -> BaseModification: - """Create an instance of IntraMolecularBondedRemovalModification. - - Parameters - ---------- - args : list - Additional arguments to be passed to the modification. - kwargs : dict - Additional keyword arguments to be passed to the modification. + """ + Create an instance of IntraMolecularBondedRemovalModification. Returns ------- @@ -218,6 +200,17 @@ class IntraMolecularBondedRemovalModification(BaseModification): NAME = "IntraMolecularBondedRemoval" + def __init__(self, modification_name: str = None): + """ + Initialize the IntraMolecularBondedRemovalModification. + + Parameters + ---------- + modification_name : str, optional + Custom name for this modification instance. + """ + super().__init__(modification_name) + @staticmethod def _should_remove(term_atoms: tuple, atom_set: set, remove_in_set: bool) -> bool: """ diff --git a/src/fes_ml/alchemical/modifications/lj_softcore.py b/src/fes_ml/alchemical/modifications/lj_softcore.py index a308e6e..b796d43 100644 --- a/src/fes_ml/alchemical/modifications/lj_softcore.py +++ b/src/fes_ml/alchemical/modifications/lj_softcore.py @@ -135,6 +135,7 @@ def apply( soft_core_force.addInteractionGroup(alchemical_atoms, mm_atoms) # Add the CustomNonbondedForce to the System + soft_core_force.setName(self.modification_name) system.addForce(soft_core_force) return system diff --git a/src/fes_ml/alchemical/modifications/ml_correction.py b/src/fes_ml/alchemical/modifications/ml_correction.py index 11439b4..40776da 100644 --- a/src/fes_ml/alchemical/modifications/ml_correction.py +++ b/src/fes_ml/alchemical/modifications/ml_correction.py @@ -97,7 +97,7 @@ def apply( mm_sum = "+".join(mm_vars) if len(mm_vars) > 0 else "0" ml_interpolation_function = f"lambda_interpolate*({ml_sum} - ({mm_sum}))" cv.setEnergyFunction(ml_interpolation_function) - cv.setName(self.NAME) + cv.setName(self.modification_name) system.addForce(cv) logger.debug(f"ML correction function: {ml_interpolation_function}") diff --git a/src/fes_ml/alchemical/modifications/ml_interpolation.py b/src/fes_ml/alchemical/modifications/ml_interpolation.py index 49173af..b51e1c0 100644 --- a/src/fes_ml/alchemical/modifications/ml_interpolation.py +++ b/src/fes_ml/alchemical/modifications/ml_interpolation.py @@ -92,7 +92,7 @@ def apply( ) cv.setEnergyFunction(ml_interpolation_function) system.addForce(cv) - cv.setName(self.NAME) + cv.setName(self.modification_name) logger.debug(f"ML interpolation function: {ml_interpolation_function}") diff --git a/src/fes_ml/alchemical/strategies/base_strategy.py b/src/fes_ml/alchemical/strategies/base_strategy.py index 602380b..a401cf6 100644 --- a/src/fes_ml/alchemical/strategies/base_strategy.py +++ b/src/fes_ml/alchemical/strategies/base_strategy.py @@ -33,6 +33,7 @@ def _run_alchemist( system: _mm.System, alchemical_atoms: List[int], lambda_schedule: Dict[str, Union[float, int]], + modifications_kwargs: Optional[Dict[str, Dict[str, Any]]] = None, *args, **kwargs, ) -> None: @@ -47,16 +48,34 @@ def _run_alchemist( The system to be modified. alchemical_atoms : list of int The list of alchemical atoms. + modifications_kwargs : dict, optional + A dictionary of keyword arguments for the modifications. + It is structured as follows: + { + "modification_name": { + "key1": value1, + "key2": value2, + ... + }, + ... + } args : list Additional arguments to be passed to the Alchemist ``apply_modifications`` method. kwargs : dict Additional keyword arguments to be passed to the Alchemist ``apply_modifications`` method. """ alchemist = Alchemist() - alchemist.create_alchemical_graph(lambda_schedule) + alchemist.create_alchemical_graph( + lambda_schedule, modifications_kwargs=modifications_kwargs + ) + + if modifications_kwargs is None: + modifications_kwargs = {} + alchemist.apply_modifications( system, alchemical_atoms, + modifications_kwargs, *args, **kwargs, ) @@ -87,6 +106,56 @@ def _report_dict( if initial: logger.debug("+" + "-" * 98 + "+") + @staticmethod + def _has_modification_type( + lambda_schedule: Dict[str, Union[float, int]], base_name: str + ) -> bool: + """ + Check if any modification of a given base type exists in the lambda schedule. + + Parameters + ---------- + lambda_schedule : Dict[str, Union[float, int]] + Dictionary mapping modification names to lambda values. + base_name : str + The base modification type name (e.g., "EMLEPotential", "CustomLJ"). + + Returns + ------- + bool + True if any modification of the base type exists, False otherwise. + """ + return any( + key == base_name or key.startswith(f"{base_name}:") + for key in lambda_schedule.keys() + ) + + @staticmethod + def _get_modification_instances( + lambda_schedule: Dict[str, Union[float, int]], base_name: str + ) -> List[str]: + """ + Get all modification instance names of a given base type from the lambda schedule. + + Parameters + ---------- + lambda_schedule : Dict[str, Union[float, int]] + Dictionary mapping modification names to lambda values. + base_name : str + The base modification type name (e.g., "EMLEPotential", "CustomLJ"). + + Returns + ------- + List[str] + List of all modification instance names matching the base type. + Examples: ["CustomLJ", "CustomLJ:region1", "CustomLJ:region2"] + """ + return [ + key + for key in lambda_schedule.keys() + if key == base_name or key.startswith(f"{base_name}:") + ] + @staticmethod def _report_energy_decomposition(context, system) -> None: """ diff --git a/src/fes_ml/alchemical/strategies/openff_strategy.py b/src/fes_ml/alchemical/strategies/openff_strategy.py index d702c87..7068ac9 100644 --- a/src/fes_ml/alchemical/strategies/openff_strategy.py +++ b/src/fes_ml/alchemical/strategies/openff_strategy.py @@ -681,15 +681,21 @@ def create_alchemical_state( # Create/update the modifications kwargs modifications_kwargs = _deepcopy(modifications_kwargs) or {} - if any(key in lambda_schedule for key in ["EMLEPotential"]): - modifications_kwargs["EMLEPotential"] = modifications_kwargs.get( - "EMLEPotential", {} - ) - - # Import required + # Handle EMLEPotential modifications + emle_instances = self._get_modification_instances( + lambda_schedule, "EMLEPotential" + ) + if emle_instances: import numpy as _np import sire as _sr + # For now, we apply the same kwargs to all instances + # In the future, this could be made per-instance specific + for modification_name in emle_instances: + modifications_kwargs[modification_name] = modifications_kwargs.get( + modification_name, {} + ) + # Write .top and .gro files via the OpenFF interchange if _os.path.exists(self._TMP_DIR): _shutil.rmtree(self._TMP_DIR) @@ -713,39 +719,60 @@ def create_alchemical_state( format=["prm7"], ) - # Add required EMLEPotential kwargs to the modifications_kwargs dict - modifications_kwargs["EMLEPotential"]["mols"] = sr_mols - modifications_kwargs["EMLEPotential"]["parm7"] = alchemical_prm7[0] - modifications_kwargs["EMLEPotential"]["mm_charges"] = _np.asarray( - [atom.charge().value() for atom in sr_mols.atoms(alchemical_atoms)] - ) - # Get the original charges of the OpenMM system openmm_charges = self._get_openmm_charges(system) - modifications_kwargs["EMLEPotential"]["openmm_charges"] = openmm_charges - if any( - key in lambda_schedule - for key in ["MLPotential", "MLInterpolation", "MLCorrection"] - ): - modifications_kwargs["MLPotential"] = modifications_kwargs.get( - "MLPotential", {} - ) - modifications_kwargs["MLPotential"]["topology"] = topology - if any(key in lambda_schedule for key in ["CustomLJ"]): - modifications_kwargs["CustomLJ"] = modifications_kwargs.get("CustomLJ", {}) - modifications_kwargs["CustomLJ"]["original_offxml"] = ffs - modifications_kwargs["CustomLJ"]["topology_off"] = topology_off + # Add required EMLEPotential kwargs to all instances + for modification_name in emle_instances: + modifications_kwargs[modification_name]["mols"] = sr_mols + modifications_kwargs[modification_name]["parm7"] = alchemical_prm7[0] + modifications_kwargs[modification_name]["mm_charges"] = _np.asarray( + [atom.charge().value() for atom in sr_mols.atoms(alchemical_atoms)] + ) + modifications_kwargs[modification_name][ + "openmm_charges" + ] = openmm_charges + + # Handle ML-related modifications + ml_types = ["MLPotential", "MLInterpolation", "MLCorrection"] + ml_instances = [] + for ml_type in ml_types: + ml_instances.extend( + self._get_modification_instances(lambda_schedule, ml_type) + ) - # Get the positions of the alchemical atoms - modifications_kwargs["CustomLJ"]["positions"] = positions + if ml_instances: + for modification_name in ml_instances: + modifications_kwargs[modification_name] = modifications_kwargs.get( + modification_name, {} + ) + modifications_kwargs[modification_name]["topology"] = topology - if any(key in lambda_schedule for key in ["ChargeTransfer"]): - modifications_kwargs["ChargeTransfer"] = modifications_kwargs.get( - "ChargeTransfer", {} - ) - modifications_kwargs["ChargeTransfer"]["original_offxml"] = ffs - modifications_kwargs["ChargeTransfer"]["topology_off"] = topology_off + # Handle CustomLJ modifications + customlj_instances = self._get_modification_instances( + lambda_schedule, "CustomLJ" + ) + if customlj_instances: + for modification_name in customlj_instances: + modifications_kwargs[modification_name] = modifications_kwargs.get( + modification_name, {} + ) + modifications_kwargs[modification_name]["original_offxml"] = ffs + modifications_kwargs[modification_name]["topology_off"] = topology_off + # Get the positions of the alchemical atoms + modifications_kwargs[modification_name]["positions"] = positions + + # Handle ChargeTransfer modifications + chargetransfer_instances = self._get_modification_instances( + lambda_schedule, "ChargeTransfer" + ) + if chargetransfer_instances: + for modification_name in chargetransfer_instances: + modifications_kwargs[modification_name] = modifications_kwargs.get( + modification_name, {} + ) + modifications_kwargs[modification_name]["original_offxml"] = ffs + modifications_kwargs[modification_name]["topology_off"] = topology_off # Run the Alchemist self._run_alchemist( diff --git a/src/fes_ml/alchemical/strategies/sire_strategy.py b/src/fes_ml/alchemical/strategies/sire_strategy.py index e92f8fb..4767f74 100644 --- a/src/fes_ml/alchemical/strategies/sire_strategy.py +++ b/src/fes_ml/alchemical/strategies/sire_strategy.py @@ -136,23 +136,36 @@ def create_alchemical_state( self._report_energy_decomposition(omm_context, omm_system) modifications_kwargs = _deepcopy(modifications_kwargs) or {} - if any(key in lambda_schedule for key in ["EMLEPotential", "MLInterpolation"]): - modifications_kwargs["EMLEPotential"] = modifications_kwargs.get( - "EMLEPotential", {} - ) - modifications_kwargs["EMLEPotential"]["mols"] = mols - modifications_kwargs["EMLEPotential"]["parm7"] = alchemical_prm7[0] - modifications_kwargs["EMLEPotential"]["mm_charges"] = _np.asarray( - [atom.charge().value() for atom in mols.atoms(alchemical_atoms)] - ) - if any( - key in lambda_schedule - for key in ["MLPotential", "MLInterpolation", "MLCorrection"] - ): - modifications_kwargs["MLPotential"] = modifications_kwargs.get( - "MLPotential", {} + + # Handle EMLEPotential modifications + emle_instances = self._get_modification_instances( + lambda_schedule, "EMLEPotential" + ) + if emle_instances: + for modification_name in emle_instances: + modifications_kwargs[modification_name] = modifications_kwargs.get( + modification_name, {} + ) + modifications_kwargs[modification_name]["mols"] = mols + modifications_kwargs[modification_name]["parm7"] = alchemical_prm7[0] + modifications_kwargs[modification_name]["mm_charges"] = _np.asarray( + [atom.charge().value() for atom in mols.atoms(alchemical_atoms)] + ) + + # Handle ML-related modifications + ml_types = ["MLPotential", "MLInterpolation", "MLCorrection"] + ml_instances = [] + for ml_type in ml_types: + ml_instances.extend( + self._get_modification_instances(lambda_schedule, ml_type) ) - modifications_kwargs["MLPotential"]["topology"] = topology + + if ml_instances: + for modification_name in ml_instances: + modifications_kwargs[modification_name] = modifications_kwargs.get( + modification_name, {} + ) + modifications_kwargs[modification_name]["topology"] = topology # Remove constraints involving alchemical atoms if remove_constraints: From 4a0c55a1745bb6f5d29dfdf58da333296e0287ce Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 16:46:06 +0100 Subject: [PATCH 02/20] Update environment --- environment.yaml | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/environment.yaml b/environment.yaml index df58c44..bce2013 100644 --- a/environment.yaml +++ b/environment.yaml @@ -6,29 +6,26 @@ channels: dependencies: - ambertools - - ase - compilers - cudatoolkit<11.9 - - deepmd-kit - eigen - loguru - - openmm>=8.1 + - openmm>=8.3 - openmm-torch + - openmm-ml - openmmforcefields - openff-toolkit - openff-interchange - nnpops - pip - pybind11 - - pytorch + - pytorch=*=*cuda* - python - pyyaml - sire - torchani - pygit2 - - xtb-python - pip: - git+https://github.com/chemle/emle-engine.git - - git+https://github.com/openmm/openmm-ml.git - coloredlogs - mace-torch From 63945bfee0e31ed945c72f757db690f8faf65200 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 17:45:11 +0100 Subject: [PATCH 03/20] Update README --- README.md | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 75 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index cfbdc81..171d208 100644 --- a/README.md +++ b/README.md @@ -11,12 +11,10 @@ A package to run hybrid ML/MM free energy simulations. 1. [Installation](#installation) 2. [Alchemical Modifications](#alchemical-modifications) 3. [Running a Multistate Equilibrium Free Energy Simulation](#running-a-multistate-equilibrium-free-energy-simulation) + 1. [Using Multiple Alchemical Groups](#using-multiple-alchemical-groups) 4. [Dynamics and EMLE settings](#dynamics-and-emle-settings) - 1. [Sire Strategy](#sire-strategy) 2. [OpenFF Strategy](#openff-strategy) - - 5. [Log Level](#log-level) ## Installation @@ -98,6 +96,80 @@ U_kln = fes.run_single_state(1000, 1000, 6) np.save("U_kln_mm_sol_6.npy", np.asarray(U_kln)) ``` +### Using Multiple Alchemical Groups + +For more complex transformations, you can define multiple alchemical groups that can be transformed independently or simultaneously. This is particularly useful when you want to apply different transformations to different regions of your system or transform multiple ligands separately. + +To use multiple alchemical groups, specify the group name as a suffix after a colon in the lambda schedule: + +```python +from fes_ml.fes import FES +import numpy as np + +# Define lambda schedule for multiple alchemical groups +lambda_schedule = { + # Group 1: Turn off LJ and charges for ligand 1 + "LJSoftCore:ligand1": [1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.0, 0.0], + "ChargeScaling:ligand1": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.0], + + # Group 2: Turn off LJ and charges for ligand 2 + "LJSoftCore:ligand2": [1.0, 1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.0], + "ChargeScaling:ligand2": [1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.33, 0.0], + + # Group 3: Interpolate between MM and ML for the entire system + "MLInterpolation:system": [0.0, 0.0, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0] +} + +# Define atom indices for each alchemical group +ligand1_atoms = [1, 2, 3, 4, 5] # Atoms belonging to first ligand +ligand2_atoms = [20, 21, 22, 23, 24] # Atoms belonging to second ligand +system_atoms = list(range(1, 50)) # All atoms for ML/MM interpolation + +# Define per-group alchemical atoms and other settings +modifications_kwargs = { + "LJSoftCore:ligand1": { + "alchemical_atoms": ligand1_atoms + }, + "ChargeScaling:ligand1": { + "alchemical_atoms": ligand1_atoms + }, + "LJSoftCore:ligand2": { + "alchemical_atoms": ligand2_atoms + }, + "ChargeScaling:ligand2": { + "alchemical_atoms": ligand2_atoms + }, + "MLInterpolation:system": { + "alchemical_atoms": system_atoms + } +} +``` + +#### Multiple Instances of the Same Modification Type + +You can also use multiple instances of the same modification type for the same group groups. For example, to apply interpolate between two sets of `CustomLJ` parameters: + +```python +# Lambda schedule with multiple CustomLJ modifications +lambda_schedule = { + "LJSoftCore:openff1": [1.0, 0.8, 0.6, 0.4, 0.2, 0.0], + "LJSoftCore:openff2": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], + "CustomLJ:openff1": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "CustomLJ:openff2": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], +} + +# Define different atoms and LJ parameters for each region +modifications_kwargs = { + "CustomLJ:region1": { + "lj_offxml": "openff_unconstrained-1.0.0.offxml", + }, + "CustomLJ:region2": { + "original_offxml": ["openff-2.1.0.offxml"], + "lj_offxml": "openff_unconstrained-2.0.0.offxml", + } +} +``` + ## Dynamics and EMLE settings In fes-ml, the default strategy for creating OpenMM systems is through Sire. Additionally, fes-ml offers the OpenFF strategy. You can select the desired creation strategy, either `'sire'` or `'openff'`, using the `strategy_name` argument when calling the `fes.create_alchemical_states` method to create the alchemical systems. Most other simulation configurations can also be set by passing additional arguments to this method. For details on customization, refer to the definitions of the `SireCreationStrategy` and `OpenFFCreationStrategy` classes. From 268797f23b914d0fc22061a7ba1076a245bce123 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 17:47:41 +0100 Subject: [PATCH 04/20] Fix typos in README --- README.md | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 171d208..28219a6 100644 --- a/README.md +++ b/README.md @@ -121,11 +121,11 @@ lambda_schedule = { } # Define atom indices for each alchemical group -ligand1_atoms = [1, 2, 3, 4, 5] # Atoms belonging to first ligand -ligand2_atoms = [20, 21, 22, 23, 24] # Atoms belonging to second ligand +ligand1_atoms = [1, 2, 3, 4, 5] # Atoms belonging to first ligand +ligand2_atoms = [20, 21, 22, 23, 24] # Atoms belonging to second ligand system_atoms = list(range(1, 50)) # All atoms for ML/MM interpolation -# Define per-group alchemical atoms and other settings +# Define per-group alchemical atoms modifications_kwargs = { "LJSoftCore:ligand1": { "alchemical_atoms": ligand1_atoms @@ -147,10 +147,9 @@ modifications_kwargs = { #### Multiple Instances of the Same Modification Type -You can also use multiple instances of the same modification type for the same group groups. For example, to apply interpolate between two sets of `CustomLJ` parameters: +You can also use multiple instances of the same modification type for the same group of atoms. For example, to interpolate between two sets of `CustomLJ` parameters: ```python -# Lambda schedule with multiple CustomLJ modifications lambda_schedule = { "LJSoftCore:openff1": [1.0, 0.8, 0.6, 0.4, 0.2, 0.0], "LJSoftCore:openff2": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], @@ -158,13 +157,12 @@ lambda_schedule = { "CustomLJ:openff2": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], } -# Define different atoms and LJ parameters for each region +# Define different LJ parameters for each region modifications_kwargs = { - "CustomLJ:region1": { + "CustomLJ:openff1": { "lj_offxml": "openff_unconstrained-1.0.0.offxml", }, - "CustomLJ:region2": { - "original_offxml": ["openff-2.1.0.offxml"], + "CustomLJ:openff2": { "lj_offxml": "openff_unconstrained-2.0.0.offxml", } } From abec88dc4557ad133a1119981951c8f3efc90433 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 17:59:27 +0100 Subject: [PATCH 05/20] Fix replacement of name in modification_kwargs --- src/fes_ml/alchemical/alchemist.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/fes_ml/alchemical/alchemist.py b/src/fes_ml/alchemical/alchemist.py index 991877e..b1e5aad 100644 --- a/src/fes_ml/alchemical/alchemist.py +++ b/src/fes_ml/alchemical/alchemist.py @@ -197,12 +197,10 @@ def create_alchemical_graph( ) if modifications_kwargs is not None and name in modifications_kwargs: - if modification_name in modifications_kwargs: - raise ValueError( - f"Cannot rename '{name}' to '{modification_name}': " - "key already exists in modifications_kwargs." + if modification_name not in modifications_kwargs: + modifications_kwargs[modification_name] = modifications_kwargs.pop( + name ) - modifications_kwargs[modification_name] = modifications_kwargs.pop(name) if base_name in Alchemist._modification_factories: factory = self._modification_factories[base_name] From 807cffed18f5d3fe6e1441308fcbe226bed2e538 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 18:22:40 +0100 Subject: [PATCH 06/20] Fix setting of MLPotential in strategies --- src/fes_ml/alchemical/strategies/openff_strategy.py | 9 ++++----- src/fes_ml/alchemical/strategies/sire_strategy.py | 9 ++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/fes_ml/alchemical/strategies/openff_strategy.py b/src/fes_ml/alchemical/strategies/openff_strategy.py index 7068ac9..f6ad8cc 100644 --- a/src/fes_ml/alchemical/strategies/openff_strategy.py +++ b/src/fes_ml/alchemical/strategies/openff_strategy.py @@ -742,11 +742,10 @@ def create_alchemical_state( ) if ml_instances: - for modification_name in ml_instances: - modifications_kwargs[modification_name] = modifications_kwargs.get( - modification_name, {} - ) - modifications_kwargs[modification_name]["topology"] = topology + modifications_kwargs["MLPotential"] = modifications_kwargs.get( + "MLPotential", {} + ) + modifications_kwargs["MLPotential"]["topology"] = topology # Handle CustomLJ modifications customlj_instances = self._get_modification_instances( diff --git a/src/fes_ml/alchemical/strategies/sire_strategy.py b/src/fes_ml/alchemical/strategies/sire_strategy.py index 4767f74..8ad57a0 100644 --- a/src/fes_ml/alchemical/strategies/sire_strategy.py +++ b/src/fes_ml/alchemical/strategies/sire_strategy.py @@ -161,11 +161,10 @@ def create_alchemical_state( ) if ml_instances: - for modification_name in ml_instances: - modifications_kwargs[modification_name] = modifications_kwargs.get( - modification_name, {} - ) - modifications_kwargs[modification_name]["topology"] = topology + modifications_kwargs["MLPotential"] = modifications_kwargs.get( + "MLPotential", {} + ) + modifications_kwargs["MLPotential"]["topology"] = topology # Remove constraints involving alchemical atoms if remove_constraints: From a8f56070ddaffdee93765ba64d931b59781f4488 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 18:22:50 +0100 Subject: [PATCH 07/20] Update tests --- tests/test_alchemical_states.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_alchemical_states.py b/tests/test_alchemical_states.py index 24d58cc..9b8519b 100644 --- a/tests/test_alchemical_states.py +++ b/tests/test_alchemical_states.py @@ -244,7 +244,7 @@ def _test_energy_decomposition( non_bonded_forces = [ "NonbondedForce", "CustomBondForce", - "CustomNonbondedForce", + "LJSoftCore:default", ] alc_nonbonded_force = sum( alc_energy_decom[force]._value From 1053fd851392cf5d6eb09f21f7cbe4e5782688e5 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 18:38:50 +0100 Subject: [PATCH 08/20] Update hooks and formatting --- .pre-commit-config.yaml | 48 ++++----------- analysis/analyse.py | 59 ------------------- examples/additional_scripts/analysis.py | 10 ++-- .../mts_benchmark/ml/agg_output.py | 2 +- .../mts_benchmark/ml/analysis.py | 2 +- .../mts_benchmark/ml_mts/agg_output.py | 2 +- .../mts_benchmark/ml_mts/analysis.py | 2 +- .../mts_benchmark/mm/agg_output.py | 2 +- .../mts_benchmark/mm/analysis.py | 2 +- .../performance_benchmark_emle_aev.py | 4 +- .../openff_strategy/run_scripts/agg_ouput.py | 4 +- .../openff_strategy/run_scripts/analysis.py | 2 +- .../run_scripts/create_run_system.py | 4 +- .../sire_strategy/benzene_ml_mm_sol_emle.py | 3 - .../sire_strategy/benzene_ml_mm_sol_mts.py | 1 - examples/sire_strategy/benzene_ml_sol_mts.py | 1 - pyproject.toml | 21 ++++++- src/fes_ml/__init__.py | 2 +- src/fes_ml/alchemical/alchemist.py | 1 - .../alchemical/modifications/__init__.py | 20 +++---- .../modifications/charge_transfer.py | 5 +- 21 files changed, 58 insertions(+), 139 deletions(-) delete mode 100644 analysis/analyse.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 816cac6..9772597 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,41 +1,13 @@ -# Pre-commit hooks for Python code -# Last revision by: Joao Morado -# Last revision date: 8.01.2023 -# See https://pre-commit.com for more information -# See https://pre-commit.com/hooks.html for more hooks - repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.2.0 - hooks: - - id: trailing-whitespace - - id: end-of-file-fixer - #- id: check-added-large-files -- repo: https://github.com/psf/black - rev: 23.1.0 - hooks: - - id: black - - id: black-jupyter -- repo: https://github.com/keewis/blackdoc - rev: v0.3.8 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.8 hooks: - - id: blackdoc - #- id: blackdoc-autoupdate-black -- repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", "--filter-files"] -- repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 - hooks: - - id: flake8 - additional_dependencies: [flake8-docstrings] - args: [--max-line-length=127, --exit-zero] - verbose: True -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.1.1 + - id: ruff + args: ["--fix"] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.17.1 hooks: - - id: mypy - args: [--no-strict-optional, --ignore-missing-imports, --namespace-packages, --explicit-package-bases] - additional_dependencies: ["types-PyYAML"] + - id: mypy + verbose: true + entry: bash -c 'mypy "$@" || true' -- \ No newline at end of file diff --git a/analysis/analyse.py b/analysis/analyse.py deleted file mode 100644 index 04d9884..0000000 --- a/analysis/analyse.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Script to calculate free energy differences using MBAR.""" - -import sys - -import matplotlib.pyplot as plt -import numpy as np -from openmm import unit -from pymbar import MBAR - -if len(sys.argv) != 3 or sys.argv[1] in ["-h", "--help"]: - print("Usage: python analyse.py ") - print( - "Please provide exactly two input arguments, which should be a file name and the temperature in Kelvin." - ) - sys.exit(1) - -input_file = sys.argv[1] -temperature = float(sys.argv[2]) -U_kn = np.load(sys.argv[1]) - -# Reformat array such that data is organised in the following away -# U_kn = [[U^{l}_{k,n} for n in nsamples for k in nstates] for l in nstates] -# where l (superscript) is the alchemical state at which the potential energy is evaluated -# and k (subscript) the alchemical state at which it is sampled -nstates, nstates, nsamples = U_kn.shape -U_kn = U_kn.transpose(1, 0, 2)[::-1, :, :] -U_kn = U_kn.reshape(nstates, nstates * nsamples) - -# Keep it in here to contemplate case where number of samples per alchemical state differs -# N_k = [ U_kn.shape[1]//nstates for _ in range(nstates)] -N_k = [nsamples for _ in range(nstates)] - -# Compute the overal -mbar = MBAR(U_kn, N_k) -overlap = mbar.compute_overlap() -plt.figure() -plt.title("Overlap") -plt.imshow(overlap["matrix"], vmin=0, vmax=1) -plt.colorbar() -plt.savefig("overlap.png") - -# If this fails try setting compute_uncertainty to false -# See this issue: https://github.com/choderalab/pymbar/issues/419 -results = mbar.compute_free_energy_differences(compute_uncertainty=True) - -# Calculate the free energy -kT = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA * unit.kelvin * temperature -print( - "Free energy = {}".format( - (results["Delta_f"][nstates - 1, 0] * kT).in_units_of(unit.kilocalorie_per_mole) - ) -) -print( - "Statistical uncertainty = {}".format( - (results["dDelta_f"][nstates - 1, 0] * kT).in_units_of( - unit.kilocalorie_per_mole - ) - ) -) diff --git a/examples/additional_scripts/analysis.py b/examples/additional_scripts/analysis.py index 807a9d8..b825fff 100644 --- a/examples/additional_scripts/analysis.py +++ b/examples/additional_scripts/analysis.py @@ -125,7 +125,7 @@ def water_atom( # Select atoms for which you want to calculate RDF water = universe.select_atoms(f"type O and not {ligand_selection}") - ligand = universe.select_atoms(ligand_selection) + # ligand = universe.select_atoms(ligand_selection) atoms = [ universe.select_atoms(f"index {atom.index} and {ligand_selection}") for atom in universe.select_atoms(ligand_selection) @@ -327,10 +327,10 @@ def __init__(self) -> None: def _plot_stdout_with_time_base(txt_file, save_location, data_name, prop): df = pd.read_csv(txt_file, sep=",") data = df[f"{data_name}"] - steps = df[f'#"Step"'] + steps = df['#"Step"'] plt.plot(steps, data, "r-", linewidth=2, color="maroon") - plt.xlabel(f"Steps") + plt.xlabel("Steps") plt.ylabel(f"{data_name}") plt.title(f"Plot of {prop} for {txt_file}") # plt.savefig(f"{save_location}") @@ -378,12 +378,12 @@ def plot_energy_all_windows( txt_file = f"{folder}/{base_name}{w}.txt" df = pd.read_csv(txt_file, sep=",") data = df[f"{data_name}"] - steps = df[f'#"Step"'] + steps = df['#"Step"'] x_axis = [r for r in range(1, len(steps) + 1, 1)] plt.plot(x_axis, data, linewidth=2, alpha=0.3, label=label, color=colors[w]) legend_labels.append(label) - plt.xlabel(f"Steps") + plt.xlabel("Steps") plt.ylabel(f"{data_name}") plt.title(f"Plot of {prop} for {txt_file}") plt.legend(bbox_to_anchor=(1, 0.5), loc="center left") diff --git a/examples/openff_strategy/mts_benchmark/ml/agg_output.py b/examples/openff_strategy/mts_benchmark/ml/agg_output.py index 3664e87..b298743 100644 --- a/examples/openff_strategy/mts_benchmark/ml/agg_output.py +++ b/examples/openff_strategy/mts_benchmark/ml/agg_output.py @@ -9,7 +9,7 @@ out = f"OUTPUT_{i}" try: f = glob.glob(out + "/*npy")[0] - except: + except Exception: break U_kln.append(np.load(f)[:, frames_disc::step]) diff --git a/examples/openff_strategy/mts_benchmark/ml/analysis.py b/examples/openff_strategy/mts_benchmark/ml/analysis.py index 759d665..05ca87c 100644 --- a/examples/openff_strategy/mts_benchmark/ml/analysis.py +++ b/examples/openff_strategy/mts_benchmark/ml/analysis.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt import numpy as np from openmm import unit -from pymbar import MBAR, timeseries +from pymbar import MBAR if len(sys.argv) != 3 or sys.argv[1] in ["-h", "--help"]: print("Usage: python analyse.py ") diff --git a/examples/openff_strategy/mts_benchmark/ml_mts/agg_output.py b/examples/openff_strategy/mts_benchmark/ml_mts/agg_output.py index 3664e87..b298743 100644 --- a/examples/openff_strategy/mts_benchmark/ml_mts/agg_output.py +++ b/examples/openff_strategy/mts_benchmark/ml_mts/agg_output.py @@ -9,7 +9,7 @@ out = f"OUTPUT_{i}" try: f = glob.glob(out + "/*npy")[0] - except: + except Exception: break U_kln.append(np.load(f)[:, frames_disc::step]) diff --git a/examples/openff_strategy/mts_benchmark/ml_mts/analysis.py b/examples/openff_strategy/mts_benchmark/ml_mts/analysis.py index 759d665..05ca87c 100644 --- a/examples/openff_strategy/mts_benchmark/ml_mts/analysis.py +++ b/examples/openff_strategy/mts_benchmark/ml_mts/analysis.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt import numpy as np from openmm import unit -from pymbar import MBAR, timeseries +from pymbar import MBAR if len(sys.argv) != 3 or sys.argv[1] in ["-h", "--help"]: print("Usage: python analyse.py ") diff --git a/examples/openff_strategy/mts_benchmark/mm/agg_output.py b/examples/openff_strategy/mts_benchmark/mm/agg_output.py index 3664e87..b298743 100644 --- a/examples/openff_strategy/mts_benchmark/mm/agg_output.py +++ b/examples/openff_strategy/mts_benchmark/mm/agg_output.py @@ -9,7 +9,7 @@ out = f"OUTPUT_{i}" try: f = glob.glob(out + "/*npy")[0] - except: + except Exception: break U_kln.append(np.load(f)[:, frames_disc::step]) diff --git a/examples/openff_strategy/mts_benchmark/mm/analysis.py b/examples/openff_strategy/mts_benchmark/mm/analysis.py index 759d665..05ca87c 100644 --- a/examples/openff_strategy/mts_benchmark/mm/analysis.py +++ b/examples/openff_strategy/mts_benchmark/mm/analysis.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt import numpy as np from openmm import unit -from pymbar import MBAR, timeseries +from pymbar import MBAR if len(sys.argv) != 3 or sys.argv[1] in ["-h", "--help"]: print("Usage: python analyse.py ") diff --git a/examples/openff_strategy/performance_benchmark_emle_aev.py b/examples/openff_strategy/performance_benchmark_emle_aev.py index 48b603f..e2d3632 100644 --- a/examples/openff_strategy/performance_benchmark_emle_aev.py +++ b/examples/openff_strategy/performance_benchmark_emle_aev.py @@ -10,13 +10,11 @@ import sys import time - import numpy as np import openff.units as offunit - import openmm as mm import openmm.app as app import openmm.unit as unit - from fes_ml import FES, MTS + from fes_ml import FES if len(sys.argv) == 1: raise ValueError("must pass window as positional arguments") diff --git a/examples/openff_strategy/run_scripts/agg_ouput.py b/examples/openff_strategy/run_scripts/agg_ouput.py index f3446ea..e20e657 100644 --- a/examples/openff_strategy/run_scripts/agg_ouput.py +++ b/examples/openff_strategy/run_scripts/agg_ouput.py @@ -15,10 +15,10 @@ def main(args): out = f"{folder}/{i}" try: f = glob.glob(f"{out}.npy")[0] - except: + except Exception: break print(np.load(f).shape) - U_kln.append(np.load(f)) # np.load(f)[:, frames_disc::step] + U_kln.append(np.load(f)[:, frames_disc::step]) U_kln = np.asarray(U_kln) print(U_kln.shape) diff --git a/examples/openff_strategy/run_scripts/analysis.py b/examples/openff_strategy/run_scripts/analysis.py index d9cc651..19c9e7e 100644 --- a/examples/openff_strategy/run_scripts/analysis.py +++ b/examples/openff_strategy/run_scripts/analysis.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt import numpy as np from openmm import unit -from pymbar import MBAR, timeseries +from pymbar import MBAR def main(args): diff --git a/examples/openff_strategy/run_scripts/create_run_system.py b/examples/openff_strategy/run_scripts/create_run_system.py index ec496a3..6d173fa 100644 --- a/examples/openff_strategy/run_scripts/create_run_system.py +++ b/examples/openff_strategy/run_scripts/create_run_system.py @@ -1,9 +1,7 @@ import logging import math import os -import sys from argparse import ArgumentParser -from typing import Optional, Union import numpy as np import openff.units as offunit @@ -33,7 +31,7 @@ def main(args): try: os.makedirs(folder) - except: + except Exception: pass # --------------------------------------------------------------- # diff --git a/examples/sire_strategy/benzene_ml_mm_sol_emle.py b/examples/sire_strategy/benzene_ml_mm_sol_emle.py index e21844c..ede045f 100644 --- a/examples/sire_strategy/benzene_ml_mm_sol_emle.py +++ b/examples/sire_strategy/benzene_ml_mm_sol_emle.py @@ -14,9 +14,6 @@ if __name__ == "__main__": import numpy as np - from fes_ml.alchemical.modifications.ml_interpolation import ( - MLInterpolationModification, - ) from fes_ml.fes import FES # Set up the alchemical modifications diff --git a/examples/sire_strategy/benzene_ml_mm_sol_mts.py b/examples/sire_strategy/benzene_ml_mm_sol_mts.py index a084ee3..a08dc26 100644 --- a/examples/sire_strategy/benzene_ml_mm_sol_mts.py +++ b/examples/sire_strategy/benzene_ml_mm_sol_mts.py @@ -10,7 +10,6 @@ if __name__ == "__main__": import numpy as np - import openmm as mm import openmm.unit as unit from fes_ml import FES, MTS diff --git a/examples/sire_strategy/benzene_ml_sol_mts.py b/examples/sire_strategy/benzene_ml_sol_mts.py index f824554..40fdd66 100644 --- a/examples/sire_strategy/benzene_ml_sol_mts.py +++ b/examples/sire_strategy/benzene_ml_sol_mts.py @@ -17,7 +17,6 @@ if __name__ == "__main__": import numpy as np - import openmm as mm import openmm.unit as unit from fes_ml import FES, MTS diff --git a/pyproject.toml b/pyproject.toml index 130652f..39a6e17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,14 +6,13 @@ build-backend = "setuptools.build_meta" name = "fes_ml" description = "A package to run hybrid ML/MM free energy simulations." requires-python = ">=3.9" -keywords = ["free energy simulation", "free energy", "force field", "machine learning", "openmm"] +keywords = ["free energy simulations", "free energy", "force field", "machine learning", "openmm"] authors = [{email = "jmorado@ed.ac.uk"},{name = "Joao Morado"}] maintainers = [{name = "Joao Morado", email = "jmorado@ed.ac.uk"}] classifiers = [ "License :: OSI Approved :: GPL License", "Intended Audience :: Science/Research", "Intended Audience :: Developers", - "Topic :: Oceanography Modeling", "Topic :: Scientific/Engineering", "Programming Language :: Python", "Programming Language :: Python :: 3.10" @@ -66,3 +65,21 @@ where = ["src"] include = ["*"] exclude = ["*__pycache__*"] namespaces = true + +[tool.ruff] +fix = true +line-length = 144 + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = ["E203", "I001"] + +[tool.ruff.lint.isort] +known-first-party = ["fes_ml"] +combine-as-imports = true +force-single-line = false + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true \ No newline at end of file diff --git a/src/fes_ml/__init__.py b/src/fes_ml/__init__.py index 0322cba..1d719b7 100644 --- a/src/fes_ml/__init__.py +++ b/src/fes_ml/__init__.py @@ -7,4 +7,4 @@ from .log import config_logger from .mts import MTS -__all__ = ["FES", "MTS"] +__all__ = ["FES", "MTS", "config_logger"] diff --git a/src/fes_ml/alchemical/alchemist.py b/src/fes_ml/alchemical/alchemist.py index b1e5aad..a835108 100644 --- a/src/fes_ml/alchemical/alchemist.py +++ b/src/fes_ml/alchemical/alchemist.py @@ -1,7 +1,6 @@ """Module for the Alchemist class.""" import logging -import sys from copy import deepcopy as _deepcopy from importlib.metadata import entry_points from typing import Any, Dict, List, Optional diff --git a/src/fes_ml/alchemical/modifications/__init__.py b/src/fes_ml/alchemical/modifications/__init__.py index 9e9014c..769f092 100644 --- a/src/fes_ml/alchemical/modifications/__init__.py +++ b/src/fes_ml/alchemical/modifications/__init__.py @@ -1,14 +1,14 @@ """Init file for the modification module.""" from . import ( - charge_scaling, - charge_transfer, - custom_lj, - emle_potential, - intramolecular, - lj_softcore, - ml_correction, - ml_interpolation, - ml_potential, + charge_scaling as charge_scaling, + charge_transfer as charge_transfer, + custom_lj as custom_lj, + emle_potential as emle_potential, + intramolecular as intramolecular, + lj_softcore as lj_softcore, + ml_correction as ml_correction, + ml_interpolation as ml_interpolation, + ml_potential as ml_potential, ) -from .emle_potential import _EMLE_CALCULATORS +from .emle_potential import _EMLE_CALCULATORS as _EMLE_CALCULATORS diff --git a/src/fes_ml/alchemical/modifications/charge_transfer.py b/src/fes_ml/alchemical/modifications/charge_transfer.py index 12eaf16..12fd737 100644 --- a/src/fes_ml/alchemical/modifications/charge_transfer.py +++ b/src/fes_ml/alchemical/modifications/charge_transfer.py @@ -2,7 +2,6 @@ from typing import List, Optional, Union import openmm as _mm -import openmm.app as _app import openmm.unit as _unit from openff.toolkit.topology import Topology as _Topology from openff.toolkit.typing.engines.smirnoff import ForceField as _ForceField @@ -73,11 +72,11 @@ def get_is_donor_acceptor( break # Acceptor: heavy atoms elif atom.atomic_number == 8 and (symmetric or idx not in alchemical_atoms): - num_H = sum(b.atomic_number == 1 for b in atom.bonded_atoms) + #num_H = sum(b.atomic_number == 1 for b in atom.bonded_atoms) is_acceptor[idx] = 1 elif atom.atomic_number == 7: # Nitrogen - num_H = sum(b.atomic_number == 1 for b in atom.bonded_atoms) + #num_H = sum(b.atomic_number == 1 for b in atom.bonded_atoms) is_acceptor[idx] = 1 # if num_H != 1 and num_H != 2: # optionally exclude primary/secondary amines # is_acceptor[idx] = 1 From 5dfb408fb51a90d9956395b31650be0c43d629d1 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 18:52:15 +0100 Subject: [PATCH 09/20] Update environments --- environment.yaml | 6 +++--- environment_rascal.yaml | 34 ---------------------------------- 2 files changed, 3 insertions(+), 37 deletions(-) delete mode 100644 environment_rascal.yaml diff --git a/environment.yaml b/environment.yaml index bce2013..cc2ac52 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,4 +1,4 @@ -name: fes-ml-aev +name: fes-ml-aev-fix channels: - conda-forge @@ -7,10 +7,10 @@ channels: dependencies: - ambertools - compilers - - cudatoolkit<11.9 + - cudatoolkit=11.8 - eigen - loguru - - openmm>=8.3 + - openmm>=8.1 - openmm-torch - openmm-ml - openmmforcefields diff --git a/environment_rascal.yaml b/environment_rascal.yaml deleted file mode 100644 index 37c0662..0000000 --- a/environment_rascal.yaml +++ /dev/null @@ -1,34 +0,0 @@ -name: fes-ml - -channels: - - conda-forge - - openbiosim/label/emle - -dependencies: - - ambertools - - ase - - compilers - - cudatoolkit<11.9 - - deepmd-kit - - eigen - - loguru - - openmm>=8.1 - - openmm-torch - - openmmforcefields - - openff-toolkit - - openff-interchange - - nnpops - - pip - - pybind11 - - pytorch<3.11 - - python - - pyyaml - - sire - - torchani - - xtb-python - - pip: - - git+https://github.com/lab-cosmo/librascal.git - - git+https://github.com/chemle/emle-engine.git@main - - git+https://github.com/openmm/openmm-ml.git - - coloredlogs - - mace-torch From fe7ce76e6195eca99c10e2119034938b9ab5afa6 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 18:53:17 +0100 Subject: [PATCH 10/20] Update action version in workflow --- .github/workflows/main.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 8e47fe4..5edee64 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -28,7 +28,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Pre-commit hooks - uses: pre-commit/action@v3.0.0 + uses: pre-commit/action@v3.0.1 build: needs: pre-commit @@ -48,7 +48,7 @@ jobs: uses: actions/checkout@v4 - name: Set up environment - uses: mamba-org/setup-micromamba@v1 + uses: mamba-org/setup-micromamba@v2 with: environment-file: environment.yaml environment-name: gha-test-env From 862bd5bbe06314e451e685326a5e6b548b3f2a58 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 18:56:18 +0100 Subject: [PATCH 11/20] Bump Python version from 3.10 to 3.12 --- .github/workflows/main.yaml | 6 +++--- pyproject.toml | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 5edee64..4240d69 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -15,7 +15,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: ["3.10"] + python-version: ["3.12"] runs-on: ${{ matrix.os }} steps: @@ -36,7 +36,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] #, macos-latest, windows-latest] - python-version: ["3.10"] + python-version: ["3.12"] defaults: run: @@ -94,7 +94,7 @@ jobs: # strategy: # matrix: # os: [ubuntu-latest] - # python-version: ["3.10"] + # python-version: ["3.12"] # runs-on: ${{ matrix.os }} # steps: diff --git a/pyproject.toml b/pyproject.toml index 39a6e17..97ea86f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ classifiers = [ "Intended Audience :: Developers", "Topic :: Scientific/Engineering", "Programming Language :: Python", - "Programming Language :: Python :: 3.10" + "Programming Language :: Python :: 2" ] dynamic = [ @@ -80,6 +80,6 @@ combine-as-imports = true force-single-line = false [tool.mypy] -python_version = "3.10" +python_version = "3.12" warn_return_any = true warn_unused_configs = true \ No newline at end of file From 57b48bf2470813ca4d7fae077ef17be2e22b3c7e Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 18:56:37 +0100 Subject: [PATCH 12/20] Correct env name --- environment.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yaml b/environment.yaml index cc2ac52..668af44 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,4 +1,4 @@ -name: fes-ml-aev-fix +name: fes-ml-aev channels: - conda-forge From a7c4478d02b92f552461b4863915f6fcb5a7cd40 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 19:10:08 +0100 Subject: [PATCH 13/20] Update pre/post/skp dependencies of MLInterpolationModification --- .../alchemical/modifications/ml_interpolation.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/fes_ml/alchemical/modifications/ml_interpolation.py b/src/fes_ml/alchemical/modifications/ml_interpolation.py index b51e1c0..ab8c09d 100644 --- a/src/fes_ml/alchemical/modifications/ml_interpolation.py +++ b/src/fes_ml/alchemical/modifications/ml_interpolation.py @@ -7,6 +7,11 @@ from .base_modification import BaseModification, BaseModificationFactory from .ml_base_modification import MLBaseModification +from .intramolecular import ( + IntraMolecularBondedRemovalModification, + IntraMolecularNonBondedExceptionsModification, +) +from .ml_potential import MLPotentialModification logger = logging.getLogger(__name__) @@ -36,10 +41,12 @@ class MLInterpolationModification(MLBaseModification, BaseModification): """Class to add a CustomCVForce to interpolate between ML and MM forces.""" NAME = "MLInterpolation" - pre_dependencies = ["MLPotential"] + pre_dependencies = [MLPotentialModification.NAME] post_dependencies = [ - "IntraMolecularNonBondedExceptions", - "IntraMolecularBondedRemoval", + IntraMolecularNonBondedExceptionsModification.NAME, + ] + skip_dependencies: List[str] = [ + IntraMolecularBondedRemovalModification.NAME, ] def apply( From 402e676da3ad79fa9c5e105d9b3af18bf50788e0 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 19:11:24 +0100 Subject: [PATCH 14/20] Update dependencies for ml mods --- src/fes_ml/alchemical/modifications/ml_correction.py | 10 ---------- .../alchemical/modifications/ml_interpolation.py | 10 ---------- 2 files changed, 20 deletions(-) diff --git a/src/fes_ml/alchemical/modifications/ml_correction.py b/src/fes_ml/alchemical/modifications/ml_correction.py index 40776da..17ae259 100644 --- a/src/fes_ml/alchemical/modifications/ml_correction.py +++ b/src/fes_ml/alchemical/modifications/ml_correction.py @@ -6,10 +6,6 @@ import openmm as _mm from .base_modification import BaseModification, BaseModificationFactory -from .intramolecular import ( - IntraMolecularBondedRemovalModification, - IntraMolecularNonBondedExceptionsModification, -) from .ml_base_modification import MLBaseModification from .ml_potential import MLPotentialModification @@ -42,12 +38,6 @@ class MLCorrectionModification(MLBaseModification, BaseModification): NAME = "MLCorrection" pre_dependencies: List[str] = [MLPotentialModification.NAME] - post_dependencies: List[str] = [ - IntraMolecularNonBondedExceptionsModification.NAME, - ] - skip_dependencies: List[str] = [ - IntraMolecularBondedRemovalModification.NAME, - ] def apply( self, diff --git a/src/fes_ml/alchemical/modifications/ml_interpolation.py b/src/fes_ml/alchemical/modifications/ml_interpolation.py index ab8c09d..e0149df 100644 --- a/src/fes_ml/alchemical/modifications/ml_interpolation.py +++ b/src/fes_ml/alchemical/modifications/ml_interpolation.py @@ -7,10 +7,6 @@ from .base_modification import BaseModification, BaseModificationFactory from .ml_base_modification import MLBaseModification -from .intramolecular import ( - IntraMolecularBondedRemovalModification, - IntraMolecularNonBondedExceptionsModification, -) from .ml_potential import MLPotentialModification logger = logging.getLogger(__name__) @@ -42,12 +38,6 @@ class MLInterpolationModification(MLBaseModification, BaseModification): NAME = "MLInterpolation" pre_dependencies = [MLPotentialModification.NAME] - post_dependencies = [ - IntraMolecularNonBondedExceptionsModification.NAME, - ] - skip_dependencies: List[str] = [ - IntraMolecularBondedRemovalModification.NAME, - ] def apply( self, From 3b988a4763a9e8dac3b55754c0799300a75c9948 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 19:39:13 +0100 Subject: [PATCH 15/20] Update order for ml mods --- src/fes_ml/alchemical/modifications/ml_correction.py | 3 +++ src/fes_ml/alchemical/modifications/ml_interpolation.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/src/fes_ml/alchemical/modifications/ml_correction.py b/src/fes_ml/alchemical/modifications/ml_correction.py index 17ae259..31ea327 100644 --- a/src/fes_ml/alchemical/modifications/ml_correction.py +++ b/src/fes_ml/alchemical/modifications/ml_correction.py @@ -8,6 +8,7 @@ from .base_modification import BaseModification, BaseModificationFactory from .ml_base_modification import MLBaseModification from .ml_potential import MLPotentialModification +from .intramolecular import IntraMolecularBondedRemovalModification, IntraMolecularNonBondedExceptionsModification logger = logging.getLogger(__name__) @@ -38,6 +39,8 @@ class MLCorrectionModification(MLBaseModification, BaseModification): NAME = "MLCorrection" pre_dependencies: List[str] = [MLPotentialModification.NAME] + post_dependencies = [IntraMolecularNonBondedExceptionsModification.NAME] + skip_dependencies: List[str] = [IntraMolecularBondedRemovalModification.NAME] def apply( self, diff --git a/src/fes_ml/alchemical/modifications/ml_interpolation.py b/src/fes_ml/alchemical/modifications/ml_interpolation.py index e0149df..43f34b6 100644 --- a/src/fes_ml/alchemical/modifications/ml_interpolation.py +++ b/src/fes_ml/alchemical/modifications/ml_interpolation.py @@ -8,6 +8,7 @@ from .base_modification import BaseModification, BaseModificationFactory from .ml_base_modification import MLBaseModification from .ml_potential import MLPotentialModification +from .intramolecular import IntraMolecularBondedRemovalModification, IntraMolecularNonBondedExceptionsModification logger = logging.getLogger(__name__) @@ -38,6 +39,7 @@ class MLInterpolationModification(MLBaseModification, BaseModification): NAME = "MLInterpolation" pre_dependencies = [MLPotentialModification.NAME] + post_dependencies = [IntraMolecularBondedRemovalModification.NAME, IntraMolecularNonBondedExceptionsModification.NAME] def apply( self, From 835175a1d83e81888b234e0fbae380805675afe5 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 20:07:58 +0100 Subject: [PATCH 16/20] Keep edge/node ordering --- src/fes_ml/alchemical/alchemist.py | 14 ++++++++++---- .../alchemical/modifications/ml_interpolation.py | 5 ++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/fes_ml/alchemical/alchemist.py b/src/fes_ml/alchemical/alchemist.py index a835108..33083c3 100644 --- a/src/fes_ml/alchemical/alchemist.py +++ b/src/fes_ml/alchemical/alchemist.py @@ -124,9 +124,12 @@ def add_modification_to_graph( pre_modification = factory.create_modification( modification_name=dep_modification_name ) + # Add edge first (like before), then recursively add dependency + self._graph.add_edge(dep_modification_name, node_name) self.add_modification_to_graph(pre_modification, None) - - self._graph.add_edge(dep_modification_name, node_name) + else: + # Dependency already exists, just add the edge + self._graph.add_edge(dep_modification_name, node_name) if modification.post_dependencies is not None: for post_dependency in modification.post_dependencies: @@ -147,9 +150,12 @@ def add_modification_to_graph( post_modification = factory.create_modification( modification_name=dep_modification_name ) + # Add edge first (like before), then recursively add dependency + self._graph.add_edge(node_name, dep_modification_name) self.add_modification_to_graph(post_modification, None) - - self._graph.add_edge(node_name, dep_modification_name) + else: + # Dependency already exists, just add the edge + self._graph.add_edge(node_name, dep_modification_name) def remove_modification_from_graph(self, modification: str) -> None: """ diff --git a/src/fes_ml/alchemical/modifications/ml_interpolation.py b/src/fes_ml/alchemical/modifications/ml_interpolation.py index 43f34b6..3e7280d 100644 --- a/src/fes_ml/alchemical/modifications/ml_interpolation.py +++ b/src/fes_ml/alchemical/modifications/ml_interpolation.py @@ -39,7 +39,10 @@ class MLInterpolationModification(MLBaseModification, BaseModification): NAME = "MLInterpolation" pre_dependencies = [MLPotentialModification.NAME] - post_dependencies = [IntraMolecularBondedRemovalModification.NAME, IntraMolecularNonBondedExceptionsModification.NAME] + post_dependencies: List[str] = [ + IntraMolecularBondedRemovalModification.NAME, + IntraMolecularNonBondedExceptionsModification.NAME, + ] def apply( self, From 0b88d6d14e9abc6c27b05d374354993333f9c0ec Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 29 Sep 2025 20:11:20 +0100 Subject: [PATCH 17/20] Bump to version 0.2.2 --- src/fes_ml/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fes_ml/__init__.py b/src/fes_ml/__init__.py index 1d719b7..1123f23 100644 --- a/src/fes_ml/__init__.py +++ b/src/fes_ml/__init__.py @@ -1,6 +1,6 @@ """fes_ml base package.""" -__version__ = "0.2.1" +__version__ = "0.2.2" __author__ = "Joao Morado" from .fes import FES From 841fec37f92ac6dca97621e27d7cb7626a1f8b9a Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Wed, 29 Oct 2025 11:23:52 +0000 Subject: [PATCH 18/20] Add Tang-Toennies to LJSoftCore --- .../alchemical/modifications/lj_softcore.py | 61 ++++++++++++++++--- 1 file changed, 51 insertions(+), 10 deletions(-) diff --git a/src/fes_ml/alchemical/modifications/lj_softcore.py b/src/fes_ml/alchemical/modifications/lj_softcore.py index b796d43..7db6368 100644 --- a/src/fes_ml/alchemical/modifications/lj_softcore.py +++ b/src/fes_ml/alchemical/modifications/lj_softcore.py @@ -51,6 +51,8 @@ def apply( system: _mm.System, alchemical_atoms: List[int], lambda_value: Optional[Union[float, int]] = 1.0, + include_repulsion: bool = True, + include_attraction: bool = True, *args, **kwargs, ) -> _mm.System: @@ -65,6 +67,12 @@ def apply( The indices of the alchemical atoms in the system. lambda_value : float The value of the alchemical state parameter. + include_repulsion : bool, optional + Whether to include the short-range repulsive term (r^-12 component). + Default is True. + include_attraction : bool, optional + Whether to include the long-range attractive term (r^-6 component). + Default is True. args : tuple Additional arguments to be passed to the modification. kwargs : dict @@ -78,21 +86,54 @@ def apply( forces = {force.__class__.__name__: force for force in system.getForces()} nb_force = forces["NonbondedForce"] + if not include_repulsion and not include_attraction: + raise ValueError("At least one of include_repulsion or include_attraction must be True.") + # Define the softcore Lennard-Jones energy function - energy_function = ( - f"{lambda_value}*4*epsilon*x*(x-1.0); x = (sigma/reff_sterics)^6;" - ) - energy_function += ( - f"reff_sterics = sigma*(0.5*(1.0-{lambda_value}) + (r/sigma)^6)^(1/6);" - ) - energy_function += ( - "sigma = 0.5*(sigma1+sigma2); epsilon = sqrt(epsilon1*epsilon2);" + # Standard LJ: 4*epsilon*(x^2 - x) where x = (sigma/r)^6 + # Softcore LJ: 4*epsilon*(x_soft^2 - x_soft) where x_soft = (sigma/reff)^6 + terms = [] + # Main energy expression pieces + if include_repulsion: + terms.append(f"{lambda_value}*4*epsilon*x*x") + if include_attraction: + terms.append(f"-{lambda_value}*4*epsilon*x*fdamp") + + energy_expr = " + ".join(terms) + + # Full OpenMM expression: E; intermediate definitions follow + expr = ( + f"{energy_expr}; " + "x=(sigma/reff)^6; " + f"reff=sigma*(0.5*(1.0-{lambda_value})+(r/sigma)^6)^(1/6); " + "sigma=0.5*(sigma1+sigma2); " + "epsilon=sqrt(epsilon1*epsilon2); " ) - logger.debug(f"LJ softcore function: {energy_function}") + # Add damping if needed + if include_attraction and not include_repulsion: + expr += ( + # Tang-Toennies f6(b*r) expansion, truncated to 6th order + "fdamp=1.0 - exp(-xdamp)*(1.0 + xdamp + 0.5*xdamp^2 + " + "(1.0/6.0)*xdamp^3 + (1.0/24.0)*xdamp^4 + " + "(1.0/120.0)*xdamp^5 + (1.0/720.0)*xdamp^6);" + "xdamp=bdamp*r; " + "bdamp=1/0.021; " + ) + else: + expr += "fdamp=1.0;" + + """ + energy_function = 'lambda*4*epsilon*x*(x-1.0); x = (sigma/reff_sterics)^6;' + energy_function += 'reff_sterics = sigma*(0.5*(1.0-lambda) + (r/sigma)^6)^(1/6);' + energy_function += 'sigma = 0.5*(sigma1+sigma2); epsilon = sqrt(epsilon1*epsilon2);' + custom_force = CustomNonbondedForce(energy_function) + """ + + logger.debug(f"Softcore LJ expression: {expr}") # Create a CustomNonbondedForce to compute the softcore Lennard-Jones - soft_core_force = _mm.CustomNonbondedForce(energy_function) + soft_core_force = _mm.CustomNonbondedForce(expr) if self._NON_BONDED_METHODS[nb_force.getNonbondedMethod()] in [ self._NON_BONDED_METHODS[3], From 6e93c570dbf835b4ba04cfb776bfd91f8a082c92 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Wed, 29 Oct 2025 11:25:09 +0000 Subject: [PATCH 19/20] Updates to OpenFF strategy --- .../alchemical/strategies/openff_strategy.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/fes_ml/alchemical/strategies/openff_strategy.py b/src/fes_ml/alchemical/strategies/openff_strategy.py index f6ad8cc..321181a 100644 --- a/src/fes_ml/alchemical/strategies/openff_strategy.py +++ b/src/fes_ml/alchemical/strategies/openff_strategy.py @@ -333,11 +333,14 @@ def _solvate( packmol_kwargs_local.pop("number_of_copies") - topology_off = _pack_box( - molecules=mols, - number_of_copies=number_of_copies, - **packmol_kwargs_local, - ) + if number_of_copies[1] == 1: + topology_off = _Topology.from_molecules(mols) + else: + topology_off = _pack_box( + molecules=mols, + number_of_copies=number_of_copies, + **packmol_kwargs_local, + ) return topology_off @@ -593,7 +596,7 @@ def create_alchemical_state( if sdf_file_ligand is None: # Only generate conformers if no SDF file is provided # Otherwise, the geometry is taken from the SDF file - logger.debug("Generating conformers for the ligand") + logger.debug("Generating conformers for the ligand.") molecules["ligand"].generate_conformers() else: logger.debug(f"Using provided ligand geometry from {sdf_file_ligand}") @@ -701,7 +704,8 @@ def create_alchemical_state( _shutil.rmtree(self._TMP_DIR) _os.makedirs(self._TMP_DIR, exist_ok=True) files_prefix = _os.path.join(self._TMP_DIR, "interchange") - interchange.to_gromacs(prefix=files_prefix) + interchange.to_gro(files_prefix + ".gro") + interchange.to_top(files_prefix + ".top") # Read back those files using Sire sr_mols = _sr.load( @@ -758,7 +762,6 @@ def create_alchemical_state( ) modifications_kwargs[modification_name]["original_offxml"] = ffs modifications_kwargs[modification_name]["topology_off"] = topology_off - # Get the positions of the alchemical atoms modifications_kwargs[modification_name]["positions"] = positions # Handle ChargeTransfer modifications From 3cd12ce4e52d9b2a6d17ae6de6b41906e0c97ea3 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Wed, 29 Oct 2025 11:25:33 +0000 Subject: [PATCH 20/20] Generalise CT --- .../modifications/charge_transfer.py | 68 +++++++++---------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/src/fes_ml/alchemical/modifications/charge_transfer.py b/src/fes_ml/alchemical/modifications/charge_transfer.py index 12fd737..a2b3d8d 100644 --- a/src/fes_ml/alchemical/modifications/charge_transfer.py +++ b/src/fes_ml/alchemical/modifications/charge_transfer.py @@ -90,6 +90,7 @@ def apply( system: _mm.System, alchemical_atoms: List[int], original_offxml: List[str], + ct_offxml: str, topology_off: _Topology, lambda_value: Optional[Union[float, int]] = 1.0, *args, @@ -104,6 +105,8 @@ def apply( The system to be modified. alchemical_atoms : list of int The indices of the alchemical atoms in the system. + ct_offxml: str + Path to the offxml file containing the charge transfer parameters. original_offxml : List[str] List of paths to the original offxml files. topology_off : openff.toolkit.topology.Topology @@ -120,34 +123,12 @@ def apply( openmm.System The modified system. """ - ct_paraams = { - "n21": {"sigma": 3.1739, "eps": 169.0640}, - "n16": {"sigma": 4.3983, "eps": 42.4290}, - "n3": {"sigma": 5.9554, "eps": 0.1890}, - "n-tip3p-O": {"sigma": 5.5019, "eps": 119.9913}, - "n-tip3p-H": {"sigma": 5.6480, "eps": 3.4209}, - "n20": {"sigma": 4.1808, "eps": 71.9843}, - "n2": {"sigma": 7.6786, "eps": 0.1502}, - "n18": {"sigma": 4.7983, "eps": 78.6935}, - "n14": {"sigma": 3.9557, "eps": 50.0886}, - "n17": {"sigma": 4.6716, "eps": 61.3267}, - "n9": {"sigma": 5.2866, "eps": 0.8104}, - "n4": {"sigma": 5.8806, "eps": 0.4651}, - "n13": {"sigma": 5.3179, "eps": 1.1033}, - "n11": {"sigma": 5.2541, "eps": 0.7611}, - "n19": {"sigma": 4.9034, "eps": 75.1253}, - "n12": {"sigma": 4.6507, "eps": 1.0854}, - "n7": {"sigma": 6.8461, "eps": 0.5855}, - "n15": {"sigma": 4.3835, "eps": 59.0651}, - "n10": {"sigma": 5.3841, "eps": 0.4192}, - "n8": {"sigma": 6.7245, "eps": 0.7046}, - } - - energy_function = f"-{lambda_value}*donor_acceptor*epsilon*exp(-sigma*r);" - energy_function += "sigma = 0.5*(sigma1+sigma2);" - energy_function += "epsilon = sqrt(epsilon1*epsilon2);" + # Convert units + energy_function = f"-{lambda_value}*donor_acceptor*epsilon*exp(-r/sigma);" + energy_function += "sigma = sqrt(sigma1*sigma2);" + energy_function += "epsilon = (epsilon1*epsilon2);" energy_function += ( - "donor_acceptor = isDonor1*isAcceptor2 + isDonor2*isAcceptor1;" + "donor_acceptor = 1;"#isDonor1*isAcceptor2 + isDonor2*isAcceptor1;" ) logger.debug(f"Charge transfer function: {energy_function}") @@ -163,28 +144,43 @@ def apply( # Add per-particle parameters to the CustomNonbondedForce charge_transfer_force.addPerParticleParameter("sigma") charge_transfer_force.addPerParticleParameter("epsilon") - charge_transfer_force.addPerParticleParameter("isDonor") - charge_transfer_force.addPerParticleParameter("isAcceptor") + #charge_transfer_force.addPerParticleParameter("isDonor") + #charge_transfer_force.addPerParticleParameter("isAcceptor") # Update the Lennard-Jones parameters in the CustomNonbondedForce - force_field = _ForceField(*original_offxml) - labels = force_field.label_molecules(topology_off) + #force_field = _ForceField(*original_offxml) + #labels = force_field.label_molecules(topology_off) # Get atom types - atom_types = [val.id for mol in labels for _, val in mol["vdW"].items()] + #atom_types = [val.id for mol in labels for _, val in mol["vdW"].items()] # Get donor/acceptor flags is_donor, is_acceptor = ChargeTransferModification.get_is_donor_acceptor( topology_off, alchemical_atoms ) + # CT force field + ct_force_field = _ForceField(ct_offxml) + labels = ct_force_field.label_molecules(topology_off) + ct_params = { + p.id: { + "epsilon": p.epsilon.to_openmm().value_in_unit( + _unit.kilojoules_per_mole + ), + "sigma": p.sigma.to_openmm().value_in_unit(_unit.nanometer), + } + for p in ct_force_field.get_parameter_handler("vdW") + } + atom_types = [val.id for mol in labels for _, val in mol["vdW"].items()] + for index in range(system.getNumParticles()): + at_type = atom_types[index] charge_transfer_force.addParticle( [ - ct_paraams.get(atom_types[index], {}).get("sigma", 0) * 10, - ct_paraams.get(atom_types[index], {}).get("eps", 0) * 1e3, - is_donor[index], - is_acceptor[index], + ct_params[at_type]["sigma"], + ct_params[at_type]["epsilon"],# * 10.0, + #is_donor[index], + #is_acceptor[index], ] )