diff --git a/model.py b/model.py
index b33e57b36a0c49770aeab1f054f5d660e2856531..24f003c8aba682cf38902afb90017bb52db803b2 100644
--- a/model.py
+++ b/model.py
@@ -43,6 +43,12 @@ class Model(LightningModule):
     hparams.num_labels = len(self.label_vocab)
 
     self.bert = AutoModel.from_pretrained(hparams.bert_flavor)
+    if self.hparams.transfer:
+      print('loading bert weights from checkpoint "%s"' % self.hparams.transfer)
+      checkpoint = torch.load(self.hparams.transfer)
+      state_dict = {x[5:]: y for x, y in checkpoint['state_dict'].items() if x.startswith('bert.')}
+      self.bert.load_state_dict(state_dict)
+
     self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels)
     self.dropout = nn.Dropout(hparams.dropout)
     if self.hparams.loss == 'bce':
@@ -126,6 +132,7 @@ class Model(LightningModule):
     parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert')
     parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
     parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)')
+    parser.add_argument('--transfer', default=None, type=str, help='transfer bert weights from checkpoint (default=do not transfer)')
 
     return parser