From d2bf6dfb020f5d22683b37ba84d38cbfb8f3b15e Mon Sep 17 00:00:00 2001
From: Benoit Favre <benoit.favre@lis-lab.fr>
Date: Tue, 2 Jun 2020 10:35:09 +0200
Subject: [PATCH] add compatibility with bibliovid data

---
 data.py  | 16 +++++++++++++++-
 model.py | 37 ++++++++++++++++++++++++++++---------
 2 files changed, 43 insertions(+), 10 deletions(-)

diff --git a/data.py b/data.py
index 6891e11..ef21491 100644
--- a/data.py
+++ b/data.py
@@ -24,10 +24,17 @@ def to_int(tokenizer, label_vocab, hparams, dataset):
   int_labels = []
   sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get))
 
-  for article in dataset:
+  for i, article in enumerate(dataset):
+    for feature in hparams.selected_features:
+      if feature not in article or article[feature] is None:
+        article[feature] = ''
     text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features])
     int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len))
     int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels])
+    if hparams.augment_data and i > hparams.valid_size: # don't forget to skip valid set
+      text = ' | '.join([''.join(article[feature] if feature != 'abstract' else '') for feature in hparams.selected_features])
+      int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len))
+      int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels])
 
   return int_texts, int_labels
 
@@ -46,6 +53,11 @@ def load(tokenizer, hparams):
   label_vocab = dict(label_vocab)
 
   dataset = [article for article in articles if 'topics' in article] # and 'abstract' in article]
+  assert len(dataset) > 0
+
+  hparams.valid_size = int(hparams.valid_size_percent * len(dataset) / 100.0)
+  assert hparams.valid_size > 0
+
   missing_abstracts = 0
   for article in dataset:
     if 'abstract' not in article or article['abstract'] == []:
@@ -59,5 +71,7 @@ def load(tokenizer, hparams):
 
   train_set = CustomDataset(int_texts[hparams.valid_size:], int_labels[hparams.valid_size:])
   valid_set = CustomDataset(int_texts[:hparams.valid_size], int_labels[:hparams.valid_size])
+  print('training set', len(train_set))
+  print('valid set', len(valid_set))
 
   return train_set, valid_set, label_vocab
diff --git a/model.py b/model.py
index a60163b..458af4a 100644
--- a/model.py
+++ b/model.py
@@ -1,4 +1,5 @@
 from argparse import ArgumentParser
+import sys
 
 import torch
 import torch.nn as nn
@@ -25,8 +26,8 @@ def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7):
   recall = tp / (tp + fn + epsilon)
 
   f1 = 2 * (precision * recall) / (precision + recall + epsilon)
-  #f1 = f1.clamp(min=epsilon, max=1 - epsilon)
-  f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)
+  f1 = f1.clamp(min=epsilon, max=1 - epsilon)
+  #f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)
 
   return 1 - f1.mean()
 
@@ -65,17 +66,34 @@ class Model(LightningModule):
     x, y = batch
     y_hat = self(x)
     loss = self.loss_function(y_hat, y)
-    num_correct = torch.sum((y_hat >= 0) == y)
-    return {'val_loss': loss, 'val_correct': num_correct, 'val_num': y.shape[0] * y.shape[1]}
+    bce = F.binary_cross_entropy_with_logits(y_hat, y)
+    num_correct = torch.sum((y_hat >= 0) * (y == 1))
+    num_hyp = torch.sum(y_hat >= 0)
+    num_ref = torch.sum(y == 1)
+    num = torch.tensor([y.shape[0] * y.shape[1]])
+    return {'val_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'num': num}
 
   def training_epoch_end(self, outputs):
     avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
     return {'loss': avg_loss, 'log': {'loss': avg_loss}}
 
   def validation_epoch_end(self, outputs):
-    avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
-    accuracy = torch.stack([x['val_correct'] for x in outputs]).sum().item() / sum([x['val_num'] for x in outputs])
-    return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'accuracy': torch.tensor([accuracy])}}
+    metrics = outputs[0].keys()
+    values = {metric: torch.stack([x[metric] for x in outputs]) for metric in metrics}
+
+    avg_loss = values['val_loss'].mean()
+
+    bce = values['bce'].mean()
+    num_correct = values['num_correct'].sum()
+    num = values['num'].sum()
+    accuracy = num_correct / float(num.item())
+    num_ref = values['num_ref'].sum()
+    num_hyp = values['num_hyp'].sum()
+    recall = num_correct / float(num_ref.item()) if num_ref != 0 else torch.tensor([0])
+    precision = num_correct / float(num_hyp.item()) if num_ref != 0 else torch.tensor([0])
+    fscore = 2 * recall * precision / float((precision + recall).item()) if precision + recall != 0 else torch.tensor([0])
+
+    return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'bce': bce, 'accuracy': accuracy, 'recall': recall, 'precision': precision, 'fscore': fscore}}
 
   def configure_optimizers(self):
     return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@@ -100,13 +118,14 @@ class Model(LightningModule):
     parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances')
     parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)')
     parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)')
-    parser.add_argument('--epochs', default=10, type=int, help='number of epochs (default=10)')
-    parser.add_argument('--valid_size', default=1000, type=int, help='validation set size (default=1000)')
+    parser.add_argument('--epochs', default=20, type=int, help='number of epochs (default=20)')
+    parser.add_argument('--valid_size_percent', default=10, type=int, help='validation set size in % (default=10)')
     parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)')
     parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed')
     parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)')
     parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert')
     parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
+    parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)')
 
     return parser
 
-- 
GitLab