From 97a34582d19d7e2319d6aade73c6223614da902f Mon Sep 17 00:00:00 2001
From: Benoit Favre <benoit.favre@lis-lab.fr>
Date: Thu, 4 Jun 2020 17:22:19 +0200
Subject: [PATCH] add option for loading a checkpoint from another model for
 transfer learning

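When --transfer is given the path to a PyTorch Lightning checkpoint, the
bert.* weights stored in that checkpoint are copied into the freshly
instantiated encoder before fine-tuning, so an encoder trained on one
task can serve as the starting point for another. Intended usage looks
roughly like the following (train.py and the checkpoint path are
illustrative names, not part of this patch):

  python train.py --bert_flavor bert-base-uncased --transfer source_task.ckpt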
---
 model.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/model.py b/model.py
index b33e57b..24f003c 100644
--- a/model.py
+++ b/model.py
@@ -43,6 +43,16 @@ class Model(LightningModule):
     hparams.num_labels = len(self.label_vocab)
 
     self.bert = AutoModel.from_pretrained(hparams.bert_flavor)
+    if self.hparams.transfer:
+      print('loading bert weights from checkpoint "%s"' % self.hparams.transfer)
+      # map to CPU so checkpoints saved on GPU also load on CPU-only machines
+      checkpoint = torch.load(self.hparams.transfer, map_location='cpu')
+      # keep only the encoder weights and strip the 'bert.' prefix so the keys
+      # match the names expected by self.bert.load_state_dict()
+      prefix = 'bert.'
+      state_dict = {name[len(prefix):]: value for name, value in checkpoint['state_dict'].items() if name.startswith(prefix)}
+      self.bert.load_state_dict(state_dict)
+
     self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels)
     self.dropout = nn.Dropout(hparams.dropout)
     if self.hparams.loss == 'bce':
@@ -126,6 +136,7 @@ class Model(LightningModule):
     parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert')
     parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
     parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)')
+    parser.add_argument('--transfer', default=None, type=str, help='path to a checkpoint to load bert weights from (default=do not transfer)')
 
     return parser
 
-- 
GitLab