@@ -536,6 +536,117 @@ def test_bigquery_table_upload_throw_error_metadata(self):
536536 features = features , labels = labels
537537 )
538538
539+ @mock .patch .object (bigquery , 'LoadJobConfig' , autospec = True )
540+ def test_upload_dataframe_with_wts_flags_as_bigquery_table_no_error (
541+ self , mock_bqclient_loadjobconfig
542+ ):
543+ self .runner_parameters .output_bigquery_table_path = (
544+ 'project.dataset.pseudo_labeled_data'
545+ )
546+ data_loader_object = data_loader .DataLoader (self .runner_parameters )
547+ feature_column_names = [
548+ 'x1' ,
549+ 'x2' ,
550+ data_loader .WEIGHT_COLUMN_NAME ,
551+ data_loader .PSEUDOLABEL_FLAG_COLUMN_NAME ,
552+ self .runner_parameters .label_col_name ,
553+ ]
554+
555+ features = np .random .rand (10 , 2 ).astype (np .float32 )
556+ labels = np .repeat (0 , 10 ).reshape (10 , 1 ).astype (np .int8 )
557+ # Two possible values for weight (alpha), repeated 10/2 = 5 times each.
558+ weights = np .repeat ([0.1 , 1.0 ], 5 ).reshape (10 , 1 ).astype (np .float32 )
559+ # The corresponding peseudolabel flags are False, True, repeated 5 times.
560+ flags = np .repeat ([1 , 0 ], 5 ).reshape (10 , 1 ).astype (np .int8 )
561+
562+ tf_dataset_instance_mock = mock .create_autospec (
563+ tf .data .Dataset , instance = True
564+ )
565+
566+ feature1_metadata = feature_metadata .FeatureMetadata ('x1' , 0 , 'FLOAT64' )
567+ feature2_metadata = feature_metadata .FeatureMetadata ('x2' , 0 , 'FLOAT64' )
568+ label_metadata = feature_metadata .FeatureMetadata (
569+ self .runner_parameters .label_col_name , 1 , 'INT64'
570+ )
571+ metadata_container = feature_metadata .FeatureMetadataContainer (
572+ [feature1_metadata , feature2_metadata , label_metadata ]
573+ )
574+
575+ self .mock_bq_dataset .return_value = (
576+ tf_dataset_instance_mock ,
577+ metadata_container ,
578+ )
579+
580+ # Perform this call so that FeatureMetadata is set.
581+ data_loader_object .load_tf_dataset_from_bigquery (
582+ input_path = self .runner_parameters .input_bigquery_table_path ,
583+ label_col_name = self .runner_parameters .label_col_name ,
584+ batch_size = self .batch_size ,
585+ )
586+
587+ data_loader_object .upload_dataframe_as_bigquery_table (
588+ features = features ,
589+ labels = labels ,
590+ weights = weights ,
591+ pseudolabel_flags = flags ,
592+ )
593+ job_config_object = mock_bqclient_loadjobconfig .return_value
594+
595+ load_table_mock_kwargs = (
596+ self .mock_bq_client .return_value .__enter__ .return_value .load_table_from_dataframe .call_args .kwargs
597+ )
598+
599+ with self .subTest (name = 'LabelColumnCorrect' ):
600+ self .assertListEqual (
601+ list (
602+ load_table_mock_kwargs ['dataframe' ][
603+ self .runner_parameters .label_col_name
604+ ]
605+ ),
606+ list (labels ),
607+ )
608+
609+ with self .subTest (name = 'LabelColumnDataTypeBool' ):
610+ self .assertEqual (
611+ load_table_mock_kwargs ['dataframe' ][
612+ self .runner_parameters .label_col_name
613+ ].dtype ,
614+ bool ,
615+ )
616+
617+ with self .subTest (name = 'WeightsColumnCorrect' ):
618+ self .assertListEqual (
619+ list (
620+ load_table_mock_kwargs ['dataframe' ][
621+ data_loader .WEIGHT_COLUMN_NAME
622+ ]
623+ ),
624+ list (weights ),
625+ )
626+
627+ with self .subTest (name = 'PseudolabelFlagsColumnCorrect' ):
628+ self .assertListEqual (
629+ list (
630+ load_table_mock_kwargs ['dataframe' ][
631+ data_loader .PSEUDOLABEL_FLAG_COLUMN_NAME
632+ ]
633+ ),
634+ list (flags ),
635+ )
636+
637+ with self .subTest (name = 'EqualColumnNames' ):
638+ self .assertListEqual (
639+ feature_column_names ,
640+ list (load_table_mock_kwargs ['dataframe' ].columns ),
641+ )
642+ with self .subTest (name = 'EqualDestinationPath' ):
643+ self .assertEqual (
644+ self .runner_parameters .output_bigquery_table_path ,
645+ load_table_mock_kwargs ['destination' ],
646+ )
647+ with self .subTest (name = 'EqualJobConfig' ):
648+ self .assertEqual (job_config_object , load_table_mock_kwargs ['job_config' ])
649+
539650 def test_get_label_thresholds_no_error (self ):
540651 mock_query_return_dictionary = {
541652 self .runner_parameters .label_col_name : [
0 commit comments