Commit 97a34582 authored by Benoit Favre

add option for loading checkpoint from other model for transfer learning

parent 6e9186d7
@@ -43,6 +43,12 @@ class Model(LightningModule):
         hparams.num_labels = len(self.label_vocab)
         self.bert = AutoModel.from_pretrained(hparams.bert_flavor)
+        if self.hparams.transfer:
+            print('loading bert weights from checkpoint "%s"' % self.hparams.transfer)
+            checkpoint = torch.load(self.hparams.transfer)
+            state_dict = {x[5:]: y for x, y in checkpoint['state_dict'].items() if x.startswith('bert.')}
+            self.bert.load_state_dict(state_dict)
         self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels)
         self.dropout = nn.Dropout(hparams.dropout)
         if self.hparams.loss == 'bce':
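A minimal standalone sketch of the same transfer step, assuming a hypothetical Lightning checkpoint file other_model.ckpt whose state_dict keys carry the 'bert.' prefix used by this Model class (the encoder flavor below is also a placeholder):

    import torch
    from transformers import AutoModel

    bert = AutoModel.from_pretrained('bert-base-uncased')  # placeholder flavor
    checkpoint = torch.load('other_model.ckpt', map_location='cpu')
    # keep only the encoder weights and strip the 'bert.' prefix so the keys
    # match AutoModel's own parameter names
    state_dict = {k[len('bert.'):]: v
                  for k, v in checkpoint['state_dict'].items()
                  if k.startswith('bert.')}
    bert.load_state_dict(state_dict)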
@@ -126,6 +132,7 @@ class Model(LightningModule):
         parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert')
         parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
         parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)')
+        parser.add_argument('--transfer', default=None, type=str, help='transfer bert weights from checkpoint (default=do not transfer)')
         return parser
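With this flag, transferring from an earlier run might look like python train.py --transfer path/to/other_model.ckpt (the training script name is illustrative; only the --transfer option itself comes from this commit). The encoder weights from the given checkpoint are copied into the freshly instantiated BERT model before training, while the decision layer is still initialized from scratch.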