Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 20 additions & 32 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,34 +1,22 @@

*.pyc
catpred/uncertainty/__pycache__/uncertainty_predictor.cpython-312.pyc
catpred/uncertainty/__pycache__/uncertainty_evaluator.cpython-312.pyc
catpred/uncertainty/__pycache__/uncertainty_estimator.cpython-312.pyc
catpred/uncertainty/__pycache__/uncertainty_calibrator.cpython-312.pyc
catpred/uncertainty/__pycache__/__init__.cpython-312.pyc
catpred/train/__pycache__/train.cpython-312.pyc
catpred/train/__pycache__/run_training.cpython-312.pyc
catpred/train/__pycache__/predict.cpython-312.pyc
catpred/train/__pycache__/molecule_fingerprint.cpython-312.pyc
catpred/train/__pycache__/metrics.cpython-312.pyc
catpred/train/__pycache__/make_predictions.cpython-312.pyc
catpred/train/__pycache__/loss_functions.cpython-312.pyc
catpred/train/__pycache__/evaluate.cpython-312.pyc
catpred/train/__pycache__/cross_validate.cpython-312.pyc
catpred/train/__pycache__/__init__.cpython-312.pyc
catpred/models/__pycache__/transformer_models.cpython-312.pyc
catpred/models/__pycache__/mpn.cpython-312.pyc
catpred/models/__pycache__/model.cpython-312.pyc
catpred/models/__pycache__/ffn.cpython-312.pyc
catpred/models/__pycache__/__init__.cpython-312.pyc
catpred/models/.ipynb_checkpoints/model-checkpoint.py
catpred/features/__pycache__/utils.cpython-312.pyc
catpred/features/__pycache__/featurization.cpython-312.pyc
catpred/features/__pycache__/features_generators.cpython-312.pyc
catpred/features/__pycache__/__init__.cpython-312.pyc
catpred/data/__pycache__/utils.cpython-312.pyc
catpred/data/__pycache__/scaler.cpython-312.pyc
catpred/data/__pycache__/scaffold.cpython-312.pyc
catpred/data/__pycache__/esm_utils.cpython-312.pyc
catpred/data/__pycache__/data.cpython-312.pyc
catpred/data/__pycache__/cache_utils.cpython-312.pyc
catpred/data/__pycache__/__init__.cpython-312.pyc
__pycache__/
*.pyo
*.pyd
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
.ipynb_checkpoints/
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
- [System Requirements](#requirements)
- [Installation](#installing)
- [Prediction](#predict)
- [pH and Temperature Features](#ph-temp)
- [Reproducibility](#reproduce)
- [Acknowledgements](#acknw)
- [License](#license)
Expand Down Expand Up @@ -78,6 +79,32 @@ pip install -e .

The Jupyter Notebook `batch_demo.ipynb` and the Python script `demo_run.py` show the usage of pre-trained models for prediction.

### 🌡️ Using pH and Temperature Features <a name="ph-temp"></a>

CatPred supports including pH and temperature as additional input features for training and prediction. To use these features:

1. **Data Preparation**: Include columns for pH and/or temperature in your CSV data file.

2. **Training with pH/Temp**: Specify the column names using the command line arguments:
```bash
python train.py --data_path <path> --ph_column <ph_column_name> --temp_column <temp_column_name> ...
```

3. **Available Options**:
- `--ph_column`: Name of the column containing pH values
- `--temp_column`: Name of the column containing temperature values
- `--no_ph_temp_features_scaling`: Disable normalization of pH and temperature features (enabled by default)

4. **Prediction**: When using a model trained with pH/Temp features, provide the same columns in your test data:
```bash
python predict.py --test_path <path> --ph_column <ph_column_name> --temp_column <temp_column_name> ...
```

**Notes**:
- Missing pH/Temp values (empty, 'nan', 'None', 'null') are handled gracefully and replaced with 0
- pH and temperature features are appended to the existing molecular features vector
- Feature scaling (normalization) is applied by default for better model performance

### 🔄 Reproducing Publication Results <a name="reproduce"></a>

We provide three separate ways for reproducing the results of the publication.
Expand Down
23 changes: 23 additions & 0 deletions catpred/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,12 @@ class TrainArgs(CommonArgs):
"""
ignore_columns: List[str] = None
"""Name of the columns to ignore when :code:`target_columns` is not provided."""
ph_column: str = None
"""Name of the column containing pH values. If provided, pH will be used as an input feature."""
temp_column: str = None
"""Name of the column containing temperature values. If provided, temperature will be used as an input feature."""
no_ph_temp_features_scaling: bool = False
"""Turn off scaling of pH and temperature features."""
dataset_type: Literal['regression', 'classification', 'multiclass', 'spectra']
"""Type of dataset. This determines the default loss function used during training."""
loss_function: Literal['mse', 'bounded_mse', 'binary_cross_entropy', 'cross_entropy', 'mcc', 'sid', 'wasserstein', 'mve', 'evidential', 'dirichlet'] = None
Expand Down Expand Up @@ -638,6 +644,19 @@ def bond_descriptor_scaling(self) -> bool:
"""
return not self.no_bond_descriptor_scaling

@property
def ph_temp_features_scaling(self) -> bool:
"""
Whether to apply normalization with a :class:`~catpred.data.scaler.StandardScaler`
to the pH and temperature features.
"""
return not self.no_ph_temp_features_scaling

@property
def use_ph_temp_features(self) -> bool:
"""Whether the model is using pH and/or temperature as input features."""
return self.ph_column is not None or self.temp_column is not None

@property
def shared_atom_bond_ffn(self) -> bool:
"""
Expand Down Expand Up @@ -935,6 +954,10 @@ class PredictArgs(CommonArgs):
"""Path to CSV or PICKLE file where predictions will be saved."""
vocabulary_path: str = None
""" Path to EC and Taxonomy vocabulary file """
ph_column: str = None
"""Name of the column containing pH values. If provided, pH will be used as an input feature."""
temp_column: str = None
"""Name of the column containing temperature values. If provided, temperature will be used as an input feature."""
drop_extra_columns: bool = False
"""Whether to drop all columns from the test data file besides the SMILES columns and the new prediction columns."""
ensemble_variance: bool = False
Expand Down
Binary file removed catpred/data/__pycache__/__init__.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Binary file removed catpred/data/__pycache__/data.cpython-312.pyc
Binary file not shown.
Binary file removed catpred/data/__pycache__/esm_utils.cpython-312.pyc
Binary file not shown.
Binary file removed catpred/data/__pycache__/scaffold.cpython-312.pyc
Binary file not shown.
Binary file removed catpred/data/__pycache__/scaler.cpython-312.pyc
Binary file not shown.
Binary file removed catpred/data/__pycache__/utils.cpython-312.pyc
Binary file not shown.
127 changes: 125 additions & 2 deletions catpred/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def __init__(self,
raw_constraints: np.ndarray = None,
constraints: np.ndarray = None,
overwrite_default_atom_features: bool = False,
overwrite_default_bond_features: bool = False):
overwrite_default_bond_features: bool = False,
ph: Optional[float] = None,
temp: Optional[float] = None):
"""
:param smiles: A list of the SMILES strings for the molecules.
:param vocabulary: A dict of vocabulary of EC and Taxonomy categories
Expand All @@ -96,6 +98,8 @@ def __init__(self,
:param constraints: A numpy array containing atom/bond-level constraints that are used in training. Param constraints is a subset of param raw_constraints.
:param overwrite_default_atom_features: Boolean to overwrite default atom features by atom_features.
:param overwrite_default_bond_features: Boolean to overwrite default bond features by bond_features.
:param ph: Optional pH value for the reaction conditions.
:param temp: Optional temperature value for the reaction conditions.

"""
self.smiles = smiles
Expand All @@ -122,6 +126,10 @@ def __init__(self,
self.is_adding_hs_list = [is_adding_hs(x) for x in self.is_mol_list]
self.is_keeping_atom_map_list = [is_keeping_atom_map(x) for x in self.is_mol_list]

# Store pH and temperature values
self.ph = ph
self.temp = temp

if data_weight is not None:
self.data_weight = data_weight
if gt_targets is not None:
Expand Down Expand Up @@ -183,6 +191,10 @@ def __init__(self,
self.raw_atom_descriptors, self.raw_atom_features, self.raw_bond_descriptors, self.raw_bond_features = \
self.atom_descriptors, self.atom_features, self.bond_descriptors, self.bond_features

# Save raw pH and temperature values for potential scaling
self.raw_ph = self.ph
self.raw_temp = self.temp

self.ec_features = []
self.tax_features = []

Expand Down Expand Up @@ -301,6 +313,36 @@ def extend_features(self, features: np.ndarray) -> None:
"""
self.features = np.append(self.features, features) if self.features is not None else features

def extend_features_with_ph_temp(self) -> None:
"""
Extends the features of the molecule with pH and temperature values.

This method appends pH and Temp to the existing features vector.
Missing values (None) are replaced with 0.
"""
ph_temp_features = [
self.ph if self.ph is not None else 0.0,
self.temp if self.temp is not None else 0.0
]
ph_temp_array = np.array(ph_temp_features)
self.features = np.append(self.features, ph_temp_array) if self.features is not None else ph_temp_array

def set_ph(self, ph: Optional[float]) -> None:
"""
Sets the pH value for the molecule.

:param ph: pH value for the reaction conditions.
"""
self.ph = ph

def set_temp(self, temp: Optional[float]) -> None:
"""
Sets the temperature value for the molecule.

:param temp: Temperature value for the reaction conditions.
"""
self.temp = temp

def num_tasks(self) -> int:
"""
Returns the number of prediction tasks.
Expand All @@ -318,11 +360,14 @@ def set_targets(self, targets: List[Optional[float]]):
self.targets = targets

def reset_features_and_targets(self) -> None:
"""Resets the features (atom, bond, and molecule) and targets to their raw values."""
"""Resets the features (atom, bond, and molecule), targets, and pH/temperature values to their raw values."""
self.features, self.targets, self.atom_targets, self.bond_targets = \
self.raw_features, self.raw_targets, self.raw_atom_targets, self.raw_bond_targets
self.atom_descriptors, self.atom_features, self.bond_descriptors, self.bond_features = \
self.raw_atom_descriptors, self.raw_atom_features, self.raw_bond_descriptors, self.raw_bond_features
# Reset pH and temperature to raw values
self.ph = self.raw_ph
self.temp = self.raw_temp


class MoleculeDataset(Dataset):
Expand Down Expand Up @@ -351,6 +396,32 @@ def tax_features(self) -> List[List[int]]:
:return: A list of lists of tax_features
"""
return [d.tax_features for d in self._data]

def ph_values(self) -> List[Optional[float]]:
"""
Returns a list containing the pH value associated with each :class:`MoleculeDatapoint`.

:return: A list of pH values (or None if not set)
"""
return [d.ph for d in self._data]

def temp_values(self) -> List[Optional[float]]:
"""
Returns a list containing the temperature value associated with each :class:`MoleculeDatapoint`.

:return: A list of temperature values (or None if not set)
"""
return [d.temp for d in self._data]

def has_ph_temp_features(self) -> bool:
"""
Returns whether any datapoint has pH or temperature features.

:return: True if any datapoint has pH or temp, False otherwise.
"""
if len(self._data) == 0:
return False
return any(d.ph is not None or d.temp is not None for d in self._data)

def smiles(self, flatten: bool = False) -> Union[List[str], List[List[str]]]:
"""
Expand Down Expand Up @@ -732,6 +803,58 @@ def normalize_features(self, scaler: StandardScaler = None, replace_nan_token: i

return scaler

def normalize_ph_temp(self, scaler: StandardScaler = None, replace_nan_token: float = 0) -> StandardScaler:
"""
Normalizes the pH and temperature features of the dataset using a :class:`~catpred.data.StandardScaler`.

The :class:`~catpred.data.StandardScaler` subtracts the mean and divides by the standard deviation
for pH and temperature independently.

If a :class:`~catpred.data.StandardScaler` is provided, it is used to perform the normalization.
Otherwise, a :class:`~catpred.data.StandardScaler` is first fit to the pH and temp in this dataset
and is then used to perform the normalization.

:param scaler: A fitted :class:`~catpred.data.StandardScaler`. If it is provided it is used,
otherwise a new :class:`~catpred.data.StandardScaler` is first fitted to this
data and is then used.
:param replace_nan_token: A token to use to replace NaN entries in the features.
:return: A fitted :class:`~catpred.data.StandardScaler`. If a :class:`~catpred.data.StandardScaler`
is provided as a parameter, this is the same :class:`~catpred.data.StandardScaler`. Otherwise,
this is a new :class:`~catpred.data.StandardScaler` that has been fit on this dataset.
"""
if len(self._data) == 0 or not self.has_ph_temp_features():
return None

# Collect pH and temp values, replacing None with replace_nan_token
ph_temp_data = []
for d in self._data:
ph_val = d.raw_ph if d.raw_ph is not None else replace_nan_token
temp_val = d.raw_temp if d.raw_temp is not None else replace_nan_token
ph_temp_data.append([ph_val, temp_val])

ph_temp_data = np.array(ph_temp_data)

if scaler is None:
scaler = StandardScaler(replace_nan_token=replace_nan_token)
scaler.fit(ph_temp_data)

scaled_ph_temp = scaler.transform(ph_temp_data)

for i, d in enumerate(self._data):
d.set_ph(scaled_ph_temp[i, 0])
d.set_temp(scaled_ph_temp[i, 1])

return scaler

def extend_features_with_ph_temp(self) -> None:
"""
Extends the features of each molecule in the dataset with pH and temperature values.

This method appends pH and Temp to the existing features vector for each datapoint.
"""
for d in self._data:
d.extend_features_with_ph_temp()

def normalize_targets(self) -> StandardScaler:
"""
Normalizes the targets of the dataset using a :class:`~catpred.data.StandardScaler`.
Expand Down
42 changes: 41 additions & 1 deletion catpred/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,12 @@ def get_data(path: str,
constraints_path = constraints_path if constraints_path is not None else args.constraints_path
max_data_size = max_data_size if max_data_size is not None else args.max_data_size
loss_function = loss_function if loss_function is not None else args.loss_function
# Get pH and temperature column names
ph_column = getattr(args, 'ph_column', None)
temp_column = getattr(args, 'temp_column', None)
else:
ph_column = None
temp_column = None

if isinstance(smiles_columns, str) or smiles_columns is None:
smiles_columns = preprocess_smiles_columns(path=path, smiles_columns=smiles_columns)
Expand Down Expand Up @@ -490,9 +496,15 @@ def get_data(path: str,
raise ValueError(f'Data file did not contain all provided smiles columns: {smiles_columns}. Data file field names are: {fieldnames}')
if any([c not in fieldnames for c in target_columns]):
raise ValueError(f'Data file did not contain all provided target columns: {target_columns}. Data file field names are: {fieldnames}')
# Validate pH and temp columns if specified
if ph_column is not None and ph_column not in fieldnames:
raise ValueError(f'Data file did not contain the specified pH column: {ph_column}. Data file field names are: {fieldnames}')
if temp_column is not None and temp_column not in fieldnames:
raise ValueError(f'Data file did not contain the specified temperature column: {temp_column}. Data file field names are: {fieldnames}')

all_smiles, all_sequences, all_targets, all_atom_targets, all_bond_targets, all_rows, all_features, all_phase_features, all_constraints_data, all_raw_constraints_data, all_weights, all_gt, all_lt = [], [], [], [], [], [], [], [], [], [], [], [], []
all_protein_records = []
all_ph, all_temp = [], []
for i, row in enumerate(tqdm(reader)):
smoke_test_counter+=1
if args.smoke_test:
Expand Down Expand Up @@ -575,6 +587,32 @@ def get_data(path: str,
if lt_targets is not None:
all_lt.append(lt_targets[i])

# Extract pH value if column is specified
if ph_column is not None:
ph_value = row[ph_column]
if ph_value in ['', 'nan', 'None', 'null']:
all_ph.append(None)
else:
try:
all_ph.append(float(ph_value))
except ValueError:
all_ph.append(None)
else:
all_ph.append(None)

# Extract temperature value if column is specified
if temp_column is not None:
temp_value = row[temp_column]
if temp_value in ['', 'nan', 'None', 'null']:
all_temp.append(None)
else:
try:
all_temp.append(float(temp_value))
except ValueError:
all_temp.append(None)
else:
all_temp.append(None)

if store_row:
all_rows.append(row)

Expand Down Expand Up @@ -629,7 +667,9 @@ def get_data(path: str,
constraints=all_constraints_data[i] if constraints_data is not None else None,
raw_constraints=all_raw_constraints_data[i] if raw_constraints_data is not None else None,
overwrite_default_atom_features=args.overwrite_default_atom_features if args is not None else False,
overwrite_default_bond_features=args.overwrite_default_bond_features if args is not None else False
overwrite_default_bond_features=args.overwrite_default_bond_features if args is not None else False,
ph=all_ph[i],
temp=all_temp[i]
) for i, (smiles, targets) in tqdm(enumerate(zip(all_smiles, all_targets)),
total=len(all_smiles))
])
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed catpred/features/__pycache__/utils.cpython-312.pyc
Binary file not shown.
Binary file removed catpred/train/__pycache__/__init__.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Binary file removed catpred/train/__pycache__/evaluate.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed catpred/train/__pycache__/metrics.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Binary file removed catpred/train/__pycache__/predict.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Binary file removed catpred/train/__pycache__/train.cpython-312.pyc
Binary file not shown.
Loading