@@ -39,7 +39,11 @@ def prepare_name_pairs_pd(
3939 entity_id_col = "entity_id" ,
4040 gt_entity_id_col = "gt_entity_id" ,
4141 positive_set_col = "positive_set" ,
42+ correct_col = "correct" ,
4243 uid_col = "uid" ,
44+ gt_uid_col = "gt_uid" ,
45+ preprocessed_col = "preprocessed" ,
46+ gt_preprocessed_col = "gt_preprocessed" ,
4347 random_seed = 42 ,
4448):
4549 """Prepare dataset of name-pair candidates for training of supervised model.
@@ -70,7 +74,12 @@ def prepare_name_pairs_pd(
7074 For matching name-pairs entity_id == gt_entity_id.
7175 positive_set_col: column that specifies which candidates remain positive and which become negative,
7276 default is "positive_set".
77+ correct_col: column that indicates a correct match, default is "correct".
78+ For matching name-pairs (entity_id == gt_entity_id) the column value is True, otherwise False.
7379 uid_col: uid column for names to match, default is "uid".
80+ gt_uid_col: uid column of ground-truth names, default is "gt_uid".
81+ preprocessed_col: name of the preprocessed names column, default is "preprocessed".
82+ gt_preprocessed_col: name of the preprocessed ground-truth names column, default is "gt_preprocessed".
7483 random_seed: random seed for selection of negative names, default is 42.
7584 """
7685 """We can have the following dataset.columns, or much more like 'count', 'counterparty_account_count_distinct', 'type1_sum':
@@ -84,7 +93,7 @@ def prepare_name_pairs_pd(
8493 assert entity_id_col in candidates_pd .columns
8594 assert gt_entity_id_col in candidates_pd .columns
8695
87- candidates_pd ["correct" ] = candidates_pd [entity_id_col ] == candidates_pd [gt_entity_id_col ]
96+ candidates_pd [correct_col ] = candidates_pd [entity_id_col ] == candidates_pd [gt_entity_id_col ]
8897
8998 # negative sample creation?
9099 # if so, add positive_set_col column for negative sample creation
@@ -110,14 +119,14 @@ def prepare_name_pairs_pd(
110119 # - happens with one correct/positive case, we just pick the correct one
111120 if drop_duplicate_candidates :
112121 candidates_pd = candidates_pd .sort_values (
113- ["uid" , "gt_preprocessed" , "correct" ], ascending = False
114- ).drop_duplicates (subset = ["uid" , "gt_preprocessed" ], keep = "first" )
122+ [uid_col , gt_preprocessed_col , correct_col ], ascending = False
123+ ).drop_duplicates (subset = [uid_col , gt_preprocessed_col ], keep = "first" )
115124 # Similarly, for a training set, remove all equal names that are not considered a match.
116125 # This can happen a lot in actual data, e.g. with franchises that are independent but have the same name.
117126 # It's a true effect in data, but this screws up our intuitive notion that identical names should be related.
118127 if drop_samename_nomatch :
119- samename_nomatch = (candidates_pd ["preprocessed" ] == candidates_pd ["gt_preprocessed" ]) & ~ candidates_pd [
120- "correct"
128+ samename_nomatch = (candidates_pd [preprocessed_col ] == candidates_pd [gt_preprocessed_col ]) & ~ candidates_pd [
129+ correct_col
121130 ]
122131 candidates_pd = candidates_pd [~ samename_nomatch ]
123132
@@ -133,7 +142,9 @@ def prepare_name_pairs_pd(
133142 # is referred to in: resources/data/howto_create_unittest_sample_namepairs.txt
134143 # create negative sample and rerank negative candidates
135144 # this drops, in part, the negative correct candidates
136- candidates_pd = create_positive_negative_samples (candidates_pd )
145+ candidates_pd = create_positive_negative_samples (
146+ candidates_pd , uid_col = uid_col , correct_col = correct_col , positive_set_col = positive_set_col
147+ )
137148
138149 # It could be that we dropped all candidates, so we need to re-introduce the no-candidate rows
139150 names_to_match_after = candidates_pd [names_to_match_cols ].drop_duplicates ()
@@ -142,12 +153,12 @@ def prepare_name_pairs_pd(
142153 )
143154 names_to_match_missing = names_to_match_missing [names_to_match_missing ["_merge" ] == "left_only" ]
144155 names_to_match_missing = names_to_match_missing .drop (columns = ["_merge" ])
145- names_to_match_missing ["correct" ] = False
156+ names_to_match_missing [correct_col ] = False
146157 # Since this column is used to calculate benchmark metrics
147158 names_to_match_missing ["score_0_rank" ] = 1
148159
149160 candidates_pd = pd .concat ([candidates_pd , names_to_match_missing ], ignore_index = True )
150- candidates_pd ["gt_preprocessed" ] = candidates_pd ["gt_preprocessed" ].fillna ("" )
151- candidates_pd ["no_candidate" ] = candidates_pd ["gt_uid" ].isnull ()
161+ candidates_pd [gt_preprocessed_col ] = candidates_pd [gt_preprocessed_col ].fillna ("" )
162+ candidates_pd ["no_candidate" ] = candidates_pd [gt_uid_col ].isnull ()
152163
153164 return candidates_pd
0 commit comments