Skip to content
Snippets Groups Projects
Commit 2cc31d8c authored by Charly LAMOTHE's avatar Charly LAMOTHE
Browse files

- Add progress bar during training;

- Update requirements packages.
parent 8b8eb9a5
No related branches found
No related tags found
1 merge request!3clean scripts
......@@ -12,7 +12,7 @@ import argparse
import pathlib
import random
import os
import errno
from tqdm import tqdm
if __name__ == "__main__":
......@@ -67,51 +67,42 @@ if __name__ == "__main__":
experiment_id = resolve_experiment_id(args.models_dir)
experiment_id_str = str(experiment_id)
for seed in seeds:
logger.debug('Seed={}'.format(seed))
seed_str = str(seed)
models_dir = args.models_dir + os.sep + experiment_id_str + os.sep + 'seeds' + \
os.sep + seed_str
try:
os.makedirs(models_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
dataset_parameters = DatasetParameters(
name=args.dataset_name,
test_size=args.test_size,
dev_size=args.dev_size,
random_state=seed,
normalize=normalize,
train_on_subset=args.train_on_subset
)
dataset_parameters.save(models_dir, experiment_id_str)
dataset = DatasetLoader.load(dataset_parameters)
trainer = Trainer(dataset)
for extracted_forest_size in args.extracted_forest_size:
logger.debug('extracted_forest_size={}'.format(extracted_forest_size))
sub_models_dir = models_dir + os.sep + 'extracted_forest_size' + os.sep + str(extracted_forest_size)
try:
os.makedirs(sub_models_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
model_parameters = ModelParameters(
forest_size=args.forest_size,
extracted_forest_size=extracted_forest_size,
seed=seed,
normalize=normalize
with tqdm(seeds) as seed_bar:
for seed in seed_bar:
seed_bar.set_description('seed={}'.format(seed))
seed_str = str(seed)
models_dir = args.models_dir + os.sep + experiment_id_str + os.sep + 'seeds' + \
os.sep + seed_str
pathlib.Path(models_dir).mkdir(parents=True, exist_ok=True)
dataset_parameters = DatasetParameters(
name=args.dataset_name,
test_size=args.test_size,
dev_size=args.dev_size,
random_state=seed,
normalize=normalize,
train_on_subset=args.train_on_subset
)
model_parameters.save(sub_models_dir, experiment_id)
dataset_parameters.save(models_dir, experiment_id_str)
model = ModelFactory.build(dataset.task, model_parameters)
dataset = DatasetLoader.load(dataset_parameters)
trainer.train(model, sub_models_dir)
trainer = Trainer(dataset)
logger.info('Error on test set: {}'.format(model.score(dataset.X_test, dataset.y_test)))
logger.info('Accuracy on test set: {}'.format(model.score_regressor(dataset.X_test, dataset.y_test)))
with tqdm(args.extracted_forest_size) as extracted_forest_size_bar:
for extracted_forest_size in extracted_forest_size_bar:
extracted_forest_size_bar.set_description('extracted_forest_size={}'.format(extracted_forest_size))
sub_models_dir = models_dir + os.sep + 'extracted_forest_size' + os.sep + str(extracted_forest_size)
pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
model_parameters = ModelParameters(
forest_size=args.forest_size,
extracted_forest_size=extracted_forest_size,
seed=seed,
normalize=normalize
)
model_parameters.save(sub_models_dir, experiment_id)
model = ModelFactory.build(dataset.task, model_parameters)
trainer.train(model, sub_models_dir)
......@@ -8,4 +8,6 @@ coverage
awscli
flake8
python-dotenv>=0.5.1
scikit-learn
\ No newline at end of file
scikit-learn
python-dotenv
tqdm
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment