diff --git a/model.py b/model.py index b33e57b36a0c49770aeab1f054f5d660e2856531..24f003c8aba682cf38902afb90017bb52db803b2 100644 --- a/model.py +++ b/model.py @@ -43,6 +43,12 @@ class Model(LightningModule): hparams.num_labels = len(self.label_vocab) self.bert = AutoModel.from_pretrained(hparams.bert_flavor) + if self.hparams.transfer: + print('loading bert weights from checkpoint "%s"' % self.hparams.transfer) + checkpoint = torch.load(self.hparams.transfer) + state_dict = {x[5:]: y for x, y in checkpoint['state_dict'].items() if x.startswith('bert.')} + self.bert.load_state_dict(state_dict) + self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels) self.dropout = nn.Dropout(hparams.dropout) if self.hparams.loss == 'bce': @@ -126,6 +132,7 @@ class Model(LightningModule): parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert') parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)') parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)') + parser.add_argument('--transfer', default=None, type=str, help='transfer bert weights from checkpoint (default=do not transfer)') return parser