1515from hanlp .transform .table import TableTransform
1616from hanlp .utils .log_util import logger
1717from hanlp .utils .util import merge_locals_kwargs
18+ import numpy as np
1819
1920
2021class TransformerTextTransform (TableTransform ):
2122
    def __init__(self, config: SerializableDict = None, map_x=False, map_y=True, x_columns=None,
                 y_column=-1, skip_header=True, delimiter='auto', multi_label=False, **kwargs) -> None:
        """Table-backed text-classification transform for transformer models.

        :param config: optional ``SerializableDict`` of settings forwarded to ``TableTransform``.
        :param map_x: whether inputs are mapped through a vocab (tokenization is done here instead).
        :param map_y: whether labels are mapped through the label vocab.
        :param x_columns: table column(s) holding the input text.
        :param y_column: table column holding the label; defaults to the last column.
        :param skip_header: whether to skip the first (header) row of the table.
        :param delimiter: cell delimiter, or ``'auto'`` to detect it.
        :param multi_label: treat each label cell as a set of labels rather than a single label.
        """
        # NOTE(review): arguments are passed positionally, so this order must match
        # TableTransform.__init__ exactly — multi_label sits between y_column and
        # skip_header here; confirm against the base-class signature.
        super().__init__(config, map_x, map_y, x_columns, y_column, multi_label, skip_header, delimiter, **kwargs)
        # Populated later (when the transformer checkpoint/tokenizer is loaded); None until then.
        self.tokenizer: FullTokenizer = None
2627
2728 def inputs_to_samples (self , inputs , gold = False ):
@@ -61,26 +62,40 @@ def inputs_to_samples(self, inputs, gold=False):
6162 segment_ids += [0 ] * diff
6263
6364 assert len (token_ids ) == max_length , "Error with input length {} vs {}" .format (len (token_ids ), max_length )
64- assert len (attention_mask ) == max_length , "Error with input length {} vs {}" .format (len (attention_mask ),
65- max_length )
66- assert len ( segment_ids ) == max_length , "Error with input length {} vs {}" . format ( len ( segment_ids ),
67- max_length )
65+ assert len (attention_mask ) == max_length , "Error with input length {} vs {}" .format (len (attention_mask ), max_length )
66+ assert len ( segment_ids ) == max_length , "Error with input length {} vs {}" . format ( len ( segment_ids ), max_length )
67+
68+
6869 label = Y
6970 yield (token_ids , attention_mask , segment_ids ), label
7071
7172 def create_types_shapes_values (self ) -> Tuple [Tuple , Tuple , Tuple ]:
7273 max_length = self .config .max_length
7374 types = (tf .int32 , tf .int32 , tf .int32 ), tf .string
74- shapes = ([max_length ], [max_length ], [max_length ]), []
75+ shapes = ([max_length ], [max_length ], [max_length ]), [None ,] if self . config . multi_label else [ ]
7576 values = (0 , 0 , 0 ), self .label_vocab .safe_pad_token
7677 return types , shapes , values
7778
7879 def x_to_idx (self , x ) -> Union [tf .Tensor , Tuple ]:
7980 logger .fatal ('map_x should always be set to True' )
8081 exit (1 )
8182
    def y_to_idx(self, y) -> tf.Tensor:
        """Map label string(s) ``y`` to integer form.

        Single-label: a plain vocab lookup. Multi-label: each sample's
        variable-length list of labels becomes one multi-hot vector of size
        ``len(label_vocab)``.
        """
        if self.config.multi_label:
            # Need to change index to a binary (multi-hot) vector.
            # Look up each sample's labels; rows may hold different numbers of labels,
            # hence the [None] output signature.
            # NOTE(review): fn_output_signature requires TF >= 2.3 — confirm the
            # minimum TensorFlow version this project supports.
            mapped = tf.map_fn(fn=lambda x: tf.cast(self.label_vocab.lookup(x), tf.int32), elems=y,
                               fn_output_signature=tf.TensorSpec(dtype=tf.dtypes.int32, shape=[None, ]))
            # One-hot each label id, then sum over the label axis (-2) to collapse the
            # per-label one-hots into a single multi-hot vector per sample.
            one_hots = tf.one_hot(mapped, len(self.label_vocab))
            idx = tf.reduce_sum(one_hots, -2)
        else:
            idx = self.label_vocab.lookup(y)
        return idx
92+
8293 def Y_to_outputs (self , Y : Union [tf .Tensor , Tuple [tf .Tensor ]], gold = False , inputs = None , X = None , batch = None ) -> Iterable :
83- preds = tf .argmax (Y , axis = - 1 )
94+ # Prediction to be Y > 0:
95+ if self .config .multi_label :
96+ preds = Y
97+ else :
98+ preds = tf .argmax (Y , axis = - 1 )
8499 for y in preds :
85100 yield self .label_vocab .idx_to_token [y ]
86101
@@ -126,7 +141,14 @@ def _y_id_to_str(self, Y_pred) -> str:
126141 return self .transform .label_vocab .idx_to_token [Y_pred .numpy ()]
127142
128143 def build_loss (self , loss , ** kwargs ):
129- loss = tf .keras .losses .SparseCategoricalCrossentropy (from_logits = True )
144+ if loss :
145+ assert isinstance (loss , tf .keras .losses .loss ), 'Must specify loss as an instance in tf.keras.losses'
146+ return loss
147+ elif self .config .multi_label :
148+ #Loss to be BinaryCrossentropy for multi-label:
149+ loss = tf .keras .losses .BinaryCrossentropy (from_logits = True )
150+ else :
151+ loss = tf .keras .losses .SparseCategoricalCrossentropy (from_logits = True )
130152 return loss
131153
132154 # noinspection PyMethodOverriding
@@ -158,3 +180,10 @@ def build_vocab(self, trn_data, logger):
158180 warmup_steps_per_epoch = math .ceil (train_examples * self .config .warmup_steps_ratio / self .config .batch_size )
159181 self .config .warmup_steps = warmup_steps_per_epoch * self .config .epochs
160182 return train_examples
183+
184+ def build_metrics (self , metrics , logger , ** kwargs ):
185+ if self .config .multi_label :
186+ metric = tf .keras .metrics .BinaryAccuracy ('binary_accuracy' )
187+ else :
188+ metric = tf .keras .metrics .SparseCategoricalAccuracy ('accuracy' )
189+ return [metric ]
0 commit comments