diff --git a/README.md b/README.md
index 1923dbf297d29c4ad380ee1815b5d6245ef15693..9bb4e63761619169fbf68134e071980322c4ab3f 100644
--- a/README.md
+++ b/README.md
@@ -1,18 +1,65 @@
 Topic classifier for biomedical articles
 ========================================
 
+Multilabel topic classifier for biomedical articles.
+
+This system trains a topic classifier for articles labelled with multiple topics.
+The included model uses a variant of BERT pre-trained on biomedical text and fine-tunes it on the task instances.
+
+Data
+----
+
+Input data is expected to be a JSON file containing a list of articles. Each article
+should have a `title`, an `abstract` and a `topics` field containing a list of topic labels.
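+
+A minimal example of the expected format (the field layout is inferred from the loading
+code; the topic labels shown here are only illustrative):
+
+```
+[
+  {
+    "title": "Example article title",
+    "abstract": ["First paragraph of the abstract.", "Second paragraph."],
+    "topics": ["Treatment", "Diagnosis"]
+  }
+]
+```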
+
+
 Installing
 ----------
 
 ```
 virtualenv -p python3 env
-source env/bin/activated
+source env/bin/activate
 pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
 ```
 
-Running
--------
+
+Training
+--------
+
+```
+python trainer.py [options]
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --gpus GPUS
+  --nodes NODES
+  --name NAME
+  --fast_dev_run
+  --train_filename TRAIN_FILENAME
+  --learning_rate LEARNING_RATE
+  --batch_size BATCH_SIZE
+  --epochs EPOCHS
+  --valid_size VALID_SIZE
+  --max_len MAX_LEN
+  --bert_flavor BERT_FLAVOR
+  --selected_features SELECTED_FEATURES
+  --dropout DROPOUT
+  --loss LOSS
+```
+
+Example training command line:
 
 ```
 python trainer.py --gpus=-1 --name test1 --train_filename ../scrappers/data/20200529/litcovid.json
 ```
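+
+Checkpoints are written to the `checkpoints/` directory, named after the run name,
+epoch and validation loss (see `trainer.py`).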
+
+pytorch-lightning provides a TensorBoard logger. You can inspect the training metrics with
+```
+tensorboard --logdir lightning_logs
+```
+Then point your browser to http://localhost:6006/.
+
+Generating predictions
+----------------------
+
+```
+python predict.py --checkpoint checkpoints/epoch\=0-val_loss\=0.2044.ckpt --test_filename ../scrappers/data/20200529/cord19-metadata.json > predicted.json
+```
diff --git a/data.py b/data.py
index 185e092960968d4e6efd15115ce7d541a7d9e3bf..8958f60c959d65b1ec135c31ff4942f703dc7d62 100644
--- a/data.py
+++ b/data.py
@@ -16,8 +16,20 @@ class CustomDataset(Dataset):
   def __len__(self):
     return len(self.labels)
 
-def bert_text_to_ids(tokenizer, sentence):
-  return torch.tensor(tokenizer.encode(sentence, add_special_tokens=True))
+def bert_text_to_ids(tokenizer, sentence, max_len):
+  return torch.tensor(tokenizer.encode(sentence, add_special_tokens=True, max_length=max_len))
+
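+# Convert each article to BERT token ids (the selected features joined with ' | ')
+# and a multi-hot label vector whose order is determined by label_vocab.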
+def to_int(tokenizer, label_vocab, hparams, dataset):
+  int_texts = []
+  int_labels = []
+  sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get))
+
+  for article in dataset:
+    text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features])
+    int_texts.append(bert_text_to_ids(tokenizer, text, hparams.max_len))
+    int_labels.append([1 if 'topics' in article and label in article['topics'] else 0 for label in sorted_labels])
+
+  return int_texts, int_labels
 
 def load(tokenizer, hparams):
 
@@ -30,26 +42,20 @@ def load(tokenizer, hparams):
     if 'topics' in article:
       for topic in article['topics']:
         label_vocab[topic]
+
   label_vocab = dict(label_vocab)
 
   dataset = [article for article in articles if 'topics' in article] # and 'abstract' in article]
+  missing_abstracts = 0
   for article in dataset:
     if 'abstract' not in article or article['abstract'] == []:
       article['abstract'] = ['']
+      missing_abstracts += 1
+  print('WARNING: %.2f%% of articles are missing an abstract' % (100 * missing_abstracts / len(dataset)))
 
   random.shuffle(dataset)
 
-  sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get))
-
-  texts = []
-  int_texts = []
-  int_labels = []
-
-  for article in dataset:
-    text = ' | '.join([''.join(article[feature]) for feature in hparams.selected_features])
-    texts.append(text)
-    int_texts.append(bert_text_to_ids(tokenizer, text)[:hparams.max_len])
-    int_labels.append([1 if label in article['topics'] else 0 for label in sorted_labels])
+  int_texts, int_labels = to_int(tokenizer, label_vocab, hparams, dataset)
 
   train_set = CustomDataset(int_texts[hparams.valid_size:], int_labels[hparams.valid_size:])
   valid_set = CustomDataset(int_texts[:hparams.valid_size], int_labels[:hparams.valid_size])
diff --git a/model.py b/model.py
index 0fdb401fc34d306552e60435815d8c7bfe88dc2b..7cc9e3aa9a2fb16cdae089dfebb15dbc1a10c88c 100644
--- a/model.py
+++ b/model.py
@@ -11,6 +11,25 @@ from transformers import AutoModel
 
 import data
 
+# based on https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
+def f1_score_binary(y_pred, y_true, epsilon=1e-7):
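+  # Differentiable "soft" F1: precision and recall are computed from sigmoid
+  # probabilities rather than hard predictions, and 1 - mean(F1) is returned
+  # so that minimizing the loss maximizes the F1 score.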
+  y_pred = torch.sigmoid(y_pred)
+  y_true = y_true.float()
+  
+  tp = (y_true * y_pred).sum(dim=0).float()
+  tn = ((1 - y_true) * (1 - y_pred)).sum(dim=0).float()
+  fp = ((1 - y_true) * y_pred).sum(dim=0).float()
+  fn = (y_true * (1 - y_pred)).sum(dim=0).float()
+
+  precision = tp / (tp + fp + epsilon)
+  recall = tp / (tp + fn + epsilon)
+
+  f1 = 2 * (precision * recall) / (precision + recall + epsilon)
+  f1 = f1.clamp(min=epsilon, max=1 - epsilon)
+
+  return 1 - f1.mean()
+
+
 class Model(LightningModule):
 
   def __init__(self, hparams):
@@ -23,26 +42,39 @@ class Model(LightningModule):
 
     self.bert = AutoModel.from_pretrained(hparams.bert_flavor)
     self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels)
+    self.dropout = nn.Dropout(hparams.dropout)
+    if self.hparams.loss == 'bce':
+      self.loss_function = F.binary_cross_entropy_with_logits
+    elif self.hparams.loss == 'f1':
+      self.loss_function = f1_score_binary
+    else:
+      raise ValueError('invalid loss "%s"' % self.hparams.loss)
 
   def forward(self, x):
     _, output = self.bert(x, attention_mask = (x != self.tokenizer.pad_token_id).long())
-    return self.decision(output)
+    return self.decision(F.gelu(self.dropout(output)))
 
   def training_step(self, batch, batch_idx):
     x, y = batch
     y_hat = self(x)
-    loss = F.binary_cross_entropy_with_logits(y_hat, y)
+    loss = self.loss_function(y_hat, y)
     return {'loss': loss}
 
   def validation_step(self, batch, batch_idx):
     x, y = batch
     y_hat = self(x)
-    loss = F.binary_cross_entropy_with_logits(y_hat, y)
-    return {'val_loss': loss}
+    loss = self.loss_function(y_hat, y)
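+    # A logit >= 0 corresponds to a predicted probability >= 0.5 after the sigmoid.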
+    num_correct = torch.sum((y_hat >= 0) == y)
+    return {'val_loss': loss, 'val_correct': num_correct, 'val_num': y.shape[0] * y.shape[1]}
+
+  def training_epoch_end(self, outputs):
+    avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
+    return {'loss': avg_loss, 'log': {'loss': avg_loss}}
 
   def validation_epoch_end(self, outputs):
-    avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean().item()
-    return {'val_loss': avg_loss}
+    avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+    accuracy = torch.stack([x['val_correct'] for x in outputs]).sum().item() / sum([x['val_num'] for x in outputs])
+    return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'accuracy': torch.tensor([accuracy])}}
 
   def configure_optimizers(self):
     return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@@ -64,14 +96,16 @@ class Model(LightningModule):
   @staticmethod
   def add_model_specific_args(parent_parser):
     parser = ArgumentParser(parents=[parent_parser])
-    parser.add_argument('--train_filename', default='litcovid.json', type=str)
-    parser.add_argument('--learning_rate', default=2e-5, type=float)
-    parser.add_argument('--batch_size', default=16, type=int)
-    parser.add_argument('--epochs', default=10, type=int)
-    parser.add_argument('--valid_size', default=300, type=int)
-    parser.add_argument('--max_len', default=384, type=int)
-    parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str)
-    parser.add_argument('--selected_features', default=['title', 'abstract'], type=list)
+    parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances')
+    parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)')
+    parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)')
+    parser.add_argument('--epochs', default=10, type=int, help='number of epochs (default=10)')
+    parser.add_argument('--valid_size', default=1000, type=int, help='validation set size (default=1000)')
+    parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)')
+    parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed)')
+    parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)')
+    parser.add_argument('--dropout', default=.3, type=float, help='dropout applied to the BERT output (default=0.3)')
+    parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
 
     return parser
 
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..f71641acacce11fb57fb46c817f805f0729babe4
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,58 @@
+from argparse import ArgumentParser
+import json
+
+from pytorch_lightning import LightningModule
+from torch.utils.data import DataLoader
+import torch
+
+import data
+from model import Model
+
+def main(hparams):
+  model = Model.load_from_checkpoint(hparams.checkpoint)
+  model.freeze()
+
+  label_vocab = model.label_vocab
+  sorted_labels = list(sorted(label_vocab.keys(), key=label_vocab.get))
+
+  with open(hparams.test_filename) as fp:
+    dataset = json.loads(fp.read())
+
+  # dataset = dataset[:32]  # uncomment to quickly test on a small subset
+
+  int_texts, int_labels = data.to_int(model.tokenizer, model.label_vocab, model.hparams, dataset)
+
+  test_set = data.CustomDataset(int_texts, int_labels)
+  test_loader = DataLoader(test_set, batch_size=hparams.batch_size, collate_fn=model.collate_fn, shuffle=False)
+
+  def generate_predictions(model, loader):
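+    # Run the model without gradients and collect the raw (pre-sigmoid) scores
+    # for every article in the loader.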
+    predictions = []
+    model.eval()
+    total_loss = num = correct = 0
+    for x, y in loader:
+      #x = x.to(device)
+      #y = y.to(device)
+      with torch.no_grad():
+        y_scores = model(x)
+        y_pred = y_scores > 0
+        predictions.extend(y_scores.cpu().tolist())
+    return predictions
+
+  predictions = generate_predictions(model, test_loader)
+  for i, article in enumerate(dataset):
+    article['topic-scores'] = {label: score for label, score in zip(sorted_labels, predictions[i])}
+    article['topic-pred'] = [label for label, score in zip(sorted_labels, predictions[i]) if score > 0]
+
+  print(json.dumps(dataset, indent=2))
+
+
+if __name__ == '__main__':
+  parser = ArgumentParser(add_help=False)
+  #parser.add_argument('--gpus', type=str, default=None)
+  parser.add_argument('--checkpoint', type=str, required=True)
+  parser.add_argument('--test_filename', type=str, required=True)
+  parser.add_argument('--batch_size', type=int, default=32)
+
+  hparams = parser.parse_args()
+  main(hparams)
+
diff --git a/trainer.py b/trainer.py
index 1177b80a1a8b5616ce2c97569bfbe72fbe61c1ea..b8fd02b1717cb7efd8da4741bdedc7e91d4864ee 100644
--- a/trainer.py
+++ b/trainer.py
@@ -1,9 +1,10 @@
 from argparse import ArgumentParser
-from pytorch_lightning import Trainer
 import os
 import json
 import sys
 
+import pytorch_lightning
+
 import warnings
 warnings.filterwarnings('ignore', message='Displayed epoch numbers in the progress bar start from.*')
 warnings.filterwarnings('ignore', message='.*does not have many workers which may be a bottleneck.*')
@@ -14,12 +15,15 @@ def main(hparams):
 
   model = Model(hparams)
 
-  trainer = Trainer(
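+  # Save a checkpoint each epoch, named after the run name, epoch and validation loss.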
+  checkpointer = pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint('checkpoints/%s-{epoch}-{val_loss:.4f}' % hparams.name)
+
+  trainer = pytorch_lightning.Trainer(
     max_nb_epochs=hparams.epochs,
     gpus=hparams.gpus,
     nb_gpu_nodes=hparams.nodes,
     check_val_every_n_epoch=1,
-    progress_bar_refresh_rate=10,
+    progress_bar_refresh_rate=1,
+    checkpoint_callback=checkpointer,
     num_sanity_val_steps=0,
     fast_dev_run=hparams.fast_dev_run,
   )