@@ -2953,10 +2953,17 @@ def get_extension_edge(self, parent, template_rxn_map, obj, T, iter_max=np.inf,
29532953 out_exts [- 1 ].append (exts [i ]) # this extension splits reactions (optimization dim)
29542954 if typ == 'atomExt' :
29552955 reg_dict [(typ , indc )][0 ].extend (grp2 .atoms [indc [0 ]].atomtype )
2956+ #still pass in the regularization data to the grp2. However, this doesn't take care of the grpc
2957+ #reg_dict[(typ, indc)][1].extend(grp2.atoms[indc[0]].atomtype)
2958+ #now take care of the complement:
2959+
29562960 elif typ == 'elExt' :
29572961 reg_dict [(typ , indc )][0 ].extend (grp2 .atoms [indc [0 ]].radical_electrons )
2962+ #reg_dict[(typ, indc)][1].extend(grp2.atoms[indc[0]].radical_electrons)
29582963 elif typ == 'bondExt' :
29592964 reg_dict [(typ , indc )][0 ].extend (grp2 .get_bond (grp2 .atoms [indc [0 ]], grp2 .atoms [indc [1 ]]).order )
2965+ #reg_dict[(typ, indc)][1].extend(grp2.get_bond(grp2.atoms[indc[0]], grp2.atoms[indc[1]]).order)
2966+
29602967
29612968 elif boo : # this extension matches all reactions (regularization dim)
29622969 if typ == 'intNewBondExt' or typ == 'extNewBondExt' :
@@ -2983,10 +2990,12 @@ def get_extension_edge(self, parent, template_rxn_map, obj, T, iter_max=np.inf,
29832990 reg_val = reg_dict [(typr , indcr )]
29842991
29852992 if first_time and parent .children == []:
2986- # parent
2993+
2994+ #parent
29872995 if typr != 'intNewBondExt' and typr != 'extNewBondExt' : # these dimensions should be regularized
29882996 if typr == 'atomExt' :
2989- grp .atoms [indcr [0 ]].reg_dim_atm = list (reg_val )
2997+ pass
2998+ #grp.atoms[indcr[0]].reg_dim_atm = list(reg_val)
29902999 elif typr == 'elExt' :
29913000 grp .atoms [indcr [0 ]].reg_dim_u = list (reg_val )
29923001 elif typr == 'ringExt' :
@@ -3080,6 +3089,32 @@ def get_extension_edge(self, parent, template_rxn_map, obj, T, iter_max=np.inf,
30803089 out .extend (x )
30813090
30823091 return out , gave_up_split
3092+
def get_compliment_reg_dim(self, parent, template_rxn_map, new_ext, comp_ext):
    """
    Collect the atom types observed in the reactions that land in the
    complement of an extension.

    The reactions currently mapped to ``parent`` are split with
    ``self._split_reactions(rxns, new_ext)``; the second element of the
    returned 3-tuple is used as the complement reaction set.  Every atom of
    every molecule of every reactant of those reactions is then inspected and
    its ``ATOMTYPES[atm.symbol]`` entry is recorded, grouped by atom label.

    Args:
        parent: node whose ``label`` keys into ``template_rxn_map``.
        template_rxn_map: mapping from node label to the reactions at that node.
        new_ext: the extension group handed to ``_split_reactions``.
        comp_ext: the complement extension group.  Currently unused; kept so
            existing callers do not need to change.

    Returns:
        A 2-tuple ``(labeled, unlabeled)``:

        * ``labeled`` -- dict mapping the integer parsed from an atom label
          (``'*3'`` -> ``3``) to the set of atom types seen with that label.
        * ``unlabeled`` -- list of the distinct atom types seen on atoms whose
          label is the empty string, in first-seen order.

    Raises:
        KeyError: if ``parent.label`` is not in ``template_rxn_map`` or an
            atom symbol is not in ``ATOMTYPES``.
        ValueError: if a non-empty atom label is not ``'*'`` followed by an
            integer.
    """
    rxns_from_parent = template_rxn_map[parent.label]
    # Only the complement reactions are needed here.
    _, comp_ext_rxns, _ = self._split_reactions(rxns_from_parent, new_ext)

    labeled_atomtypes = {}
    # Kept as a list (not a set) so the first-seen order is preserved.
    unlabeled_atomtypes = []

    for rxn in comp_ext_rxns:
        for reactant in rxn.reactants:
            for mol in reactant.molecule:
                for atm in mol.atoms:
                    atomtype = ATOMTYPES[atm.symbol]
                    if atm.label == '':
                        if atomtype not in unlabeled_atomtypes:
                            unlabeled_atomtypes.append(atomtype)
                    else:
                        label_index = int(atm.label.replace('*', ''))
                        labeled_atomtypes.setdefault(label_index, set()).add(atomtype)

    return labeled_atomtypes, unlabeled_atomtypes
30833118
30843119 def extend_node (self , parent , template_rxn_map , obj = None , T = 1000.0 , iter_max = np .inf , iter_item_cap = np .inf ):
30853120 """
@@ -3156,9 +3191,46 @@ def extend_node(self, parent, template_rxn_map, obj=None, T=1000.0, iter_max=np.
31563191
31573192 extname = ext [2 ]
31583193
3194+ print (extname , ext [3 ])
31593195 if ext [3 ] == 'atomExt' :
3160- ext [0 ].atoms [ext [4 ][0 ]].reg_dim_atm = [ext [0 ].atoms [ext [4 ][0 ]].atomtype , ext [0 ].atoms [ext [4 ][0 ]].atomtype ]
3161- elif ext [3 ] == 'elExt' :
3196+ ext [0 ].atoms [ext [4 ][0 ]].reg_dim_atm = [ext [0 ].atoms [ext [4 ][0 ]].atomtype , ext [0 ].atoms [ext [4 ][0 ]].atomtype ] #passing regularization information to the selected extension node
3197+
3198+ #handling regularization in complement below
3199+ atom_labeling_in_comp_rxns_set , unlabeled_atoms_in_comp_rxns = self .get_compliment_reg_dim (parent , template_rxn_map , ext [0 ], ext [1 ])
3200+ #print(ext[0].atoms[ext[4][0]], ext[0].atoms[ext[4][0]].label, ext[1].atoms[ext[4][0]], ext[1].atoms[ext[4][0]].label)
3201+
3202+ #regularize the atom in which the extension was performed on
3203+ if ext [1 ] is not None :
3204+ if ext [1 ].atoms [ext [4 ][0 ]].label == '' :
3205+ #extension was performed on an unlabeled atom
3206+ limited_atomtypes_comp = set (ext [1 ].atoms [ext [4 ][0 ]].atomtype ).intersection (set (unlabeled_atoms_in_comp_rxns ))
3207+ #print(ext[1].atoms[ext[4][0]].atomtype, list(limited_atomtypes_comp))
3208+ ext [1 ].atoms [ext [4 ][0 ]].reg_dim_atm = [ext [1 ].atoms [ext [4 ][0 ]].atomtype , list (limited_atomtypes_comp )]
3209+ else :
3210+ adjusted_index = int (ext [1 ].atoms [ext [4 ][0 ]].label .replace ('*' ,'' )) #i.e. ext[4]= (3,), ext[4][0] = 3, ext[0].atoms[3]=<GroupAtom [*5 'N', 'C']>, ext[0].atoms[3].label = '*5'
3211+ ext [1 ].atoms [ext [4 ][0 ]].reg_dim_atm = [ext [1 ].atoms [ext [4 ][0 ]].atomtype , list (atom_labeling_in_comp_rxns_set [adjusted_index ])]
3212+
3213+ #make sure the rest of the atoms in the extension take on the same regularization dimensions as the parent.
3214+ for i , parent_atm in enumerate (parent .item .atoms ):
3215+ if i == ext [4 ][0 ]:
3216+ print ('extension atom' )
3217+ continue #this is the atom that the extension is focused on, handled above
3218+ elif parent_atm .reg_dim_atm [1 ]== []:
3219+ print ('parent atm reg_dim is empty' )
3220+ continue #only take on regularization dimensions of parent if there is some
3221+ else :
3222+ ext [0 ].atoms [i ].reg_dim_atm [1 ] = parent_atm .reg_dim_atm [1 ] #passing regularization info from parent to the extension
3223+ if ext [1 ] is not None :
3224+ ext [1 ].atoms [i ].reg_dim_atm [1 ] = parent_atm .reg_dim_atm [1 ] #passing regularization info from parent to the complimentary extension
3225+
3226+
3227+ #print(ext[1].atoms[i].atomtype,' ', ext[1].atoms[i].reg_dim_atm[1])
3228+
3229+
3230+
3231+ # print(ext[1].atoms[ext[4][0]].atomtype, )
3232+ # ext[1].atoms[ext[4][0]].reg_dim_atm = [ext[1].atoms[ext[4][0]].atomtype, ext[1].atoms[ext[4][0]].atomtype] #must also pass regularization information to the compliment
3233+ if ext [3 ] == 'elExt' :
31623234 ext [0 ].atoms [ext [4 ][0 ]].reg_dim_u = [ext [0 ].atoms [ext [4 ][0 ]].radical_electrons ,
31633235 ext [0 ].atoms [ext [4 ][0 ]].radical_electrons ]
31643236
@@ -3244,8 +3316,11 @@ def extend_node(self, parent, template_rxn_map, obj=None, T=1000.0, iter_max=np.
32443316 if complement :
32453317 template_rxn_map [parent .label ] = []
32463318 template_rxn_map [cextname ] = comp_entries
3319+ if cextname == "Root_N-4R!H->O" :
3320+ print (f'end of extend_node: { self .groups .entries ["Root_N-4R!H->O" ].item .atoms [3 ].reg_dim_atm } ' )
32473321 else :
32483322 template_rxn_map [parent .label ] = comp_entries
3323+
32493324 return True
32503325
32513326 def generate_tree (self , rxns = None , obj = None , thermo_database = None , T = 1000.0 , nprocs = 1 , min_splitable_entry_num = 2 ,
@@ -3314,6 +3389,8 @@ def rxnkey(rxn):
33143389 logging .error ("built tree with {} nodes" .format (len (list (self .groups .entries ))))
33153390
33163391 self .auto_generated = True
3392+ print (f'end of generate_tree: { self .groups .entries ["Root_N-4R!H->O" ].item .atoms [3 ].reg_dim_atm } ' )
3393+
33173394
33183395 def get_rxn_batches (self , rxns , T = 1000.0 , max_batch_size = 800 , outlier_fraction = 0.02 , stratum_num = 8 ):
33193396 """
@@ -3481,6 +3558,8 @@ def make_tree_nodes(self, template_rxn_map=None, obj=None, T=1000.0, nprocs=0, d
34813558 continue
34823559 boo2 = self .extend_node (entry , template_rxn_map , obj , T , iter_max = extension_iter_max , iter_item_cap = extension_iter_item_cap )
34833560 if boo2 : # extended node so restart while loop
3561+ # if "Root_N-4R!H->O" in template_rxn_map.keys():
3562+ # print(f'at boo2: {self.groups.entries["Root_N-4R!H->O"].item.atoms[3].reg_dim_atm}')
34843563 break
34853564 else : # no extensions could be generated since all reactions were identical
34863565 mult_completed_nodes .append (entry )
@@ -3512,6 +3591,8 @@ def make_tree_nodes(self, template_rxn_map=None, obj=None, T=1000.0, nprocs=0, d
35123591 entry .parent = self .groups .entries [pname ]
35133592 entry .parent .children .append (entry )
35143593
3594+ print (f'end of make_tree_nodes: { self .groups .entries ["Root_N-4R!H->O" ].item .atoms [3 ].reg_dim_atm } ' )
3595+
35153596 return
35163597
35173598 def _absorb_process (self , p , conn , name ):
@@ -3777,9 +3858,10 @@ def simple_regularization(self, node, template_rxn_map, test=True):
37773858 self .simple_regularization (child , template_rxn_map )
37783859
37793860 grp = node .item
3861+ parent = node .parent .item
37803862 rxns = template_rxn_map [node .label ]
37813863
3782- R = ['H' , 'C' , 'N' , 'O' , 'Si' , 'S' , 'Cl' , 'F' , 'Br' ] # set of possible R elements/atoms
3864+ R = ['H' , 'C' , 'N' , 'O' , 'Si' , 'S' , 'Cl' , 'F' , 'Br' , 'Li' ] # set of possible R elements/atoms
37833865 R = [ATOMTYPES [x ] for x in R ]
37843866
37853867 RnH = R [:]
@@ -3794,14 +3876,15 @@ def simple_regularization(self, node, template_rxn_map, test=True):
37943876 for i , atm1 in enumerate (grp .atoms ):
37953877
37963878 skip = False
3797- if node .children == []: # if the atoms or bonds are graphically indistinguishable don't regularize
3798- bdpairs = {(atm , tuple (bd .order )) for atm , bd in atm1 .bonds .items ()}
3799- for atm2 in grp .atoms :
3800- if atm1 is not atm2 and atm1 .atomtype == atm2 .atomtype and len (atm1 .bonds ) == len (atm2 .bonds ):
3801- bdpairs2 = {(atm , tuple (bd .order )) for atm , bd in atm2 .bonds .items ()}
3802- if bdpairs == bdpairs2 :
3803- skip = True
3804- indistinguishable .append (i )
3879+ if i <= len (parent .atoms )- 1 : #if we aren't at an atom definition that the parent node doesn't have (due to this child being an extNewBondExt type)
3880+ if node .children == [] and parent .atoms [i ].reg_dim_atm [1 ]== []: # if the atoms or bonds are graphically indistinguishable don't regularize
3881+ bdpairs = {(atm , tuple (bd .order )) for atm , bd in atm1 .bonds .items ()}
3882+ for atm2 in grp .atoms :
3883+ if atm1 is not atm2 and atm1 .atomtype == atm2 .atomtype and len (atm1 .bonds ) == len (atm2 .bonds ):
3884+ bdpairs2 = {(atm , tuple (bd .order )) for atm , bd in atm2 .bonds .items ()}
3885+ if bdpairs == bdpairs2 :
3886+ skip = True
3887+ indistinguishable .append (i )
38053888
38063889 if not skip and atm1 .reg_dim_atm [1 ] != [] and set (atm1 .reg_dim_atm [1 ]) != set (atm1 .atomtype ):
38073890 atyp = atm1 .atomtype
@@ -3813,14 +3896,14 @@ def simple_regularization(self, node, template_rxn_map, test=True):
38133896
38143897 vals = list (set (atyp ) & set (atm1 .reg_dim_atm [1 ]))
38153898 assert vals != [], 'cannot regularize to empty'
3816- if all ([set (child .item .atoms [i ].atomtype ) <= set (vals ) for child in node .children ]):
3817- if not test :
3818- atm1 .atomtype = vals
3819- else :
3820- oldvals = atm1 .atomtype
3821- atm1 .atomtype = vals
3822- if not self .rxns_match_node (node , rxns ):
3823- atm1 .atomtype = oldvals
3899+ # if all([set(child.item.atoms[i].atomtype) <= set(vals) for child in node.children]):
3900+ if not test :
3901+ atm1 .atomtype = vals
3902+ else :
3903+ oldvals = atm1 .atomtype
3904+ atm1 .atomtype = vals
3905+ if not self .rxns_match_node (node , rxns ):
3906+ atm1 .atomtype = oldvals
38243907
38253908 if not skip and atm1 .reg_dim_u [1 ] != [] and set (atm1 .reg_dim_u [1 ]) != set (atm1 .radical_electrons ):
38263909 if len (atm1 .radical_electrons ) == 1 :
0 commit comments