From cb400cc6bb846326a030e083936bc9854f62b54b Mon Sep 17 00:00:00 2001 From: Charly LAMOTHE <lamothe.c@intlocal.univ-amu.fr> Date: Sun, 3 Nov 2019 13:40:36 +0100 Subject: [PATCH] - Finish training part (TODO: normalization implementation) - Add error handling module (TODO: add logging over the code) - Record dataset parameters and model parameters - Begin compute_results, plotter and visualize files --- bolsonaro/__init__.py | 3 + bolsonaro/data/dataset_loader.py | 5 +- bolsonaro/data/dataset_parameters.py | 16 ++++ bolsonaro/error_handling/__init__.py | 29 +++++++ bolsonaro/error_handling/color_print.py | 61 ++++++++++++++ bolsonaro/error_handling/console_logger.py | 81 +++++++++++++++++++ .../error_handling/exception_decorators.py | 55 +++++++++++++ bolsonaro/error_handling/logger_factory.py | 66 +++++++++++++++ bolsonaro/models/model_factory.py | 8 +- bolsonaro/models/model_parameters.py | 32 ++++++++ bolsonaro/models/omp_forest_regressor.py | 19 +++-- bolsonaro/trainer.py | 38 +++++---- bolsonaro/utils.py | 10 +++ bolsonaro/visualization/plotter.py | 26 +++++- compute_results.py | 21 +++++ train.py | 66 ++++++++++----- visualize.py | 8 ++ 17 files changed, 490 insertions(+), 54 deletions(-) create mode 100644 bolsonaro/error_handling/__init__.py create mode 100644 bolsonaro/error_handling/color_print.py create mode 100644 bolsonaro/error_handling/console_logger.py create mode 100644 bolsonaro/error_handling/exception_decorators.py create mode 100644 bolsonaro/error_handling/logger_factory.py create mode 100644 bolsonaro/models/model_parameters.py create mode 100644 compute_results.py diff --git a/bolsonaro/__init__.py b/bolsonaro/__init__.py index e69de29..ce8e424 100644 --- a/bolsonaro/__init__.py +++ b/bolsonaro/__init__.py @@ -0,0 +1,3 @@ +import os + +LOG_PATH = os.path.abspath(os.path.dirname(__file__) + os.sep + '..' + os.sep + '..' + os.sep + 'log') diff --git a/bolsonaro/data/dataset_loader.py b/bolsonaro/data/dataset_loader.py index 1e4264e..c510a90 100644 --- a/bolsonaro/data/dataset_loader.py +++ b/bolsonaro/data/dataset_loader.py @@ -66,11 +66,12 @@ class DatasetLoader(object): X, y = dataset_loading_func(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=dataset_parameters.test_size, - random_state=dataset_parameters.seed) + random_state=dataset_parameters.random_state) X_train, X_dev, y_train, y_dev = train_test_split(X_train, y_train, test_size=dataset_parameters.dev_size, - random_state=dataset_parameters.seed) + random_state=dataset_parameters.random_state) + # TODO if dataset_parameters.normalize: pass diff --git a/bolsonaro/data/dataset_parameters.py b/bolsonaro/data/dataset_parameters.py index e820b8f..556c960 100644 --- a/bolsonaro/data/dataset_parameters.py +++ b/bolsonaro/data/dataset_parameters.py @@ -1,3 +1,7 @@ +import json +import os + + class DatasetParameters(object): def __init__(self, name, test_size, dev_size, random_state, normalize): @@ -26,3 +30,15 @@ class DatasetParameters(object): @property def normalize(self): return self._normalize + + def save(self, directory_path, experiment_id): + with open(directory_path + os.sep + 'dataset_parameters_{}.json'.format(experiment_id), 'w') as output_file: + json.dump({ + 'name': self._name, + 'test_size': self._test_size, + 'dev_size': self._dev_size, + 'random_state': self._random_state, + 'normalize': self._normalize + }, + output_file, + indent=4) diff --git a/bolsonaro/error_handling/__init__.py b/bolsonaro/error_handling/__init__.py new file mode 100644 index 0000000..a8ca18d --- /dev/null +++ b/bolsonaro/error_handling/__init__.py @@ -0,0 +1,29 @@ + ##################################################################################### + # MIT License # + # # + # Copyright (C) 2019 Charly Lamothe # + # # + # This file is part of VQ-VAE-Speech. # + # # + # Permission is hereby granted, free of charge, to any person obtaining a copy # + # of this software and associated documentation files (the "Software"), to deal # + # in the Software without restriction, including without limitation the rights # + # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # + # copies of the Software, and to permit persons to whom the Software is # + # furnished to do so, subject to the following conditions: # + # # + # The above copyright notice and this permission notice shall be included in all # + # copies or substantial portions of the Software. # + # # + # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # + # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # + # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # + # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # + # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # + # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # + # SOFTWARE. # + ##################################################################################### + +import os + +LOG_PATH = os.path.abspath(os.path.dirname(__file__) + os.sep + '..' + os.sep + '..' + os.sep + 'log') diff --git a/bolsonaro/error_handling/color_print.py b/bolsonaro/error_handling/color_print.py new file mode 100644 index 0000000..b577e5a --- /dev/null +++ b/bolsonaro/error_handling/color_print.py @@ -0,0 +1,61 @@ + ##################################################################################### + # MIT License # + # # + # Copyright (C) 2019 Charly Lamothe # + # # + # This file is part of VQ-VAE-Speech. # + # # + # Permission is hereby granted, free of charge, to any person obtaining a copy # + # of this software and associated documentation files (the "Software"), to deal # + # in the Software without restriction, including without limitation the rights # + # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # + # copies of the Software, and to permit persons to whom the Software is # + # furnished to do so, subject to the following conditions: # + # # + # The above copyright notice and this permission notice shall be included in all # + # copies or substantial portions of the Software. # + # # + # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # + # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # + # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # + # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # + # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # + # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # + # SOFTWARE. # + ##################################################################################### + +import sys + + +class ColorPrint(object): + """ Colored printing functions for strings that use universal ANSI escape sequences. + + fail: bold red, pass: bold green, warn: bold yellow, + info: bold blue, bold: bold white + + :source: https://stackoverflow.com/a/47622205 + """ + + @staticmethod + def print_fail(message, end='\n'): + sys.stderr.write('\x1b[1;31m' + message.strip() + '\x1b[0m' + end) + + @staticmethod + def print_pass(message, end='\n'): + sys.stdout.write('\x1b[1;32m' + message.strip() + '\x1b[0m' + end) + + @staticmethod + def print_warn(message, end='\n'): + sys.stderr.write('\x1b[1;33m' + message.strip() + '\x1b[0m' + end) + + @staticmethod + def print_info(message, end='\n'): + sys.stdout.write('\x1b[1;34m' + message.strip() + '\x1b[0m' + end) + + @staticmethod + def print_major_fail(message, end='\n'): + sys.stdout.write('\x1b[1;35m' + message.strip() + '\x1b[0m' + end) + + @staticmethod + def print_bold(message, end='\n'): + sys.stdout.write('\x1b[1;37m' + message.strip() + '\x1b[0m' + end) diff --git a/bolsonaro/error_handling/console_logger.py b/bolsonaro/error_handling/console_logger.py new file mode 100644 index 0000000..7014b4c --- /dev/null +++ b/bolsonaro/error_handling/console_logger.py @@ -0,0 +1,81 @@ + ##################################################################################### + # MIT License # + # # + # Copyright (C) 2019 Charly Lamothe # + # # + # This file is part of VQ-VAE-Speech. # + # # + # Permission is hereby granted, free of charge, to any person obtaining a copy # + # of this software and associated documentation files (the "Software"), to deal # + # in the Software without restriction, including without limitation the rights # + # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # + # copies of the Software, and to permit persons to whom the Software is # + # furnished to do so, subject to the following conditions: # + # # + # The above copyright notice and this permission notice shall be included in all # + # copies or substantial portions of the Software. # + # # + # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # + # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # + # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # + # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # + # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # + # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # + # SOFTWARE. # + ##################################################################################### + +from error_handling.color_print import ColorPrint + +import sys +import traceback +import os + + +class ConsoleLogger(object): + + @staticmethod + def status(message): + if os.name == 'nt': + print('[~] {message}'.format(message=message)) + else: + ColorPrint.print_info('[~] {message}'.format(message=message)) + + @staticmethod + def success(message): + if os.name == 'nt': + print('[+] {message}'.format(message=message)) + else: + ColorPrint.print_pass('[+] {message}'.format(message=message)) + + @staticmethod + def error(message): + if sys.exc_info()[2]: + line = traceback.extract_tb(sys.exc_info()[2])[-1].lineno + error_message = '[-] {message} with cause: {cause} (line {line})'.format( \ + message=message, cause=str(sys.exc_info()[1]), line=line) + else: + error_message = '[-] {message}'.format(message=message) + if os.name == 'nt': + print(error_message) + else: + ColorPrint.print_fail(error_message) + + @staticmethod + def warn(message): + if os.name == 'nt': + print('[-] {message}'.format(message=message)) + else: + ColorPrint.print_warn('[-] {message}'.format(message=message)) + + @staticmethod + def critical(message): + if sys.exc_info()[2]: + line = traceback.extract_tb(sys.exc_info()[2])[-1].lineno + error_message = '[!] {message} with cause: {cause} (line {line})'.format( \ + message=message, cause=str(sys.exc_info()[1]), line=line) + else: + error_message = '[!] {message}'.format(message=message) + if os.name == 'nt': + print(error_message) + else: + ColorPrint.print_major_fail(error_message) diff --git a/bolsonaro/error_handling/exception_decorators.py b/bolsonaro/error_handling/exception_decorators.py new file mode 100644 index 0000000..428c618 --- /dev/null +++ b/bolsonaro/error_handling/exception_decorators.py @@ -0,0 +1,55 @@ + ##################################################################################### + # MIT License # + # # + # Copyright (C) 2019 Charly Lamothe # + # # + # This file is part of VQ-VAE-Speech. # + # # + # Permission is hereby granted, free of charge, to any person obtaining a copy # + # of this software and associated documentation files (the "Software"), to deal # + # in the Software without restriction, including without limitation the rights # + # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # + # copies of the Software, and to permit persons to whom the Software is # + # furnished to do so, subject to the following conditions: # + # # + # The above copyright notice and this permission notice shall be included in all # + # copies or substantial portions of the Software. # + # # + # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # + # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # + # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # + # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # + # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # + # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # + # SOFTWARE. # + ##################################################################################### + +from functools import wraps + + +class InvalidRaiseException(Exception): + pass + + +def only_throws(E): + """ + :source: https://stackoverflow.com/a/18289516 + """ + + def decorator(f): + @wraps(f) + def wrapped(*args, **kwargs): + try: + return f(*args, **kwargs) + except E: + raise + except InvalidRaiseException: + raise + except Exception as e: + raise InvalidRaiseException( + 'got %s, expected %s, from %s' % (e.__class__.__name__, E.__name__, f.__name__) + ) + + return wrapped + + return decorator diff --git a/bolsonaro/error_handling/logger_factory.py b/bolsonaro/error_handling/logger_factory.py new file mode 100644 index 0000000..f524851 --- /dev/null +++ b/bolsonaro/error_handling/logger_factory.py @@ -0,0 +1,66 @@ + ##################################################################################### + # MIT License # + # # + # Copyright (C) 2019 Charly Lamothe # + # # + # This file is part of VQ-VAE-Speech. # + # # + # Permission is hereby granted, free of charge, to any person obtaining a copy # + # of this software and associated documentation files (the "Software"), to deal # + # in the Software without restriction, including without limitation the rights # + # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # + # copies of the Software, and to permit persons to whom the Software is # + # furnished to do so, subject to the following conditions: # + # # + # The above copyright notice and this permission notice shall be included in all # + # copies or substantial portions of the Software. # + # # + # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # + # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # + # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # + # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # + # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # + # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # + # SOFTWARE. # + ##################################################################################### + +import logging +from logging.handlers import RotatingFileHandler +import os +import errno + + +class LoggerFactory(object): + + @staticmethod + def create(path, module_name): + # Create logger + logger = logging.getLogger(module_name) + logger.setLevel(logging.DEBUG) + + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + # Create file handler + fh = RotatingFileHandler(path + os.sep + module_name + '.log', maxBytes=1000000, backupCount=5) + fh.setLevel(logging.DEBUG) + + # Create console handler + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) + + # Create formatter + formatter = logging.Formatter('%(asctime)s - %(filename)s:%(lineno)s - %(name)s - %(levelname)s - %(message)s') + + # Add formatter to handlers + fh.setFormatter(formatter) + ch.setFormatter(formatter) # TODO: add another formatter to the console logger? + + # Add fh and ch to logger + logger.addHandler(fh) + logger.addHandler(ch) + + return logger diff --git a/bolsonaro/models/model_factory.py b/bolsonaro/models/model_factory.py index a2f02a6..5bad7f4 100644 --- a/bolsonaro/models/model_factory.py +++ b/bolsonaro/models/model_factory.py @@ -6,15 +6,11 @@ from bolsonaro.data.task import Task class ModelFactory(object): @staticmethod - def build(task, forest_size, extracted_forest_size, seed=None): + def build(task, model_parameters): if task == Task.CLASSIFICATION: model_func = OmpForestClassifier elif task == Task.REGRESSION: model_func = OmpForestRegressor else: raise ValueError("Unsupported task '{}'".format(task)) - return model_func( - forest_size=forest_size, - extracted_forest_size=extracted_forest_size, - seed=seed - ) + return model_func(model_parameters) diff --git a/bolsonaro/models/model_parameters.py b/bolsonaro/models/model_parameters.py new file mode 100644 index 0000000..b1fec8c --- /dev/null +++ b/bolsonaro/models/model_parameters.py @@ -0,0 +1,32 @@ +import json +import os + + +class ModelParameters(object): + + def __init__(self, forest_size, extracted_forest_size, seed=None): + self._forest_size = forest_size + self._extracted_forest_size = extracted_forest_size + self._seed = seed + + @property + def forest_size(self): + return self._forest_size + + @property + def extracted_forest_size(self): + return self._extracted_forest_size + + @property + def seed(self): + return self._seed + + def save(self, directory_path, experiment_id): + with open(directory_path + os.sep + 'model_parameters_{}.json'.format(experiment_id), 'w') as output_file: + json.dump({ + 'forest_size': self._forest_size, + 'extracted_forest_size': self._extracted_forest_size, + 'seed': self._seed + }, + output_file, + indent=4) diff --git a/bolsonaro/models/omp_forest_regressor.py b/bolsonaro/models/omp_forest_regressor.py index 17d99aa..be60cae 100644 --- a/bolsonaro/models/omp_forest_regressor.py +++ b/bolsonaro/models/omp_forest_regressor.py @@ -1,14 +1,14 @@ -from sklearn.base import BaseEstimator from sklearn.ensemble import RandomForestRegressor from sklearn.linear_model import OrthogonalMatchingPursuit +from sklearn.base import BaseEstimator class OmpForestRegressor(BaseEstimator): - def __init__(self, forest_size, extracted_forest_size, seed=None): - self._regressor = RandomForestRegressor(n_estimators=forest_size, - random_state=seed) - self._extracted_forest_size = extracted_forest_size + def __init__(self, models_parameters): + self._regressor = RandomForestRegressor(n_estimators=models_parameters.forest_size, + random_state=models_parameters.seed) + self._models_parameters = models_parameters def fit(self, X_train, y_train): self._forest = self._train_forest(X_train, y_train) @@ -25,14 +25,19 @@ class OmpForestRegressor(BaseEstimator): def weights(self): return self._weights + @property + def models_parameters(self): + return self._models_parameters + def _train_forest(self, X_train, y_train): self._regressor.fit(X_train, y_train) forest = self._regressor.estimators_ return forest def _extract_subforest(self, X_train, y_train): - D = [[tree.predict([elem])[0] for tree in forest] for elem in X_train] - omp = OrthogonalMatchingPursuit(n_nonzero_coefs=self._extracted_forest_size, + D = [[tree.predict([elem])[0] for tree in self._forest] for elem in X_train] + omp = OrthogonalMatchingPursuit( + n_nonzero_coefs=self._models_parameters.extracted_forest_size, fit_intercept=False, normalize=False) omp.fit(D, y_train) weights = omp.coef_ diff --git a/bolsonaro/trainer.py b/bolsonaro/trainer.py index cb9f9fe..7c1436b 100644 --- a/bolsonaro/trainer.py +++ b/bolsonaro/trainer.py @@ -1,26 +1,30 @@ -from bolsonaro.utils import resolve_output_file_name +from bolsonaro.error_handling.logger_factory import LoggerFactory +from . import LOG_PATH import pickle +import os +import time +import datetime class Trainer(object): - def __init__(self, dataset, model, results_dir, models_dir): + def __init__(self, dataset): self._dataset = dataset - self._model = model - self._results_dir = results_dir - self._models_dir = models_dir + self._logger = LoggerFactory.create(LOG_PATH, __name__) - def process(self): - self._model.fit(self._dataset.X_train, self._dataset.y_train) - output_file_name = resolve_output_file_name( - self._dataset.dataset_parameters, - self._model.model_parameters, - self._results_dir, - self._models_dir - ) - with open(output_file_name, 'wb') as output_file: - pickle.dump(output_file, { + def iterate(self, model, models_dir): + self._logger.info('Training model using train set...') + begin_time = time.time() + model.fit(self._dataset.X_train, self._dataset.y_train) + end_time = time.time() - }) - # save forest and weights here + output_file_path = models_dir + os.sep + 'model.pickle' + self._logger.info('Saving trained model to {}'.format(output_file_path)) + with open(output_file_path, 'wb') as output_file: + pickle.dump({ + 'forest': model.forest, + 'weights': model.weights, + 'training_time': end_time - begin_time, + 'datetime': datetime.datetime.now() + }, output_file) diff --git a/bolsonaro/utils.py b/bolsonaro/utils.py index e69de29..2affd37 100644 --- a/bolsonaro/utils.py +++ b/bolsonaro/utils.py @@ -0,0 +1,10 @@ +import os + + +def resolve_experiment_id(models_dir): + ids = [x for x in os.listdir(models_dir) + if os.path.isdir(models_dir + os.sep + x)] + if len(ids) > 0: + ids.sort(key=int) + return int(max(ids)) + 1 + return 1 diff --git a/bolsonaro/visualization/plotter.py b/bolsonaro/visualization/plotter.py index 01f0f03..c119d47 100644 --- a/bolsonaro/visualization/plotter.py +++ b/bolsonaro/visualization/plotter.py @@ -1,3 +1,27 @@ +import matplotlib.pyplot as plt +import numpy as np +from sklearn.neighbors.kde import KernelDensity + + class Plotter(object): - \ No newline at end of file + @staticmethod + def weight_density(weights): + """ + TODO: to complete + """ + X_plot = [np.exp(elem) for elem in weights] + fig, ax = plt.subplots() + + for kernel in ['gaussian', 'tophat', 'epanechnikov']: + kde = KernelDensity(kernel=kernel, bandwidth=0.5).fit(X_plot) + log_dens = kde.score_samples(X_plot) + ax.plot(X_plot[:, 0], np.exp(log_dens), '-', + label="kernel = '{0}'".format(kernel)) + + ax.legend(loc='upper left') + ax.plot(X[:, 0], -0.005 - 0.01 * np.random.random(X.shape[0]), '+k') + + ax.set_xlim(-4, 9) + ax.set_ylim(-0.02, 0.4) + plt.show() diff --git a/compute_results.py b/compute_results.py new file mode 100644 index 0000000..ba80f0b --- /dev/null +++ b/compute_results.py @@ -0,0 +1,21 @@ +import argparse +import pathlib + + +if __name__ == "__main__": + default_results_dir = 'results' + default_models_dir = 'models' + default_experiment_id = -1 + + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--results_dir', nargs='?', type=str, default=default_results_dir, help='The output directory of the results.') + parser.add_argument('--models_dir', nargs='?', type=str, default=default_models_dir, help='The output directory of the trained models.') + parser.add_argument('--experiment_id', nargs='?', type=int, default=default_experiment_id, help='Compute the results of a single experiment id') + args = parser.parse_args() + + pathlib.Path(args.results_dir).mkdir(parents=True, exist_ok=True) + + if args.experiment_id == -1: + pass + else: + pass diff --git a/train.py b/train.py index cd6f228..0e6896c 100644 --- a/train.py +++ b/train.py @@ -1,11 +1,15 @@ from bolsonaro.data.dataset_parameters import DatasetParameters from bolsonaro.data.dataset_loader import DatasetLoader from bolsonaro.models.model_factory import ModelFactory +from bolsonaro.models.model_parameters import ModelParameters from bolsonaro.trainer import Trainer +from bolsonaro.utils import resolve_experiment_id import argparse import pathlib import random +import os +import errno if __name__ == "__main__": @@ -13,7 +17,6 @@ if __name__ == "__main__": default_normalize = False default_forest_size = 100 default_extracted_forest_size = 10 - default_results_dir = 'results' default_models_dir = 'models' default_dev_size = 0.2 default_test_size = 0.2 @@ -27,7 +30,6 @@ if __name__ == "__main__": parser.add_argument('--normalize', action='store_true', default=default_normalize, help='Normalize the data by doing the L2 division of the pred vectors.') parser.add_argument('--forest_size', nargs='?', type=int, default=default_forest_size, help='The number of trees of the random forest.') parser.add_argument('--extracted_forest_size', nargs='+', type=int, default=default_extracted_forest_size, help='The number of trees selected by OMP.') - parser.add_argument('--results_dir', nargs='?', type=str, default=default_results_dir, help='The output directory of the results.') parser.add_argument('--models_dir', nargs='?', type=str, default=default_models_dir, help='The output directory of the trained models.') parser.add_argument('--dev_size', nargs='?', type=float, default=default_dev_size, help='Dev subset ratio') parser.add_argument('--test_size', nargs='?', type=float, default=default_test_size, help='Test subset ratio') @@ -35,35 +37,57 @@ if __name__ == "__main__": parser.add_argument('--random_seed_number', nargs='?', type=int, default=default_random_seed_number, help='Number of random seeds used') args = parser.parse_args() - pathlib.Path(args.results_dir).mkdir(parents=True, exist_ok=True) pathlib.Path(args.models_dir).mkdir(parents=True, exist_ok=True) - random_seeds = [random.randint(begin_random_seed_range, end_random_seed_range) for i in range(args.random_seed_number)] \ + args.extracted_forest_size = args.extracted_forest_size \ + if type(args.extracted_forest_size) == list \ + else [args.extracted_forest_size] + + random_seeds = [random.randint(begin_random_seed_range, end_random_seed_range) \ + for i in range(args.random_seed_number)] \ if args.use_random_seed else None + experiment_id = resolve_experiment_id(args.models_dir) + experiment_id_str = str(experiment_id) + for random_seed in random_seeds: - dataset = DatasetLoader.load_from_name( - DatasetParameters( - name=args.dataset_name, - test_size=args.test_size, - dev_size=args.dev_size, - random_state=random_seed, - normalize=args.normalize - ) + random_seed_str = str(random_seed) + models_dir = args.models_dir + os.sep + experiment_id_str + os.sep + 'seeds' + \ + os.sep + random_seed_str + try: + os.makedirs(models_dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + dataset_parameters = DatasetParameters( + name=args.dataset_name, + test_size=args.test_size, + dev_size=args.dev_size, + random_state=random_seed, + normalize=args.normalize ) + dataset_parameters.save(models_dir, experiment_id_str) + + dataset = DatasetLoader.load_from_name(dataset_parameters) + + trainer = Trainer(dataset) for extracted_forest_size in args.extracted_forest_size: - model = ModelFactory( - task=dataset.task, + sub_models_dir = models_dir + os.sep + 'extracted_forest_size' + os.sep + str(extracted_forest_size) + try: + os.makedirs(sub_models_dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + model_parameters = ModelParameters( forest_size=args.forest_size, extracted_forest_size=extracted_forest_size, seed=random_seed ) + model_parameters.save(sub_models_dir, experiment_id) - trainer = Trainer( - dataset=dataset, - model=model, - results_dir=args.results_dir, - models_dir=args.models_dir - ) - trainer.process() + model = ModelFactory.build(dataset.task, model_parameters) + + trainer.iterate(model, sub_models_dir) diff --git a/visualize.py b/visualize.py index e69de29..6ae9da1 100644 --- a/visualize.py +++ b/visualize.py @@ -0,0 +1,8 @@ +import argparse + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--weight_density', action='store_true', default=default_use_weright_density, help='') + args = parser.parse_args() + -- GitLab