Skip to content
Snippets Groups Projects
Commit 14243bb9 authored by Benoit Favre's avatar Benoit Favre
Browse files

add custom logger

parent 567e3e7c
No related branches found
No related tags found
No related merge requests found
......@@ -40,38 +40,52 @@ def to_int(tokenizer, label_vocab, hparams, dataset):
def load(tokenizer, hparams):
    """Load the train/valid/test article splits and build the datasets.

    Reads ``<hparams.stem>.{train,valid,test}`` (each a JSON list of article
    dicts), builds the topic-label vocabulary from the *training* split only,
    keeps the articles that carry a ``'topics'`` key, substitutes an empty
    abstract where one is missing, and converts each split with ``to_int``.

    Returns:
        (train_set, valid_set, test_set, label_vocab) where the sets are
        ``CustomDataset`` instances and label_vocab maps topic -> int id.
    """
    with open(hparams.stem + '.train') as fp:
        train_articles = json.loads(fp.read())
    with open(hparams.stem + '.valid') as fp:
        valid_articles = json.loads(fp.read())
    with open(hparams.stem + '.test') as fp:
        test_articles = json.loads(fp.read())
    # Label ids are assigned in training-split order; topics only seen in the
    # valid/test splits are therefore absent from the vocabulary.
    label_vocab = collections.defaultdict(lambda: len(label_vocab))
    for article in train_articles:
        if 'topics' in article:
            for topic in article['topics']:
                label_vocab[topic]  # touching the key assigns the next id
    label_vocab = dict(label_vocab)
    # Only articles with topic annotations are usable as supervised examples.
    train_dataset = [article for article in train_articles if 'topics' in article]
    valid_dataset = [article for article in valid_articles if 'topics' in article]
    test_dataset = [article for article in test_articles if 'topics' in article]
    assert len(train_dataset) > 0 and len(valid_dataset) > 0 and len(test_dataset) > 0
    for name, dataset in [('train', train_dataset), ('valid', valid_dataset), ('test', test_dataset)]:
        missing_abstracts = 0
        for article in dataset:
            if 'abstract' not in article or article['abstract'] == []:
                article['abstract'] = ['']  # placeholder so to_int sees a list
                missing_abstracts += 1
        print('WARNING: %.2f%% missing abstract in %s' % (100 * missing_abstracts / len(dataset), name))
    train_int_texts, train_int_labels = to_int(tokenizer, label_vocab, hparams, train_dataset)
    valid_int_texts, valid_int_labels = to_int(tokenizer, label_vocab, hparams, valid_dataset)
    test_int_texts, test_int_labels = to_int(tokenizer, label_vocab, hparams, test_dataset)
    train_set = CustomDataset(train_int_texts, train_int_labels)
    valid_set = CustomDataset(valid_int_texts, valid_int_labels)
    test_set = CustomDataset(test_int_texts, test_int_labels)
    print('training set', len(train_set))
    print('valid set', len(valid_set))
    print('test set', len(test_set))
    return train_set, valid_set, test_set, label_vocab
import json
import os
import sys
import collections
class Logger:
    """Minimal file-based experiment logger with best-checkpoint tracking.

    Metrics, hyper-parameters and test results are persisted after every
    update as ``<logdir>/<name>/run.json``; checkpoints are written through a
    user-supplied ``save_function`` (typically ``trainer.save_checkpoint``).
    """

    def __init__(self, name, checkpoint_metric='val_loss', logdir='logs', mode='auto'):
        # mode controls what counts as an improvement of checkpoint_metric:
        # 'max' keeps the highest value, 'min' the lowest, 'auto' infers from
        # the metric name (loss-like metrics are minimized).  The original
        # code always used '>', which saved the WORST model for 'val_loss'
        # and 'bce'.
        self.directory = os.path.join(logdir, name)
        os.makedirs(self.directory, exist_ok=True)
        self.metrics = collections.defaultdict(dict)  # epoch -> {metric: value}
        self.checkpoint_metric = checkpoint_metric
        if mode == 'auto':
            mode = 'min' if any(tag in checkpoint_metric for tag in ('loss', 'bce', 'err')) else 'max'
        self.mode = mode
        self.hparams = {}
        # kept under the historical name so existing run.json consumers work
        self.best_loss = None
        self.best_checkpoint = os.path.join(self.directory, 'best_checkpoint')
        self.test_metrics = {}
        self.save_function = None

    def set_save_function(self, save_function):
        """Register the callable used to write checkpoints (path -> None)."""
        self.save_function = save_function

    def _improved(self, value):
        # First observation always wins; afterwards compare per self.mode.
        if self.best_loss is None:
            return True
        return value < self.best_loss if self.mode == 'min' else value > self.best_loss

    def log_metrics(self, epoch, metrics):
        """Record validation metrics for an epoch and checkpoint the model."""
        self.metrics[epoch].update(metrics)
        if self.save_function is not None:
            self.save_function(os.path.join(self.directory, 'last_checkpoint'))
        if self.checkpoint_metric in metrics and self._improved(metrics[self.checkpoint_metric]):
            self.best_loss = metrics[self.checkpoint_metric]
            if self.save_function is not None:
                self.save_function(self.best_checkpoint)
        self.save()

    def log_test(self, metrics):
        """Record final test metrics."""
        self.test_metrics = metrics
        self.save()

    def log_hparams(self, hparams):
        """Record the experiment hyper-parameters (an argparse Namespace)."""
        self.hparams = vars(hparams)
        self.save()

    def save(self):
        # Rewrite the whole run file on every update; it is small.
        with open(os.path.join(self.directory, 'run.json'), 'w') as fp:
            fp.write(json.dumps({
                'metrics': self.metrics,
                'hparams': self.hparams,
                'test': self.test_metrics,
                'best_loss': self.best_loss,
            }, indent=2))
if __name__ == '__main__':
    # Tiny monitoring web app: serves logged metrics as JSON series and a
    # small auto-refreshing chart page (port 6006, tensorboard's default).
    import bottle
    import glob

    logdir = sys.argv[1] if len(sys.argv) > 1 else 'logs'

    @bottle.route('/<metric>')
    def metric(metric='val_loss'):
        """Return one {values, name} series per run for the requested metric."""
        series = []
        for path in glob.glob(logdir + '/*/*.json'):
            with open(path) as fp:
                logs = json.loads(fp.read())
            # epochs are JSON keys (strings): sort numerically
            values = [x[metric] for epoch, x in sorted(logs['metrics'].items(), key=lambda k: int(k[0])) if metric in x]
            if len(values) > 0:
                series.append({
                    'values': values,
                    'name': '\n'.join(['%s = %s' % (k, str(v)) for k, v in logs['hparams'].items()])
                })
        bottle.response.content_type = 'application/json'
        return json.dumps(series)

    @bottle.route('/')
    def index():
        """Render the chart page with one selector link per known metric."""
        metrics = set()
        # was hardcoded to 'logs/...': honor the logdir command-line argument
        for path in glob.glob(logdir + '/*/*.json'):
            with open(path) as fp:
                logs = json.loads(fp.read())
            for row in logs['metrics'].values():
                metrics.update(row.keys())
        if not metrics:
            # avoids IndexError on sorted(metrics)[0] before any run is logged
            return 'No metrics logged yet in %s' % logdir
        buttons = '<div id="buttons">' + ' | '.join(['<a href="#" onclick="update(\'%s\')">%s</a>' % (metric, metric) for metric in sorted(metrics)]) + '</div>'
        html = buttons + """<canvas id="canvas"></canvas>
<script src="https://pageperso.lis-lab.fr/benoit.favre/files/autoplot.js"></script>
<script>
var selected_metric;
function update(metric) {
  selected_metric = metric;
  fetch('/' + metric).then(res => res.json()).then(series => {
    chart('canvas', series);
  });
}
setInterval(function() {
  update(selected_metric);
}, 60 * 1000);
update('%s');
</script>""" % sorted(metrics)[0]
        return html

    bottle.run(host='localhost', port=6006, quiet=True)
......@@ -31,6 +31,67 @@ def binary_f1_score_with_logits(y_pred, y_true, epsilon=1e-7):
return 1 - f1.mean()
class RNNLayer(nn.Module):
    """Residual bidirectional-GRU block.

    Input and output both have shape (batch, seq, hidden_size): the GRU output
    is projected back to hidden_size, passed through GELU and dropout, added
    to the input, and layer-normalized.
    """

    def __init__(self, hidden_size=128, dropout=0.3):
        super().__init__()
        # The GRU below is a single bidirectional layer, so its output width
        # is hidden_size * 2.  (The original computed this from undefined
        # globals `rnn_layers`/`directions`, a NameError at construction.)
        rnn_output = hidden_size * 2
        self.rnn = nn.GRU(hidden_size, hidden_size, bias=True, num_layers=1, bidirectional=True, batch_first=True)
        self.dense = nn.Linear(rnn_output, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        output, hidden = self.rnn(x)
        # project back to hidden_size, regularize, then residual connection
        layer = self.dropout(F.gelu(self.dense(output))) + x
        return self.norm(layer)
class RNN(nn.Module):
    """Token-embedding encoder followed by a stack of residual GRU blocks.

    Returns per-token activations of shape (batch, seq, hidden_size).
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout, padding_idx=0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=padding_idx)
        self.embed_to_rnn = nn.Linear(embed_size, hidden_size)
        self.layers = nn.ModuleList(
            RNNLayer(hidden_size=hidden_size, dropout=dropout) for _ in range(num_layers)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_text):
        # embed -> dropout -> GELU -> project to the RNN width
        hidden = self.embed_to_rnn(F.gelu(self.dropout(self.embed(x_text))))
        for block in self.layers:
            hidden = block(hidden)
        return hidden
class CNNLayer(nn.Module):
    """Residual 1-D convolution block with GELU, dropout and layer norm.

    The convolution uses no padding, so its output is kernel_size - 1
    positions shorter than the input; the tail is zero-padded back to the
    input length before the residual addition, keeping shape
    (batch, seq, hidden_size).
    """

    def __init__(self, hidden_size, kernel_size, dropout):
        super().__init__()
        self.conv = nn.Conv1d(hidden_size, hidden_size, kernel_size=kernel_size)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        # Conv1d expects (batch, channels, seq); transpose in and back out.
        convolved = self.conv(x.transpose(1, 2)).transpose(2, 1)
        shortfall = x.shape[1] - convolved.shape[1]
        tail = torch.zeros(x.shape[0], shortfall, x.shape[2], device=x.device)
        restored = torch.cat([convolved, tail], 1)
        return self.norm(self.dropout(F.gelu(restored)) + x)
class CNN(nn.Module):
    """Token-embedding encoder, stacked residual conv blocks, global max pool.

    Returns one pooled vector per sequence, shape (batch, hidden_size).
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, kernel_size, dropout, padding_idx=0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=padding_idx)
        self.embed_to_cnn = nn.Linear(embed_size, hidden_size)
        self.layers = nn.ModuleList(
            CNNLayer(hidden_size=hidden_size, kernel_size=kernel_size, dropout=dropout)
            for _ in range(num_layers)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.embed_to_cnn(F.gelu(self.dropout(self.embed(x))))
        for block in self.layers:
            hidden = block(hidden)
        # max over the time dimension -> (batch, hidden_size)
        pooled = F.max_pool1d(hidden.transpose(1, 2), hidden.size(1))
        return pooled.view(x.shape[0], -1)
# NOTE(review): this span is a scraped GitLab diff hunk -- pre-commit and
# post-commit lines are interleaved and the `def __init__` header falls
# outside the hunk, so the code below is annotated rather than rewritten.
class Model(LightningModule):
......@@ -38,18 +99,32 @@ class Model(LightningModule):
super().__init__()
self.hparams = hparams
self.epoch = 1
self.tokenizer = AutoTokenizer.from_pretrained(hparams.bert_flavor)
# pre-commit line (superseded by the 4-value unpacking just below):
self.train_set, self.valid_set, self.label_vocab = data.load(self.tokenizer, hparams)
self.train_set, self.valid_set, self.test_set, self.label_vocab = data.load(self.tokenizer, hparams)
hparams.num_labels = len(self.label_vocab)
if hparams.model == 'bert':
self.bert = AutoModel.from_pretrained(hparams.bert_flavor)
if self.hparams.transfer:
print('loading bert weights from checkpoint "%s"' % self.hparams.transfer)
checkpoint = torch.load(self.hparams.transfer)
# keep only encoder weights; strip the 'bert.' prefix from the keys
state_dict = {x[5:]: y for x, y in checkpoint['state_dict'].items() if x.startswith('bert.')}
self.bert.load_state_dict(state_dict)
decision_input_size = self.bert.config.hidden_size
# pre-commit line (replaced by the shared nn.Linear further down):
self.decision = nn.Linear(self.bert.config.hidden_size, hparams.num_labels)
elif hparams.model == 'rnn':
self.rnn = RNN(self.tokenizer.vocab_size, hparams.rnn_embed_size, hparams.rnn_hidden_size, hparams.rnn_layers, hparams.dropout, self.tokenizer.pad_token_id)
decision_input_size = self.hparams.rnn_hidden_size
elif hparams.model == 'cnn':
self.cnn = CNN(self.tokenizer.vocab_size, hparams.cnn_embed_size, hparams.cnn_hidden_size, hparams.cnn_layers, hparams.cnn_kernel_size, hparams.dropout, self.tokenizer.pad_token_id)
decision_input_size = self.hparams.cnn_hidden_size
else:
raise ValueError('invalid model type "%s"' % hparams.model)
# one shared decision layer sized for whichever encoder was selected
self.decision = nn.Linear(decision_input_size, hparams.num_labels)
self.dropout = nn.Dropout(hparams.dropout)
if self.hparams.loss == 'bce':
self.loss_function = F.binary_cross_entropy_with_logits
......@@ -59,7 +134,15 @@ class Model(LightningModule):
raise ValueError('invalid loss "%s"' % self.hparams.loss)
def forward(self, x):
    """Map a padded batch of token ids to label logits via the selected encoder."""
    encoder = self.hparams.model
    if encoder == 'bert':
        # mask out padding positions for the transformer
        _, pooled = self.bert(x, attention_mask=(x != self.tokenizer.pad_token_id).long())
    elif encoder == 'rnn':
        # NOTE(review): self.rnn returns per-token activations; confirm the
        # decision layer is meant to be applied per position here.
        pooled = self.rnn(x)
    elif encoder == 'cnn':
        pooled = self.cnn(x)
    else:
        raise ValueError('invalid model type "%s"' % self.hparams.model)
    return self.decision(F.gelu(self.dropout(pooled)))
def training_step(self, batch, batch_idx):
# NOTE(review): scraped diff hunk -- the training_step body and the
# validation_step header are outside this hunk; the lines below belong to
# validation_step and include superseded pre-commit lines.
......@@ -73,14 +156,28 @@ class Model(LightningModule):
y_hat = self(x)
loss = self.loss_function(y_hat, y)
bce = F.binary_cross_entropy_with_logits(y_hat, y)
# element-wise accuracy counts over the whole label matrix; logits are
# thresholded at 0 (probability 0.5)
acc_correct = torch.sum((y_hat >= 0) == y)
acc_num = torch.tensor([y.shape[0] * y.shape[1]])
# micro precision/recall counting statistics
num_correct = torch.sum((y_hat >= 0) * (y == 1))
num_hyp = torch.sum(y_hat >= 0)
num_ref = torch.sum(y == 1)
# pre-commit lines (superseded by the final return below):
num = torch.tensor([y.shape[0] * y.shape[1]])
return {'val_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'num': num}
return {'val_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'acc_correct': acc_correct, 'acc_num': acc_num}
def test_step(self, batch, batch_idx):
    """Compute loss and micro precision/recall counting stats for one test batch."""
    x, y = batch
    logits = self(x)
    loss = self.loss_function(logits, y)
    bce = F.binary_cross_entropy_with_logits(logits, y)
    predicted = logits >= 0  # decision threshold at probability 0.5
    # element-wise accuracy over the full (batch x labels) matrix
    acc_correct = torch.sum(predicted == y)
    acc_num = torch.tensor([y.shape[0] * y.shape[1]])
    # counts for micro precision/recall/f-score
    num_correct = torch.sum(predicted * (y == 1))
    num_hyp = torch.sum(predicted)
    num_ref = torch.sum(y == 1)
    return {'test_loss': loss, 'bce': bce, 'num_correct': num_correct, 'num_ref': num_ref, 'num_hyp': num_hyp, 'acc_correct': acc_correct, 'acc_num': acc_num}
def training_epoch_end(self, outputs):
    """Average the per-batch training losses and advance the epoch counter."""
    batch_losses = [step['loss'] for step in outputs]
    avg_loss = torch.stack(batch_losses).mean()
    self.epoch += 1
    return {'loss': avg_loss, 'log': {'loss': avg_loss}}
def validation_epoch_end(self, outputs):
# NOTE(review): scraped diff hunk -- the first lines of this method (which
# stack the step outputs into `values` and compute `avg_loss`) fall outside
# the hunk; old and new lines are interleaved below.
......@@ -91,18 +188,59 @@ class Model(LightningModule):
bce = values['bce'].mean()
num_correct = values['num_correct'].sum()
# pre-commit lines (accuracy is now computed from acc_correct/acc_num):
num = values['num'].sum()
accuracy = num_correct / float(num.item())
acc_num = values['acc_num'].sum()
accuracy = values['acc_correct'].sum() / float(acc_num.item())
num_ref = values['num_ref'].sum()
num_hyp = values['num_hyp'].sum()
# ratios are guarded against empty denominators
recall = num_correct / float(num_ref.item()) if num_ref != 0 else torch.tensor([0])
# pre-commit line (its guard wrongly tested num_ref instead of num_hyp):
precision = num_correct / float(num_hyp.item()) if num_ref != 0 else torch.tensor([0])
precision = num_correct / float(num_hyp.item()) if num_hyp != 0 else torch.tensor([0])
fscore = 2 * recall * precision / float((precision + recall).item()) if precision + recall != 0 else torch.tensor([0])
# pre-commit line (metrics now go through the custom logger instead):
return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss, 'bce': bce, 'accuracy': accuracy, 'recall': recall, 'precision': precision, 'fscore': fscore}}
log_metrics = {'bce': bce.item(), 'accuracy': accuracy.item(), 'recall': recall.item(), 'precision': precision.item(), 'fscore': fscore.item()}
self.custom_logger.log_metrics(self.epoch, log_metrics)
return {'val_loss': avg_loss}
def test_epoch_end(self, outputs):
    """Aggregate test-step outputs into accuracy/precision/recall/f-score,
    forward them to the custom logger, and return the mean test loss."""
    keys = outputs[0].keys()
    stacked = {key: torch.stack([step[key] for step in outputs]) for key in keys}
    avg_loss = stacked['test_loss'].mean()
    bce = stacked['bce'].mean()
    num_correct = stacked['num_correct'].sum()
    acc_num = stacked['acc_num'].sum()
    accuracy = stacked['acc_correct'].sum() / float(acc_num.item())
    num_ref = stacked['num_ref'].sum()
    num_hyp = stacked['num_hyp'].sum()
    # guard each ratio against an empty denominator
    recall = num_correct / float(num_ref.item()) if num_ref != 0 else torch.tensor([0])
    precision = num_correct / float(num_hyp.item()) if num_hyp != 0 else torch.tensor([0])
    fscore = 2 * recall * precision / float((precision + recall).item()) if precision + recall != 0 else torch.tensor([0])
    log_metrics = {'bce': bce.item(), 'accuracy': accuracy.item(), 'recall': recall.item(), 'precision': precision.item(), 'fscore': fscore.item()}
    self.custom_logger.log_test(log_metrics)
    return {'test_loss': avg_loss}
def configure_optimizers(self):
    """Build the AdamW optimizer and, optionally, a warmup-linear LR schedule.

    Returns the bare optimizer when no scheduler is configured, otherwise the
    ([optimizer], [scheduler]) pair Lightning expects.  Raises ValueError for
    an unknown scheduler name.  (A stale pre-commit `return Adam(...)` line
    from the scraped diff was removed.)
    """
    optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
    scheduler = None
    if self.hparams.scheduler == 'warmup_linear':
        # the schedule is stepped per epoch, so steps are counted in epochs
        num_warmup_steps = self.hparams.scheduler_warmup
        num_training_steps = self.hparams.epochs
        def lr_lambda(current_step):
            # linear ramp from 0 to the base LR, then linear decay to 0
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, -1)
    elif self.hparams.scheduler is not None:  # was `!= None`
        raise ValueError('invalid scheduler "%s"' % self.hparams.scheduler)
    if scheduler:
        return [optimizer], [scheduler]
    return optimizer
def collate_fn(self, inputs):
# Pad every text in the batch to the length of the longest one.
text_len = max([len(x[0]) for x in inputs])
# NOTE(review): the remainder of collate_fn (and train_dataloader) falls
# outside this scraped diff hunk.
......@@ -118,21 +256,34 @@ class Model(LightningModule):
def val_dataloader(self):
    """Build the validation DataLoader (deterministic order, no shuffling)."""
    loader = DataLoader(self.valid_set, batch_size=self.hparams.batch_size, pin_memory=True, collate_fn=self.collate_fn)
    return loader
def test_dataloader(self):
    """Build the test DataLoader (deterministic order, no shuffling)."""
    loader = DataLoader(self.test_set, batch_size=self.hparams.batch_size, pin_memory=True, collate_fn=self.collate_fn)
    return loader
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser])
parser.add_argument('--train_filename', type=str, required=True, help='name of json file containing training/validation instances')
parser.add_argument('--stem', type=str, required=True, help='stem name of json files containing training/validation/test instances (<stem>.{train,valid,test})')
parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)')
parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)')
parser.add_argument('--epochs', default=20, type=int, help='number of epochs (default=20)')
parser.add_argument('--valid_size_percent', default=10, type=int, help='validation set size in %% (default=10)')
parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)')
parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed')
parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)')
parser.add_argument('--dropout', default=.3, type=float, help='dropout after bert')
parser.add_argument('--loss', default='f1', type=str, help='choose loss function [f1, bce] (default=f1)')
parser.add_argument('--loss', default='bce', type=str, help='choose loss function [f1, bce] (default=bce)')
parser.add_argument('--augment_data', default=False, action='store_true', help='simulate missing abstract through augmentation (default=do not augment data)')
parser.add_argument('--transfer', default=None, type=str, help='transfer bert weights from checkpoint (default=do not transfer)')
parser.add_argument('--transfer', default=None, type=str, help='transfer weights from checkpoint (default=do not transfer)')
parser.add_argument('--model', default='bert', type=str, help='model type [rnn, bert] (default=bert)')
parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed')
parser.add_argument('--rnn_embed_size', default=128, type=int, help='rnn embedding size (default=128)')
parser.add_argument('--rnn_hidden_size', default=128, type=int, help='rnn hidden size (default=128)')
parser.add_argument('--rnn_layers', default=1, type=int, help='rnn number of layers (default=1)')
parser.add_argument('--cnn_embed_size', default=128, type=int, help='cnn embedding size (default=128)')
parser.add_argument('--cnn_hidden_size', default=128, type=int, help='cnn hidden size (default=128)')
parser.add_argument('--cnn_layers', default=1, type=int, help='cnn number of layers (default=1)')
parser.add_argument('--cnn_kernel_size', default=3, type=int, help='cnn kernel size (default=3)')
parser.add_argument('--scheduler', default=None, type=str, help='learning rate schedule [warmup_linear] (default=fixed learning rate)')
parser.add_argument('--scheduler_warmup', default=1, type=int, help='learning rate schedule warmup epochs (default=1)')
return parser
absl-py==0.9.0
bottle==0.12.18
cachetools==4.1.0
certifi==2020.4.5.1
chardet==3.0.4
......
......@@ -10,12 +10,16 @@ warnings.filterwarnings('ignore', message='Displayed epoch numbers in the progre
warnings.filterwarnings('ignore', message='.*does not have many workers which may be a bottleneck.*')
from model import Model
from logger import Logger
# NOTE(review): scraped diff -- pre- and post-commit lines are interleaved
# below, and at least one line of the Trainer(...) call (likely the gpus
# argument) lies in the gap between the two hunks, so this span is annotated
# rather than rewritten.
def main(hparams):
pytorch_lightning.seed_everything(hparams.seed)
# pre-commit line (the model is now built after the logger, below):
model = Model(hparams)
# higher-is-better metric for f1 training, loss-like metric otherwise
logger = Logger(hparams.name, checkpoint_metric='fscore' if hparams.loss == 'f1' else 'bce')
# pre-commit line (checkpointing is now handled by the custom Logger):
checkpointer = pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint('checkpoints/%s-{epoch}-{val_loss:.4f}' % hparams.name)
model = Model(hparams)
model.custom_logger = logger
logger.log_hparams(hparams)
trainer = pytorch_lightning.Trainer(
max_nb_epochs=hparams.epochs,
......@@ -23,19 +27,26 @@ def main(hparams):
nb_gpu_nodes=hparams.nodes,
check_val_every_n_epoch=1,
progress_bar_refresh_rate=1,
# pre-commit line (replaced by logger=None/checkpoint_callback=None):
checkpoint_callback=checkpointer,
logger=None,
checkpoint_callback=None,
num_sanity_val_steps=0,
fast_dev_run=hparams.fast_dev_run,
deterministic=True,
)
# the Logger writes checkpoints through Lightning's save_checkpoint
logger.set_save_function(trainer.save_checkpoint)
trainer.fit(model)
# reload the best checkpoint and evaluate it once on the test split
model = Model.load_from_checkpoint(logger.best_checkpoint)
model.custom_logger = logger
trainer.test(model)
if __name__ == '__main__':
parser = ArgumentParser(add_help=False)
parser.add_argument('--gpus', type=str, default=None)
parser.add_argument('--nodes', type=int, default=1)
parser.add_argument('--name', type=str, required=True)
parser.add_argument('--fast_dev_run', default=False, action='store_true')
parser.add_argument('--gpus', type=str, default=None, help='ids of GPUs to use (use -1 for all available GPUs, defaults to CPU)')
parser.add_argument('--nodes', type=int, default=1, help='number of computation nodes for distributed training (see lightning docs, defaults to 1)')
parser.add_argument('--name', type=str, required=True, help='name of experiment (required)')
parser.add_argument('--fast_dev_run', default=False, action='store_true', help='run a single batch through the whole training loop for catching bugs')
parser.add_argument('--seed', default=123, type=int, help='set global random seed (defaults to 123)')
parser = Model.add_model_specific_args(parser)
command_line = 'python ' + ' '.join(sys.argv)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment