Skip to content
Snippets Groups Projects
Commit 97a34582 authored by Benoit Favre
Browse files

add option for loading checkpoint from other model for transfer learning

parent 6e9186d7
No related branches found
No related tags found
No related merge requests found
...@@ -43,6 +43,12 @@ class Model(LightningModule): ...@@ -43,6 +43,12 @@ class Model(LightningModule):
hparams.num_labels = len(self.label_vocab) hparams.num_labels = len(self.label_vocab)
self.bert = AutoModel.from_pretrained(hparams.bert_flavor) self.bert = AutoModel.from_pretrained(hparams.bert_flavor)
if self.hparams.transfer:
print('loading bert weights from checkpoint "%s"' % self.hparams.transfer)
checkpoint = torch.load(self.hparams.transfer)
state_dict = {x[5:]: y for x, y in checkpoint['state_dict'].items() if x.startswith('bert.')}
self.bert.load_state_dict(state_dict)
self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels) self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels)
self.dropout = nn.Dropout(hparams.dropout) self.dropout = nn.Dropout(hparams.dropout)
if self.hparams.loss == 'bce': if self.hparams.loss == 'bce':
...@@ -126,6 +132,7 @@ class Model(LightningModule): ...@@ -126,6 +132,7 @@ class Model(LightningModule):
parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert') parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert')
parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)') parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)') parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)')
parser.add_argument('--transfer', default=None, type=str, help='transfer bert weights from checkpoint (default=do not transfer)')
return parser return parser
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment