diff --git a/.gitignore b/.gitignore index b155533..34f9fb1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,34 +1,22 @@ *.pyc -catpred/uncertainty/__pycache__/uncertainty_predictor.cpython-312.pyc -catpred/uncertainty/__pycache__/uncertainty_evaluator.cpython-312.pyc -catpred/uncertainty/__pycache__/uncertainty_estimator.cpython-312.pyc -catpred/uncertainty/__pycache__/uncertainty_calibrator.cpython-312.pyc -catpred/uncertainty/__pycache__/__init__.cpython-312.pyc -catpred/train/__pycache__/train.cpython-312.pyc -catpred/train/__pycache__/run_training.cpython-312.pyc -catpred/train/__pycache__/predict.cpython-312.pyc -catpred/train/__pycache__/molecule_fingerprint.cpython-312.pyc -catpred/train/__pycache__/metrics.cpython-312.pyc -catpred/train/__pycache__/make_predictions.cpython-312.pyc -catpred/train/__pycache__/loss_functions.cpython-312.pyc -catpred/train/__pycache__/evaluate.cpython-312.pyc -catpred/train/__pycache__/cross_validate.cpython-312.pyc -catpred/train/__pycache__/__init__.cpython-312.pyc -catpred/models/__pycache__/transformer_models.cpython-312.pyc -catpred/models/__pycache__/mpn.cpython-312.pyc -catpred/models/__pycache__/model.cpython-312.pyc -catpred/models/__pycache__/ffn.cpython-312.pyc -catpred/models/__pycache__/__init__.cpython-312.pyc -catpred/models/.ipynb_checkpoints/model-checkpoint.py -catpred/features/__pycache__/utils.cpython-312.pyc -catpred/features/__pycache__/featurization.cpython-312.pyc -catpred/features/__pycache__/features_generators.cpython-312.pyc -catpred/features/__pycache__/__init__.cpython-312.pyc -catpred/data/__pycache__/utils.cpython-312.pyc -catpred/data/__pycache__/scaler.cpython-312.pyc -catpred/data/__pycache__/scaffold.cpython-312.pyc -catpred/data/__pycache__/esm_utils.cpython-312.pyc -catpred/data/__pycache__/data.cpython-312.pyc -catpred/data/__pycache__/cache_utils.cpython-312.pyc -catpred/data/__pycache__/__init__.cpython-312.pyc +__pycache__/ +*.pyo +*.pyd +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +.ipynb_checkpoints/ diff --git a/README.md b/README.md index 1933ecd..221de52 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ - [System Requirements](#requirements) - [Installation](#installing) - [Prediction](#predict) + - [pH and Temperature Features](#ph-temp) - [Reproducibility](#reproduce) - [Acknowledgements](#acknw) - [License](#license) @@ -78,6 +79,32 @@ pip install -e . The Jupyter Notebook `batch_demo.ipynb` and the Python script `demo_run.py` show the usage of pre-trained models for prediction. +### 🌡️ Using pH and Temperature Features + +CatPred supports including pH and temperature as additional input features for training and prediction. To use these features: + +1. **Data Preparation**: Include columns for pH and/or temperature in your CSV data file. + +2. **Training with pH/Temp**: Specify the column names using the command line arguments: +```bash +python train.py --data_path --ph_column --temp_column ... +``` + +3. **Available Options**: + - `--ph_column`: Name of the column containing pH values + - `--temp_column`: Name of the column containing temperature values + - `--no_ph_temp_features_scaling`: Disable normalization of pH and temperature features (enabled by default) + +4. **Prediction**: When using a model trained with pH/Temp features, provide the same columns in your test data: +```bash +python predict.py --test_path --ph_column --temp_column ... +``` + +**Notes**: +- Missing pH/Temp values (empty, 'nan', 'None', 'null') are handled gracefully and replaced with 0 +- pH and temperature features are appended to the existing molecular features vector +- Feature scaling (normalization) is applied by default for better model performance + ### 🔄 Reproducing Publication Results We provide three separate ways for reproducing the results of the publication. diff --git a/catpred/args.py b/catpred/args.py index d52b198..6ca5b73 100644 --- a/catpred/args.py +++ b/catpred/args.py @@ -264,6 +264,12 @@ class TrainArgs(CommonArgs): """ ignore_columns: List[str] = None """Name of the columns to ignore when :code:`target_columns` is not provided.""" + ph_column: str = None + """Name of the column containing pH values. If provided, pH will be used as an input feature.""" + temp_column: str = None + """Name of the column containing temperature values. If provided, temperature will be used as an input feature.""" + no_ph_temp_features_scaling: bool = False + """Turn off scaling of pH and temperature features.""" dataset_type: Literal['regression', 'classification', 'multiclass', 'spectra'] """Type of dataset. This determines the default loss function used during training.""" loss_function: Literal['mse', 'bounded_mse', 'binary_cross_entropy', 'cross_entropy', 'mcc', 'sid', 'wasserstein', 'mve', 'evidential', 'dirichlet'] = None @@ -638,6 +644,19 @@ def bond_descriptor_scaling(self) -> bool: """ return not self.no_bond_descriptor_scaling + @property + def ph_temp_features_scaling(self) -> bool: + """ + Whether to apply normalization with a :class:`~catpred.data.scaler.StandardScaler` + to the pH and temperature features. + """ + return not self.no_ph_temp_features_scaling + + @property + def use_ph_temp_features(self) -> bool: + """Whether the model is using pH and/or temperature as input features.""" + return self.ph_column is not None or self.temp_column is not None + @property def shared_atom_bond_ffn(self) -> bool: """ @@ -935,6 +954,10 @@ class PredictArgs(CommonArgs): """Path to CSV or PICKLE file where predictions will be saved.""" vocabulary_path: str = None """ Path to EC and Taxonomy vocabulary file """ + ph_column: str = None + """Name of the column containing pH values. If provided, pH will be used as an input feature.""" + temp_column: str = None + """Name of the column containing temperature values. If provided, temperature will be used as an input feature.""" drop_extra_columns: bool = False """Whether to drop all columns from the test data file besides the SMILES columns and the new prediction columns.""" ensemble_variance: bool = False diff --git a/catpred/data/__pycache__/__init__.cpython-312.pyc b/catpred/data/__pycache__/__init__.cpython-312.pyc deleted file mode 100644 index 5dae7d1..0000000 Binary files a/catpred/data/__pycache__/__init__.cpython-312.pyc and /dev/null differ diff --git a/catpred/data/__pycache__/cache_utils.cpython-312.pyc b/catpred/data/__pycache__/cache_utils.cpython-312.pyc deleted file mode 100644 index 680d8a2..0000000 Binary files a/catpred/data/__pycache__/cache_utils.cpython-312.pyc and /dev/null differ diff --git a/catpred/data/__pycache__/data.cpython-312.pyc b/catpred/data/__pycache__/data.cpython-312.pyc deleted file mode 100644 index 022118f..0000000 Binary files a/catpred/data/__pycache__/data.cpython-312.pyc and /dev/null differ diff --git a/catpred/data/__pycache__/esm_utils.cpython-312.pyc b/catpred/data/__pycache__/esm_utils.cpython-312.pyc deleted file mode 100644 index 97db928..0000000 Binary files a/catpred/data/__pycache__/esm_utils.cpython-312.pyc and /dev/null differ diff --git a/catpred/data/__pycache__/scaffold.cpython-312.pyc b/catpred/data/__pycache__/scaffold.cpython-312.pyc deleted file mode 100644 index 95f8eff..0000000 Binary files a/catpred/data/__pycache__/scaffold.cpython-312.pyc and /dev/null differ diff --git a/catpred/data/__pycache__/scaler.cpython-312.pyc b/catpred/data/__pycache__/scaler.cpython-312.pyc deleted file mode 100644 index d8767fe..0000000 Binary files a/catpred/data/__pycache__/scaler.cpython-312.pyc and /dev/null differ diff --git a/catpred/data/__pycache__/utils.cpython-312.pyc b/catpred/data/__pycache__/utils.cpython-312.pyc deleted file mode 100644 index 5bade1f..0000000 Binary files a/catpred/data/__pycache__/utils.cpython-312.pyc and /dev/null differ diff --git a/catpred/data/data.py b/catpred/data/data.py index cf04dd3..81a81d3 100644 --- a/catpred/data/data.py +++ b/catpred/data/data.py @@ -75,7 +75,9 @@ def __init__(self, raw_constraints: np.ndarray = None, constraints: np.ndarray = None, overwrite_default_atom_features: bool = False, - overwrite_default_bond_features: bool = False): + overwrite_default_bond_features: bool = False, + ph: Optional[float] = None, + temp: Optional[float] = None): """ :param smiles: A list of the SMILES strings for the molecules. :param vocabulary: A dict of vocabulary of EC and Taxonomy categories @@ -96,6 +98,8 @@ def __init__(self, :param constraints: A numpy array containing atom/bond-level constraints that are used in training. Param constraints is a subset of param raw_constraints. :param overwrite_default_atom_features: Boolean to overwrite default atom features by atom_features. :param overwrite_default_bond_features: Boolean to overwrite default bond features by bond_features. + :param ph: Optional pH value for the reaction conditions. + :param temp: Optional temperature value for the reaction conditions. """ self.smiles = smiles @@ -122,6 +126,10 @@ def __init__(self, self.is_adding_hs_list = [is_adding_hs(x) for x in self.is_mol_list] self.is_keeping_atom_map_list = [is_keeping_atom_map(x) for x in self.is_mol_list] + # Store pH and temperature values + self.ph = ph + self.temp = temp + if data_weight is not None: self.data_weight = data_weight if gt_targets is not None: @@ -183,6 +191,10 @@ def __init__(self, self.raw_atom_descriptors, self.raw_atom_features, self.raw_bond_descriptors, self.raw_bond_features = \ self.atom_descriptors, self.atom_features, self.bond_descriptors, self.bond_features + # Save raw pH and temperature values for potential scaling + self.raw_ph = self.ph + self.raw_temp = self.temp + self.ec_features = [] self.tax_features = [] @@ -301,6 +313,36 @@ def extend_features(self, features: np.ndarray) -> None: """ self.features = np.append(self.features, features) if self.features is not None else features + def extend_features_with_ph_temp(self) -> None: + """ + Extends the features of the molecule with pH and temperature values. + + This method appends pH and Temp to the existing features vector. + Missing values (None) are replaced with 0. + """ + ph_temp_features = [ + self.ph if self.ph is not None else 0.0, + self.temp if self.temp is not None else 0.0 + ] + ph_temp_array = np.array(ph_temp_features) + self.features = np.append(self.features, ph_temp_array) if self.features is not None else ph_temp_array + + def set_ph(self, ph: Optional[float]) -> None: + """ + Sets the pH value for the molecule. + + :param ph: pH value for the reaction conditions. + """ + self.ph = ph + + def set_temp(self, temp: Optional[float]) -> None: + """ + Sets the temperature value for the molecule. + + :param temp: Temperature value for the reaction conditions. + """ + self.temp = temp + def num_tasks(self) -> int: """ Returns the number of prediction tasks. @@ -318,11 +360,14 @@ def set_targets(self, targets: List[Optional[float]]): self.targets = targets def reset_features_and_targets(self) -> None: - """Resets the features (atom, bond, and molecule) and targets to their raw values.""" + """Resets the features (atom, bond, and molecule), targets, and pH/temperature values to their raw values.""" self.features, self.targets, self.atom_targets, self.bond_targets = \ self.raw_features, self.raw_targets, self.raw_atom_targets, self.raw_bond_targets self.atom_descriptors, self.atom_features, self.bond_descriptors, self.bond_features = \ self.raw_atom_descriptors, self.raw_atom_features, self.raw_bond_descriptors, self.raw_bond_features + # Reset pH and temperature to raw values + self.ph = self.raw_ph + self.temp = self.raw_temp class MoleculeDataset(Dataset): @@ -351,6 +396,32 @@ def tax_features(self) -> List[List[int]]: :return: A list of lists of tax_features """ return [d.tax_features for d in self._data] + + def ph_values(self) -> List[Optional[float]]: + """ + Returns a list containing the pH value associated with each :class:`MoleculeDatapoint`. + + :return: A list of pH values (or None if not set) + """ + return [d.ph for d in self._data] + + def temp_values(self) -> List[Optional[float]]: + """ + Returns a list containing the temperature value associated with each :class:`MoleculeDatapoint`. + + :return: A list of temperature values (or None if not set) + """ + return [d.temp for d in self._data] + + def has_ph_temp_features(self) -> bool: + """ + Returns whether any datapoint has pH or temperature features. + + :return: True if any datapoint has pH or temp, False otherwise. + """ + if len(self._data) == 0: + return False + return any(d.ph is not None or d.temp is not None for d in self._data) def smiles(self, flatten: bool = False) -> Union[List[str], List[List[str]]]: """ @@ -732,6 +803,58 @@ def normalize_features(self, scaler: StandardScaler = None, replace_nan_token: i return scaler + def normalize_ph_temp(self, scaler: StandardScaler = None, replace_nan_token: float = 0) -> StandardScaler: + """ + Normalizes the pH and temperature features of the dataset using a :class:`~catpred.data.StandardScaler`. + + The :class:`~catpred.data.StandardScaler` subtracts the mean and divides by the standard deviation + for pH and temperature independently. + + If a :class:`~catpred.data.StandardScaler` is provided, it is used to perform the normalization. + Otherwise, a :class:`~catpred.data.StandardScaler` is first fit to the pH and temp in this dataset + and is then used to perform the normalization. + + :param scaler: A fitted :class:`~catpred.data.StandardScaler`. If it is provided it is used, + otherwise a new :class:`~catpred.data.StandardScaler` is first fitted to this + data and is then used. + :param replace_nan_token: A token to use to replace NaN entries in the features. + :return: A fitted :class:`~catpred.data.StandardScaler`. If a :class:`~catpred.data.StandardScaler` + is provided as a parameter, this is the same :class:`~catpred.data.StandardScaler`. Otherwise, + this is a new :class:`~catpred.data.StandardScaler` that has been fit on this dataset. + """ + if len(self._data) == 0 or not self.has_ph_temp_features(): + return None + + # Collect pH and temp values, replacing None with replace_nan_token + ph_temp_data = [] + for d in self._data: + ph_val = d.raw_ph if d.raw_ph is not None else replace_nan_token + temp_val = d.raw_temp if d.raw_temp is not None else replace_nan_token + ph_temp_data.append([ph_val, temp_val]) + + ph_temp_data = np.array(ph_temp_data) + + if scaler is None: + scaler = StandardScaler(replace_nan_token=replace_nan_token) + scaler.fit(ph_temp_data) + + scaled_ph_temp = scaler.transform(ph_temp_data) + + for i, d in enumerate(self._data): + d.set_ph(scaled_ph_temp[i, 0]) + d.set_temp(scaled_ph_temp[i, 1]) + + return scaler + + def extend_features_with_ph_temp(self) -> None: + """ + Extends the features of each molecule in the dataset with pH and temperature values. + + This method appends pH and Temp to the existing features vector for each datapoint. + """ + for d in self._data: + d.extend_features_with_ph_temp() + def normalize_targets(self) -> StandardScaler: """ Normalizes the targets of the dataset using a :class:`~catpred.data.StandardScaler`. diff --git a/catpred/data/utils.py b/catpred/data/utils.py index a2baaee..9eceaef 100644 --- a/catpred/data/utils.py +++ b/catpred/data/utils.py @@ -418,6 +418,12 @@ def get_data(path: str, constraints_path = constraints_path if constraints_path is not None else args.constraints_path max_data_size = max_data_size if max_data_size is not None else args.max_data_size loss_function = loss_function if loss_function is not None else args.loss_function + # Get pH and temperature column names + ph_column = getattr(args, 'ph_column', None) + temp_column = getattr(args, 'temp_column', None) + else: + ph_column = None + temp_column = None if isinstance(smiles_columns, str) or smiles_columns is None: smiles_columns = preprocess_smiles_columns(path=path, smiles_columns=smiles_columns) @@ -490,9 +496,15 @@ def get_data(path: str, raise ValueError(f'Data file did not contain all provided smiles columns: {smiles_columns}. Data file field names are: {fieldnames}') if any([c not in fieldnames for c in target_columns]): raise ValueError(f'Data file did not contain all provided target columns: {target_columns}. Data file field names are: {fieldnames}') + # Validate pH and temp columns if specified + if ph_column is not None and ph_column not in fieldnames: + raise ValueError(f'Data file did not contain the specified pH column: {ph_column}. Data file field names are: {fieldnames}') + if temp_column is not None and temp_column not in fieldnames: + raise ValueError(f'Data file did not contain the specified temperature column: {temp_column}. Data file field names are: {fieldnames}') all_smiles, all_sequences, all_targets, all_atom_targets, all_bond_targets, all_rows, all_features, all_phase_features, all_constraints_data, all_raw_constraints_data, all_weights, all_gt, all_lt = [], [], [], [], [], [], [], [], [], [], [], [], [] all_protein_records = [] + all_ph, all_temp = [], [] for i, row in enumerate(tqdm(reader)): smoke_test_counter+=1 if args.smoke_test: @@ -575,6 +587,32 @@ def get_data(path: str, if lt_targets is not None: all_lt.append(lt_targets[i]) + # Extract pH value if column is specified + if ph_column is not None: + ph_value = row[ph_column] + if ph_value in ['', 'nan', 'None', 'null']: + all_ph.append(None) + else: + try: + all_ph.append(float(ph_value)) + except ValueError: + all_ph.append(None) + else: + all_ph.append(None) + + # Extract temperature value if column is specified + if temp_column is not None: + temp_value = row[temp_column] + if temp_value in ['', 'nan', 'None', 'null']: + all_temp.append(None) + else: + try: + all_temp.append(float(temp_value)) + except ValueError: + all_temp.append(None) + else: + all_temp.append(None) + if store_row: all_rows.append(row) @@ -629,7 +667,9 @@ def get_data(path: str, constraints=all_constraints_data[i] if constraints_data is not None else None, raw_constraints=all_raw_constraints_data[i] if raw_constraints_data is not None else None, overwrite_default_atom_features=args.overwrite_default_atom_features if args is not None else False, - overwrite_default_bond_features=args.overwrite_default_bond_features if args is not None else False + overwrite_default_bond_features=args.overwrite_default_bond_features if args is not None else False, + ph=all_ph[i], + temp=all_temp[i] ) for i, (smiles, targets) in tqdm(enumerate(zip(all_smiles, all_targets)), total=len(all_smiles)) ]) diff --git a/catpred/features/__pycache__/__init__.cpython-312.pyc b/catpred/features/__pycache__/__init__.cpython-312.pyc deleted file mode 100644 index 74b596f..0000000 Binary files a/catpred/features/__pycache__/__init__.cpython-312.pyc and /dev/null differ diff --git a/catpred/features/__pycache__/features_generators.cpython-312.pyc b/catpred/features/__pycache__/features_generators.cpython-312.pyc deleted file mode 100644 index c3552ff..0000000 Binary files a/catpred/features/__pycache__/features_generators.cpython-312.pyc and /dev/null differ diff --git a/catpred/features/__pycache__/featurization.cpython-312.pyc b/catpred/features/__pycache__/featurization.cpython-312.pyc deleted file mode 100644 index 24645a8..0000000 Binary files a/catpred/features/__pycache__/featurization.cpython-312.pyc and /dev/null differ diff --git a/catpred/features/__pycache__/utils.cpython-312.pyc b/catpred/features/__pycache__/utils.cpython-312.pyc deleted file mode 100644 index 900f591..0000000 Binary files a/catpred/features/__pycache__/utils.cpython-312.pyc and /dev/null differ diff --git a/catpred/train/__pycache__/__init__.cpython-312.pyc b/catpred/train/__pycache__/__init__.cpython-312.pyc deleted file mode 100644 index 4eb42ae..0000000 Binary files a/catpred/train/__pycache__/__init__.cpython-312.pyc and /dev/null differ diff --git a/catpred/train/__pycache__/cross_validate.cpython-312.pyc b/catpred/train/__pycache__/cross_validate.cpython-312.pyc deleted file mode 100644 index 7e37387..0000000 Binary files a/catpred/train/__pycache__/cross_validate.cpython-312.pyc and /dev/null differ diff --git a/catpred/train/__pycache__/evaluate.cpython-312.pyc b/catpred/train/__pycache__/evaluate.cpython-312.pyc deleted file mode 100644 index 77c412c..0000000 Binary files a/catpred/train/__pycache__/evaluate.cpython-312.pyc and /dev/null differ diff --git a/catpred/train/__pycache__/loss_functions.cpython-312.pyc b/catpred/train/__pycache__/loss_functions.cpython-312.pyc deleted file mode 100644 index 485f483..0000000 Binary files a/catpred/train/__pycache__/loss_functions.cpython-312.pyc and /dev/null differ diff --git a/catpred/train/__pycache__/make_predictions.cpython-312.pyc b/catpred/train/__pycache__/make_predictions.cpython-312.pyc deleted file mode 100644 index b45b07b..0000000 Binary files a/catpred/train/__pycache__/make_predictions.cpython-312.pyc and /dev/null differ diff --git a/catpred/train/__pycache__/metrics.cpython-312.pyc b/catpred/train/__pycache__/metrics.cpython-312.pyc deleted file mode 100644 index 12d98df..0000000 Binary files a/catpred/train/__pycache__/metrics.cpython-312.pyc and /dev/null differ diff --git a/catpred/train/__pycache__/molecule_fingerprint.cpython-312.pyc b/catpred/train/__pycache__/molecule_fingerprint.cpython-312.pyc deleted file mode 100644 index 87b5abd..0000000 Binary files a/catpred/train/__pycache__/molecule_fingerprint.cpython-312.pyc and /dev/null differ diff --git a/catpred/train/__pycache__/predict.cpython-312.pyc b/catpred/train/__pycache__/predict.cpython-312.pyc deleted file mode 100644 index f60f021..0000000 Binary files a/catpred/train/__pycache__/predict.cpython-312.pyc and /dev/null differ diff --git a/catpred/train/__pycache__/run_training.cpython-312.pyc b/catpred/train/__pycache__/run_training.cpython-312.pyc deleted file mode 100644 index 63cc4ef..0000000 Binary files a/catpred/train/__pycache__/run_training.cpython-312.pyc and /dev/null differ diff --git a/catpred/train/__pycache__/train.cpython-312.pyc b/catpred/train/__pycache__/train.cpython-312.pyc deleted file mode 100644 index 2608877..0000000 Binary files a/catpred/train/__pycache__/train.cpython-312.pyc and /dev/null differ diff --git a/catpred/train/run_training.py b/catpred/train/run_training.py index 1e52394..597f1e4 100644 --- a/catpred/train/run_training.py +++ b/catpred/train/run_training.py @@ -149,6 +149,18 @@ def run_training(args: TrainArgs, else: bond_descriptor_scaler = None + # Normalize and extend features with pH and temperature if they are used + ph_temp_scaler = None + if args.use_ph_temp_features: + if args.ph_temp_features_scaling: + ph_temp_scaler = train_data.normalize_ph_temp(replace_nan_token=0) + val_data.normalize_ph_temp(ph_temp_scaler) + test_data.normalize_ph_temp(ph_temp_scaler) + # Extend features with pH and temp values + train_data.extend_features_with_ph_temp() + val_data.extend_features_with_ph_temp() + test_data.extend_features_with_ph_temp() + args.train_data_size = len(train_data) debug(f'Total size = {len(data):,} | '