# Default RNG seed so ensemble training is reproducible across runs.
_RANDOM_SEED: Final[int] = 42

# Dtype name for label tensors; lower-cased before tf.dtypes.as_dtype in
# _get_filter_by_label_value_func, so it must name a valid tf dtype.
_LABEL_TYPE: Final[str] = 'INT64'


# TODO(b/247116870): Create abstract class for templating out future OCC models.
4749class GmmEnsemble :
@@ -74,6 +76,12 @@ class GmmEnsemble:
7476 precision when raising this value, and an increase in recall when lowering
    it. Equivalent to saying the given data point needs to be X percentile or
7678 greater in order to be considered anomalous.
79+ unlabeled_record_count: The number of unlabeled records in the dataset.
80+ negative_record_count: The number of negative records in the dataset.
81+ unlabeled_data_value: The value used in the label column to denote unlabeled
82+ data.
83+ negative_data_value: The value used in the label column to denote negative
84+ data.
7785 verbose: Boolean denoting whether to send model performance and
7886 pseudo-labeling metrics to the GCP console.
7987 ensemble: A trained ensemble of one class classifiers.
@@ -90,6 +98,10 @@ def __init__(
9098 positive_threshold : float = 1.0 ,
9199 negative_threshold : float = 95.0 ,
92100 random_seed : int = _RANDOM_SEED ,
101+ unlabeled_record_count : int | None = None ,
102+ negative_record_count : int | None = None ,
103+ unlabeled_data_value : int | None = None ,
104+ negative_data_value : int | None = None ,
93105 verbose : bool = False ,
94106 ) -> None :
95107 self .n_components = n_components
@@ -100,6 +112,10 @@ def __init__(
100112 self .positive_threshold = positive_threshold
101113 self .negative_threshold = negative_threshold
102114 self ._random_seed = random_seed
115+ self .unlabeled_record_count = unlabeled_record_count
116+ self .negative_record_count = negative_record_count
117+ self .unlabeled_data_value = unlabeled_data_value
118+ self .negative_data_value = negative_data_value
103119 self .verbose = verbose
104120
105121 self .ensemble = []
@@ -121,6 +137,38 @@ def _get_model(self) -> mixture.GaussianMixture:
121137 random_state = self ._random_seed ,
122138 )
123139
140+ def _get_filter_by_label_value_func (self , label_column_filter_value : int ):
141+ """Returns a function that filters a record based on the label column value.
142+
143+ Args:
144+ label_column_filter_value: The value of the label column to use as a
145+ filter. If None, all records are included.
146+
147+ Returns:
148+ A function that returns True if the label column value is equal to the
149+ label_column_filter_value parameter.
150+ """
151+
152+ def filter_func (features : tf .Tensor , label : tf .Tensor ) -> bool : # pylint: disable=unused-argument
153+ if label_column_filter_value is None :
154+ return True
155+ label_cast = tf .cast (label , tf .dtypes .as_dtype (_LABEL_TYPE .lower ()))
156+ label_column_filter_value_cast = tf .cast (
157+ label_column_filter_value , label_cast .dtype
158+ )
159+ broadcast_equal = tf .equal (label_column_filter_value_cast , label_cast )
160+ return tf .reduce_all (broadcast_equal )
161+
162+ return filter_func
163+
164+ def is_batched (self , dataset : tf .data .Dataset ) -> bool :
165+ """Returns True if the dataset is batched."""
166+ # This suffices for the current use case of the OCC ensemble.
167+ return len (dataset .element_spec [0 ].shape ) == 2 and (
168+ dataset .element_spec [0 ].shape [0 ] is None
169+ or isinstance (dataset .element_spec [0 ].shape [0 ], int )
170+ )
171+
124172 def fit (
125173 self , train_x : tf .data .Dataset , batches_per_occ : int
126174 ) -> Sequence [mixture .GaussianMixture ]:
@@ -142,15 +190,73 @@ def fit(
142190 if batches_per_occ > 1 :
143191 self ._warm_start = True
144192
145- dataset_iterator = train_x .as_numpy_iterator ()
193+ has_batches = self .is_batched (train_x )
194+ logging .info ('has_batches is %s' , has_batches )
195+ negative_features = None
196+
197+ if (
198+ not self .unlabeled_record_count
199+ or not self .negative_record_count
200+ or not has_batches
201+ or self .unlabeled_data_value is None
202+ or self .negative_data_value is None
203+ ):
204+ # Either the dataset is not batched, or we don't have all the details to
205+ # extract the negative-labeled data. Hence we will use all the data for
206+ # training.
207+ dataset_iterator = train_x .as_numpy_iterator ()
208+ else :
209+ # We unbatch the dataset so that we can separate-out the unlabeled and
210+ # negative data points
211+ ds_unbatched = train_x .unbatch ()
212+
213+ ds_unlabeled = ds_unbatched .filter (
214+ self ._get_filter_by_label_value_func (self .unlabeled_data_value )
215+ )
216+
217+ ds_negative = ds_unbatched .filter (
218+ self ._get_filter_by_label_value_func (self .negative_data_value )
219+ )
220+
221+ negative_features_and_labels_zip = list (
222+ zip (* ds_negative .as_numpy_iterator ())
223+ )
224+
225+ negative_features = (
226+ negative_features_and_labels_zip [0 ]
227+ if len (negative_features_and_labels_zip ) == 2
228+ else None
229+ )
230+
231+ if negative_features is None :
232+ # The negative features were not extracted. This can happen when the
233+ # dataset elements are not tuples of features and labels. So we will use
234+ # all the data for training.
235+ ds_batched = train_x
236+ else :
      # The negative features were extracted. Now we can proceed with creating
238+ # batches of unlabeled data, to which the negative data will be added
239+ # before training.
240+ batch_size = (
241+ self .unlabeled_record_count // self .ensemble_count
242+ ) // batches_per_occ
243+ ds_batched = ds_unlabeled .batch (
244+ batch_size ,
245+ drop_remainder = False ,
246+ )
247+ dataset_iterator = ds_batched .as_numpy_iterator ()
146248
147249 for _ in range (self .ensemble_count ):
148250 model = self ._get_model ()
149251
150252 for _ in range (batches_per_occ ):
151- features , labels = dataset_iterator .next ()
152- del labels # Not needed for this task.
153- model .fit (features )
253+ features , _ = dataset_iterator .next ()
254+ all_features = (
255+ np .concatenate ([features , negative_features ], axis = 0 )
256+ if negative_features is not None
257+ else features
258+ )
259+ model .fit (all_features )
154260
155261 self .ensemble .append (model )
156262
0 commit comments