4444import tensorflow as tf
4545
4646
47- # Types are from //cloud/ml/research/data_utils/feature_metadata.py
4847_FEATURES_TYPE : Final [str ] = 'FLOAT64'
4948_SOURCE_LABEL_TYPE : Final [str ] = 'STRING'
5049_SOURCE_LABEL_DEFAULT_VALUE : Final [str ] = '-1'
5150_LABEL_TYPE : Final [str ] = 'INT64'
51+ _STRING_TO_INTEGER_LABEL_MAP : dict [str | int , int ] = {
52+ 1 : 1 ,
53+ 0 : 0 ,
54+ - 1 : - 1 ,
55+ '' : - 1 ,
56+ '-1' : - 1 ,
57+ '0' : 0 ,
58+ '1' : 1 ,
59+ 'positive' : 1 ,
60+ 'negative' : 0 ,
61+ 'unlabeled' : - 1 ,
62+ }
5263
5364# Setting the shuffle buffer size to 1M seems to be necessary to get the CSV
5465# reader to provide a diversity of data to the model.
@@ -167,12 +178,12 @@ def from_inputs_file(
167178 raise ValueError (
168179 f'Label column { label_column_name } not found in the header: { header } '
169180 )
170- num_features = len (all_columns ) - 1
171181 features_types = [_FEATURES_TYPE ] * len (all_columns )
172182 column_names_dict = collections .OrderedDict (
173183 zip (all_columns , features_types )
174184 )
175185 column_names_dict [label_column_name ] = _SOURCE_LABEL_DEFAULT_VALUE
186+ num_features = len (all_columns ) - 1
176187 return ColumnNamesInfo (
177188 column_names_dict = column_names_dict ,
178189 header = header ,
@@ -216,6 +227,13 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
216227 self .runner_parameters .negative_data_value ,
217228 self .runner_parameters .unlabeled_data_value ,
218229 ]
230+ # Add any labels that are not already in the map.
231+ _STRING_TO_INTEGER_LABEL_MAP [self .runner_parameters .positive_data_value ] = 1
232+ _STRING_TO_INTEGER_LABEL_MAP [self .runner_parameters .negative_data_value ] = 0
233+ _STRING_TO_INTEGER_LABEL_MAP [
234+ self .runner_parameters .unlabeled_data_value
235+ ] = - 1
236+
219237 # Construct a label remap from string labels to integers. The table is not
220238 # necessary for the case when the labels are all integers. But instead of
221239 # checking if the labels are all integers, we construct the table and use
@@ -286,7 +304,8 @@ def get_inputs_metadata(
286304 )
287305 # Get information about the columns.
288306 column_names_info = ColumnNamesInfo .from_inputs_file (
289- csv_filenames [0 ], label_column_name
307+ csv_filenames [0 ],
308+ label_column_name ,
290309 )
291310 logging .info (
292311 'Obtained metadata for data with CSV prefix %s (number of features=%d)' ,
@@ -360,22 +379,19 @@ def filter_func(features: tf.Tensor, label: tf.Tensor) -> bool: # pylint: disab
360379 @classmethod
361380 def convert_str_to_int (cls , value : str ) -> int :
362381 """Converts a string integer label to an integer label."""
363- if isinstance (value , str ) and value .lstrip ('-' ).isdigit ():
364- return int (value )
365- elif isinstance (value , int ):
366- return value
382+ if value in _STRING_TO_INTEGER_LABEL_MAP :
383+ return _STRING_TO_INTEGER_LABEL_MAP [value ]
367384 else :
368385 raise ValueError (
369- f'Label { value } of type { type (value )} is not a string integer.'
386+ f'Label { value } of type { type (value )} is not a string integer or '
387+ 'mappable to an integer.'
370388 )
371389
372390 @classmethod
373391 def _get_label_remap_table (
374392 cls , labels_mapping : dict [str , int ]
375393 ) -> tf .lookup .StaticHashTable :
376394 """Returns a label remap table that converts string labels to integers."""
377- # The possible keys are '', '-1, '0', '1'. None is not included because the
378- # Data Loader will default to '' if the label is None.
379395 keys_tensor = tf .constant (
380396 list (labels_mapping .keys ()),
381397 dtype = tf .dtypes .as_dtype (_SOURCE_LABEL_TYPE .lower ()),
@@ -390,6 +406,14 @@ def _get_label_remap_table(
390406 )
391407 return label_remap_table
392408
409+ def remap_label (self , label : str | tf .Tensor ) -> int | tf .Tensor :
410+ """Remaps the label to an integer."""
411+ if isinstance (label , str ) or (
412+ isinstance (label , tf .Tensor ) and label .dtype == tf .dtypes .string
413+ ):
414+ return self ._label_remap_table .lookup (label )
415+ return label
416+
393417 def load_tf_dataset_from_csv (
394418 self ,
395419 input_path : str ,
@@ -441,6 +465,7 @@ def load_tf_dataset_from_csv(
441465 self ._last_read_metadata .column_names_info .column_names_dict .values ()
442466 )
443467 ]
468+ logging .info ('column_defaults: %s' , column_defaults )
444469
445470 # Construct a single dataset out of multiple CSV files.
446471 # TODO(sinharaj): Remove the determinism after testing.
@@ -456,7 +481,7 @@ def load_tf_dataset_from_csv(
456481 na_value = '' ,
457482 header = True ,
458483 num_epochs = 1 ,
459- shuffle = True ,
484+ shuffle = False ,
460485 shuffle_buffer_size = _SHUFFLE_BUFFER_SIZE ,
461486 shuffle_seed = self .runner_parameters .random_seed ,
462487 prefetch_buffer_size = tf .data .AUTOTUNE ,
@@ -473,17 +498,9 @@ def load_tf_dataset_from_csv(
473498 'created.'
474499 )
475500
476- def remap_label (label : str | tf .Tensor ) -> int | tf .Tensor :
477- """Remaps the label to an integer."""
478- if isinstance (label , str ) or (
479- isinstance (label , tf .Tensor ) and label .dtype == tf .dtypes .string
480- ):
481- return self ._label_remap_table .lookup (label )
482- return label
483-
484501 # The Dataset can have labels of type int or str. Cast them to int.
485502 dataset = dataset .map (
486- lambda features , label : (features , remap_label (label )),
503+ lambda features , label : (features , self . remap_label (label )),
487504 num_parallel_calls = tf .data .AUTOTUNE ,
488505 deterministic = True ,
489506 )
@@ -535,7 +552,6 @@ def combine_features_dict_into_tensor(
535552 self ._label_counts = {
536553 k : v .numpy () for k , v in self .counts_by_label (dataset ).items ()
537554 }
538- logging .info ('Label counts: %s' , self ._label_counts )
539555
540556 return dataset
541557
@@ -554,11 +570,11 @@ def counts_by_label(self, dataset: tf.data.Dataset) -> Dict[int, tf.Tensor]:
554570
555571 @tf .function
556572 def count_class (
557- counts : Dict [int , int ], # Keys are always strings.
573+ counts : Dict [int , int ],
558574 batch : Tuple [tf .Tensor , tf .Tensor ],
559575 ) -> Dict [int , int ]:
560576 _ , labels = batch
561- # Keys are always strings.
577+ labels = self . remap_label ( labels )
562578 new_counts : Dict [int , int ] = counts .copy ()
563579 for i in self .all_labels :
564580 # This function is called after the Dataset is constructed and the
@@ -582,6 +598,59 @@ def count_class(
582598 )
583599 return counts
584600
601+ def counts_by_original_label (
602+ self , dataset : tf .data .Dataset
603+ ) -> tuple [dict [str , tf .Tensor ], dict [int , tf .Tensor ]]:
604+ """Counts the number of samples in each label class in the dataset."""
605+
606+ all_int_labels = [l for l in self .all_labels if isinstance (l , int )]
607+ logging .info ('all_int_labels: %s' , all_int_labels )
608+ all_str_labels = [l for l in self .all_labels if isinstance (l , str )]
609+ logging .info ('all_str_labels: %s' , all_str_labels )
610+
611+ @tf .function
612+ def count_original_class (
613+ counts : Dict [int | str , int ],
614+ batch : Tuple [tf .Tensor , tf .Tensor ],
615+ ) -> Dict [int | str , int ]:
616+ keys_are_int = all (isinstance (k , int ) for k in counts .keys ())
617+ if keys_are_int :
618+ all_labels = all_int_labels
619+ else :
620+ all_labels = all_str_labels
621+ _ , labels = batch
622+ new_counts : Dict [int | str , int ] = counts .copy ()
623+ for label in all_labels :
624+ cc : tf .Tensor = tf .cast (labels == label , tf .int32 )
625+ if label in list (new_counts .keys ()):
626+ new_counts [label ] += tf .reduce_sum (cc )
627+ else :
628+ new_counts [label ] = tf .reduce_sum (cc )
629+ return new_counts
630+
631+ int_keys_map = {
632+ k : v
633+ for k , v in _STRING_TO_INTEGER_LABEL_MAP .items ()
634+ if isinstance (k , int )
635+ }
636+ initial_int_state = dict ((int (label ), 0 ) for label in int_keys_map .keys ())
637+ if initial_int_state :
638+ int_counts = dataset .reduce (
639+ initial_state = initial_int_state , reduce_func = count_original_class
640+ )
641+ else :
642+ int_counts = {}
643+ str_keys_map = {
644+ k : v
645+ for k , v in _STRING_TO_INTEGER_LABEL_MAP .items ()
646+ if isinstance (k , str )
647+ }
648+ initial_str_state = dict ((str (label ), 0 ) for label in str_keys_map .keys ())
649+ str_counts = dataset .reduce (
650+ initial_state = initial_str_state , reduce_func = count_original_class
651+ )
652+ return int_counts , str_counts
653+
585654 def get_label_thresholds (self ) -> Mapping [str , float ]:
586655 """Computes positive and negative thresholds based on label ratios.
587656
0 commit comments