Commit ee81e251 authored by Benoit Favre's avatar Benoit Favre

fix test model selection; add oar scripts

parent eccc6d81
from glob import glob
import sys
import os
import json
import re
import math
import collections


def match(hparams):
    # keep a run only if every command-line argument equals one of its
    # 'name=value' hyperparameter pairs
    for pattern in sys.argv[1:]:
        found = False
        for name, value in hparams.items():
            key = '%s=%s' % (name, str(value))
            if pattern == key:
                found = True
                break
        if not found:
            return False
    return True


names = collections.defaultdict(list)
test_metrics = collections.defaultdict(lambda: collections.defaultdict(list))
val_metrics = collections.defaultdict(lambda: collections.defaultdict(list))
seen_hparams = collections.defaultdict(set)

for filename in glob('logs/*/run.json'):
    dirname = os.path.dirname(filename)
    # only consider runs that completed
    if os.path.exists(dirname + '/finished'):
        with open(filename) as fp:
            experiment = json.loads(fp.read())
        hparams = experiment['hparams']
        name = hparams['name']
        del hparams['name']
        del hparams['cmd']
        del hparams['num_labels']
        # strip any trailing -<digits> suffix so such runs share a signature
        hparams['stem'] = re.sub(r'-\d*$', '', hparams['stem'])
        if match(hparams):
            signature = json.dumps(hparams, sort_keys=True)
            names[signature].append(name)
            for metric, value in experiment['test'].items():
                test_metrics[signature][metric].append(value)
            test_metrics[signature]['val_loss'].append(experiment['best_loss'])
            for metric, value in experiment['metrics']['20'].items():
                val_metrics[signature][metric].append(value)
            for k, v in hparams.items():
                seen_hparams[k].add(json.dumps(v, sort_keys=True))


def compute_stats(values):
    values.sort()
    mean = sum(values) / len(values)
    variance = sum((x - mean) ** 2 for x in values) / len(values)
    return {'median': values[len(values) // 2], 'min': values[0], 'max': values[-1], 'mean': mean, 'var': variance}


# WARN: select on validation metrics instead of test metrics
test_metrics = val_metrics

best = None
best_hparams = None
best_signature = None
for signature in test_metrics:
    stats = {metric: compute_stats(test_metrics[signature][metric]) for metric in test_metrics[signature]}
    if best is None or stats['fscore']['mean'] > best['fscore']['mean']:
        best = stats
        best_hparams = {k: v for k, v in json.loads(signature).items()}  # if len(seen_hparams[k]) > 1}
        best_signature = signature

print('fscore mean=%.4f var=%.4f' % (best['fscore']['mean'], best['fscore']['var']))
for name, value in best_hparams.items():
    print(name, '=', value)
print(names[best_signature])
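
For reference, the command-line arguments above are literal name=value filters: a run is kept only when every argument equals one of its name=value hyperparameter pairs. A minimal standalone sketch of that matching behavior, with hypothetical hyperparameter values:

# Standalone illustration of the match() logic above; the hparams values
# are hypothetical examples, not taken from the repository.
def match(hparams, patterns):
    keys = {'%s=%s' % (name, str(value)) for name, value in hparams.items()}
    return all(pattern in keys for pattern in patterns)

print(match({'stem': 'camembert', 'loss': 'f1'}, ['loss=f1']))               # True
print(match({'stem': 'camembert', 'loss': 'f1'}, ['loss=f1', 'stem=bert']))  # False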
@@ -4,11 +4,12 @@ import sys
 import collections
 
 class Logger:
-    def __init__(self, name, checkpoint_metric='val_loss', logdir='logs', save_checkpoints=True):
+    def __init__(self, name, checkpoint_metric='val_loss', metric_aggregator=min, logdir='logs', save_checkpoints=True):
         self.directory = os.path.join(logdir, name)
         os.makedirs(self.directory, exist_ok=True)
         self.metrics = collections.defaultdict(dict)
         self.checkpoint_metric = checkpoint_metric
+        self.metric_aggregator = metric_aggregator
         self.hparams = {}
         self.best_loss = None
         self.best_checkpoint = os.path.join(self.directory, 'best_checkpoint')
@@ -23,7 +24,7 @@ class Logger:
         self.metrics[epoch].update(metrics)
         if self.save_checkpoints:
             self.save_function(os.path.join(self.directory, 'last_checkpoint'))
-        if self.checkpoint_metric in metrics and (self.best_loss is None or metrics[self.checkpoint_metric] > self.best_loss):
+        if self.checkpoint_metric in metrics and (self.best_loss is None or self.metric_aggregator(metrics[self.checkpoint_metric], self.best_loss) == metrics[self.checkpoint_metric]):
            self.best_loss = metrics[self.checkpoint_metric]
            if self.save_checkpoints:
                self.save_function(os.path.join(self.directory, 'best_checkpoint'))
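
This change is the heart of the fix: the old comparison used >, which treats a larger metric as better and therefore kept the worst checkpoint whenever the checkpoint metric was a loss. The aggregator makes the direction explicit, since aggregator(new, best) == new holds exactly when new is at least as good as best. A minimal sketch of that predicate (names are hypothetical):

# Sketch of the comparison introduced above: metric_aggregator is min for
# loss-like metrics and max for score-like metrics. Names are hypothetical.
def improved(new, best, aggregator):
    return best is None or aggregator(new, best) == new

print(improved(0.30, 0.45, min))  # True: a lower loss is better
print(improved(0.81, 0.84, max))  # False: the fscore did not improve

One consequence of this formulation is that ties refresh the checkpoint, since aggregator(x, x) == x, whereas the old > comparison kept the earlier checkpoint on a tie.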
#!/bin/bash
properties="(gpu IS NOT NULL) AND ((host != 'diflives1') AND (host != 'lifnode1') AND (host != 'sensei1') AND (host != 'asfalda1') AND (host != 'lisnode3'))"

# resubmit every experiment directory that has not finished yet
find logs/ -mindepth 1 -maxdepth 1 -type d | while read -r dir; do
    if [ ! -f "$dir/finished" ]; then
        echo "$dir"
        rm -f "$dir/stderr" "$dir/stdout"
        oarsub -p "$properties" -l "walltime=16:00:00" -E "$dir/stderr" -O "$dir/stdout" "$dir/cmd"
    fi
done
#!/bin/bash
#properties="(gpu IS NOT NULL) AND ((host != 'diflives1') AND (host != 'lifnode1') AND (host != 'sensei1') AND (host != 'asfalda1'))"
properties="(gpu IS NOT NULL) AND ((host != 'diflives1') AND (host != 'lifnode1') AND (host != 'sensei1') AND (host != 'asfalda1') AND (host != 'lisnode3'))"

#DEBUG=--fast_dev_run
DEBUG=

# read one job per line from stdin: an id followed by the command to run
while read -r id job; do
    dir="logs/exp_$id"
    if [ -d "$dir" ]; then
        echo "ERROR: $dir already exists"
        exit 1
    fi
    mkdir -p "$dir"
    cat > "$dir/cmd" << EOF
#!/bin/bash
source env/bin/activate
set -e -u -o pipefail
date > "$dir/started"
hostname >&2
echo CUDA_VISIBLE_DEVICES=\$CUDA_VISIBLE_DEVICES >&2
$job $DEBUG
date > "$dir/finished"
ls "$dir/"*checkpoint
rm -f "$dir/"*checkpoint
EOF
    chmod +x "$dir/cmd"
    oarsub -p "$properties" -l "walltime=16:00:00" -E "$dir/stderr" -O "$dir/stdout" "$dir/cmd"
done
@@ -15,7 +15,12 @@ from logger import Logger
 
 def main(hparams):
     pytorch_lightning.seed_everything(hparams.seed)
-    logger = Logger(hparams.name, checkpoint_metric='fscore' if hparams.loss == 'f1' else 'bce')
+    if hparams.loss == 'f1':
+        logger = Logger(hparams.name, checkpoint_metric='fscore', metric_aggregator=max)
+    elif hparams.loss == 'bce':
+        logger = Logger(hparams.name, checkpoint_metric='bce', metric_aggregator=min)
+    else:
+        raise ValueError('invalid loss "%s"' % hparams.loss)
     model = Model(hparams)
     model.custom_logger = logger
@@ -36,6 +41,7 @@ def main(hparams):
     logger.set_save_function(trainer.save_checkpoint)
     trainer.fit(model)
+    model = None
     model = Model.load_from_checkpoint(logger.best_checkpoint)
     model.custom_logger = logger
     trainer.test(model)
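
This replaces the one-line conditional with an explicit branch per loss, so each checkpoint metric is paired with the right direction and unknown losses fail loudly. For illustration only, an equivalent dispatch-table formulation (not in the repository):

# Hypothetical equivalent of the if/elif above, written as a dispatch
# table: each loss names its checkpoint metric and its direction.
CHECKPOINT_RULES = {'f1': ('fscore', max), 'bce': ('bce', min)}

def build_logger(hparams):
    if hparams.loss not in CHECKPOINT_RULES:
        raise ValueError('invalid loss "%s"' % hparams.loss)
    metric, aggregator = CHECKPOINT_RULES[hparams.loss]
    return Logger(hparams.name, checkpoint_metric=metric, metric_aggregator=aggregator)

The added model = None drops the reference to the trained model before the best checkpoint is loaded, presumably so the old instance can be garbage-collected rather than coexisting in memory with the reloaded one.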