diff --git a/data.py b/data.py
index ef21491a5bc8005e01674efcbf4109a1b3b0e14f..f28007287dc6a3e7684b1568ad4fb4cb75acfa12 100644
--- a/data.py
+++ b/data.py
@@ -40,38 +40,52 @@ def to_int(tokenizer, label_vocab, hparams, dataset):
 
 def load(tokenizer, hparams):
 
-  with open(hparams.train_filename) as fp:
-    articles = json.loads(fp.read())
+  with open(hparams.stem + '.train') as fp:
+    train_articles = json.loads(fp.read())
+
+  with open(hparams.stem + '.valid') as fp:
+    valid_articles = json.loads(fp.read())
+
+  with open(hparams.stem + '.test') as fp:
+    test_articles = json.loads(fp.read())
 
   label_vocab = collections.defaultdict(lambda: len(label_vocab))
 
-  for article in articles:
+  for article in train_articles:
     if 'topics' in article:
       for topic in article['topics']:
         label_vocab[topic]
 
   label_vocab = dict(label_vocab)
 
-  dataset = [article for article in articles if 'topics' in article] # and 'abstract' in article]
-  assert len(dataset) > 0
+  train_dataset = [article for article in train_articles if 'topics' in article] # and 'abstract' in article]
+  valid_dataset = [article for article in valid_articles if 'topics' in article] # and 'abstract' in article]
+  test_dataset = [article for article in test_articles if 'topics' in article] # and 'abstract' in article]
+  assert len(train_dataset) > 0 and len(valid_dataset) > 0 and len(test_dataset) > 0
 
-  hparams.valid_size = int(hparams.valid_size_percent * len(dataset) / 100.0)
-  assert hparams.valid_size > 0
+  #hparams.valid_size = int(hparams.valid_size_percent * len(dataset) / 100.0)
+  #assert hparams.valid_size > 0
 
-  missing_abstracts = 0
-  for article in dataset:
-    if 'abstract' not in article or article['abstract'] == []:
-      article['abstract'] = ['']
-      missing_abstracts += 1
-  print('WARNING: %.2f%% missing abstract' % (100 * missing_abstracts / len(dataset)))
+  for name, dataset in [('train', train_dataset), ('valid', valid_dataset), ('test', test_dataset)]:
+    missing_abstracts = 0
+    for article in dataset:
+      if 'abstract' not in article or article['abstract'] == []:
+        article['abstract'] = ['']
+        missing_abstracts += 1
+    print('WARNING: %.2f%% missing abstract in %s' % (100 * missing_abstracts / len(dataset), name))
 
-  random.shuffle(dataset)
+  #random.shuffle(dataset)
 
-  int_texts, int_labels = to_int(tokenizer, label_vocab, hparams, dataset)
+  train_int_texts, train_int_labels = to_int(tokenizer, label_vocab, hparams, train_dataset)
+  valid_int_texts, valid_int_labels = to_int(tokenizer, label_vocab, hparams, valid_dataset)
+  test_int_texts, test_int_labels = to_int(tokenizer, label_vocab, hparams, test_dataset)
 
-  train_set = CustomDataset(int_texts[hparams.valid_size:], int_labels[hparams.valid_size:])
-  valid_set = CustomDataset(int_texts[:hparams.valid_size], int_labels[:hparams.valid_size])
+  train_set = CustomDataset(train_int_texts, train_int_labels)
+  valid_set = CustomDataset(valid_int_texts, valid_int_labels)
+  test_set = CustomDataset(test_int_texts, test_int_labels)
   print('training set', len(train_set))
   print('valid set', len(valid_set))
+  print('test set', len(test_set))
+
+  return train_set, valid_set, test_set, label_vocab
 
-  return train_set, valid_set, label_vocab
diff --git a/logger.py b/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc61b64183c0762dd9c82bf8227fe0320bdccc56
--- /dev/null
+++ b/logger.py
@@ -0,0 +1,95 @@
+import json
+import os
+import sys
+import collections
+
+class Logger:
+  def __init__(self, name, checkpoint_metric='val_loss', logdir='logs'):
+    self.directory = os.path.join(logdir, name)
+    os.makedirs(self.directory, exist_ok=True)
+    self.metrics = collections.defaultdict(dict)
+    self.checkpoint_metric = checkpoint_metric
+    self.hparams = {}
+    self.best_loss = None
+    self.best_checkpoint = os.path.join(self.directory, 'best_checkpoint')
+    self.test_metrics = {}
+    self.save_function = None
+
+  def set_save_function(self, save_function):
+    self.save_function = save_function
+
+  def log_metrics(self, epoch, metrics):
+    self.metrics[epoch].update(metrics)
+    self.save_function(os.path.join(self.directory, 'last_checkpoint'))
+    if self.checkpoint_metric in metrics and (self.best_loss is None or (metrics[self.checkpoint_metric] > self.best_loss) == (self.checkpoint_metric in ('fscore', 'accuracy', 'precision', 'recall'))):
+      self.best_loss = metrics[self.checkpoint_metric]
+      self.save_function(os.path.join(self.directory, 'best_checkpoint'))
+    self.save()
+
+  def log_test(self, metrics):
+    self.test_metrics = metrics
+    self.save()
+
+  def log_hparams(self, hparams):
+    self.hparams = vars(hparams)
+    self.save()
+
+  def save(self):
+    with open(os.path.join(self.directory, 'run.json'), 'w') as fp:
+      fp.write(json.dumps({
+        'metrics': self.metrics, 
+        'hparams': self.hparams, 
+        'test': self.test_metrics, 
+        'best_loss': self.best_loss,
+      }, indent=2))
+
+
+if __name__ == '__main__':
+  import bottle
+  import glob
+  logdir = sys.argv[1] if len(sys.argv) > 1 else 'logs'
+
+  @bottle.route('/<metric>')
+  def metric(metric='val_loss'):
+    series = []
+    for path in glob.glob(logdir + '/*/*.json'):
+      with open(path) as fp:
+        logs = json.loads(fp.read())
+        values = [x[metric] for epoch, x in sorted(logs['metrics'].items(), key=lambda k: int(k[0])) if metric in x]
+        if len(values) > 0: 
+          series.append({
+            'values': values,
+            'name': '\n'.join(['%s = %s' % (k, str(v)) for k, v in logs['hparams'].items()])
+          })
+    bottle.response.content_type = 'application/json'
+    return json.dumps(series)
+
+  @bottle.route('/')
+  def index():
+    metrics = set()
+    for path in glob.glob('logs/*/*.json'):
+      with open(path) as fp:
+        logs = json.loads(fp.read())
+        for row in logs['metrics'].values():
+          metrics.update(row.keys())
+
+    buttons = '<div id="buttons">' + ' | '.join(['<a href="#" onclick="update(\'%s\')">%s</a>' % (metric, metric) for metric in sorted(metrics)]) + '</div>'
+    html = buttons + """<canvas id="canvas"></canvas>
+    <script src="https://pageperso.lis-lab.fr/benoit.favre/files/autoplot.js"></script>
+    <script>
+      var selected_metric;
+      function update(metric) {
+        selected_metric = metric;
+        fetch('/' + metric).then(res => res.json()).then(series => {
+          chart('canvas', series);
+        });
+      }
+      setInterval(function() {
+        update(selected_metric);
+      }, 60 * 1000);
+      update('%s');
+    </script>""" % sorted(metrics)[0]
+    
+    return html
+
+  bottle.run(host='localhost', port=6006, quiet=True)
diff --git a/model.py b/model.py
index 24f003c8aba682cf38902afb90017bb52db803b2..a310f6cf61248f79cb05aeddbe676963d905fb82 100644
--- a/model.py
+++ b/model.py
@@ -31,6 +31,67 @@ def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7):
 
   return 1 - f1.mean()
 
+class RNNLayer(nn.Module):
+  def __init__(self, hidden_size=128, dropout=0.3):
+    super().__init__()
+    rnn_output = hidden_size * 2  # single-layer bidirectional GRU: output feature dim is 2 * hidden_size
+    self.rnn = nn.GRU(hidden_size, hidden_size, bias=True, num_layers=1, bidirectional=True, batch_first=True)
+    self.dense = nn.Linear(rnn_output, hidden_size)
+    self.dropout = nn.Dropout(dropout)
+    self.norm = nn.LayerNorm(hidden_size)
+
+  def forward(self, x):
+    output, hidden = self.rnn(x)
+    layer = self.dropout(F.gelu(self.dense(output))) + x
+    return self.norm(layer)
+
+
+class RNN(nn.Module):
+  def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout, padding_idx=0):
+    super().__init__()
+    self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=padding_idx)
+    self.embed_to_rnn = nn.Linear(embed_size, hidden_size)
+    self.layers = nn.ModuleList([RNNLayer(hidden_size=hidden_size, dropout=dropout) for i in range(num_layers)])
+    self.dropout = nn.Dropout(dropout)
+
+  def forward(self, x_text):
+    embed = self.dropout(self.embed(x_text))
+    activations = self.embed_to_rnn(F.gelu(embed))
+    for layer in self.layers:
+      activations = layer(activations)
+    return activations.max(dim=1)[0]  # max-pool over time so decision layer sees (batch, hidden), matching the CNN/BERT paths
+
+
+class CNNLayer(nn.Module):
+  def __init__(self, hidden_size, kernel_size, dropout):
+    super().__init__()
+    self.conv = nn.Conv1d(hidden_size, hidden_size, kernel_size=kernel_size)
+    self.dropout = nn.Dropout(dropout)
+    self.norm = nn.LayerNorm(hidden_size)
+
+  def forward(self, x):
+    output = self.conv(x.transpose(1, 2)).transpose(2, 1)
+    missing = x.shape[1] - output.shape[1]
+    output = torch.cat([output, torch.zeros(x.shape[0], missing, x.shape[2], device=x.device)], 1)
+    layer = self.dropout(F.gelu(output)) + x
+    return self.norm(layer)
+
+class CNN(nn.Module):
+  def __init__(self, vocab_size, embed_size, hidden_size, num_layers, kernel_size, dropout, padding_idx=0):
+    super().__init__()
+    self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=padding_idx)
+    self.embed_to_cnn = nn.Linear(embed_size, hidden_size)
+    self.layers = nn.ModuleList([CNNLayer(hidden_size=hidden_size, kernel_size=kernel_size, dropout=dropout) for i in range(num_layers)])
+    self.dropout = nn.Dropout(dropout)
+
+  def forward(self, x):
+    embed = self.dropout(self.embed(x))
+    activations = self.embed_to_cnn(F.gelu(embed))
+    for layer in self.layers:
+      activations = layer(activations)
+    pool = F.max_pool1d(activations.transpose(1, 2), activations.size(1))
+    return pool.view(x.shape[0], -1)
+
 
 class Model(LightningModule):
 
@@ -38,18 +99,32 @@ class Model(LightningModule):
     super().__init__()
 
     self.hparams = hparams
+    self.epoch = 1
     self.tokenizer = AutoTokenizer.from_pretrained(hparams.bert_flavor)
-    self.train_set, self.valid_set, self.label_vocab = data.load(self.tokenizer, hparams)
+    self.train_set, self.valid_set, self.test_set, self.label_vocab = data.load(self.tokenizer, hparams)
     hparams.num_labels = len(self.label_vocab)
 
-    self.bert = AutoModel.from_pretrained(hparams.bert_flavor)
-    if self.hparams.transfer:
-      print('loading bert weights from checkpoint "%s"' % self.hparams.transfer)
-      checkpoint = torch.load(self.hparams.transfer)
-      state_dict = {x[5:]: y for x, y in checkpoint['state_dict'].items() if x.startswith('bert.')}
-      self.bert.load_state_dict(state_dict)
+    if hparams.model == 'bert':
+      self.bert = AutoModel.from_pretrained(hparams.bert_flavor)
+      if self.hparams.transfer:
+        print('loading bert weights from checkpoint "%s"' % self.hparams.transfer)
+        checkpoint = torch.load(self.hparams.transfer)
+        state_dict = {x[5:]: y for x, y in checkpoint['state_dict'].items() if x.startswith('bert.')}
+        self.bert.load_state_dict(state_dict)
+      decision_input_size = self.bert.config.hidden_size
+
+    elif hparams.model == 'rnn':
+      self.rnn = RNN(self.tokenizer.vocab_size, hparams.rnn_embed_size, hparams.rnn_hidden_size, hparams.rnn_layers, hparams.dropout, self.tokenizer.pad_token_id)
+      decision_input_size = self.hparams.rnn_hidden_size
+    
+    elif hparams.model == 'cnn':
+      self.cnn = CNN(self.tokenizer.vocab_size, hparams.cnn_embed_size, hparams.cnn_hidden_size, hparams.cnn_layers, hparams.cnn_kernel_size, hparams.dropout, self.tokenizer.pad_token_id)
+      decision_input_size = self.hparams.cnn_hidden_size
+
+    else:
+      raise ValueError('invalid model type "%s"' % hparams.model)
 
-    self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels)
+    self.decision = nn.Linear(decision_input_size, hparams.num_labels)
     self.dropout = nn.Dropout(hparams.dropout)
     if self.hparams.loss == 'bce':
       self.loss_function = F.binary_cross_entropy_with_logits
@@ -59,7 +134,15 @@ class Model(LightningModule):
       raise ValueError('invalid loss "%s"' % self.hparams.loss)
 
   def forward(self, x):
-    _, output = self.bert(x, attention_mask = (x != self.tokenizer.pad_token_id).long())
+    if self.hparams.model == 'bert':
+      _, output = self.bert(x, attention_mask = (x != self.tokenizer.pad_token_id).long())
+    elif self.hparams.model == 'rnn':
+      output = self.rnn(x)
+    elif self.hparams.model == 'cnn':
+      output = self.cnn(x)
+    else:
+      raise ValueError('invalid model type "%s"' % self.hparams.model)
+
     return self.decision(F.gelu(self.dropout(output)))
 
   def training_step(self, batch, batch_idx):
@@ -73,14 +156,28 @@ class Model(LightningModule):
     y_hat = self(x)
     loss = self.loss_function(y_hat, y)
     bce = F.binary_cross_entropy_with_logits(y_hat, y)
+    acc_correct = torch.sum((y_hat >= 0) == y)
+    acc_num = torch.tensor([y.shape[0] * y.shape[1]])
+    num_correct = torch.sum((y_hat >= 0) * (y == 1))
+    num_hyp = torch.sum(y_hat >= 0)
+    num_ref = torch.sum(y == 1)
+    return {'val_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'acc_correct': acc_correct, 'acc_num': acc_num}
+
+  def test_step(self, batch, batch_idx):
+    x, y = batch
+    y_hat = self(x)
+    loss = self.loss_function(y_hat, y)
+    bce = F.binary_cross_entropy_with_logits(y_hat, y)
+    acc_correct = torch.sum((y_hat >= 0) == y)
+    acc_num = torch.tensor([y.shape[0] * y.shape[1]])
     num_correct = torch.sum((y_hat >= 0) * (y == 1))
     num_hyp = torch.sum(y_hat >= 0)
     num_ref = torch.sum(y == 1)
-    num = torch.tensor([y.shape[0] * y.shape[1]])
-    return {'val_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'num': num}
+    return {'test_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'acc_correct': acc_correct, 'acc_num': acc_num}
 
   def training_epoch_end(self, outputs):
     avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
+    self.epoch += 1
     return {'loss': avg_loss, 'log': {'loss': avg_loss}}
 
   def validation_epoch_end(self, outputs):
@@ -91,18 +188,59 @@ class Model(LightningModule):
 
     bce = values['bce'].mean()
     num_correct = values['num_correct'].sum()
-    num = values['num'].sum()
-    accuracy = num_correct / float(num.item())
+    acc_num = values['acc_num'].sum()
+    accuracy = values['acc_correct'].sum() / float(acc_num.item())
     num_ref = values['num_ref'].sum()
     num_hyp = values['num_hyp'].sum()
     recall = num_correct / float(num_ref.item()) if num_ref != 0 else torch.tensor([0])
-    precision = num_correct / float(num_hyp.item()) if num_ref != 0 else torch.tensor([0])
+    precision = num_correct / float(num_hyp.item()) if num_hyp != 0 else torch.tensor([0])
     fscore = 2 * recall * precision / float((precision + recall).item()) if precision + recall != 0 else torch.tensor([0])
 
-    return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'bce': bce, 'accuracy': accuracy, 'recall': recall, 'precision': precision, 'fscore': fscore}}
+    log_metrics = {'bce': bce.item(), 'accuracy': accuracy.item(), 'recall': recall.item(), 'precision': precision.item(), 'fscore': fscore.item()}
+    self.custom_logger.log_metrics(self.epoch, log_metrics)
+
+    return {'val_loss': avg_loss}
+
+  def test_epoch_end(self, outputs):
+    metrics = outputs[0].keys()
+    values = {metric: torch.stack([x[metric] for x in outputs]) for metric in metrics}
+
+    avg_loss = values['test_loss'].mean()
+
+    bce = values['bce'].mean()
+    num_correct = values['num_correct'].sum()
+    acc_num = values['acc_num'].sum()
+    accuracy = values['acc_correct'].sum() / float(acc_num.item())
+    num_ref = values['num_ref'].sum()
+    num_hyp = values['num_hyp'].sum()
+    recall = num_correct / float(num_ref.item()) if num_ref != 0 else torch.tensor([0])
+    precision = num_correct / float(num_hyp.item()) if num_hyp != 0 else torch.tensor([0])
+    fscore = 2 * recall * precision / float((precision + recall).item()) if precision + recall != 0 else torch.tensor([0])
+
+    log_metrics = {'bce': bce.item(), 'accuracy': accuracy.item(), 'recall': recall.item(), 'precision': precision.item(), 'fscore': fscore.item()}
+    self.custom_logger.log_test(log_metrics)
+
+    return {'test_loss': avg_loss}
 
   def configure_optimizers(self):
-    return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+    optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
+    scheduler = None
+    if self.hparams.scheduler == 'warmup_linear':
+      num_warmup_steps = self.hparams.scheduler_warmup
+      num_training_steps = self.hparams.epochs
+
+      def lr_lambda(current_step):
+          if current_step < num_warmup_steps:
+              return float(current_step) / float(max(1, num_warmup_steps))
+          return max( 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) )
+      scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, -1)
+    elif self.hparams.scheduler is not None:
+      raise ValueError('invalid scheduler "%s"' % self.hparams.scheduler)
+
+    if scheduler:
+      return [optimizer], [scheduler]
+    else:
+      return optimizer
 
   def collate_fn(self, inputs):
     text_len = max([len(x[0]) for x in inputs])
@@ -118,21 +256,34 @@ class Model(LightningModule):
   def val_dataloader(self):
     return DataLoader(self.valid_set, batch_size=self.hparams.batch_size, pin_memory=True, collate_fn=self.collate_fn)
 
+  def test_dataloader(self):
+    return DataLoader(self.test_set, batch_size=self.hparams.batch_size, pin_memory=True, collate_fn=self.collate_fn)
+
   @staticmethod
   def add_model_specific_args(parent_parser):
     parser = ArgumentParser(parents=[parent_parser])
-    parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances')
+    parser.add_argument('--stem', type=str, required=True, help='stem name of json files containing training/validation/test instances (<stem>.{train,valid,test})')
     parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)')
     parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)')
     parser.add_argument('--epochs', default=20, type=int, help='number of epochs (default=20)')
     parser.add_argument('--valid_size_percent', default=10, type=int, help='validation set size in %% (default=10)')
     parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)')
-    parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed')
     parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)')
     parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert')
-    parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
+    parser.add_argument('--loss', default='bce', type=str, help='choose loss function [f1, bce] (default=bce)')
     parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)')
-    parser.add_argument('--transfer', default=None, type=str, help='transfer bert weights from checkpoint (default=do not transfer)')
+    parser.add_argument('--transfer', default=None, type=str, help='transfer weights from checkpoint (default=do not transfer)')
+    parser.add_argument('--model', default='bert', type=str, help='model type [rnn, cnn, bert] (default=bert)')
+    parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed)')
+    parser.add_argument('--rnn_embed_size', default=128, type=int, help='rnn embedding size (default=128)')
+    parser.add_argument('--rnn_hidden_size', default=128, type=int, help='rnn hidden size (default=128)')
+    parser.add_argument('--rnn_layers', default=1, type=int, help='rnn number of layers (default=1)')
+    parser.add_argument('--cnn_embed_size', default=128, type=int, help='cnn embedding size (default=128)')
+    parser.add_argument('--cnn_hidden_size', default=128, type=int, help='cnn hidden size (default=128)')
+    parser.add_argument('--cnn_layers', default=1, type=int, help='cnn number of layers (default=1)')
+    parser.add_argument('--cnn_kernel_size', default=3, type=int, help='cnn kernel size (default=3)')
+    parser.add_argument('--scheduler', default=None, type=str, help='learning rate schedule [warmup_linear] (default=fixed learning rate)')
+    parser.add_argument('--scheduler_warmup', default=1, type=int, help='learning rate schedule warmup epochs (default=1)')
 
     return parser
 
diff --git a/requirements-freeze.txt b/requirements-freeze.txt
index 2a6bb22282a13676631188d3e44240d3e9229658..30b3da66b6b331aac1f01ae650e9a7fd6bf9ebbf 100644
--- a/requirements-freeze.txt
+++ b/requirements-freeze.txt
@@ -1,4 +1,5 @@
 absl-py==0.9.0
+bottle==0.12.18
 cachetools==4.1.0
 certifi==2020.4.5.1
 chardet==3.0.4
diff --git a/requirements.txt b/requirements.txt
index 031902e4565366bf7a8eda2c001efd28c4f6e101..55c25dab096bb18b5a633bd61b1f26f9f858e179 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 pytorch-lightning
 torch==1.5.0+cu101
 transformers
+bottle
diff --git a/trainer.py b/trainer.py
index b8fd02b1717cb7efd8da4741bdedc7e91d4864ee..08e631e5d0fd5760eb30e1e5d72ff4f254973d42 100644
--- a/trainer.py
+++ b/trainer.py
@@ -10,12 +10,16 @@ warnings.filterwarnings('ignore', message='Displayed epoch numbers in the progre
 warnings.filterwarnings('ignore', message='.*does not have many workers which may be a bottleneck.*')
 
 from model import Model
+from logger import Logger
 
 def main(hparams):
+  pytorch_lightning.seed_everything(hparams.seed)
 
-  model = Model(hparams)
+  logger = Logger(hparams.name, checkpoint_metric='fscore' if hparams.loss == 'f1' else 'bce')
 
-  checkpointer = pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint('checkpoints/%s-{epoch}-{val_loss:.4f}' % hparams.name)
+  model = Model(hparams)
+  model.custom_logger = logger
+  logger.log_hparams(hparams)
 
   trainer = pytorch_lightning.Trainer(
     max_nb_epochs=hparams.epochs,
@@ -23,19 +27,26 @@ def main(hparams):
     nb_gpu_nodes=hparams.nodes,
     check_val_every_n_epoch=1,
     progress_bar_refresh_rate=1,
-    checkpoint_callback=checkpointer,
+    logger=None,
+    checkpoint_callback=None,
     num_sanity_val_steps=0,
     fast_dev_run=hparams.fast_dev_run,
+    deterministic=True,
   )
 
+  logger.set_save_function(trainer.save_checkpoint)
   trainer.fit(model)
+  model = Model.load_from_checkpoint(logger.best_checkpoint)
+  model.custom_logger = logger
+  trainer.test(model)
 
 if __name__ == '__main__':
   parser = ArgumentParser(add_help=False)
-  parser.add_argument('--gpus', type=str, default=None)
-  parser.add_argument('--nodes', type=int, default=1)
-  parser.add_argument('--name', type=str, required=True)
-  parser.add_argument('--fast_dev_run', default=False, action='store_true')
+  parser.add_argument('--gpus', type=str, default=None, help='ids of GPUs to use (use -1 for all available GPUs, defaults to CPU)')
+  parser.add_argument('--nodes', type=int, default=1, help='number of computation nodes for distributed training (see lightning docs, defaults to 1)')
+  parser.add_argument('--name', type=str, required=True, help='name of experiment (required)')
+  parser.add_argument('--fast_dev_run', default=False, action='store_true', help='run a single batch through the whole training loop for catching bugs')
+  parser.add_argument('--seed', default=123, type=int, help='set global random seed (defaults to 123)')
 
   parser = Model.add_model_specific_args(parser)
   command_line = 'python ' + ' '.join(sys.argv)