diff --git a/data.py b/data.py index ef21491a5bc8005e01674efcbf4109a1b3b0e14f..f28007287dc6a3e7684b1568ad4fb4cb75acfa12 100644 --- a/data.py +++ b/data.py @@ -40,38 +40,52 @@ def to_int(tokenizer, label_vocab, hparams, dataset): def load(tokenizer, hparams): - with open(hparams.train_filename) as fp: - articles = json.loads(fp.read()) + with open(hparams.stem + '.train') as fp: + train_articles = json.loads(fp.read()) + + with open(hparams.stem + '.valid') as fp: + valid_articles = json.loads(fp.read()) + + with open(hparams.stem + '.test') as fp: + test_articles = json.loads(fp.read()) label_vocab = collections.defaultdict(lambda: len(label_vocab)) - for article in articles: + for article in train_articles: if 'topics' in article: for topic in article['topics']: label_vocab[topic] label_vocab = dict(label_vocab) - dataset = [article for article in articles if 'topics' in article] # and 'abstract' in article] - assert len(dataset) > 0 + train_dataset = [article for article in train_articles if 'topics' in article] # and 'abstract' in article] + valid_dataset = [article for article in valid_articles if 'topics' in article] # and 'abstract' in article] + test_dataset = [article for article in test_articles if 'topics' in article] # and 'abstract' in article] + assert len(train_dataset) > 0 and len(valid_dataset) > 0 and len(test_dataset) > 0 - hparams.valid_size = int(hparams.valid_size_percent * len(dataset) / 100.0) - assert hparams.valid_size > 0 + #hparams.valid_size = int(hparams.valid_size_percent * len(dataset) / 100.0) + #assert hparams.valid_size > 0 - missing_abstracts = 0 - for article in dataset: - if 'abstract' not in article or article['abstract'] == []: - article['abstract'] = [''] - missing_abstracts += 1 - print('WARNING: %.2f%% missing abstract' % (100 * missing_abstracts / len(dataset))) + for name, dataset in [('train', train_dataset), ('valid', valid_dataset), ('test', test_dataset)]: + missing_abstracts = 0 + for article in dataset: + if 'abstract' not in article or article['abstract'] == []: + article['abstract'] = [''] + missing_abstracts += 1 + print('WARNING: %.2f%% missing abstract in %s' % (100 * missing_abstracts / len(dataset), name)) - random.shuffle(dataset) + #random.shuffle(dataset) - int_texts, int_labels = to_int(tokenizer, label_vocab, hparams, dataset) + train_int_texts, train_int_labels = to_int(tokenizer, label_vocab, hparams, train_dataset) + valid_int_texts, valid_int_labels = to_int(tokenizer, label_vocab, hparams, valid_dataset) + test_int_texts, test_int_labels = to_int(tokenizer, label_vocab, hparams, test_dataset) - train_set = CustomDataset(int_texts[hparams.valid_size:], int_labels[hparams.valid_size:]) - valid_set = CustomDataset(int_texts[:hparams.valid_size], int_labels[:hparams.valid_size]) + train_set = CustomDataset(train_int_texts, train_int_labels) + valid_set = CustomDataset(valid_int_texts, valid_int_labels) + test_set = CustomDataset(test_int_texts, test_int_labels) print('training set', len(train_set)) print('valid set', len(valid_set)) + print('test set', len(test_set)) + + return train_set, valid_set, test_set, label_vocab - return train_set, valid_set, label_vocab diff --git a/logger.py b/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..cc61b64183c0762dd9c82bf8227fe0320bdccc56 --- /dev/null +++ b/logger.py @@ -0,0 +1,95 @@ +import json +import os +import sys +import collections + +class Logger: + def __init__(self, name, checkpoint_metric='val_loss', logdir='logs'): + self.directory = os.path.join(logdir, name) + os.makedirs(self.directory, exist_ok=True) + self.metrics = collections.defaultdict(dict) + self.checkpoint_metric = checkpoint_metric + self.hparams = {} + self.best_loss = None + self.best_checkpoint = os.path.join(self.directory, 'best_checkpoint') + self.test_metrics = {} + self.save_function = None + + def set_save_function(self, save_function): + self.save_function = save_function + + def log_metrics(self, epoch, metrics): + self.metrics[epoch].update(metrics) + self.save_function(os.path.join(self.directory, 'last_checkpoint')) + if self.checkpoint_metric in metrics and (self.best_loss is None or metrics[self.checkpoint_metric] > self.best_loss): + self.best_loss = metrics[self.checkpoint_metric] + self.save_function(os.path.join(self.directory, 'best_checkpoint')) + self.save() + + def log_test(self, metrics): + self.test_metrics = metrics + self.save() + + def log_hparams(self, hparams): + self.hparams = vars(hparams) + self.save() + + def save(self): + with open(os.path.join(self.directory, 'run.json'), 'w') as fp: + fp.write(json.dumps({ + 'metrics': self.metrics, + 'hparams': self.hparams, + 'test': self.test_metrics, + 'best_loss': self.best_loss, + }, indent=2)) + + +if __name__ == '__main__': + import bottle + import glob + logdir = sys.argv[1] if len(sys.argv) > 1 else 'logs' + + @bottle.route('/<metric>') + def metric(metric='val_loss'): + series = [] + for path in glob.glob(logdir + '/*/*.json'): + with open(path) as fp: + logs = json.loads(fp.read()) + values = [x[metric] for epoch, x in sorted(logs['metrics'].items(), key=lambda k: int(k[0])) if metric in x] + if len(values) > 0: + series.append({ + 'values': values, + 'name': '\n'.join(['%s = %s' % (k, str(v)) for k, v in logs['hparams'].items()]) + }) + bottle.response.content_type = 'application/json' + return json.dumps(series) + + @bottle.route('/') + def index(): + metrics = set() + for path in glob.glob('logs/*/*.json'): + with open(path) as fp: + logs = json.loads(fp.read()) + for row in logs['metrics'].values(): + metrics.update(row.keys()) + + buttons = '<div id="buttons">' + ' | '.join(['<a href="#" onclick="update(\'%s\')">%s</a>' % (metric, metric) for metric in sorted(metrics)]) + '</div>' + html = buttons + """<canvas id="canvas"> + <script src="https://pageperso.lis-lab.fr/benoit.favre/files/autoplot.js"></script> + <script> + var selected_metric; + function update(metric) { + selected_metric = metric; + fetch('/' + metric).then(res => res.json()).then(series => { + chart('canvas', series); + }); + } + setInterval(function() { + update(selected_metric); + }, 60 * 1000); + update('%s'); + </script>""" % sorted(metrics)[0] + + return html + + bottle.run(host='localhost', port=6006, quiet=True) diff --git a/model.py b/model.py index 24f003c8aba682cf38902afb90017bb52db803b2..a310f6cf61248f79cb05aeddbe676963d905fb82 100644 --- a/model.py +++ b/model.py @@ -31,6 +31,67 @@ def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7): return 1 - f1.mean() +class RNNLayer(nn.Module): + def __init__(self, hidden_size=128, dropout=0.3): + super().__init__() + rnn_output = hidden_size * rnn_layers * directions + self.rnn = nn.GRU(hidden_size, hidden_size, bias=True, num_layers=1, bidirectional=True, batch_first=True) + self.dense = nn.Linear(rnn_output, hidden_size) + self.dropout = nn.Dropout(dropout) + self.norm = nn.LayerNorm(hidden_size) + + def forward(self, x): + output, hidden = self.rnn(x) + layer = self.dropout(F.gelu(self.dense(output))) + x + return self.norm(layer) + + +class RNN(nn.Module): + def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout, padding_idx=0): + super().__init__() + self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=padding_idx) + self.embed_to_rnn = nn.Linear(embed_size, hidden_size) + self.layers = nn.ModuleList([RNNLayer(hidden_size=hidden_size, dropout=dropout) for i in range(num_layers)]) + self.dropout = nn.Dropout(dropout) + + def forward(self, x_text): + embed = self.dropout(self.embed(x_text)) + activations = self.embed_to_rnn(F.gelu(embed)) + for layer in self.layers: + activations = layer(activations) + return activations + + +class CNNLayer(nn.Module): + def __init__(self, hidden_size, kernel_size, dropout): + super().__init__() + self.conv = nn.Conv1d(hidden_size, hidden_size, kernel_size=kernel_size) + self.dropout = nn.Dropout(dropout) + self.norm = nn.LayerNorm(hidden_size) + + def forward(self, x): + output = self.conv(x.transpose(1, 2)).transpose(2, 1) + missing = x.shape[1] - output.shape[1] + output = torch.cat([output, torch.zeros(x.shape[0], missing, x.shape[2], device=x.device)], 1) + layer = self.dropout(F.gelu(output)) + x + return self.norm(layer) + +class CNN(nn.Module): + def __init__(self, vocab_size, embed_size, hidden_size, num_layers, kernel_size, dropout, padding_idx=0): + super().__init__() + self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=padding_idx) + self.embed_to_cnn = nn.Linear(embed_size, hidden_size) + self.layers = nn.ModuleList([CNNLayer(hidden_size=hidden_size, kernel_size=kernel_size, dropout=dropout) for i in range(num_layers)]) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + embed = self.dropout(self.embed(x)) + activations = self.embed_to_cnn(F.gelu(embed)) + for layer in self.layers: + activations = layer(activations) + pool = F.max_pool1d(activations.transpose(1, 2), activations.size(1)) + return pool.view(x.shape[0], -1) + class Model(LightningModule): @@ -38,18 +99,32 @@ class Model(LightningModule): super().__init__() self.hparams = hparams + self.epoch = 1 self.tokenizer = AutoTokenizer.from_pretrained(hparams.bert_flavor) - self.train_set, self.valid_set, self.label_vocab = data.load(self.tokenizer, hparams) + self.train_set, self.valid_set, self.test_set, self.label_vocab = data.load(self.tokenizer, hparams) hparams.num_labels = len(self.label_vocab) - self.bert = AutoModel.from_pretrained(hparams.bert_flavor) - if self.hparams.transfer: - print('loading bert weights from checkpoint "%s"' % self.hparams.transfer) - checkpoint = torch.load(self.hparams.transfer) - state_dict = {x[5:]: y for x, y in checkpoint['state_dict'].items() if x.startswith('bert.')} - self.bert.load_state_dict(state_dict) + if hparams.model == 'bert': + self.bert = AutoModel.from_pretrained(hparams.bert_flavor) + if self.hparams.transfer: + print('loading bert weights from checkpoint "%s"' % self.hparams.transfer) + checkpoint = torch.load(self.hparams.transfer) + state_dict = {x[5:]: y for x, y in checkpoint['state_dict'].items() if x.startswith('bert.')} + self.bert.load_state_dict(state_dict) + decision_input_size = self.bert.config.hidden_size + + elif hparams.model == 'rnn': + self.rnn = RNN(self.tokenizer.vocab_size, hparams.rnn_embed_size, hparams.rnn_hidden_size, hparams.rnn_layers, hparams.dropout, self.tokenizer.pad_token_id) + decision_input_size = self.hparams.rnn_hidden_size + + elif hparams.model == 'cnn': + self.cnn = CNN(self.tokenizer.vocab_size, hparams.cnn_embed_size, hparams.cnn_hidden_size, hparams.cnn_layers, hparams.cnn_kernel_size, hparams.dropout, self.tokenizer.pad_token_id) + decision_input_size = self.hparams.cnn_hidden_size + + else: + raise ValueError('invalid model type "%s"' % hparams.model) - self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels) + self.decision = nn.Linear(decision_input_size, hparams.num_labels) self.dropout = nn.Dropout(hparams.dropout) if self.hparams.loss == 'bce': self.loss_function = F.binary_cross_entropy_with_logits @@ -59,7 +134,15 @@ class Model(LightningModule): raise ValueError('invalid loss "%s"' % self.hparams.loss) def forward(self, x): - _, output = self.bert(x, attention_mask = (x != self.tokenizer.pad_token_id).long()) + if self.hparams.model == 'bert': + _, output = self.bert(x, attention_mask = (x != self.tokenizer.pad_token_id).long()) + elif self.hparams.model == 'rnn': + output = self.rnn(x) + elif self.hparams.model == 'cnn': + output = self.cnn(x) + else: + raise ValueError('invalid model type "%s"' % self.hparams.model) + return self.decision(F.gelu(self.dropout(output))) def training_step(self, batch, batch_idx): @@ -73,14 +156,28 @@ class Model(LightningModule): y_hat = self(x) loss = self.loss_function(y_hat, y) bce = F.binary_cross_entropy_with_logits(y_hat, y) + acc_correct = torch.sum((y_hat >= 0) == y) + acc_num = torch.tensor([y.shape[0] * y.shape[1]]) + num_correct = torch.sum((y_hat >= 0) * (y == 1)) + num_hyp = torch.sum(y_hat >= 0) + num_ref = torch.sum(y == 1) + return {'val_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'acc_correct': acc_correct, 'acc_num': acc_num} + + def test_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = self.loss_function(y_hat, y) + bce = F.binary_cross_entropy_with_logits(y_hat, y) + acc_correct = torch.sum((y_hat >= 0) == y) + acc_num = torch.tensor([y.shape[0] * y.shape[1]]) num_correct = torch.sum((y_hat >= 0) * (y == 1)) num_hyp = torch.sum(y_hat >= 0) num_ref = torch.sum(y == 1) - num = torch.tensor([y.shape[0] * y.shape[1]]) - return {'val_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'num': num} + return {'test_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'acc_correct': acc_correct, 'acc_num': acc_num} def training_epoch_end(self, outputs): avg_loss = torch.stack([x['loss'] for x in outputs]).mean() + self.epoch += 1 return {'loss': avg_loss, 'log': {'loss': avg_loss}} def validation_epoch_end(self, outputs): @@ -91,18 +188,59 @@ class Model(LightningModule): bce = values['bce'].mean() num_correct = values['num_correct'].sum() - num = values['num'].sum() - accuracy = num_correct / float(num.item()) + acc_num = values['acc_num'].sum() + accuracy = values['acc_correct'].sum() / float(acc_num.item()) num_ref = values['num_ref'].sum() num_hyp = values['num_hyp'].sum() recall = num_correct / float(num_ref.item()) if num_ref != 0 else torch.tensor([0]) - precision = num_correct / float(num_hyp.item()) if num_ref != 0 else torch.tensor([0]) + precision = num_correct / float(num_hyp.item()) if num_hyp != 0 else torch.tensor([0]) fscore = 2 * recall * precision / float((precision + recall).item()) if precision + recall != 0 else torch.tensor([0]) - return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'bce': bce, 'accuracy': accuracy, 'recall': recall, 'precision': precision, 'fscore': fscore}} + log_metrics = {'bce': bce.item(), 'accuracy': accuracy.item(), 'recall': recall.item(), 'precision': precision.item(), 'fscore': fscore.item()} + self.custom_logger.log_metrics(self.epoch, log_metrics) + + return {'val_loss': avg_loss} + + def test_epoch_end(self, outputs): + metrics = outputs[0].keys() + values = {metric: torch.stack([x[metric] for x in outputs]) for metric in metrics} + + avg_loss = values['test_loss'].mean() + + bce = values['bce'].mean() + num_correct = values['num_correct'].sum() + acc_num = values['acc_num'].sum() + accuracy = values['acc_correct'].sum() / float(acc_num.item()) + num_ref = values['num_ref'].sum() + num_hyp = values['num_hyp'].sum() + recall = num_correct / float(num_ref.item()) if num_ref != 0 else torch.tensor([0]) + precision = num_correct / float(num_hyp.item()) if num_hyp != 0 else torch.tensor([0]) + fscore = 2 * recall * precision / float((precision + recall).item()) if precision + recall != 0 else torch.tensor([0]) + + log_metrics = {'bce': bce.item(), 'accuracy': accuracy.item(), 'recall': recall.item(), 'precision': precision.item(), 'fscore': fscore.item()} + self.custom_logger.log_test(log_metrics) + + return {'test_loss': avg_loss} def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) + scheduler = None + if self.hparams.scheduler == 'warmup_linear': + num_warmup_steps = self.hparams.scheduler_warmup + num_training_steps = self.hparams.epochs + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) ) + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, -1) + elif self.hparams.scheduler != None: + raise ValueError('invalid scheduler "%s"' % self.hparams.scheduler) + + if scheduler: + return [optimizer], [scheduler] + else: + return optimizer def collate_fn(self, inputs): text_len = max([len(x[0]) for x in inputs]) @@ -118,21 +256,34 @@ class Model(LightningModule): def val_dataloader(self): return DataLoader(self.valid_set, batch_size=self.hparams.batch_size, pin_memory=True, collate_fn=self.collate_fn) + def test_dataloader(self): + return DataLoader(self.test_set, batch_size=self.hparams.batch_size, pin_memory=True, collate_fn=self.collate_fn) + @staticmethod def add_model_specific_args(parent_parser): parser = ArgumentParser(parents=[parent_parser]) - parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances') + parser.add_argument('--stem', type=str, required=True, help='stem name of json files containing training/validation/test instances (<stem>.{train,valid,test})') parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)') parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)') parser.add_argument('--epochs', default=20, type=int, help='number of epochs (default=20)') parser.add_argument('--valid_size_percent', default=10, type=int, help='validation set size in %% (default=10)') parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)') - parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed') parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)') parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert') - parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)') + parser.add_argument('--loss', default='bce', type=str, help='choose loss function [f1, bce] (default=bce)') parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)') - parser.add_argument('--transfer', default=None, type=str, help='transfer bert weights from checkpoint (default=do not transfer)') + parser.add_argument('--transfer', default=None, type=str, help='transfer weights from checkpoint (default=do not transfer)') + parser.add_argument('--model', default='bert', type=str, help='model type [rnn, bert] (default=bert)') + parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed') + parser.add_argument('--rnn_embed_size', default=128, type=int, help='rnn embedding size (default=128)') + parser.add_argument('--rnn_hidden_size', default=128, type=int, help='rnn hidden size (default=128)') + parser.add_argument('--rnn_layers', default=1, type=int, help='rnn number of layers (default=1)') + parser.add_argument('--cnn_embed_size', default=128, type=int, help='cnn embedding size (default=128)') + parser.add_argument('--cnn_hidden_size', default=128, type=int, help='cnn hidden size (default=128)') + parser.add_argument('--cnn_layers', default=1, type=int, help='cnn number of layers (default=1)') + parser.add_argument('--cnn_kernel_size', default=3, type=int, help='cnn kernel size (default=3)') + parser.add_argument('--scheduler', default=None, type=str, help='learning rate schedule [warmup_linear] (default=fixed learning rate)') + parser.add_argument('--scheduler_warmup', default=1, type=int, help='learning rate schedule warmup epochs (default=1)') return parser diff --git a/requirements-freeze.txt b/requirements-freeze.txt index 2a6bb22282a13676631188d3e44240d3e9229658..30b3da66b6b331aac1f01ae650e9a7fd6bf9ebbf 100644 --- a/requirements-freeze.txt +++ b/requirements-freeze.txt @@ -1,4 +1,5 @@ absl-py==0.9.0 +bottle==0.12.18 cachetools==4.1.0 certifi==2020.4.5.1 chardet==3.0.4 diff --git a/requirements.txt b/requirements.txt index 031902e4565366bf7a8eda2c001efd28c4f6e101..55c25dab096bb18b5a633bd61b1f26f9f858e179 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ pytorch-lightning torch==1.5.0+cu101 transformers +bottle diff --git a/trainer.py b/trainer.py index b8fd02b1717cb7efd8da4741bdedc7e91d4864ee..08e631e5d0fd5760eb30e1e5d72ff4f254973d42 100644 --- a/trainer.py +++ b/trainer.py @@ -10,12 +10,16 @@ warnings.filterwarnings('ignore', message='Displayed epoch numbers in the progre warnings.filterwarnings('ignore', message='.*does not have many workers which may be a bottleneck.*') from model import Model +from logger import Logger def main(hparams): + pytorch_lightning.seed_everything(hparams.seed) - model = Model(hparams) + logger = Logger(hparams.name, checkpoint_metric='fscore' if hparams.loss == 'f1' else 'bce') - checkpointer = pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint('checkpoints/%s-{epoch}-{val_loss:.4f}' % hparams.name) + model = Model(hparams) + model.custom_logger = logger + logger.log_hparams(hparams) trainer = pytorch_lightning.Trainer( max_nb_epochs=hparams.epochs, @@ -23,19 +27,26 @@ def main(hparams): nb_gpu_nodes=hparams.nodes, check_val_every_n_epoch=1, progress_bar_refresh_rate=1, - checkpoint_callback=checkpointer, + logger=None, + checkpoint_callback=None, num_sanity_val_steps=0, fast_dev_run=hparams.fast_dev_run, + deterministic=True, ) + logger.set_save_function(trainer.save_checkpoint) trainer.fit(model) + model = Model.load_from_checkpoint(logger.best_checkpoint) + model.custom_logger = logger + trainer.test(model) if __name__ == '__main__': parser = ArgumentParser(add_help=False) - parser.add_argument('--gpus', type=str, default=None) - parser.add_argument('--nodes', type=int, default=1) - parser.add_argument('--name', type=str, required=True) - parser.add_argument('--fast_dev_run', default=False, action='store_true') + parser.add_argument('--gpus', type=str, default=None, help='ids of GPUs to use (use -1 for all available GPUs, defaults to CPU)') + parser.add_argument('--nodes', type=int, default=1, help='number of computation nodes for distributed training (see lightning docs, defaults to 1)') + parser.add_argument('--name', type=str, required=True, help='name of experiment (required)') + parser.add_argument('--fast_dev_run', default=False, action='store_true', help='run a single batch through the whole training loop for catching bugs') + parser.add_argument('--seed', default=123, type=int, help='set global random seed (defaults to 123)') parser = Model.add_model_specific_args(parser) command_line = 'python ' + ' '.join(sys.argv)