Skip to content
Snippets Groups Projects
Commit cb400cc6 authored by Charly LAMOTHE's avatar Charly LAMOTHE
Browse files

- Finish training part (TODO: normalization implementation)

- Add error handling module (TODO: add logging over the code)
- Record dataset parameters and model parameters
- Begin compute_results, plotter and visualize files
parent baa0e4ea
No related branches found
No related tags found
1 merge request!3clean scripts
Showing
with 490 additions and 54 deletions
import os
LOG_PATH = os.path.abspath(os.path.dirname(__file__) + os.sep + '..' + os.sep + '..' + os.sep + 'log')
......@@ -66,11 +66,12 @@ class DatasetLoader(object):
X, y = dataset_loading_func(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=dataset_parameters.test_size,
random_state=dataset_parameters.seed)
random_state=dataset_parameters.random_state)
X_train, X_dev, y_train, y_dev = train_test_split(X_train, y_train,
test_size=dataset_parameters.dev_size,
random_state=dataset_parameters.seed)
random_state=dataset_parameters.random_state)
# TODO
if dataset_parameters.normalize:
pass
......
import json
import os
class DatasetParameters(object):
def __init__(self, name, test_size, dev_size, random_state, normalize):
......@@ -26,3 +30,15 @@ class DatasetParameters(object):
@property
def normalize(self):
return self._normalize
def save(self, directory_path, experiment_id):
with open(directory_path + os.sep + 'dataset_parameters_{}.json'.format(experiment_id), 'w') as output_file:
json.dump({
'name': self._name,
'test_size': self._test_size,
'dev_size': self._dev_size,
'random_state': self._random_state,
'normalize': self._normalize
},
output_file,
indent=4)
#####################################################################################
# MIT License #
# #
# Copyright (C) 2019 Charly Lamothe #
# #
# This file is part of VQ-VAE-Speech. #
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy #
# of this software and associated documentation files (the "Software"), to deal #
# in the Software without restriction, including without limitation the rights #
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell #
# copies of the Software, and to permit persons to whom the Software is #
# furnished to do so, subject to the following conditions: #
# #
# The above copyright notice and this permission notice shall be included in all #
# copies or substantial portions of the Software. #
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR #
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, #
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, #
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE #
# SOFTWARE. #
#####################################################################################
import os
LOG_PATH = os.path.abspath(os.path.dirname(__file__) + os.sep + '..' + os.sep + '..' + os.sep + 'log')
#####################################################################################
# MIT License #
# #
# Copyright (C) 2019 Charly Lamothe #
# #
# This file is part of VQ-VAE-Speech. #
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy #
# of this software and associated documentation files (the "Software"), to deal #
# in the Software without restriction, including without limitation the rights #
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell #
# copies of the Software, and to permit persons to whom the Software is #
# furnished to do so, subject to the following conditions: #
# #
# The above copyright notice and this permission notice shall be included in all #
# copies or substantial portions of the Software. #
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR #
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, #
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, #
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE #
# SOFTWARE. #
#####################################################################################
import sys
class ColorPrint(object):
""" Colored printing functions for strings that use universal ANSI escape sequences.
fail: bold red, pass: bold green, warn: bold yellow,
info: bold blue, bold: bold white
:source: https://stackoverflow.com/a/47622205
"""
@staticmethod
def print_fail(message, end='\n'):
sys.stderr.write('\x1b[1;31m' + message.strip() + '\x1b[0m' + end)
@staticmethod
def print_pass(message, end='\n'):
sys.stdout.write('\x1b[1;32m' + message.strip() + '\x1b[0m' + end)
@staticmethod
def print_warn(message, end='\n'):
sys.stderr.write('\x1b[1;33m' + message.strip() + '\x1b[0m' + end)
@staticmethod
def print_info(message, end='\n'):
sys.stdout.write('\x1b[1;34m' + message.strip() + '\x1b[0m' + end)
@staticmethod
def print_major_fail(message, end='\n'):
sys.stdout.write('\x1b[1;35m' + message.strip() + '\x1b[0m' + end)
@staticmethod
def print_bold(message, end='\n'):
sys.stdout.write('\x1b[1;37m' + message.strip() + '\x1b[0m' + end)
#####################################################################################
# MIT License #
# #
# Copyright (C) 2019 Charly Lamothe #
# #
# This file is part of VQ-VAE-Speech. #
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy #
# of this software and associated documentation files (the "Software"), to deal #
# in the Software without restriction, including without limitation the rights #
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell #
# copies of the Software, and to permit persons to whom the Software is #
# furnished to do so, subject to the following conditions: #
# #
# The above copyright notice and this permission notice shall be included in all #
# copies or substantial portions of the Software. #
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR #
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, #
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, #
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE #
# SOFTWARE. #
#####################################################################################
from error_handling.color_print import ColorPrint
import sys
import traceback
import os
class ConsoleLogger(object):
@staticmethod
def status(message):
if os.name == 'nt':
print('[~] {message}'.format(message=message))
else:
ColorPrint.print_info('[~] {message}'.format(message=message))
@staticmethod
def success(message):
if os.name == 'nt':
print('[+] {message}'.format(message=message))
else:
ColorPrint.print_pass('[+] {message}'.format(message=message))
@staticmethod
def error(message):
if sys.exc_info()[2]:
line = traceback.extract_tb(sys.exc_info()[2])[-1].lineno
error_message = '[-] {message} with cause: {cause} (line {line})'.format( \
message=message, cause=str(sys.exc_info()[1]), line=line)
else:
error_message = '[-] {message}'.format(message=message)
if os.name == 'nt':
print(error_message)
else:
ColorPrint.print_fail(error_message)
@staticmethod
def warn(message):
if os.name == 'nt':
print('[-] {message}'.format(message=message))
else:
ColorPrint.print_warn('[-] {message}'.format(message=message))
@staticmethod
def critical(message):
if sys.exc_info()[2]:
line = traceback.extract_tb(sys.exc_info()[2])[-1].lineno
error_message = '[!] {message} with cause: {cause} (line {line})'.format( \
message=message, cause=str(sys.exc_info()[1]), line=line)
else:
error_message = '[!] {message}'.format(message=message)
if os.name == 'nt':
print(error_message)
else:
ColorPrint.print_major_fail(error_message)
#####################################################################################
# MIT License #
# #
# Copyright (C) 2019 Charly Lamothe #
# #
# This file is part of VQ-VAE-Speech. #
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy #
# of this software and associated documentation files (the "Software"), to deal #
# in the Software without restriction, including without limitation the rights #
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell #
# copies of the Software, and to permit persons to whom the Software is #
# furnished to do so, subject to the following conditions: #
# #
# The above copyright notice and this permission notice shall be included in all #
# copies or substantial portions of the Software. #
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR #
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, #
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, #
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE #
# SOFTWARE. #
#####################################################################################
from functools import wraps
class InvalidRaiseException(Exception):
pass
def only_throws(E):
"""
:source: https://stackoverflow.com/a/18289516
"""
def decorator(f):
@wraps(f)
def wrapped(*args, **kwargs):
try:
return f(*args, **kwargs)
except E:
raise
except InvalidRaiseException:
raise
except Exception as e:
raise InvalidRaiseException(
'got %s, expected %s, from %s' % (e.__class__.__name__, E.__name__, f.__name__)
)
return wrapped
return decorator
#####################################################################################
# MIT License #
# #
# Copyright (C) 2019 Charly Lamothe #
# #
# This file is part of VQ-VAE-Speech. #
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy #
# of this software and associated documentation files (the "Software"), to deal #
# in the Software without restriction, including without limitation the rights #
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell #
# copies of the Software, and to permit persons to whom the Software is #
# furnished to do so, subject to the following conditions: #
# #
# The above copyright notice and this permission notice shall be included in all #
# copies or substantial portions of the Software. #
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR #
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, #
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, #
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE #
# SOFTWARE. #
#####################################################################################
import logging
from logging.handlers import RotatingFileHandler
import os
import errno
class LoggerFactory(object):
@staticmethod
def create(path, module_name):
# Create logger
logger = logging.getLogger(module_name)
logger.setLevel(logging.DEBUG)
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
# Create file handler
fh = RotatingFileHandler(path + os.sep + module_name + '.log', maxBytes=1000000, backupCount=5)
fh.setLevel(logging.DEBUG)
# Create console handler
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
# Create formatter
formatter = logging.Formatter('%(asctime)s - %(filename)s:%(lineno)s - %(name)s - %(levelname)s - %(message)s')
# Add formatter to handlers
fh.setFormatter(formatter)
ch.setFormatter(formatter) # TODO: add another formatter to the console logger?
# Add fh and ch to logger
logger.addHandler(fh)
logger.addHandler(ch)
return logger
......@@ -6,15 +6,11 @@ from bolsonaro.data.task import Task
class ModelFactory(object):
@staticmethod
def build(task, forest_size, extracted_forest_size, seed=None):
def build(task, model_parameters):
if task == Task.CLASSIFICATION:
model_func = OmpForestClassifier
elif task == Task.REGRESSION:
model_func = OmpForestRegressor
else:
raise ValueError("Unsupported task '{}'".format(task))
return model_func(
forest_size=forest_size,
extracted_forest_size=extracted_forest_size,
seed=seed
)
return model_func(model_parameters)
import json
import os
class ModelParameters(object):
def __init__(self, forest_size, extracted_forest_size, seed=None):
self._forest_size = forest_size
self._extracted_forest_size = extracted_forest_size
self._seed = seed
@property
def forest_size(self):
return self._forest_size
@property
def extracted_forest_size(self):
return self._extracted_forest_size
@property
def seed(self):
return self._seed
def save(self, directory_path, experiment_id):
with open(directory_path + os.sep + 'model_parameters_{}.json'.format(experiment_id), 'w') as output_file:
json.dump({
'forest_size': self._forest_size,
'extracted_forest_size': self._extracted_forest_size,
'seed': self._seed
},
output_file,
indent=4)
from sklearn.base import BaseEstimator
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import OrthogonalMatchingPursuit
from sklearn.base import BaseEstimator
class OmpForestRegressor(BaseEstimator):
def __init__(self, forest_size, extracted_forest_size, seed=None):
self._regressor = RandomForestRegressor(n_estimators=forest_size,
random_state=seed)
self._extracted_forest_size = extracted_forest_size
def __init__(self, models_parameters):
self._regressor = RandomForestRegressor(n_estimators=models_parameters.forest_size,
random_state=models_parameters.seed)
self._models_parameters = models_parameters
def fit(self, X_train, y_train):
self._forest = self._train_forest(X_train, y_train)
......@@ -25,14 +25,19 @@ class OmpForestRegressor(BaseEstimator):
def weights(self):
return self._weights
@property
def models_parameters(self):
return self._models_parameters
def _train_forest(self, X_train, y_train):
self._regressor.fit(X_train, y_train)
forest = self._regressor.estimators_
return forest
def _extract_subforest(self, X_train, y_train):
D = [[tree.predict([elem])[0] for tree in forest] for elem in X_train]
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=self._extracted_forest_size,
D = [[tree.predict([elem])[0] for tree in self._forest] for elem in X_train]
omp = OrthogonalMatchingPursuit(
n_nonzero_coefs=self._models_parameters.extracted_forest_size,
fit_intercept=False, normalize=False)
omp.fit(D, y_train)
weights = omp.coef_
......
from bolsonaro.utils import resolve_output_file_name
from bolsonaro.error_handling.logger_factory import LoggerFactory
from . import LOG_PATH
import pickle
import os
import time
import datetime
class Trainer(object):
def __init__(self, dataset, model, results_dir, models_dir):
def __init__(self, dataset):
self._dataset = dataset
self._model = model
self._results_dir = results_dir
self._models_dir = models_dir
self._logger = LoggerFactory.create(LOG_PATH, __name__)
def process(self):
self._model.fit(self._dataset.X_train, self._dataset.y_train)
output_file_name = resolve_output_file_name(
self._dataset.dataset_parameters,
self._model.model_parameters,
self._results_dir,
self._models_dir
)
with open(output_file_name, 'wb') as output_file:
pickle.dump(output_file, {
def iterate(self, model, models_dir):
self._logger.info('Training model using train set...')
begin_time = time.time()
model.fit(self._dataset.X_train, self._dataset.y_train)
end_time = time.time()
})
# save forest and weights here
output_file_path = models_dir + os.sep + 'model.pickle'
self._logger.info('Saving trained model to {}'.format(output_file_path))
with open(output_file_path, 'wb') as output_file:
pickle.dump({
'forest': model.forest,
'weights': model.weights,
'training_time': end_time - begin_time,
'datetime': datetime.datetime.now()
}, output_file)
import os
def resolve_experiment_id(models_dir):
ids = [x for x in os.listdir(models_dir)
if os.path.isdir(models_dir + os.sep + x)]
if len(ids) > 0:
ids.sort(key=int)
return int(max(ids)) + 1
return 1
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors.kde import KernelDensity
class Plotter(object):
\ No newline at end of file
@staticmethod
def weight_density(weights):
"""
TODO: to complete
"""
X_plot = [np.exp(elem) for elem in weights]
fig, ax = plt.subplots()
for kernel in ['gaussian', 'tophat', 'epanechnikov']:
kde = KernelDensity(kernel=kernel, bandwidth=0.5).fit(X_plot)
log_dens = kde.score_samples(X_plot)
ax.plot(X_plot[:, 0], np.exp(log_dens), '-',
label="kernel = '{0}'".format(kernel))
ax.legend(loc='upper left')
ax.plot(X[:, 0], -0.005 - 0.01 * np.random.random(X.shape[0]), '+k')
ax.set_xlim(-4, 9)
ax.set_ylim(-0.02, 0.4)
plt.show()
import argparse
import pathlib
if __name__ == "__main__":
default_results_dir = 'results'
default_models_dir = 'models'
default_experiment_id = -1
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--results_dir', nargs='?', type=str, default=default_results_dir, help='The output directory of the results.')
parser.add_argument('--models_dir', nargs='?', type=str, default=default_models_dir, help='The output directory of the trained models.')
parser.add_argument('--experiment_id', nargs='?', type=int, default=default_experiment_id, help='Compute the results of a single experiment id')
args = parser.parse_args()
pathlib.Path(args.results_dir).mkdir(parents=True, exist_ok=True)
if args.experiment_id == -1:
pass
else:
pass
from bolsonaro.data.dataset_parameters import DatasetParameters
from bolsonaro.data.dataset_loader import DatasetLoader
from bolsonaro.models.model_factory import ModelFactory
from bolsonaro.models.model_parameters import ModelParameters
from bolsonaro.trainer import Trainer
from bolsonaro.utils import resolve_experiment_id
import argparse
import pathlib
import random
import os
import errno
if __name__ == "__main__":
......@@ -13,7 +17,6 @@ if __name__ == "__main__":
default_normalize = False
default_forest_size = 100
default_extracted_forest_size = 10
default_results_dir = 'results'
default_models_dir = 'models'
default_dev_size = 0.2
default_test_size = 0.2
......@@ -27,7 +30,6 @@ if __name__ == "__main__":
parser.add_argument('--normalize', action='store_true', default=default_normalize, help='Normalize the data by doing the L2 division of the pred vectors.')
parser.add_argument('--forest_size', nargs='?', type=int, default=default_forest_size, help='The number of trees of the random forest.')
parser.add_argument('--extracted_forest_size', nargs='+', type=int, default=default_extracted_forest_size, help='The number of trees selected by OMP.')
parser.add_argument('--results_dir', nargs='?', type=str, default=default_results_dir, help='The output directory of the results.')
parser.add_argument('--models_dir', nargs='?', type=str, default=default_models_dir, help='The output directory of the trained models.')
parser.add_argument('--dev_size', nargs='?', type=float, default=default_dev_size, help='Dev subset ratio')
parser.add_argument('--test_size', nargs='?', type=float, default=default_test_size, help='Test subset ratio')
......@@ -35,35 +37,57 @@ if __name__ == "__main__":
parser.add_argument('--random_seed_number', nargs='?', type=int, default=default_random_seed_number, help='Number of random seeds used')
args = parser.parse_args()
pathlib.Path(args.results_dir).mkdir(parents=True, exist_ok=True)
pathlib.Path(args.models_dir).mkdir(parents=True, exist_ok=True)
random_seeds = [random.randint(begin_random_seed_range, end_random_seed_range) for i in range(args.random_seed_number)] \
args.extracted_forest_size = args.extracted_forest_size \
if type(args.extracted_forest_size) == list \
else [args.extracted_forest_size]
random_seeds = [random.randint(begin_random_seed_range, end_random_seed_range) \
for i in range(args.random_seed_number)] \
if args.use_random_seed else None
experiment_id = resolve_experiment_id(args.models_dir)
experiment_id_str = str(experiment_id)
for random_seed in random_seeds:
dataset = DatasetLoader.load_from_name(
DatasetParameters(
name=args.dataset_name,
test_size=args.test_size,
dev_size=args.dev_size,
random_state=random_seed,
normalize=args.normalize
)
random_seed_str = str(random_seed)
models_dir = args.models_dir + os.sep + experiment_id_str + os.sep + 'seeds' + \
os.sep + random_seed_str
try:
os.makedirs(models_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
dataset_parameters = DatasetParameters(
name=args.dataset_name,
test_size=args.test_size,
dev_size=args.dev_size,
random_state=random_seed,
normalize=args.normalize
)
dataset_parameters.save(models_dir, experiment_id_str)
dataset = DatasetLoader.load_from_name(dataset_parameters)
trainer = Trainer(dataset)
for extracted_forest_size in args.extracted_forest_size:
model = ModelFactory(
task=dataset.task,
sub_models_dir = models_dir + os.sep + 'extracted_forest_size' + os.sep + str(extracted_forest_size)
try:
os.makedirs(sub_models_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
model_parameters = ModelParameters(
forest_size=args.forest_size,
extracted_forest_size=extracted_forest_size,
seed=random_seed
)
model_parameters.save(sub_models_dir, experiment_id)
trainer = Trainer(
dataset=dataset,
model=model,
results_dir=args.results_dir,
models_dir=args.models_dir
)
trainer.process()
model = ModelFactory.build(dataset.task, model_parameters)
trainer.iterate(model, sub_models_dir)
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--weight_density', action='store_true', default=default_use_weright_density, help='')
args = parser.parse_args()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment