Skip to content

Commit 97e545c

Browse files
Nora Khalil
authored and committed
cleaned up and commented
1 parent 2edc09c commit 97e545c

File tree

1 file changed

+33
-35
lines changed

1 file changed

+33
-35
lines changed

rmgpy/data/kinetics/family.py

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2994,8 +2994,7 @@ def get_extension_edge(self, parent, template_rxn_map, obj, T, iter_max=np.inf,
29942994
#parent
29952995
if typr != 'intNewBondExt' and typr != 'extNewBondExt': # these dimensions should be regularized
29962996
if typr == 'atomExt':
2997-
pass
2998-
#grp.atoms[indcr[0]].reg_dim_atm = list(reg_val)
2997+
pass #no longer passing regularization info to the parent here. Doing this instead in `extend_node`
29992998
elif typr == 'elExt':
30002999
grp.atoms[indcr[0]].reg_dim_u = list(reg_val)
30013000
elif typr == 'ringExt':
@@ -3091,10 +3090,27 @@ def get_extension_edge(self, parent, template_rxn_map, obj, T, iter_max=np.inf,
30913090
return out, gave_up_split
30923091

30933092
def get_compliment_reg_dim(self, parent, template_rxn_map, new_ext, comp_ext):
3093+
"""
3094+
Takes in a parent node (`parent`), an extension node (`new_ext`), and its complement (`comp_ext`).
3095+
The reactions of the parent node are split between the extension and its complement.
3096+
While iterating over all the reactions that fit the complementary node, the atomtype of each labeled atom in each reaction is saved to a dictionary `atom_labeling_in_comp_rxns`,
3097+
where the key is the integer of the atom label (e.g. 5 for '*5') and the value is a set of all the atomtypes found in the complementary reactions for that atom label.
3098+
3099+
Additionally, during the same iteration over the reactions that fit the complementary node, the atomtype of each unlabeled atom in each reaction is saved to a list `unlabeled_atoms_in_comp_rxns`.
3100+
"""
3101+
3102+
3103+
assert comp_ext is not None, "This extension does not include a complimentary node. Cannot get regularization dimensions of complimentary node."
3104+
3105+
#divide parent reactions into the extension node and its compliment
30943106
rxns_from_parent = template_rxn_map[parent.label]
30953107
new_ext_rxns, comp_ext_rxns, _ = self._split_reactions(rxns_from_parent, new_ext)
3108+
3109+
#for saving data
30963110
atom_labeling_in_comp_rxns = dict()
30973111
unlabeled_atoms_in_comp_rxns = []
3112+
3113+
#iterate through each complimentary rxn
30983114
for rxn_c in comp_ext_rxns:
30993115
for reactant in rxn_c.reactants:
31003116
for mol in reactant.molecule:
@@ -3105,13 +3121,13 @@ def get_compliment_reg_dim(self, parent, template_rxn_map, new_ext, comp_ext):
31053121
if unlabeled_atmtype not in unlabeled_atoms_in_comp_rxns:
31063122
unlabeled_atoms_in_comp_rxns.append(unlabeled_atmtype)
31073123
else:
3124+
#this is a labeled atom
31083125
atm_label = int(atm.label.replace('*',''))
31093126
if atm_label not in atom_labeling_in_comp_rxns.keys():
31103127
atom_labeling_in_comp_rxns[atm_label] = [ATOMTYPES[atm.symbol]]
31113128
else:
31123129
existing_atomtypes = atom_labeling_in_comp_rxns[atm_label]
31133130
existing_atomtypes.append(ATOMTYPES[atm.symbol])
3114-
#print(f'count of missing * is {count}')
31153131
atom_labeling_in_comp_rxns_set = {k: set(v) for k, v in atom_labeling_in_comp_rxns.items()}
31163132

31173133
return atom_labeling_in_comp_rxns_set, unlabeled_atoms_in_comp_rxns
@@ -3191,45 +3207,35 @@ def extend_node(self, parent, template_rxn_map, obj=None, T=1000.0, iter_max=np.
31913207

31923208
extname = ext[2]
31933209

3194-
print(extname, ext[3])
3210+
31953211
if ext[3] == 'atomExt':
31963212
ext[0].atoms[ext[4][0]].reg_dim_atm = [ext[0].atoms[ext[4][0]].atomtype, ext[0].atoms[ext[4][0]].atomtype] #passing regularization information to the selected extension node
31973213

3198-
#handling regularization in complement below
3214+
#handling regularization in complement below:
31993215
atom_labeling_in_comp_rxns_set, unlabeled_atoms_in_comp_rxns = self.get_compliment_reg_dim(parent, template_rxn_map, ext[0], ext[1])
3200-
#print(ext[0].atoms[ext[4][0]], ext[0].atoms[ext[4][0]].label, ext[1].atoms[ext[4][0]], ext[1].atoms[ext[4][0]].label)
32013216

32023217
#regularize the atom on which the extension was performed
3203-
if ext[1] is not None:
3204-
if ext[1].atoms[ext[4][0]].label=='':
3205-
#extension was performed on an unlabeled atom
3206-
limited_atomtypes_comp = set(ext[1].atoms[ext[4][0]].atomtype).intersection(set(unlabeled_atoms_in_comp_rxns))
3207-
#print(ext[1].atoms[ext[4][0]].atomtype, list(limited_atomtypes_comp))
3208-
ext[1].atoms[ext[4][0]].reg_dim_atm = [ext[1].atoms[ext[4][0]].atomtype, list(limited_atomtypes_comp)]
3209-
else:
3210-
adjusted_index = int(ext[1].atoms[ext[4][0]].label.replace('*','')) #i.e. ext[4]= (3,), ext[4][0] = 3, ext[0].atoms[3]=<GroupAtom [*5 'N', 'C']>, ext[0].atoms[3].label = '*5'
3211-
ext[1].atoms[ext[4][0]].reg_dim_atm = [ext[1].atoms[ext[4][0]].atomtype, list(atom_labeling_in_comp_rxns_set[adjusted_index])]
3212-
3213-
#make sure the rest of the atoms in the extension take on the same regularization dimensions as the parent.
3218+
if ext[1].atoms[ext[4][0]].label=='':
3219+
#extension was performed on an unlabeled atom, so pass in regularization dimensions that are at least limited to the atomtypes of all the unlabeled atoms
3220+
limited_atomtypes_comp = set(ext[1].atoms[ext[4][0]].atomtype).intersection(set(unlabeled_atoms_in_comp_rxns))
3221+
ext[1].atoms[ext[4][0]].reg_dim_atm = [ext[1].atoms[ext[4][0]].atomtype, list(limited_atomtypes_comp)]
3222+
else:
3223+
#extension was performed on a labeled atom. For each labeled atom, we know all the atomtypes in the training reactions. Let's limit regularization dimensions to these known atomtypes
3224+
adjusted_index = int(ext[1].atoms[ext[4][0]].label.replace('*','')) #i.e. ext[4]= (3,), ext[4][0] = 3, ext[0].atoms[3]=<GroupAtom [*5 'N', 'C']>, ext[0].atoms[3].label = '*5'
3225+
ext[1].atoms[ext[4][0]].reg_dim_atm = [ext[1].atoms[ext[4][0]].atomtype, list(atom_labeling_in_comp_rxns_set[adjusted_index])]
3226+
3227+
#make sure the rest of the atoms in the extension take on the same regularization dimensions as the parent. Ensures subgraph isomorphism.
32143228
for i, parent_atm in enumerate(parent.item.atoms):
32153229
if i == ext[4][0]:
3216-
print('extension atom')
3217-
continue #this is the atom that the extension is focused on, handled above
3230+
continue #this is the atom that the extension is focused on, handled above if the extension was an 'atomExt' extension type
32183231
elif parent_atm.reg_dim_atm[1]==[]:
3219-
print('parent atm reg_dim is empty')
32203232
continue #only take on regularization dimensions of parent if there is some
32213233
else:
32223234
ext[0].atoms[i].reg_dim_atm[1] = parent_atm.reg_dim_atm[1] #passing regularization info from parent to the extension
3223-
if ext[1] is not None:
3235+
if ext[1] is not None: #check if there's a complimentary node
32243236
ext[1].atoms[i].reg_dim_atm[1] = parent_atm.reg_dim_atm[1] #passing regularization info from parent to the complimentary extension
32253237

3226-
3227-
#print(ext[1].atoms[i].atomtype,' ', ext[1].atoms[i].reg_dim_atm[1])
3228-
32293238

3230-
3231-
# print(ext[1].atoms[ext[4][0]].atomtype, )
3232-
# ext[1].atoms[ext[4][0]].reg_dim_atm = [ext[1].atoms[ext[4][0]].atomtype, ext[1].atoms[ext[4][0]].atomtype] #must also pass regularization information to the compliment
32333239
if ext[3] == 'elExt':
32343240
ext[0].atoms[ext[4][0]].reg_dim_u = [ext[0].atoms[ext[4][0]].radical_electrons,
32353241
ext[0].atoms[ext[4][0]].radical_electrons]
@@ -3316,8 +3322,6 @@ def extend_node(self, parent, template_rxn_map, obj=None, T=1000.0, iter_max=np.
33163322
if complement:
33173323
template_rxn_map[parent.label] = []
33183324
template_rxn_map[cextname] = comp_entries
3319-
if cextname=="Root_N-4R!H->O":
3320-
print(f'end of extend_node: {self.groups.entries["Root_N-4R!H->O"].item.atoms[3].reg_dim_atm}')
33213325
else:
33223326
template_rxn_map[parent.label] = comp_entries
33233327

@@ -3389,8 +3393,6 @@ def rxnkey(rxn):
33893393
logging.error("built tree with {} nodes".format(len(list(self.groups.entries))))
33903394

33913395
self.auto_generated = True
3392-
print(f'end of generate_tree: {self.groups.entries["Root_N-4R!H->O"].item.atoms[3].reg_dim_atm}')
3393-
33943396

33953397
def get_rxn_batches(self, rxns, T=1000.0, max_batch_size=800, outlier_fraction=0.02, stratum_num=8):
33963398
"""
@@ -3558,8 +3560,6 @@ def make_tree_nodes(self, template_rxn_map=None, obj=None, T=1000.0, nprocs=0, d
35583560
continue
35593561
boo2 = self.extend_node(entry, template_rxn_map, obj, T, iter_max=extension_iter_max, iter_item_cap=extension_iter_item_cap)
35603562
if boo2: # extended node so restart while loop
3561-
# if "Root_N-4R!H->O" in template_rxn_map.keys():
3562-
# print(f'at boo2: {self.groups.entries["Root_N-4R!H->O"].item.atoms[3].reg_dim_atm}')
35633563
break
35643564
else: # no extensions could be generated since all reactions were identical
35653565
mult_completed_nodes.append(entry)
@@ -3591,8 +3591,6 @@ def make_tree_nodes(self, template_rxn_map=None, obj=None, T=1000.0, nprocs=0, d
35913591
entry.parent = self.groups.entries[pname]
35923592
entry.parent.children.append(entry)
35933593

3594-
print(f'end of make_tree_nodes: {self.groups.entries["Root_N-4R!H->O"].item.atoms[3].reg_dim_atm}')
3595-
35963594
return
35973595

35983596
def _absorb_process(self, p, conn, name):

0 commit comments

Comments
 (0)