Skip to content

Commit 2edc09c

Browse files
Nora Khalil
authored and committed
first attempt at fixing regularization
1 parent d65c5f3 commit 2edc09c

File tree

1 file changed

+104
-21
lines changed

1 file changed

+104
-21
lines changed

rmgpy/data/kinetics/family.py

Lines changed: 104 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -2953,10 +2953,17 @@ def get_extension_edge(self, parent, template_rxn_map, obj, T, iter_max=np.inf,
29532953
out_exts[-1].append(exts[i]) # this extension splits reactions (optimization dim)
29542954
if typ == 'atomExt':
29552955
reg_dict[(typ, indc)][0].extend(grp2.atoms[indc[0]].atomtype)
2956+
#still pass the regularization data to grp2. However, this doesn't take care of grpc
2957+
#reg_dict[(typ, indc)][1].extend(grp2.atoms[indc[0]].atomtype)
2958+
#now take care of the complement:
2959+
29562960
elif typ == 'elExt':
29572961
reg_dict[(typ, indc)][0].extend(grp2.atoms[indc[0]].radical_electrons)
2962+
#reg_dict[(typ, indc)][1].extend(grp2.atoms[indc[0]].radical_electrons)
29582963
elif typ == 'bondExt':
29592964
reg_dict[(typ, indc)][0].extend(grp2.get_bond(grp2.atoms[indc[0]], grp2.atoms[indc[1]]).order)
2965+
#reg_dict[(typ, indc)][1].extend(grp2.get_bond(grp2.atoms[indc[0]], grp2.atoms[indc[1]]).order)
2966+
29602967

29612968
elif boo: # this extension matches all reactions (regularization dim)
29622969
if typ == 'intNewBondExt' or typ == 'extNewBondExt':
@@ -2983,10 +2990,12 @@ def get_extension_edge(self, parent, template_rxn_map, obj, T, iter_max=np.inf,
29832990
reg_val = reg_dict[(typr, indcr)]
29842991

29852992
if first_time and parent.children == []:
2986-
# parent
2993+
2994+
#parent
29872995
if typr != 'intNewBondExt' and typr != 'extNewBondExt': # these dimensions should be regularized
29882996
if typr == 'atomExt':
2989-
grp.atoms[indcr[0]].reg_dim_atm = list(reg_val)
2997+
pass
2998+
#grp.atoms[indcr[0]].reg_dim_atm = list(reg_val)
29902999
elif typr == 'elExt':
29913000
grp.atoms[indcr[0]].reg_dim_u = list(reg_val)
29923001
elif typr == 'ringExt':
@@ -3080,6 +3089,32 @@ def get_extension_edge(self, parent, template_rxn_map, obj, T, iter_max=np.inf,
30803089
out.extend(x)
30813090

30823091
return out, gave_up_split
3092+
3093+
def get_compliment_reg_dim(self, parent, template_rxn_map, new_ext, comp_ext):
    """
    Gather the atom types seen in the reactions that land on the complement
    side when the reactions of `parent` are split by the extension `new_ext`.

    The reactions are read from ``template_rxn_map[parent.label]`` and divided
    with ``self._split_reactions``; only the second element of its result is
    inspected (treated here as the reactions belonging to the complement).
    Every atom of every molecule of every reactant of those reactions is
    visited and ``ATOMTYPES[atom.symbol]`` is recorded.

    `comp_ext` is accepted for the caller's convenience but is not used by
    this method.

    Returns a tuple ``(labeled, unlabeled)`` where

    * ``labeled`` maps the integer atom label (the label with every ``*``
      removed) to the set of atom types found on atoms carrying that label
    * ``unlabeled`` is a list, without repeats and in first-seen order, of the
      atom types found on atoms whose label is the empty string
    """
    split_result = self._split_reactions(template_rxn_map[parent.label], new_ext)
    complement_rxns = split_result[1]
    assert len(split_result) == 3  # mirrors the three-way unpacking this method relies on

    atomtypes_by_label = {}
    unlabeled_atomtypes = []

    for rxn in complement_rxns:
        for species in rxn.reactants:
            for molecule in species.molecule:
                for atom in molecule.atoms:
                    if atom.label == '':
                        # unlabeled atom: keep each atom type once, preserving order
                        atomtype = ATOMTYPES[atom.symbol]
                        if atomtype not in unlabeled_atomtypes:
                            unlabeled_atomtypes.append(atomtype)
                        continue

                    # labeled atom: '*5' -> 5
                    label_number = int(atom.label.replace('*', ''))
                    atomtype = ATOMTYPES[atom.symbol]
                    atomtypes_by_label.setdefault(label_number, set()).add(atomtype)

    return atomtypes_by_label, unlabeled_atomtypes
30833118

30843119
def extend_node(self, parent, template_rxn_map, obj=None, T=1000.0, iter_max=np.inf, iter_item_cap=np.inf):
30853120
"""
@@ -3156,9 +3191,46 @@ def extend_node(self, parent, template_rxn_map, obj=None, T=1000.0, iter_max=np.
31563191

31573192
extname = ext[2]
31583193

3194+
print(extname, ext[3])
31593195
if ext[3] == 'atomExt':
3160-
ext[0].atoms[ext[4][0]].reg_dim_atm = [ext[0].atoms[ext[4][0]].atomtype, ext[0].atoms[ext[4][0]].atomtype]
3161-
elif ext[3] == 'elExt':
3196+
ext[0].atoms[ext[4][0]].reg_dim_atm = [ext[0].atoms[ext[4][0]].atomtype, ext[0].atoms[ext[4][0]].atomtype] #passing regularization information to the selected extension node
3197+
3198+
#handling regularization in complement below
3199+
atom_labeling_in_comp_rxns_set, unlabeled_atoms_in_comp_rxns = self.get_compliment_reg_dim(parent, template_rxn_map, ext[0], ext[1])
3200+
#print(ext[0].atoms[ext[4][0]], ext[0].atoms[ext[4][0]].label, ext[1].atoms[ext[4][0]], ext[1].atoms[ext[4][0]].label)
3201+
3202+
#regularize the atom on which the extension was performed
3203+
if ext[1] is not None:
3204+
if ext[1].atoms[ext[4][0]].label=='':
3205+
#extension was performed on an unlabeled atom
3206+
limited_atomtypes_comp = set(ext[1].atoms[ext[4][0]].atomtype).intersection(set(unlabeled_atoms_in_comp_rxns))
3207+
#print(ext[1].atoms[ext[4][0]].atomtype, list(limited_atomtypes_comp))
3208+
ext[1].atoms[ext[4][0]].reg_dim_atm = [ext[1].atoms[ext[4][0]].atomtype, list(limited_atomtypes_comp)]
3209+
else:
3210+
adjusted_index = int(ext[1].atoms[ext[4][0]].label.replace('*','')) #i.e. ext[4]= (3,), ext[4][0] = 3, ext[0].atoms[3]=<GroupAtom [*5 'N', 'C']>, ext[0].atoms[3].label = '*5'
3211+
ext[1].atoms[ext[4][0]].reg_dim_atm = [ext[1].atoms[ext[4][0]].atomtype, list(atom_labeling_in_comp_rxns_set[adjusted_index])]
3212+
3213+
#make sure the rest of the atoms in the extension take on the same regularization dimensions as the parent.
3214+
for i, parent_atm in enumerate(parent.item.atoms):
3215+
if i == ext[4][0]:
3216+
print('extension atom')
3217+
continue #this is the atom that the extension is focused on, handled above
3218+
elif parent_atm.reg_dim_atm[1]==[]:
3219+
print('parent atm reg_dim is empty')
3220+
continue #only take on regularization dimensions of parent if there is some
3221+
else:
3222+
ext[0].atoms[i].reg_dim_atm[1] = parent_atm.reg_dim_atm[1] #passing regularization info from parent to the extension
3223+
if ext[1] is not None:
3224+
ext[1].atoms[i].reg_dim_atm[1] = parent_atm.reg_dim_atm[1] #passing regularization info from parent to the complimentary extension
3225+
3226+
3227+
#print(ext[1].atoms[i].atomtype,' ', ext[1].atoms[i].reg_dim_atm[1])
3228+
3229+
3230+
3231+
# print(ext[1].atoms[ext[4][0]].atomtype, )
3232+
# ext[1].atoms[ext[4][0]].reg_dim_atm = [ext[1].atoms[ext[4][0]].atomtype, ext[1].atoms[ext[4][0]].atomtype] #must also pass regularization information to the compliment
3233+
if ext[3] == 'elExt':
31623234
ext[0].atoms[ext[4][0]].reg_dim_u = [ext[0].atoms[ext[4][0]].radical_electrons,
31633235
ext[0].atoms[ext[4][0]].radical_electrons]
31643236

@@ -3244,8 +3316,11 @@ def extend_node(self, parent, template_rxn_map, obj=None, T=1000.0, iter_max=np.
32443316
if complement:
32453317
template_rxn_map[parent.label] = []
32463318
template_rxn_map[cextname] = comp_entries
3319+
if cextname=="Root_N-4R!H->O":
3320+
print(f'end of extend_node: {self.groups.entries["Root_N-4R!H->O"].item.atoms[3].reg_dim_atm}')
32473321
else:
32483322
template_rxn_map[parent.label] = comp_entries
3323+
32493324
return True
32503325

32513326
def generate_tree(self, rxns=None, obj=None, thermo_database=None, T=1000.0, nprocs=1, min_splitable_entry_num=2,
@@ -3314,6 +3389,8 @@ def rxnkey(rxn):
33143389
logging.error("built tree with {} nodes".format(len(list(self.groups.entries))))
33153390

33163391
self.auto_generated = True
3392+
print(f'end of generate_tree: {self.groups.entries["Root_N-4R!H->O"].item.atoms[3].reg_dim_atm}')
3393+
33173394

33183395
def get_rxn_batches(self, rxns, T=1000.0, max_batch_size=800, outlier_fraction=0.02, stratum_num=8):
33193396
"""
@@ -3481,6 +3558,8 @@ def make_tree_nodes(self, template_rxn_map=None, obj=None, T=1000.0, nprocs=0, d
34813558
continue
34823559
boo2 = self.extend_node(entry, template_rxn_map, obj, T, iter_max=extension_iter_max, iter_item_cap=extension_iter_item_cap)
34833560
if boo2: # extended node so restart while loop
3561+
# if "Root_N-4R!H->O" in template_rxn_map.keys():
3562+
# print(f'at boo2: {self.groups.entries["Root_N-4R!H->O"].item.atoms[3].reg_dim_atm}')
34843563
break
34853564
else: # no extensions could be generated since all reactions were identical
34863565
mult_completed_nodes.append(entry)
@@ -3512,6 +3591,8 @@ def make_tree_nodes(self, template_rxn_map=None, obj=None, T=1000.0, nprocs=0, d
35123591
entry.parent = self.groups.entries[pname]
35133592
entry.parent.children.append(entry)
35143593

3594+
print(f'end of make_tree_nodes: {self.groups.entries["Root_N-4R!H->O"].item.atoms[3].reg_dim_atm}')
3595+
35153596
return
35163597

35173598
def _absorb_process(self, p, conn, name):
@@ -3777,9 +3858,10 @@ def simple_regularization(self, node, template_rxn_map, test=True):
37773858
self.simple_regularization(child, template_rxn_map)
37783859

37793860
grp = node.item
3861+
parent = node.parent.item
37803862
rxns = template_rxn_map[node.label]
37813863

3782-
R = ['H', 'C', 'N', 'O', 'Si', 'S', 'Cl', 'F', 'Br'] # set of possible R elements/atoms
3864+
R = ['H', 'C', 'N', 'O', 'Si', 'S', 'Cl', 'F', 'Br', 'Li'] # set of possible R elements/atoms
37833865
R = [ATOMTYPES[x] for x in R]
37843866

37853867
RnH = R[:]
@@ -3794,14 +3876,15 @@ def simple_regularization(self, node, template_rxn_map, test=True):
37943876
for i, atm1 in enumerate(grp.atoms):
37953877

37963878
skip = False
3797-
if node.children == []: # if the atoms or bonds are graphically indistinguishable don't regularize
3798-
bdpairs = {(atm, tuple(bd.order)) for atm, bd in atm1.bonds.items()}
3799-
for atm2 in grp.atoms:
3800-
if atm1 is not atm2 and atm1.atomtype == atm2.atomtype and len(atm1.bonds) == len(atm2.bonds):
3801-
bdpairs2 = {(atm, tuple(bd.order)) for atm, bd in atm2.bonds.items()}
3802-
if bdpairs == bdpairs2:
3803-
skip = True
3804-
indistinguishable.append(i)
3879+
if i <= len(parent.atoms)-1: #if we aren't at an atom definition that the parent node doesn't have (due to this child being an extNewBondExt type)
3880+
if node.children == [] and parent.atoms[i].reg_dim_atm[1]==[]: # if the atoms or bonds are graphically indistinguishable don't regularize
3881+
bdpairs = {(atm, tuple(bd.order)) for atm, bd in atm1.bonds.items()}
3882+
for atm2 in grp.atoms:
3883+
if atm1 is not atm2 and atm1.atomtype == atm2.atomtype and len(atm1.bonds) == len(atm2.bonds):
3884+
bdpairs2 = {(atm, tuple(bd.order)) for atm, bd in atm2.bonds.items()}
3885+
if bdpairs == bdpairs2:
3886+
skip = True
3887+
indistinguishable.append(i)
38053888

38063889
if not skip and atm1.reg_dim_atm[1] != [] and set(atm1.reg_dim_atm[1]) != set(atm1.atomtype):
38073890
atyp = atm1.atomtype
@@ -3813,14 +3896,14 @@ def simple_regularization(self, node, template_rxn_map, test=True):
38133896

38143897
vals = list(set(atyp) & set(atm1.reg_dim_atm[1]))
38153898
assert vals != [], 'cannot regularize to empty'
3816-
if all([set(child.item.atoms[i].atomtype) <= set(vals) for child in node.children]):
3817-
if not test:
3818-
atm1.atomtype = vals
3819-
else:
3820-
oldvals = atm1.atomtype
3821-
atm1.atomtype = vals
3822-
if not self.rxns_match_node(node, rxns):
3823-
atm1.atomtype = oldvals
3899+
#if all([set(child.item.atoms[i].atomtype) <= set(vals) for child in node.children]):
3900+
if not test:
3901+
atm1.atomtype = vals
3902+
else:
3903+
oldvals = atm1.atomtype
3904+
atm1.atomtype = vals
3905+
if not self.rxns_match_node(node, rxns):
3906+
atm1.atomtype = oldvals
38243907

38253908
if not skip and atm1.reg_dim_u[1] != [] and set(atm1.reg_dim_u[1]) != set(atm1.radical_electrons):
38263909
if len(atm1.radical_electrons) == 1:

0 commit comments

Comments
 (0)