From d275fd7d0d811bfacbfa6febbf62ab653c1ec98e Mon Sep 17 00:00:00 2001 From: Benoit Favre <benoit.favre@lis-lab.fr> Date: Fri, 29 May 2020 21:56:29 +0200 Subject: [PATCH] fix prediction and add f-score loss --- data.py | 2 +- model.py | 7 ++++--- predict.py | 3 +-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/data.py b/data.py index 8958f60..6891e11 100644 --- a/data.py +++ b/data.py @@ -27,7 +27,7 @@ def to_int(tokenizer, label_vocab, hparams, dataset): for article in dataset: text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features]) int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len)) - int_labels.append([1 if label in 'topics' in article and article['topics'] else 0 for label in sorted_labels]) + int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels]) return int_texts, int_labels diff --git a/model.py b/model.py index 7cc9e3a..a60163b 100644 --- a/model.py +++ b/model.py @@ -12,7 +12,7 @@ from transformers import AutoModel import data # based on https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric -def f1_score_binary(y_pred, y_true, epsilon=1e-7): +def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7): y_pred = torch.sigmoid(y_pred) y_true = y_true.float() @@ -25,7 +25,8 @@ def f1_score_binary(y_pred, y_true, epsilon=1e-7): recall = tp / (tp + fn + epsilon) f1 = 2 * (precision * recall) / (precision + recall + epsilon) - f1 = f1.clamp(min=epsilon, max=1 - epsilon) + #f1 = f1.clamp(min=epsilon, max=1 - epsilon) + f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1) return 1 - f1.mean() @@ -46,7 +47,7 @@ class Model(LightningModule): if self.hparams.loss == 'bce': self.loss_function = F.binary_cross_entropy_with_logits elif self.hparams.loss == 'f1': - self.loss_function = f1_score_binary + self.loss_function = binary_f1_score_with_logits else: raise ValueError('invalid loss "%s"' % self.hparams.loss) diff --git a/predict.py b/predict.py index f71641a..67fc6f4 100644 --- a/predict.py +++ b/predict.py @@ -34,14 +34,13 @@ def main(hparams): #y = y.to(device) with torch.no_grad(): y_scores = model(x) - y_pred = y_scores > 0 predictions.extend(y_scores.cpu().tolist()) return predictions predictions = generate_predictions(model, test_loader) for i, article in enumerate(dataset): article['topic-scores'] = {label: score for label, score in zip(sorted_labels, predictions[i])} - article['topic-pred'] = [label for label, score in zip(sorted_labels, predictions[i]) if score > 0] + article['topic-pred'] = [label for label, score in zip(sorted_labels, predictions[i]) if score >= 0] print(json.dumps(dataset, indent=2)) -- GitLab