Skip to content
Snippets Groups Projects
Commit d275fd7d authored by Benoit Favre's avatar Benoit Favre
Browse files

fix prediction and add f-score loss

parent 7a7cbb6b
Branches
No related tags found
No related merge requests found
...@@ -27,7 +27,7 @@ def to_int(tokenizer, label_vocab, hparams, dataset): ...@@ -27,7 +27,7 @@ def to_int(tokenizer, label_vocab, hparams, dataset):
for article in dataset: for article in dataset:
text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features]) text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features])
int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len)) int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len))
int_labels.append([1 if label in 'topics' in article and article['topics'] else 0 for label in sorted_labels]) int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels])
return int_texts, int_labels return int_texts, int_labels
......
...@@ -12,7 +12,7 @@ from transformers import AutoModel ...@@ -12,7 +12,7 @@ from transformers import AutoModel
import data import data
# based on https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric # based on https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
def f1_score_binary(y_pred, y_true, epsilon=1e-7): def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7):
y_pred = torch.sigmoid(y_pred) y_pred = torch.sigmoid(y_pred)
y_true = y_true.float() y_true = y_true.float()
...@@ -25,7 +25,8 @@ def f1_score_binary(y_pred, y_true, epsilon=1e-7): ...@@ -25,7 +25,8 @@ def f1_score_binary(y_pred, y_true, epsilon=1e-7):
recall = tp / (tp + fn + epsilon) recall = tp / (tp + fn + epsilon)
f1 = 2 * (precision * recall) / (precision + recall + epsilon) f1 = 2 * (precision * recall) / (precision + recall + epsilon)
f1 = f1.clamp(min=epsilon, max=1 - epsilon) #f1 = f1.clamp(min=epsilon, max=1 - epsilon)
f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)
return 1 - f1.mean() return 1 - f1.mean()
...@@ -46,7 +47,7 @@ class Model(LightningModule): ...@@ -46,7 +47,7 @@ class Model(LightningModule):
if self.hparams.loss == 'bce': if self.hparams.loss == 'bce':
self.loss_function = F.binary_cross_entropy_with_logits self.loss_function = F.binary_cross_entropy_with_logits
elif self.hparams.loss == 'f1': elif self.hparams.loss == 'f1':
self.loss_function = f1_score_binary self.loss_function = binary_f1_score_with_logits
else: else:
raise ValueError('invalid loss "%s"' % self.hparams.loss) raise ValueError('invalid loss "%s"' % self.hparams.loss)
......
...@@ -34,14 +34,13 @@ def main(hparams): ...@@ -34,14 +34,13 @@ def main(hparams):
#y = y.to(device) #y = y.to(device)
with torch.no_grad(): with torch.no_grad():
y_scores = model(x) y_scores = model(x)
y_pred = y_scores > 0
predictions.extend(y_scores.cpu().tolist()) predictions.extend(y_scores.cpu().tolist())
return predictions return predictions
predictions = generate_predictions(model, test_loader) predictions = generate_predictions(model, test_loader)
for i, article in enumerate(dataset): for i, article in enumerate(dataset):
article['topic-scores'] = {label: score for label, score in zip(sorted_labels, predictions[i])} article['topic-scores'] = {label: score for label, score in zip(sorted_labels, predictions[i])}
article['topic-pred'] = [label for label, score in zip(sorted_labels, predictions[i]) if score > 0] article['topic-pred'] = [label for label, score in zip(sorted_labels, predictions[i]) if score >= 0]
print(json.dumps(dataset, indent=2)) print(json.dumps(dataset, indent=2))
......
0% — Loading, or unable to display this diff.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment