Skip to content

Commit 56c44d3

Browse files
callzhanghankcs
authored andcommitted
implemented multi-label support
Revert "minor fix" This reverts commit 91e9847. On branch master Your branch is up to date with 'origin/master'. Changes to be committed: modified: hanlp/common/component.py modified: hanlp/layers/transformers/loader.py (cherry picked from commit 7bae452) minor fix (cherry picked from commit 91e9847) minor fix (cherry picked from commit d4104d7) multi-label support cherry picked to master (cherry picked from commit 62f7b3d) implemented multi-label support (cherry picked from commit a844481)
1 parent 4ea9595 commit 56c44d3

File tree

5 files changed

+65
-18
lines changed

5 files changed

+65
-18
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,4 +284,5 @@ fabric.properties
284284
.idea/caches/build_file_checksums.ser
285285
.idea
286286
*.iml
287-
data
287+
data
288+
.vscode/settings.json

hanlp/common/vocab.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,10 @@ def update(self, tokens: Iterable[str]) -> None:
7979
self.add(token)
8080

8181
def get_idx(self, token: str) -> int:
82-
idx = self.token_to_idx.get(token, None)
82+
if type(token) is list:
83+
idx = [self.get_idx(t) for t in token]
84+
else:
85+
idx = self.token_to_idx.get(token, None)
8386
if idx is None:
8487
if self.mutable:
8588
idx = len(self.token_to_idx)

hanlp/components/classifiers/transformer_classifier.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
from hanlp.transform.table import TableTransform
1616
from hanlp.utils.log_util import logger
1717
from hanlp.utils.util import merge_locals_kwargs
18+
import numpy as np
1819

1920

2021
class TransformerTextTransform(TableTransform):
2122

2223
def __init__(self, config: SerializableDict = None, map_x=False, map_y=True, x_columns=None,
23-
y_column=-1, skip_header=True, delimiter='auto', **kwargs) -> None:
24-
super().__init__(config, map_x, map_y, x_columns, y_column, skip_header, delimiter, **kwargs)
24+
y_column=-1, skip_header=True, delimiter='auto', multi_label=False, **kwargs) -> None:
25+
super().__init__(config, map_x, map_y, x_columns, y_column, multi_label, skip_header, delimiter, **kwargs)
2526
self.tokenizer: FullTokenizer = None
2627

2728
def inputs_to_samples(self, inputs, gold=False):
@@ -61,26 +62,40 @@ def inputs_to_samples(self, inputs, gold=False):
6162
segment_ids += [0] * diff
6263

6364
assert len(token_ids) == max_length, "Error with input length {} vs {}".format(len(token_ids), max_length)
64-
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask),
65-
max_length)
66-
assert len(segment_ids) == max_length, "Error with input length {} vs {}".format(len(segment_ids),
67-
max_length)
65+
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
66+
assert len(segment_ids) == max_length, "Error with input length {} vs {}".format(len(segment_ids), max_length)
67+
68+
6869
label = Y
6970
yield (token_ids, attention_mask, segment_ids), label
7071

7172
def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
7273
max_length = self.config.max_length
7374
types = (tf.int32, tf.int32, tf.int32), tf.string
74-
shapes = ([max_length], [max_length], [max_length]), []
75+
shapes = ([max_length], [max_length], [max_length]), [None,] if self.config.multi_label else []
7576
values = (0, 0, 0), self.label_vocab.safe_pad_token
7677
return types, shapes, values
7778

7879
def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
7980
logger.fatal('map_x should always be set to True')
8081
exit(1)
8182

83+
def y_to_idx(self, y) -> tf.Tensor:
84+
if self.config.multi_label:
85+
#need to change index to binary vector
86+
mapped = tf.map_fn(fn=lambda x: tf.cast(self.label_vocab.lookup(x), tf.int32), elems=y, fn_output_signature=tf.TensorSpec(dtype=tf.dtypes.int32, shape=[None,]))
87+
one_hots = tf.one_hot(mapped, len(self.label_vocab))
88+
idx = tf.reduce_sum(one_hots, -2)
89+
else:
90+
idx = self.label_vocab.lookup(y)
91+
return idx
92+
8293
def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
83-
preds = tf.argmax(Y, axis=-1)
94+
# Prediction to be Y > 0:
95+
if self.config.multi_label:
96+
preds = Y
97+
else:
98+
preds = tf.argmax(Y, axis=-1)
8499
for y in preds:
85100
yield self.label_vocab.idx_to_token[y]
86101

@@ -126,7 +141,14 @@ def _y_id_to_str(self, Y_pred) -> str:
126141
return self.transform.label_vocab.idx_to_token[Y_pred.numpy()]
127142

128143
def build_loss(self, loss, **kwargs):
129-
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
144+
if loss:
145+
assert isinstance(loss, tf.keras.losses.loss), 'Must specify loss as an instance in tf.keras.losses'
146+
return loss
147+
elif self.config.multi_label:
148+
#Loss to be BinaryCrossentropy for multi-label:
149+
loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
150+
else:
151+
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
130152
return loss
131153

132154
# noinspection PyMethodOverriding
@@ -158,3 +180,10 @@ def build_vocab(self, trn_data, logger):
158180
warmup_steps_per_epoch = math.ceil(train_examples * self.config.warmup_steps_ratio / self.config.batch_size)
159181
self.config.warmup_steps = warmup_steps_per_epoch * self.config.epochs
160182
return train_examples
183+
184+
def build_metrics(self, metrics, logger, **kwargs):
185+
if self.config.multi_label:
186+
metric = tf.keras.metrics.BinaryAccuracy('binary_accuracy')
187+
else:
188+
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
189+
return [metric]

hanlp/transform/table.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Date: 2019-11-10 21:00
44
from abc import ABC
55
from typing import Tuple, Union
6-
6+
import numpy as np
77
import tensorflow as tf
88

99
from hanlp.common.structure import SerializableDict
@@ -16,9 +16,9 @@
1616

1717
class TableTransform(Transform, ABC):
1818
def __init__(self, config: SerializableDict = None, map_x=False, map_y=True, x_columns=None,
19-
y_column=-1,
19+
y_column=-1, multi_label=False,
2020
skip_header=True, delimiter='auto', **kwargs) -> None:
21-
super().__init__(config, map_x, map_y, x_columns=x_columns, y_column=y_column,
21+
super().__init__(config, map_x, map_y, x_columns=x_columns, y_column=y_column, multi_label=multi_label,
2222
skip_header=skip_header,
2323
delimiter=delimiter, **kwargs)
2424
self.label_vocab = create_label_vocab()
@@ -28,6 +28,9 @@ def file_to_inputs(self, filepath: str, gold=True):
2828
y_column = self.config.y_column
2929
num_features = self.config.get('num_features', None)
3030
for cells in read_cells(filepath, skip_header=self.config.skip_header, delimiter=self.config.delimiter):
31+
#multi-label: Dataset in .tsv format: x_columns: at most 2 columns being a sentence pair while in most
32+
# cases just one column being the doc content. y_column being the single label, which shall be modified
33+
# to load a list of labels.
3134
if x_columns:
3235
inputs = tuple(c for i, c in enumerate(cells) if i in x_columns), cells[y_column]
3336
else:
@@ -37,6 +40,15 @@ def file_to_inputs(self, filepath: str, gold=True):
3740
if num_features is None:
3841
num_features = len(inputs[0])
3942
self.config.num_features = num_features
43+
# multi-label support
44+
if self.config.multi_label:
45+
assert type(inputs[1]) is str, 'Y value has to be string'
46+
if inputs[1][0] == '[':
47+
# multi-label is in literal form of a list
48+
labels = eval(inputs[1])
49+
else:
50+
labels = inputs[1].strip().split(',')
51+
inputs = inputs[0], labels
4052
else:
4153
assert num_features == len(inputs[0]), f'Numbers of columns {num_features} ' \
4254
f'inconsistent with current {len(inputs[0])}'
@@ -56,7 +68,11 @@ def y_to_idx(self, y) -> tf.Tensor:
5668
def fit(self, trn_path: str, **kwargs):
5769
samples = 0
5870
for t in self.file_to_samples(trn_path, gold=True):
59-
self.label_vocab.add(t[1]) # the second one regardless of t is pair or triple
71+
if self.config.multi_label:
72+
for l in t[1]:
73+
self.label_vocab.add(l)
74+
else:
75+
self.label_vocab.add(t[1]) # the second one regardless of t is pair or triple
6076
samples += 1
6177
return samples
6278

hanlp/utils/tf_util.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111

1212

1313
def size_of_dataset(dataset: tf.data.Dataset) -> int:
14-
count = 0
15-
for element in dataset.unbatch().batch(1):
16-
count += 1
14+
count = len(list(dataset.unbatch().as_numpy_iterator()))
1715
return count
1816

1917

0 commit comments

Comments
 (0)