
Commit 462e53e

Vineet Joshi authored and The spade_anomaly_detection Authors committed
Evenly distributed the negative-labeled records to all ensemble models.
The different models in the ensemble should not receive different proportions of negative-labeled records; keeping this distribution uniform keeps their training and performance consistent. With this commit, we explicitly combine all of the negative-labeled features with each batch of unlabeled records for every model in the ensemble. PiperOrigin-RevId: 705983061
1 parent 2ec1a5a commit 462e53e
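
In essence, the change makes every one-class model train on its own slice of the unlabeled data plus the complete set of negative-labeled records. A minimal sketch of the idea, with illustrative names rather than the repository's actual implementation:

import numpy as np

def make_training_batches(unlabeled_features, negative_features, ensemble_count):
  """Yields one training batch per ensemble model (hypothetical helper)."""
  batch_size = len(unlabeled_features) // ensemble_count
  for i in range(ensemble_count):
    # Each model gets a distinct slice of the unlabeled data...
    unlabeled_batch = unlabeled_features[i * batch_size:(i + 1) * batch_size]
    # ...but the SAME full set of negative-labeled records, so the negative
    # proportion is identical across all models in the ensemble.
    yield np.concatenate([unlabeled_batch, negative_features], axis=0)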

File tree

4 files changed

+201 -12 lines changed


spade_anomaly_detection/csv_data_loader.py

Lines changed: 8 additions & 5 deletions

@@ -528,15 +528,18 @@ def combine_features_dict_into_tensor(
     dataset = dataset.batch(batch_size, deterministic=True)
     dataset = dataset.prefetch(tf.data.AUTOTUNE)
 
-    # This Dataset was just created. Calculate the label distribution.
-    # Any string labels were already re-mapped to integers. So keys are always
-    # strings and values are always integers.
-    self._label_counts = self.counts_by_label(dataset)
+    # This Dataset was just created. Calculate the label distribution. Any
+    # string labels were already re-mapped to integers. So keys are always
+    # integers and values are EagerTensors. We need to extract the value within
+    # this Tensor for subsequent use.
+    self._label_counts = {
+        k: v.numpy() for k, v in self.counts_by_label(dataset).items()
+    }
     logging.info('Label counts: %s', self._label_counts)
 
     return dataset
 
-  def counts_by_label(self, dataset: tf.data.Dataset) -> Dict[int, int]:
+  def counts_by_label(self, dataset: tf.data.Dataset) -> Dict[int, tf.Tensor]:
     """Counts the number of samples in each label class in the dataset.
 
     When this function is called, the labels in the Dataset have already been
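
For context on the .numpy() extraction above: reductions over a tf.data.Dataset produce scalar EagerTensors rather than plain integers, so the values are unwrapped before being stored and logged. A small illustration (the value is made up):

import tensorflow as tf

count = tf.constant(4333, dtype=tf.int64)  # a shape-() EagerTensor, as now returned per label
print(count)          # tf.Tensor(4333, shape=(), dtype=int64)
print(count.numpy())  # 4333 -- a plain integer, ready for arithmetic and logging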

spade_anomaly_detection/occ_ensemble.py

Lines changed: 110 additions & 4 deletions
@@ -42,6 +42,8 @@
 
 _RANDOM_SEED: Final[int] = 42
 
+_LABEL_TYPE: Final[str] = 'INT64'
+
 
 # TODO(b/247116870): Create abstract class for templating out future OCC models.
 class GmmEnsemble:
@@ -74,6 +76,12 @@ class GmmEnsemble:
       precision when raising this value, and an increase in recall when lowering
       it. Equavalent to saying the given data point needs to be X percentile or
       greater in order to be considered anomalous.
+    unlabeled_record_count: The number of unlabeled records in the dataset.
+    negative_record_count: The number of negative records in the dataset.
+    unlabeled_data_value: The value used in the label column to denote unlabeled
+      data.
+    negative_data_value: The value used in the label column to denote negative
+      data.
     verbose: Boolean denoting whether to send model performance and
       pseudo-labeling metrics to the GCP console.
     ensemble: A trained ensemble of one class classifiers.
@@ -90,6 +98,10 @@ def __init__(
       positive_threshold: float = 1.0,
       negative_threshold: float = 95.0,
       random_seed: int = _RANDOM_SEED,
+      unlabeled_record_count: int | None = None,
+      negative_record_count: int | None = None,
+      unlabeled_data_value: int | None = None,
+      negative_data_value: int | None = None,
       verbose: bool = False,
   ) -> None:
     self.n_components = n_components
@@ -100,6 +112,10 @@ def __init__(
     self.positive_threshold = positive_threshold
     self.negative_threshold = negative_threshold
     self._random_seed = random_seed
+    self.unlabeled_record_count = unlabeled_record_count
+    self.negative_record_count = negative_record_count
+    self.unlabeled_data_value = unlabeled_data_value
+    self.negative_data_value = negative_data_value
     self.verbose = verbose
 
     self.ensemble = []
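
Putting the new constructor parameters together, a caller can now describe the label layout up front. A minimal usage sketch (the values mirror the tests below; the attribute-style assignment of counts mirrors runner.py):

ensemble_obj = occ_ensemble.GmmEnsemble(
    n_components=1,
    ensemble_count=10,
    unlabeled_data_value=-1,   # label value that marks unlabeled rows
    negative_data_value=0,     # label value that marks negative rows
)
# The record counts may also be assigned after construction:
ensemble_obj.unlabeled_record_count = 94950
ensemble_obj.negative_record_count = 4333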
@@ -121,6 +137,38 @@ def _get_model(self) -> mixture.GaussianMixture:
         random_state=self._random_seed,
     )
 
+  def _get_filter_by_label_value_func(self, label_column_filter_value: int):
+    """Returns a function that filters a record based on the label column value.
+
+    Args:
+      label_column_filter_value: The value of the label column to use as a
+        filter. If None, all records are included.
+
+    Returns:
+      A function that returns True if the label column value is equal to the
+      label_column_filter_value parameter.
+    """
+
+    def filter_func(features: tf.Tensor, label: tf.Tensor) -> bool:  # pylint: disable=unused-argument
+      if label_column_filter_value is None:
+        return True
+      label_cast = tf.cast(label, tf.dtypes.as_dtype(_LABEL_TYPE.lower()))
+      label_column_filter_value_cast = tf.cast(
+          label_column_filter_value, label_cast.dtype
+      )
+      broadcast_equal = tf.equal(label_column_filter_value_cast, label_cast)
+      return tf.reduce_all(broadcast_equal)
+
+    return filter_func
+
+  def is_batched(self, dataset: tf.data.Dataset) -> bool:
+    """Returns True if the dataset is batched."""
+    # This suffices for the current use case of the OCC ensemble.
+    return len(dataset.element_spec[0].shape) == 2 and (
+        dataset.element_spec[0].shape[0] is None
+        or isinstance(dataset.element_spec[0].shape[0], int)
+    )
+
   def fit(
       self, train_x: tf.data.Dataset, batches_per_occ: int
   ) -> Sequence[mixture.GaussianMixture]:
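
A note on the is_batched heuristic: a batched dataset of feature vectors has a rank-2 element spec whose leading dimension is the batch size (often None when it is dynamic), which is exactly what the check above inspects. A small demonstration with made-up shapes:

import tensorflow as tf

features = tf.zeros([100, 8])             # 100 records, 8 features each
labels = tf.zeros([100], dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices((features, labels))

print(ds.element_spec[0].shape)            # (8,)      -- unbatched, rank 1
print(ds.batch(32).element_spec[0].shape)  # (None, 8) -- batched, rank 2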
@@ -142,15 +190,73 @@ def fit(
     if batches_per_occ > 1:
       self._warm_start = True
 
-    dataset_iterator = train_x.as_numpy_iterator()
+    has_batches = self.is_batched(train_x)
+    logging.info('has_batches is %s', has_batches)
+    negative_features = None
+
+    if (
+        not self.unlabeled_record_count
+        or not self.negative_record_count
+        or not has_batches
+        or self.unlabeled_data_value is None
+        or self.negative_data_value is None
+    ):
+      # Either the dataset is not batched, or we don't have all the details to
+      # extract the negative-labeled data. Hence we will use all the data for
+      # training.
+      dataset_iterator = train_x.as_numpy_iterator()
+    else:
+      # We unbatch the dataset so that we can separate out the unlabeled and
+      # negative data points.
+      ds_unbatched = train_x.unbatch()
+
+      ds_unlabeled = ds_unbatched.filter(
+          self._get_filter_by_label_value_func(self.unlabeled_data_value)
+      )
+
+      ds_negative = ds_unbatched.filter(
+          self._get_filter_by_label_value_func(self.negative_data_value)
+      )
+
+      negative_features_and_labels_zip = list(
+          zip(*ds_negative.as_numpy_iterator())
+      )
+
+      negative_features = (
+          negative_features_and_labels_zip[0]
+          if len(negative_features_and_labels_zip) == 2
+          else None
+      )
+
+      if negative_features is None:
+        # The negative features were not extracted. This can happen when the
+        # dataset elements are not tuples of features and labels. So we will use
+        # all the data for training.
+        ds_batched = train_x
+      else:
+        # The negative features were extracted. Now we can proceed with creating
+        # batches of unlabeled data, to which the negative data will be added
+        # before training.
+        batch_size = (
+            self.unlabeled_record_count // self.ensemble_count
+        ) // batches_per_occ
+        ds_batched = ds_unlabeled.batch(
+            batch_size,
+            drop_remainder=False,
+        )
+      dataset_iterator = ds_batched.as_numpy_iterator()
 
     for _ in range(self.ensemble_count):
       model = self._get_model()
 
       for _ in range(batches_per_occ):
-        features, labels = dataset_iterator.next()
-        del labels  # Not needed for this task.
-        model.fit(features)
+        features, _ = dataset_iterator.next()
+        all_features = (
+            np.concatenate([features, negative_features], axis=0)
+            if negative_features is not None
+            else features
+        )
+        model.fit(all_features)
 
       self.ensemble.append(model)
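
To make the batch-size arithmetic concrete with the counts used in the tests below: for 94950 unlabeled records, ensemble_count=10, and batches_per_occ=10, each unlabeled batch holds (94950 // 10) // 10 = 949 records, and all 4333 negative records are concatenated onto every batch, so each model.fit call sees 949 + 4333 = 5282 rows.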

spade_anomaly_detection/occ_ensemble_test.py

Lines changed: 76 additions & 0 deletions

@@ -85,6 +85,82 @@ def test_ensemble_training_no_error(
         msg='Model count in ensemble not equal to specified ensemble size.',
     )
 
+  @parameterized.named_parameters(
+      ('components_1_ensemble_10_full', 1, 10, 'full'),
+      ('components_1_ensemble_5_tied', 1, 5, 'tied'),
+  )
+  def test_ensemble_training_unlabeled_negative_no_error(
+      self, n_components, ensemble_count, covariance_type
+  ):
+    batches_per_occ = 10
+    negative_data_value = 0
+    unlabeled_data_value = -1
+
+    ensemble_obj = occ_ensemble.GmmEnsemble(
+        n_components=n_components,
+        ensemble_count=ensemble_count,
+        covariance_type=covariance_type,
+        negative_data_value=negative_data_value,
+        unlabeled_data_value=unlabeled_data_value,
+    )
+
+    tf_dataset = data_loader.load_tf_dataset_from_csv(
+        dataset_name='covertype_pnu_100000', batch_size=None
+    )
+    # These are the actual counts of unlabeled and negative records in the
+    # dataset.
+    unlabeled_record_count = 94950
+    negative_record_count = 4333
+    ensemble_obj.unlabeled_record_count = unlabeled_record_count
+    ensemble_obj.negative_record_count = negative_record_count
+
+    features_len = tf_dataset.cardinality().numpy()
+    records_per_occ = features_len // ensemble_obj.ensemble_count
+    batch_size = records_per_occ // batches_per_occ
+
+    tf_dataset = tf_dataset.shuffle(batch_size).batch(
+        batch_size, drop_remainder=True
+    )
+
+    ensemble_models = ensemble_obj.fit(tf_dataset, batches_per_occ)
+
+    self.assertLen(
+        ensemble_models,
+        ensemble_obj.ensemble_count,
+        msg='Model count in ensemble not equal to specified ensemble size.',
+    )
+
+  def test_dataset_filtering(self):
+    positive_data_value = 1
+    negative_data_value = 0
+    unlabeled_data_value = -1
+    gmm_ensemble = occ_ensemble.GmmEnsemble(n_components=1, ensemble_count=10)
+
+    tf_dataset = data_loader.load_tf_dataset_from_csv(
+        dataset_name='covertype_pnu_100000', batch_size=None
+    )
+    tf_unlabeled_dataset = tf_dataset.filter(
+        gmm_ensemble._get_filter_by_label_value_func(unlabeled_data_value)
+    )
+    tf_negative_dataset = tf_dataset.filter(
+        gmm_ensemble._get_filter_by_label_value_func(negative_data_value)
+    )
+    tf_positive_dataset = tf_dataset.filter(
+        gmm_ensemble._get_filter_by_label_value_func(positive_data_value)
+    )
+    self.assertEqual(
+        tf_unlabeled_dataset.reduce(0, lambda x, _: x + 1).numpy(),
+        94950,
+    )
+    self.assertEqual(
+        tf_negative_dataset.reduce(0, lambda x, _: x + 1).numpy(),
+        4333,
+    )
+    self.assertEqual(
+        tf_positive_dataset.reduce(0, lambda x, _: x + 1).numpy(),
+        715,
+    )
+
   @parameterized.named_parameters(
       ('labels_are_integers', False),
       ('labels_are_strings', True),

spade_anomaly_detection/runner.py

Lines changed: 7 additions & 3 deletions

@@ -313,6 +313,8 @@ def instantiate_and_fit_ensemble(
         negative_threshold=self.runner_parameters.negative_threshold,
         random_seed=self.runner_parameters.random_seed,
         verbose=self.runner_parameters.verbose,
+        unlabeled_data_value=self.int_unlabeled_data_value,
+        negative_data_value=self.int_negative_data_value,
     )
 
     training_record_count = unlabeled_record_count + negative_record_count
@@ -327,7 +329,7 @@ def instantiate_and_fit_ensemble(
       self.input_data_loader = cast(
          data_loader.DataLoader, self.input_data_loader
       )
-      unlabeled_data = self.input_data_loader.load_tf_dataset_from_bigquery(
+      training_data = self.input_data_loader.load_tf_dataset_from_bigquery(
          input_path=self.runner_parameters.input_bigquery_table_path,
          label_col_name=self.runner_parameters.label_col_name,
          where_statements=self.runner_parameters.where_statements,
@@ -346,7 +348,7 @@ def instantiate_and_fit_ensemble(
       self.input_data_loader = cast(
          csv_data_loader.CsvDataLoader, self.input_data_loader
       )
-      unlabeled_data = self.input_data_loader.load_tf_dataset_from_csv(
+      training_data = self.input_data_loader.load_tf_dataset_from_csv(
          input_path=self.runner_parameters.data_input_gcs_uri,
          label_col_name=self.runner_parameters.label_col_name,
          batch_size=batch_size,
@@ -358,10 +360,12 @@ def instantiate_and_fit_ensemble(
             self.int_negative_data_value,
         ],
     )
+    ensemble_object.unlabeled_record_count = unlabeled_record_count
+    ensemble_object.negative_record_count = negative_record_count
 
     logging.info('Fitting ensemble.')
     ensemble_object.fit(
-        train_x=unlabeled_data,
+        train_x=training_data,
         batches_per_occ=self.runner_parameters.batches_per_model,
     )
     logging.info('Ensemble fit complete.')