@@ -2994,8 +2994,7 @@ def get_extension_edge(self, parent, template_rxn_map, obj, T, iter_max=np.inf,
29942994 #parent
29952995 if typr != 'intNewBondExt' and typr != 'extNewBondExt' : # these dimensions should be regularized
29962996 if typr == 'atomExt' :
2997- pass
2998- #grp.atoms[indcr[0]].reg_dim_atm = list(reg_val)
2997+ pass #no longer passing regularization info to the parent here. Doing this instead in `extend_node`
29992998 elif typr == 'elExt' :
30002999 grp .atoms [indcr [0 ]].reg_dim_u = list (reg_val )
30013000 elif typr == 'ringExt' :
@@ -3091,10 +3090,27 @@ def get_extension_edge(self, parent, template_rxn_map, obj, T, iter_max=np.inf,
30913090 return out , gave_up_split
30923091
30933092 def get_compliment_reg_dim (self , parent , template_rxn_map , new_ext , comp_ext ):
3093+ """
3094+ Function takes in a parent node (`parent`), an extension node (`new_ext`) and its compliment (`comp_ext`).
3095+ Reactions of the parent node are split to extension and compliment.
3096+ Iterating over all the reactions that fit the complimentary node, the atomtypes of each labeled atom in each reaction are saved to a dictionary `atom_labeling_in_comp_rxns`,
3097+ where the key is the integer of the atom label (i.e. 5 in '*5') and the value is a set of all the atomtypes in all the complimentary reactions with that atom label.
3098+
3099+ Additionally, when iterating over all the reactions that fit the complimentary node, the atomtypes of each unlabeled atom in each reaction are saved to a list `unlabeled_atoms_in_comp_rxns`.
3100+ """
3101+
3102+
3103+ assert comp_ext is not None , "This extension does not include a complimentary node. Cannot get regularization dimensions of complimentary node."
3104+
3105+ #divide parent reactions into the extension node and its compliment
30943106 rxns_from_parent = template_rxn_map [parent .label ]
30953107 new_ext_rxns , comp_ext_rxns , _ = self ._split_reactions (rxns_from_parent , new_ext )
3108+
3109+ #for saving data
30963110 atom_labeling_in_comp_rxns = dict ()
30973111 unlabeled_atoms_in_comp_rxns = []
3112+
3113+ #iterate through each complimentary rxn
30983114 for rxn_c in comp_ext_rxns :
30993115 for reactant in rxn_c .reactants :
31003116 for mol in reactant .molecule :
@@ -3105,13 +3121,13 @@ def get_compliment_reg_dim(self, parent, template_rxn_map, new_ext, comp_ext):
31053121 if unlabeled_atmtype not in unlabeled_atoms_in_comp_rxns :
31063122 unlabeled_atoms_in_comp_rxns .append (unlabeled_atmtype )
31073123 else :
3124+ #this is a labeled atom
31083125 atm_label = int (atm .label .replace ('*' ,'' ))
31093126 if atm_label not in atom_labeling_in_comp_rxns .keys ():
31103127 atom_labeling_in_comp_rxns [atm_label ] = [ATOMTYPES [atm .symbol ]]
31113128 else :
31123129 existing_atomtypes = atom_labeling_in_comp_rxns [atm_label ]
31133130 existing_atomtypes .append (ATOMTYPES [atm .symbol ])
3114- #print(f'count of missing * is {count}')
31153131 atom_labeling_in_comp_rxns_set = {k : set (v ) for k , v in atom_labeling_in_comp_rxns .items ()}
31163132
31173133 return atom_labeling_in_comp_rxns_set , unlabeled_atoms_in_comp_rxns
@@ -3191,45 +3207,35 @@ def extend_node(self, parent, template_rxn_map, obj=None, T=1000.0, iter_max=np.
31913207
31923208 extname = ext [2 ]
31933209
3194- print ( extname , ext [ 3 ])
3210+
31953211 if ext [3 ] == 'atomExt' :
31963212 ext [0 ].atoms [ext [4 ][0 ]].reg_dim_atm = [ext [0 ].atoms [ext [4 ][0 ]].atomtype , ext [0 ].atoms [ext [4 ][0 ]].atomtype ] #passing regularization information to the selected extension node
31973213
3198- #handling regularization in complement below
3214+ #handling regularization in complement below:
31993215 atom_labeling_in_comp_rxns_set , unlabeled_atoms_in_comp_rxns = self .get_compliment_reg_dim (parent , template_rxn_map , ext [0 ], ext [1 ])
3200- #print(ext[0].atoms[ext[4][0]], ext[0].atoms[ext[4][0]].label, ext[1].atoms[ext[4][0]], ext[1].atoms[ext[4][0]].label)
32013216
32023217 #regularize the atom in which the extension was performed on
3203- if ext [1 ] is not None :
3204- if ext [1 ].atoms [ext [4 ][0 ]].label == '' :
3205- #extension was performed on an unlabeled atom
3206- limited_atomtypes_comp = set (ext [1 ].atoms [ext [4 ][0 ]].atomtype ).intersection (set (unlabeled_atoms_in_comp_rxns ))
3207- #print(ext[1].atoms[ext[4][0]].atomtype, list(limited_atomtypes_comp))
3208- ext [1 ].atoms [ext [4 ][0 ]].reg_dim_atm = [ext [1 ].atoms [ext [4 ][0 ]].atomtype , list (limited_atomtypes_comp )]
3209- else :
3210- adjusted_index = int (ext [1 ].atoms [ext [4 ][0 ]].label .replace ('*' ,'' )) #i.e. ext[4]= (3,), ext[4][0] = 3, ext[0].atoms[3]=<GroupAtom [*5 'N', 'C']>, ext[0].atoms[3].label = '*5'
3211- ext [1 ].atoms [ext [4 ][0 ]].reg_dim_atm = [ext [1 ].atoms [ext [4 ][0 ]].atomtype , list (atom_labeling_in_comp_rxns_set [adjusted_index ])]
3212-
3213- #make sure the rest of the atoms in the extension take on the same regularization dimensions as the parent.
3218+ if ext [1 ].atoms [ext [4 ][0 ]].label == '' :
3219+ #extension was performed on an unlabeled atom, so pass in regularization dimensions that are at least limited to the atomtypes of all the unlabeled atoms
3220+ limited_atomtypes_comp = set (ext [1 ].atoms [ext [4 ][0 ]].atomtype ).intersection (set (unlabeled_atoms_in_comp_rxns ))
3221+ ext [1 ].atoms [ext [4 ][0 ]].reg_dim_atm = [ext [1 ].atoms [ext [4 ][0 ]].atomtype , list (limited_atomtypes_comp )]
3222+ else :
3223+ #extension was performed on a labeled atom. For each labeled atom, we know all the atomtypes in the training reactions. Let's limit regularization dimensions to these known atomtypes
3224+ adjusted_index = int (ext [1 ].atoms [ext [4 ][0 ]].label .replace ('*' ,'' )) #i.e. ext[4]= (3,), ext[4][0] = 3, ext[0].atoms[3]=<GroupAtom [*5 'N', 'C']>, ext[0].atoms[3].label = '*5'
3225+ ext [1 ].atoms [ext [4 ][0 ]].reg_dim_atm = [ext [1 ].atoms [ext [4 ][0 ]].atomtype , list (atom_labeling_in_comp_rxns_set [adjusted_index ])]
3226+
3227+ #make sure the rest of the atoms in the extension take on the same regularization dimensions as the parent. Ensures subgraph isomorphism.
32143228 for i , parent_atm in enumerate (parent .item .atoms ):
32153229 if i == ext [4 ][0 ]:
3216- print ('extension atom' )
3217- continue #this is the atom that the extension is focused on, handled above
3230+ continue #this is the atom that the extension is focused on, handled above if the extension was an 'atomExt' extension type
32183231 elif parent_atm .reg_dim_atm [1 ]== []:
3219- print ('parent atm reg_dim is empty' )
32203232 continue #only take on regularization dimensions of parent if there is some
32213233 else :
32223234 ext [0 ].atoms [i ].reg_dim_atm [1 ] = parent_atm .reg_dim_atm [1 ] #passing regularization info from parent to the extension
3223- if ext [1 ] is not None :
3235+ if ext [1 ] is not None : #check if there's a complimentary node
32243236 ext [1 ].atoms [i ].reg_dim_atm [1 ] = parent_atm .reg_dim_atm [1 ] #passing regularization info from parent to the complimentary extension
32253237
3226-
3227- #print(ext[1].atoms[i].atomtype,' ', ext[1].atoms[i].reg_dim_atm[1])
3228-
32293238
3230-
3231- # print(ext[1].atoms[ext[4][0]].atomtype, )
3232- # ext[1].atoms[ext[4][0]].reg_dim_atm = [ext[1].atoms[ext[4][0]].atomtype, ext[1].atoms[ext[4][0]].atomtype] #must also pass regularization information to the compliment
32333239 if ext [3 ] == 'elExt' :
32343240 ext [0 ].atoms [ext [4 ][0 ]].reg_dim_u = [ext [0 ].atoms [ext [4 ][0 ]].radical_electrons ,
32353241 ext [0 ].atoms [ext [4 ][0 ]].radical_electrons ]
@@ -3316,8 +3322,6 @@ def extend_node(self, parent, template_rxn_map, obj=None, T=1000.0, iter_max=np.
33163322 if complement :
33173323 template_rxn_map [parent .label ] = []
33183324 template_rxn_map [cextname ] = comp_entries
3319- if cextname == "Root_N-4R!H->O" :
3320- print (f'end of extend_node: { self .groups .entries ["Root_N-4R!H->O" ].item .atoms [3 ].reg_dim_atm } ' )
33213325 else :
33223326 template_rxn_map [parent .label ] = comp_entries
33233327
@@ -3389,8 +3393,6 @@ def rxnkey(rxn):
33893393 logging .error ("built tree with {} nodes" .format (len (list (self .groups .entries ))))
33903394
33913395 self .auto_generated = True
3392- print (f'end of generate_tree: { self .groups .entries ["Root_N-4R!H->O" ].item .atoms [3 ].reg_dim_atm } ' )
3393-
33943396
33953397 def get_rxn_batches (self , rxns , T = 1000.0 , max_batch_size = 800 , outlier_fraction = 0.02 , stratum_num = 8 ):
33963398 """
@@ -3558,8 +3560,6 @@ def make_tree_nodes(self, template_rxn_map=None, obj=None, T=1000.0, nprocs=0, d
35583560 continue
35593561 boo2 = self .extend_node (entry , template_rxn_map , obj , T , iter_max = extension_iter_max , iter_item_cap = extension_iter_item_cap )
35603562 if boo2 : # extended node so restart while loop
3561- # if "Root_N-4R!H->O" in template_rxn_map.keys():
3562- # print(f'at boo2: {self.groups.entries["Root_N-4R!H->O"].item.atoms[3].reg_dim_atm}')
35633563 break
35643564 else : # no extensions could be generated since all reactions were identical
35653565 mult_completed_nodes .append (entry )
@@ -3591,8 +3591,6 @@ def make_tree_nodes(self, template_rxn_map=None, obj=None, T=1000.0, nprocs=0, d
35913591 entry .parent = self .groups .entries [pname ]
35923592 entry .parent .children .append (entry )
35933593
3594- print (f'end of make_tree_nodes: { self .groups .entries ["Root_N-4R!H->O" ].item .atoms [3 ].reg_dim_atm } ' )
3595-
35963594 return
35973595
35983596 def _absorb_process (self , p , conn , name ):
0 commit comments