|
36 | 36 | """ |
37 | 37 |
|
38 | 38 | import enum |
39 | | -# TODO(b/247116870): Change to collections when Vertex supports python 3.9 |
40 | 39 | from typing import Mapping, Optional, Tuple, cast |
41 | 40 |
|
42 | 41 | from absl import logging |
|
49 | 48 | from spade_anomaly_detection import supervised_model |
50 | 49 | import tensorflow as tf |
51 | 50 |
|
| 51 | +# TODO(b/247116870): Change the `typing` imports above (e.g. Mapping) to `collections.abc` when Vertex supports python 3.9 |
| 52 | + |
52 | 53 |
|
53 | 54 | @enum.unique |
54 | 55 | class DataFormat(enum.Enum): |
@@ -135,6 +136,7 @@ def __init__(self, runner_parameters: parameters.RunnerParameters): |
135 | 136 | else: |
136 | 137 | self.supervised_model_object = None |
137 | 138 |
|
| 139 | + # If the thresholds are not set, use the thresholds from the input table. |
138 | 140 | if ( |
139 | 141 | self.runner_parameters.positive_threshold is None |
140 | 142 | or self.runner_parameters.negative_threshold is None |
@@ -760,7 +762,7 @@ def run(self) -> None: |
760 | 762 | batch_size=1, |
761 | 763 | ) |
762 | 764 | train_label_counts = self.input_data_loader.label_counts |
763 | | - # TODO(sinharaj): This is not ideal, we should not need to read the files |
| 765 | + # This is not ideal; we should not need to read the files |
764 | 766 | # again. Find a way to get the label counts without reading the files. |
765 | 767 | # Assumes that data loader has already been used to read the input table. |
766 | 768 | total_record_count = sum(train_label_counts.values()) |
@@ -885,6 +887,7 @@ def run(self) -> None: |
885 | 887 | labels=updated_labels, |
886 | 888 | weights=weights, |
887 | 889 | ) |
| 890 | + # End of pseudolabeling and supervised model training loop. |
888 | 891 |
|
889 | 892 | if not self.runner_parameters.upload_only: |
890 | 893 | self.evaluate_model() |
|
0 commit comments