Commit bba6181

Merge pull request #25 from google-research/test_652155894
Write out the pseudolabel weights and a flag that indicates whether a sample has a ground truth label (0) or a pseudolabel (1).
2 parents: 0fcfeda + 84cbfe7
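In practice, the change gives both data loaders' upload methods two new optional arrays. A minimal caller-side sketch, assuming a `DataLoader` instance (`loader` here is hypothetical) whose metadata was already populated by a prior `load_tf_dataset_from_bigquery()` call, as the new tests do; array shapes mirror those tests:

```python
import numpy as np

# Hypothetical caller; `loader` is an already-constructed data_loader.DataLoader.
features = np.random.rand(10, 2).astype(np.float32)  # model inputs
labels = np.zeros((10, 1), dtype=np.int8)            # labels, written last
weights = np.repeat([0.1, 1.0], 5).reshape(10, 1)    # pseudolabel weights ("alpha")
flags = np.repeat([1, 0], 5).reshape(10, 1)          # 1 = pseudolabel, 0 = ground truth

loader.upload_dataframe_as_bigquery_table(
    features=features,
    labels=labels,
    weights=weights,          # lands in the new "alpha" column
    pseudolabel_flags=flags,  # lands in the new "is_pseudolabel" column
)
```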

File tree

11 files changed: +315 −41 lines


CHANGELOG.md

Lines changed: 6 additions & 1 deletion
```diff
@@ -24,6 +24,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 ## [Unreleased]
 
+## [0.3.1] - 2024-07-13
+
+* Now writes out the pseudolabel weights and a flag that indicates whether a sample has a ground truth label (0) or a pseudolabel (1).
+
 ## [0.3.0] - 2024-07-10
 
 * Add the ability to use CSV files on GCS as data input/output/test sources.
@@ -45,7 +49,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 * Initial release
 
-[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.0...HEAD
+[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.1...HEAD
+[0.3.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.0...v0.3.1
 [0.3.0]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.2...v0.3.0
 [0.2.2]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.1...v0.2.2
 [0.2.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.0...v0.2.1
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -23,7 +23,7 @@ dependencies = [
     "pyarrow==14.0.1",
     "retry==0.9.2",
     "scikit-learn==1.4.2",
-    "tensorflow",
+    "tensorflow==2.12.1",
     "tensorflow-datasets==4.9.6",
     "parameterized==0.8.1",
     "pytest==7.1.2",
```

spade_anomaly_detection/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -31,4 +31,4 @@
 
 # A new PyPI release will be pushed every time `__version__` is increased.
 # When changing this, also update the CHANGELOG.md.
-__version__ = '0.3.0'
+__version__ = '0.3.1'
```

spade_anomaly_detection/csv_data_loader.py

Lines changed: 39 additions & 4 deletions
```diff
@@ -38,6 +38,7 @@
 from google.cloud import storage
 import numpy as np
 import pandas as pd
+from spade_anomaly_detection import data_loader
 from spade_anomaly_detection import parameters
 import tensorflow as tf
 
@@ -489,13 +490,17 @@ def upload_dataframe_to_gcs(
       batch: int,
       features: np.ndarray,
       labels: np.ndarray,
+      weights: Optional[np.ndarray] = None,
+      pseudolabel_flags: Optional[np.ndarray] = None,
   ) -> None:
     """Uploads the dataframe to GCS, creating or replacing the CSV file.
 
     Args:
       batch: The batch number of the pseudo-labeled data.
       features: Numpy array of features.
       labels: Numpy array of labels.
+      weights: Optional numpy array of weights.
+      pseudolabel_flags: Optional numpy array of pseudolabel flags.
 
     Returns:
       None.
@@ -515,15 +520,37 @@ def upload_dataframe_to_gcs(
           'Data output GCS URI is not set in the runner parameters. Please set '
           'the `data_output_gcs_uri` field in the runner parameters.'
       )
-    combined_data = np.concatenate(
-        [features, labels.reshape(len(features), 1)], axis=1
-    )
+    combined_data = features
 
     column_names = list(
         self._last_read_metadata.column_names_info.column_names_dict.keys()
     )
+
+    # If the weights are provided, add them to the column names and to the
+    # combined data.
+    if weights is not None:
+      column_names.append(data_loader.WEIGHT_COLUMN_NAME)
+      combined_data = np.concatenate(
+          [combined_data, weights.reshape(len(features), 1).astype(np.float64)],
+          axis=1,
+      )
+
+    # If the pseudolabel flags are provided, add them to the column names and
+    # to the combined data.
+    if pseudolabel_flags is not None:
+      column_names.append(data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME)
+      combined_data = np.concatenate(
+          [
+              combined_data,
+              pseudolabel_flags.reshape(len(features), 1).astype(np.int64),
+          ],
+          axis=1,
+      )
+
     # Make sure the label column is the last column.
-    # TODO(b/347332980): Add support for the pseudolabel flag.
+    combined_data = np.concatenate(
+        [combined_data, labels.reshape(len(features), 1)], axis=1
+    )
     column_names.remove(self.runner_parameters.label_col_name)
     column_names.append(self.runner_parameters.label_col_name)
 
@@ -536,6 +563,14 @@ def upload_dataframe_to_gcs(
         complete_dataframe[self.runner_parameters.label_col_name].astype('bool')
     )
 
+    # Adjust pseudolabel flag column type.
+    if pseudolabel_flags is not None:
+      complete_dataframe[data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME] = (
+          complete_dataframe[data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME].astype(
+              np.int64
+          )
+      )
+
     output_path = os.path.join(
         self.runner_parameters.data_output_gcs_uri,
         f'pseudo_labeled_batch_{batch}.csv',
```
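The resulting column order (features, then `alpha`, then `is_pseudolabel`, with the label last) can be reproduced in isolation. A self-contained sketch of the assembly logic, with placeholder column names standing in for the loader's metadata:

```python
import numpy as np
import pandas as pd

WEIGHT_COLUMN_NAME = 'alpha'
PSEUDOLABEL_FLAG_COLUMN_NAME = 'is_pseudolabel'

features = np.array([[1.0, 2.0], [3.0, 4.0]])
labels = np.array([0, 1])
weights = np.array([0.1, 1.0])  # pseudolabel weights
flags = np.array([1, 0])        # 1 = pseudolabel, 0 = ground truth

column_names = ['x1', 'x2', 'y']  # placeholder metadata; 'y' is the label
combined = features

# Optional columns are appended in the same order as in the loader.
column_names.append(WEIGHT_COLUMN_NAME)
combined = np.concatenate([combined, weights.reshape(len(features), 1)], axis=1)
column_names.append(PSEUDOLABEL_FLAG_COLUMN_NAME)
combined = np.concatenate([combined, flags.reshape(len(features), 1)], axis=1)

# The label column is moved to the end, mirroring the remove/append dance.
combined = np.concatenate([combined, labels.reshape(len(features), 1)], axis=1)
column_names.remove('y')
column_names.append('y')

df = pd.DataFrame(combined, columns=column_names)
print(list(df.columns))  # ['x1', 'x2', 'alpha', 'is_pseudolabel', 'y']
```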

spade_anomaly_detection/csv_data_loader_test.py

Lines changed: 51 additions & 6 deletions
```diff
@@ -385,30 +385,75 @@ def test_upload_dataframe_to_gcs(self):
     all_features = self.data_df[["x1", "x2"]].to_numpy()
     all_labels = self.data_df["y"].to_numpy()
     # Create 2 batches of features and labels.
-    # TODO(b/347332980): Update test when pseudolabel flag is added.
     features1 = all_features[0:2]
     labels1 = all_labels[0:2]
+    # Add weights and flags to the first batch. These are pseudolabeled samples.
+    weights1 = (
+        np.repeat([0.1], len(features1))
+        .reshape(len(features1), 1)
+        .astype(np.float64)
+    )
+    flags1 = (
+        np.repeat([1], len(features1))
+        .reshape(len(features1), 1)
+        .astype(np.int64)
+    )
+    # Add weights and flags to the second batch. These are ground truth samples.
     features2 = all_features[2:]
     labels2 = all_labels[2:]
-    # Upload batch 1.
+    weights2 = (
+        np.repeat([1.0], len(features2))
+        .reshape(len(features2), 1)
+        .astype(np.float64)
+    )
+    flags2 = (
+        np.repeat([0], len(features2))
+        .reshape(len(features2), 1)
+        .astype(np.int64)
+    )  # Upload batch 1.
     data_loader.upload_dataframe_to_gcs(
         batch=1,
         features=features1,
         labels=labels1,
+        weights=weights1,
+        pseudolabel_flags=flags1,
     )
     # Upload batch 2.
     data_loader.upload_dataframe_to_gcs(
         batch=2,
         features=features2,
         labels=labels2,
+        weights=weights2,
+        pseudolabel_flags=flags2,
     )
     # Sorting means batch 1 file will be first.
     files_list = sorted(tf.io.gfile.listdir(output_dir))
     self.assertLen(files_list, 2)
-    expected_dfs = [
-        self.data_df.iloc[0:2].reset_index(drop=True),
-        self.data_df.iloc[2:].reset_index(drop=True),
-    ]
+    col_names = ["x1", "x2", "alpha", "is_pseudolabel", "y"]
+    expected_df1 = pd.concat(
+        [
+            self.data_df.iloc[0:2, 0:-1].reset_index(drop=True),
+            pd.DataFrame(weights1, columns=["alpha"]),
+            pd.DataFrame(flags1, columns=["is_pseudolabel"]),
+            self.data_df.iloc[0:2, -1].reset_index(drop=True),
+        ],
+        names=col_names,
+        ignore_index=True,
+        axis=1,
+    )
+    expected_df1.columns = col_names
+    expected_df2 = pd.concat(
+        [
+            self.data_df.iloc[2:, 0:-1].reset_index(drop=True),
+            pd.DataFrame(weights2, columns=["alpha"]),
+            pd.DataFrame(flags2, columns=["is_pseudolabel"]),
+            self.data_df.iloc[2:, -1].reset_index(drop=True),
+        ],
+        ignore_index=True,
+        axis=1,
+    )
+    expected_df2.columns = col_names
+    expected_dfs = [expected_df1, expected_df2]
     for i, file_name in enumerate(files_list):
       with self.subTest(msg=f"file_{i}"):
         file_path = os.path.join(output_dir, file_name)
```
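A side note on the expected-frame construction above: with `axis=1` and `ignore_index=True`, `pd.concat` drops the incoming column labels and renumbers them `0..n-1` (the `names=` argument on the first call appears to have no effect without `keys=`), which is why the test reassigns `col_names` afterwards. A small demonstration:

```python
import numpy as np
import pandas as pd

a = pd.DataFrame({'x1': [1.0, 2.0]})
b = pd.DataFrame(np.array([[0.1], [1.0]]), columns=['alpha'])

out = pd.concat([a, b], axis=1, ignore_index=True)
print(list(out.columns))  # [0, 1] -- the original labels are gone

out.columns = ['x1', 'alpha']  # restore meaningful names, as the test does
```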

spade_anomaly_detection/data_loader.py

Lines changed: 39 additions & 3 deletions
```diff
@@ -54,6 +54,9 @@
 
 _DATA_ROOT: Final[str] = 'spade_anomaly_detection/example_data/'
 
+WEIGHT_COLUMN_NAME: Final[str] = 'alpha'
+PSEUDOLABEL_FLAG_COLUMN_NAME: Final[str] = 'is_pseudolabel'
+
 
 def load_dataframe(
     dataset_name: str,
@@ -691,12 +694,19 @@ def upload_dataframe_as_bigquery_table(
       self,
       features: np.ndarray,
       labels: np.ndarray,
+      weights: Optional[np.ndarray] = None,
+      pseudolabel_flags: Optional[np.ndarray] = None,
   ) -> None:
     """Uploads the dataframe to BigQuery, create or replace table.
 
     Args:
       features: Numpy array of features.
       labels: Numpy array of labels.
+      weights: Optional numpy array of weights.
+      pseudolabel_flags: Optional numpy array of pseudolabel flags.
+
+    Raises:
+      ValueError: If the metadata has not been read yet.
     """
     if not self.input_feature_metadata:
       raise ValueError(
@@ -705,11 +715,31 @@ def upload_dataframe_as_bigquery_table(
           'load_tf_dataset_from_bigquery() before this method '
          'is called.'
       )
-    combined_data = np.concatenate(
-        [features, labels.reshape(len(features), 1)], axis=1
-    )
+    combined_data = features
 
+    # Get the list of feature and label column names.
     column_names = list(self.input_feature_metadata.names)
+
+    # If the weights are provided, add them to the column names and to the
+    # combined data.
+    if weights is not None:
+      column_names.append(WEIGHT_COLUMN_NAME)
+      combined_data = np.concatenate(
+          [combined_data, weights.reshape(len(features), 1)], axis=1
+      )
+
+    # If the pseudolabel flags are provided, add them to the column names and
+    # to the combined data.
+    if pseudolabel_flags is not None:
+      column_names.append(PSEUDOLABEL_FLAG_COLUMN_NAME)
+      combined_data = np.concatenate(
+          [combined_data, pseudolabel_flags.reshape(len(features), 1)], axis=1
+      )
+
+    # Make sure the label column is the last column.
+    combined_data = np.concatenate(
+        [combined_data, labels.reshape(len(features), 1)], axis=1
+    )
     column_names.remove(self.runner_parameters.label_col_name)
     column_names.append(self.runner_parameters.label_col_name)
 
@@ -722,6 +752,12 @@ def upload_dataframe_as_bigquery_table(
         complete_dataframe[self.runner_parameters.label_col_name].astype('bool')
     )
 
+    # Adjust pseudolabel flag column type.
+    if pseudolabel_flags is not None:
+      complete_dataframe[PSEUDOLABEL_FLAG_COLUMN_NAME] = complete_dataframe[
+          PSEUDOLABEL_FLAG_COLUMN_NAME
+      ].astype(np.int64)
+
     with bigquery.Client(
         project=self.table_parts.project_id
     ) as big_query_client:
```
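Why the explicit `astype(np.int64)` at the end: `np.concatenate` promotes its inputs to a common dtype, so stacking integer flags next to float features silently upcasts the flags to float64, and without the cast the `is_pseudolabel` column would reach BigQuery as floats. A minimal sketch of the effect and the fix:

```python
import numpy as np
import pandas as pd

features = np.random.rand(4, 2)         # float64
flags = np.array([[1], [0], [1], [0]])  # int64

combined = np.concatenate([features, flags], axis=1)
print(combined.dtype)  # float64 -- the integer flags were upcast

df = pd.DataFrame(combined, columns=['x1', 'x2', 'is_pseudolabel'])
# Cast the flag column back, as upload_dataframe_as_bigquery_table() does.
df['is_pseudolabel'] = df['is_pseudolabel'].astype(np.int64)
print(df['is_pseudolabel'].dtype)  # int64
```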

spade_anomaly_detection/data_loader_test.py

Lines changed: 111 additions & 0 deletions
```diff
@@ -536,6 +536,117 @@ def test_bigquery_table_upload_throw_error_metadata(self):
           features=features, labels=labels
       )
 
+  @mock.patch.object(bigquery, 'LoadJobConfig', autospec=True)
+  def test_upload_dataframe_with_wts_flags_as_bigquery_table_no_error(
+      self, mock_bqclient_loadjobconfig
+  ):
+    self.runner_parameters.output_bigquery_table_path = (
+        'project.dataset.pseudo_labeled_data'
+    )
+    data_loader_object = data_loader.DataLoader(self.runner_parameters)
+    feature_column_names = [
+        'x1',
+        'x2',
+        data_loader.WEIGHT_COLUMN_NAME,
+        data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME,
+        self.runner_parameters.label_col_name,
+    ]
+
+    features = np.random.rand(10, 2).astype(np.float32)
+    labels = np.repeat(0, 10).reshape(10, 1).astype(np.int8)
+    # Two possible values for weight (alpha), repeated 10/2 = 5 times each.
+    weights = np.repeat([0.1, 1.0], 5).reshape(10, 1).astype(np.float32)
+    # The corresponding pseudolabel flags are 1 and 0, repeated 5 times each.
+    flags = np.repeat([1, 0], 5).reshape(10, 1).astype(np.int8)
+
+    tf_dataset_instance_mock = mock.create_autospec(
+        tf.data.Dataset, instance=True
+    )
+
+    feature1_metadata = feature_metadata.FeatureMetadata('x1', 0, 'FLOAT64')
+    feature2_metadata = feature_metadata.FeatureMetadata('x2', 0, 'FLOAT64')
+    label_metadata = feature_metadata.FeatureMetadata(
+        self.runner_parameters.label_col_name, 1, 'INT64'
+    )
+    metadata_container = feature_metadata.FeatureMetadataContainer(
+        [feature1_metadata, feature2_metadata, label_metadata]
+    )
+
+    self.mock_bq_dataset.return_value = (
+        tf_dataset_instance_mock,
+        metadata_container,
+    )
+
+    # Perform this call so that FeatureMetadata is set.
+    data_loader_object.load_tf_dataset_from_bigquery(
+        input_path=self.runner_parameters.input_bigquery_table_path,
+        label_col_name=self.runner_parameters.label_col_name,
+        batch_size=self.batch_size,
+    )
+
+    data_loader_object.upload_dataframe_as_bigquery_table(
+        features=features,
+        labels=labels,
+        weights=weights,
+        pseudolabel_flags=flags,
+    )
+    job_config_object = mock_bqclient_loadjobconfig.return_value
+
+    load_table_mock_kwargs = (
+        self.mock_bq_client.return_value.__enter__.return_value.load_table_from_dataframe.call_args.kwargs
+    )
+
+    with self.subTest(name='LabelColumnCorrect'):
+      self.assertListEqual(
+          list(
+              load_table_mock_kwargs['dataframe'][
+                  self.runner_parameters.label_col_name
+              ]
+          ),
+          list(labels),
+      )
+
+    with self.subTest(name='LabelColumnDataTypeBool'):
+      self.assertEqual(
+          load_table_mock_kwargs['dataframe'][
+              self.runner_parameters.label_col_name
+          ].dtype,
+          bool,
+      )
+
+    with self.subTest(name='WeightsColumnCorrect'):
+      self.assertListEqual(
+          list(
+              load_table_mock_kwargs['dataframe'][
+                  data_loader.WEIGHT_COLUMN_NAME
+              ]
+          ),
+          list(weights),
+      )
+
+    with self.subTest(name='PseudolabelFlagsColumnCorrect'):
+      self.assertListEqual(
+          list(
+              load_table_mock_kwargs['dataframe'][
+                  data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME
+              ]
+          ),
+          list(flags),
+      )
+
+    with self.subTest(name='EqualColumnNames'):
+      self.assertListEqual(
+          feature_column_names,
+          list(load_table_mock_kwargs['dataframe'].columns),
+      )
+    with self.subTest(name='EqualDestinationPath'):
+      self.assertEqual(
+          self.runner_parameters.output_bigquery_table_path,
+          load_table_mock_kwargs['destination'],
+      )
+    with self.subTest(name='EqualJobConfig'):
+      self.assertEqual(job_config_object, load_table_mock_kwargs['job_config'])
+
   def test_get_label_thresholds_no_error(self):
     mock_query_return_dictionary = {
         self.runner_parameters.label_col_name: [
```
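The assertions above read the keyword arguments of the mocked `load_table_from_dataframe` call back out via `call_args.kwargs`. This is plain `unittest.mock` behavior (Python 3.8+), shown here with hypothetical names unrelated to the codebase:

```python
from unittest import mock

client = mock.MagicMock()
client.load_table_from_dataframe(dataframe='df', destination='project.dataset.table')

# call_args.kwargs holds the keyword arguments of the most recent call.
kwargs = client.load_table_from_dataframe.call_args.kwargs
assert kwargs['destination'] == 'project.dataset.table'
```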
