Skip to content

Commit 7b73d0c

Browse files
Nora Khalil
authored and committed
Scrapping previous changes/attempts to fix the bug. Starting a fix that allows problematic nodes to generate atomExt extensions that do not split the node when the optimization dimension of the regularization dictionary is more specific than the atomtype of the atom being extended. For example, if the atomtype of an atom labeled *5 is [Si, F, Li, N, C, P, S] and the regularization dictionary has an optimization dimension that narrows down these atomtypes (e.g. reg_dim_atm[0] = <N,C>), then we can allow atomExt extensions that change *5's atomtype to [N,C] (rather than just [N] or just [C]). This way, we have an extension that narrows *5 down from [Si, F, Li, N, C, P, S] to <N,C> but still matches all of the training reactions at the node, so the regularization information (reg_dim_atm[1]) is passed to the group.
1 parent 349625a commit 7b73d0c

File tree

2 files changed

+50
-97
lines changed

2 files changed

+50
-97
lines changed

rmgpy/data/kinetics/family.py

Lines changed: 21 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -2958,7 +2958,6 @@ def get_extension_edge(self, parent, template_rxn_map, obj, T, iter_max=np.inf,
29582958
elif typ == 'bondExt':
29592959
reg_dict[(typ, indc)][0].extend(grp2.get_bond(grp2.atoms[indc[0]], grp2.atoms[indc[1]]).order)
29602960

2961-
29622961
elif boo: # this extension matches all reactions (regularization dim)
29632962
if typ == 'intNewBondExt' or typ == 'extNewBondExt':
29642963
# these are bond formation extensions, we want to expand these until we get splits
@@ -2984,11 +2983,10 @@ def get_extension_edge(self, parent, template_rxn_map, obj, T, iter_max=np.inf,
29842983
reg_val = reg_dict[(typr, indcr)]
29852984

29862985
if first_time and parent.children == []:
2987-
2988-
#parent
2986+
# parent
29892987
if typr != 'intNewBondExt' and typr != 'extNewBondExt': # these dimensions should be regularized
29902988
if typr == 'atomExt':
2991-
pass #no longer passing regularization info to the parent here. Doing this instead in `extend_node`
2989+
grp.atoms[indcr[0]].reg_dim_atm = list(reg_val)
29922990
elif typr == 'elExt':
29932991
grp.atoms[indcr[0]].reg_dim_u = list(reg_val)
29942992
elif typr == 'ringExt':
@@ -3082,49 +3080,6 @@ def get_extension_edge(self, parent, template_rxn_map, obj, T, iter_max=np.inf,
30823080
out.extend(x)
30833081

30843082
return out, gave_up_split
3085-
3086-
def get_compliment_reg_dim(self, parent, template_rxn_map, new_ext, comp_ext):
3087-
"""
3088-
Function takes in a parent node (`parent`), an extension node (`new_ext`) and its compliment (`comp_ext`).
3089-
Reactions of the parent node are split to extension and compliment.
3090-
Iterating over all the reactions that fit the complimentary node, the atomtypes of each labeled atom in each reaction are saved to a dictionary `atom_labeling_in_comp_rxns`,
3091-
where the key is the integer of the atom label (i.e. 5 in '*5') and the value is a set of all the atomtypes in all the complimentary reactions with that atom label.
3092-
3093-
Additionally, when iterating over all the reactions that fit the complimentary node, the atomtypes of each unlabeled atom in each reaction are saved to a list `unlabeled_atoms_in_comp_rxns`.
3094-
"""
3095-
3096-
3097-
assert comp_ext is not None, "This extension does not include a complimentary node. Cannot get regularization dimensions of complimentary node."
3098-
3099-
#divide parent reactions into the extension node and its compliment
3100-
rxns_from_parent = template_rxn_map[parent.label]
3101-
new_ext_rxns, comp_ext_rxns, _ = self._split_reactions(rxns_from_parent, new_ext)
3102-
3103-
#for saving data
3104-
atom_labeling_in_comp_rxns = dict()
3105-
unlabeled_atoms_in_comp_rxns = []
3106-
3107-
#iterate through each complimentary rxn
3108-
for rxn_c in comp_ext_rxns:
3109-
for reactant in rxn_c.reactants:
3110-
for mol in reactant.molecule:
3111-
for atm in mol.atoms:
3112-
if atm.label == '':
3113-
#this atom was unlabeled
3114-
unlabeled_atmtype = atm.atomtype
3115-
if unlabeled_atmtype not in unlabeled_atoms_in_comp_rxns:
3116-
unlabeled_atoms_in_comp_rxns.append(unlabeled_atmtype)
3117-
else:
3118-
#this is a labeled atom
3119-
atm_label = int(atm.label.replace('*',''))
3120-
if atm_label not in atom_labeling_in_comp_rxns.keys():
3121-
atom_labeling_in_comp_rxns[atm_label] = [atm.atomtype]
3122-
else:
3123-
existing_atomtypes = atom_labeling_in_comp_rxns[atm_label]
3124-
existing_atomtypes.append(atm.atomtype)
3125-
atom_labeling_in_comp_rxns_set = {k: set(v) for k, v in atom_labeling_in_comp_rxns.items()}
3126-
3127-
return atom_labeling_in_comp_rxns_set, unlabeled_atoms_in_comp_rxns
31283083

31293084
def extend_node(self, parent, template_rxn_map, obj=None, T=1000.0, iter_max=np.inf, iter_item_cap=np.inf):
31303085
"""
@@ -3201,36 +3156,9 @@ def extend_node(self, parent, template_rxn_map, obj=None, T=1000.0, iter_max=np.
32013156

32023157
extname = ext[2]
32033158

3204-
32053159
if ext[3] == 'atomExt':
3206-
ext[0].atoms[ext[4][0]].reg_dim_atm = [ext[0].atoms[ext[4][0]].atomtype, ext[0].atoms[ext[4][0]].atomtype] #passing regularization information to the selected extension node
3207-
3208-
#handling regularization in complement below:
3209-
atom_labeling_in_comp_rxns_set, unlabeled_atoms_in_comp_rxns = self.get_compliment_reg_dim(parent, template_rxn_map, ext[0], ext[1])
3210-
3211-
#regularize the atom in which the extension was performed on
3212-
if ext[1].atoms[ext[4][0]].label=='':
3213-
#extension was performed on an unlabeled atom, so pass in regularization dimensions that are at least limited to the atomtypes of all the unlabeled atoms
3214-
limited_atomtypes_comp = set(ext[1].atoms[ext[4][0]].atomtype).intersection(set(unlabeled_atoms_in_comp_rxns))
3215-
ext[1].atoms[ext[4][0]].reg_dim_atm = [ext[1].atoms[ext[4][0]].atomtype, list(limited_atomtypes_comp)]
3216-
else:
3217-
#extension was performed on a labeled atom. For each labeled atom, we know all the atomtypes in the training reactions. Let's limit regularization dimensions to these known atomtypes
3218-
adjusted_index = int(ext[1].atoms[ext[4][0]].label.replace('*','')) #i.e. ext[4]= (3,), ext[4][0] = 3, ext[0].atoms[3]=<GroupAtom [*5 'N', 'C']>, ext[0].atoms[3].label = '*5'
3219-
ext[1].atoms[ext[4][0]].reg_dim_atm = [ext[1].atoms[ext[4][0]].atomtype, list(atom_labeling_in_comp_rxns_set[adjusted_index])]
3220-
3221-
#make sure the rest of the atoms in the extension take on the same regularization dimensions as the parent. Ensures subgraph isomorphism.
3222-
for i, parent_atm in enumerate(parent.item.atoms):
3223-
if i == ext[4][0]:
3224-
continue #this is the atom that the extension is focused on, handled above if the extension was an 'atomExt' extension type
3225-
elif parent_atm.reg_dim_atm[1]==[]:
3226-
continue #only take on regularization dimensions of parent if there is some
3227-
else:
3228-
ext[0].atoms[i].reg_dim_atm[1] = parent_atm.reg_dim_atm[1] #passing regularization info from parent to the extension
3229-
if ext[1] is not None: #check if there's a complimentary node
3230-
ext[1].atoms[i].reg_dim_atm[1] = parent_atm.reg_dim_atm[1] #passing regularization info from parent to the complimentary extension
3231-
3232-
3233-
if ext[3] == 'elExt':
3160+
ext[0].atoms[ext[4][0]].reg_dim_atm = [ext[0].atoms[ext[4][0]].atomtype, ext[0].atoms[ext[4][0]].atomtype]
3161+
elif ext[3] == 'elExt':
32343162
ext[0].atoms[ext[4][0]].reg_dim_u = [ext[0].atoms[ext[4][0]].radical_electrons,
32353163
ext[0].atoms[ext[4][0]].radical_electrons]
32363164

@@ -3318,7 +3246,6 @@ def extend_node(self, parent, template_rxn_map, obj=None, T=1000.0, iter_max=np.
33183246
template_rxn_map[cextname] = comp_entries
33193247
else:
33203248
template_rxn_map[parent.label] = comp_entries
3321-
33223249
return True
33233250

33243251
def generate_tree(self, rxns=None, obj=None, thermo_database=None, T=1000.0, nprocs=1, min_splitable_entry_num=2,
@@ -3850,10 +3777,9 @@ def simple_regularization(self, node, template_rxn_map, test=True):
38503777
self.simple_regularization(child, template_rxn_map)
38513778

38523779
grp = node.item
3853-
parent = node.parent.item
38543780
rxns = template_rxn_map[node.label]
38553781

3856-
R = ['H', 'C', 'N', 'O', 'Si', 'S', 'Cl', 'F', 'Br', 'Li'] # set of possible R elements/atoms
3782+
R = ['H', 'C', 'N', 'O', 'Si', 'S', 'Cl', 'F', 'Br'] # set of possible R elements/atoms
38573783
R = [ATOMTYPES[x] for x in R]
38583784

38593785
RnH = R[:]
@@ -3868,15 +3794,14 @@ def simple_regularization(self, node, template_rxn_map, test=True):
38683794
for i, atm1 in enumerate(grp.atoms):
38693795

38703796
skip = False
3871-
if i <= len(parent.atoms)-1: #if we aren't at an atom definition that the parent node doesn't have (due to this child being an extNewBondExt type)
3872-
if node.children == [] and parent.atoms[i].reg_dim_atm[1]==[]: # if the atoms or bonds are graphically indistinguishable don't regularize
3873-
bdpairs = {(atm, tuple(bd.order)) for atm, bd in atm1.bonds.items()}
3874-
for atm2 in grp.atoms:
3875-
if atm1 is not atm2 and atm1.atomtype == atm2.atomtype and len(atm1.bonds) == len(atm2.bonds):
3876-
bdpairs2 = {(atm, tuple(bd.order)) for atm, bd in atm2.bonds.items()}
3877-
if bdpairs == bdpairs2:
3878-
skip = True
3879-
indistinguishable.append(i)
3797+
if node.children == []: # if the atoms or bonds are graphically indistinguishable don't regularize
3798+
bdpairs = {(atm, tuple(bd.order)) for atm, bd in atm1.bonds.items()}
3799+
for atm2 in grp.atoms:
3800+
if atm1 is not atm2 and atm1.atomtype == atm2.atomtype and len(atm1.bonds) == len(atm2.bonds):
3801+
bdpairs2 = {(atm, tuple(bd.order)) for atm, bd in atm2.bonds.items()}
3802+
if bdpairs == bdpairs2:
3803+
skip = True
3804+
indistinguishable.append(i)
38803805

38813806
if not skip and atm1.reg_dim_atm[1] != [] and set(atm1.reg_dim_atm[1]) != set(atm1.atomtype):
38823807
atyp = atm1.atomtype
@@ -3888,14 +3813,14 @@ def simple_regularization(self, node, template_rxn_map, test=True):
38883813

38893814
vals = list(set(atyp) & set(atm1.reg_dim_atm[1]))
38903815
assert vals != [], 'cannot regularize to empty'
3891-
#if all([set(child.item.atoms[i].atomtype) <= set(vals) for child in node.children]):
3892-
if not test:
3893-
atm1.atomtype = vals
3894-
else:
3895-
oldvals = atm1.atomtype
3896-
atm1.atomtype = vals
3897-
if not self.rxns_match_node(node, rxns):
3898-
atm1.atomtype = oldvals
3816+
if all([set(child.item.atoms[i].atomtype) <= set(vals) for child in node.children]):
3817+
if not test:
3818+
atm1.atomtype = vals
3819+
else:
3820+
oldvals = atm1.atomtype
3821+
atm1.atomtype = vals
3822+
if not self.rxns_match_node(node, rxns):
3823+
atm1.atomtype = oldvals
38993824

39003825
if not skip and atm1.reg_dim_u[1] != [] and set(atm1.reg_dim_u[1]) != set(atm1.radical_electrons):
39013826
if len(atm1.radical_electrons) == 1:

rmgpy/molecule/group.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1571,7 +1571,7 @@ def get_extensions(self, r=None, r_bonds=None, r_un=None, basename='', atm_ind=N
15711571
"""
15721572
cython.declare(atoms=list, atm=GroupAtom, atm2=GroupAtom, bd=GroupBond, i=int, j=int,
15731573
extents=list, RnH=list, typ=list)
1574-
1574+
print('im in')
15751575
extents = []
15761576
if r_bonds is None:
15771577
r_bonds = [1, 1.5, 2, 3, 4]
@@ -1684,6 +1684,7 @@ def get_extensions(self, r=None, r_bonds=None, r_un=None, basename='', atm_ind=N
16841684
elif typ[0].label == 'R!H':
16851685
extents.extend(self.specify_atom_extensions(i, basename, list(set(atm.reg_dim_atm[0]) & set(r))))
16861686
else:
1687+
print(set(typ), set(atm.reg_dim_atm[0]), list(set(typ) & set(atm.reg_dim_atm[0])))
16871688
extents.extend(self.specify_atom_extensions(i, basename, list(set(typ) & set(atm.reg_dim_atm[0]))))
16881689
if atm.reg_dim_u == []:
16891690
if len(atm.radical_electrons) != 1:
@@ -1726,6 +1727,8 @@ def specify_atom_extensions(self, i, basename, r):
17261727

17271728
grps = []
17281729
Rset = set(r)
1730+
1731+
#consider node splitting
17291732
for item in r:
17301733
grp = deepcopy(self)
17311734
grpc = deepcopy(self)
@@ -1751,6 +1754,31 @@ def specify_atom_extensions(self, i, basename, r):
17511754
grps.append(
17521755
(grp, grpc, basename + '_' + str(i + 1) + old_atom_type_str + '->' + item.label, 'atomExt', (i,)))
17531756

1757+
#generate an extension without node splitting
1758+
if len(self.atoms[i].atomtype)>len(Rset):
1759+
if all(r in self.atoms[i].atomtype for r in Rset):
1760+
#that means even if we update the atomtype of the atom to the Rset, it will still be a specification
1761+
grp = deepcopy(self)
1762+
grp.atoms[i].atomtype = list(Rset)
1763+
1764+
#rename
1765+
old_atom_type = grp.atoms[i].atomtype
1766+
1767+
if len(old_atom_type) > 1:
1768+
labelList = []
1769+
old_atom_type_str = ''
1770+
for k in old_atom_type:
1771+
labelList.append(k.label)
1772+
for p in sorted(labelList):
1773+
old_atom_type_str += p
1774+
elif len(old_atom_type) == 0:
1775+
old_atom_type_str = ""
1776+
else:
1777+
old_atom_type_str = old_atom_type[0].label
1778+
1779+
grps.append(
1780+
(grp, None, basename + '_' + str(i + 1) + old_atom_type_str + '->' + ''.join(r.label for r in Rset), 'atomExt', (i,)))
1781+
17541782
return grps
17551783

17561784
def specify_ring_extensions(self, i, basename):

0 commit comments

Comments
 (0)