From cb400cc6bb846326a030e083936bc9854f62b54b Mon Sep 17 00:00:00 2001
From: Charly LAMOTHE <lamothe.c@intlocal.univ-amu.fr>
Date: Sun, 3 Nov 2019 13:40:36 +0100
Subject: [PATCH] - Finish training part (TODO: normalization implementation) -
 Add error handling module (TODO: add logging over the code) - Record dataset
 parameters and model parameters - Begin compute_results, plotter and
 visualize files

---
 bolsonaro/__init__.py                         |  3 +
 bolsonaro/data/dataset_loader.py              |  5 +-
 bolsonaro/data/dataset_parameters.py          | 16 ++++
 bolsonaro/error_handling/__init__.py          | 29 +++++++
 bolsonaro/error_handling/color_print.py       | 61 ++++++++++++++
 bolsonaro/error_handling/console_logger.py    | 81 +++++++++++++++++++
 .../error_handling/exception_decorators.py    | 55 +++++++++++++
 bolsonaro/error_handling/logger_factory.py    | 66 +++++++++++++++
 bolsonaro/models/model_factory.py             |  8 +-
 bolsonaro/models/model_parameters.py          | 32 ++++++++
 bolsonaro/models/omp_forest_regressor.py      | 19 +++--
 bolsonaro/trainer.py                          | 38 +++++----
 bolsonaro/utils.py                            | 10 +++
 bolsonaro/visualization/plotter.py            | 26 +++++-
 compute_results.py                            | 21 +++++
 train.py                                      | 66 ++++++++++-----
 visualize.py                                  |  8 ++
 17 files changed, 490 insertions(+), 54 deletions(-)
 create mode 100644 bolsonaro/error_handling/__init__.py
 create mode 100644 bolsonaro/error_handling/color_print.py
 create mode 100644 bolsonaro/error_handling/console_logger.py
 create mode 100644 bolsonaro/error_handling/exception_decorators.py
 create mode 100644 bolsonaro/error_handling/logger_factory.py
 create mode 100644 bolsonaro/models/model_parameters.py
 create mode 100644 compute_results.py

diff --git a/bolsonaro/__init__.py b/bolsonaro/__init__.py
index e69de29..ce8e424 100644
--- a/bolsonaro/__init__.py
+++ b/bolsonaro/__init__.py
@@ -0,0 +1,3 @@
+import os
+
+LOG_PATH = os.path.abspath(os.path.dirname(__file__) + os.sep + '..' + os.sep + '..' + os.sep + 'log')
diff --git a/bolsonaro/data/dataset_loader.py b/bolsonaro/data/dataset_loader.py
index 1e4264e..c510a90 100644
--- a/bolsonaro/data/dataset_loader.py
+++ b/bolsonaro/data/dataset_loader.py
@@ -66,11 +66,12 @@ class DatasetLoader(object):
         X, y = dataset_loading_func(return_X_y=True)
         X_train, X_test, y_train, y_test = train_test_split(X, y,
             test_size=dataset_parameters.test_size,
-            random_state=dataset_parameters.seed)
+            random_state=dataset_parameters.random_state)
         X_train, X_dev, y_train, y_dev = train_test_split(X_train, y_train,
             test_size=dataset_parameters.dev_size,
-            random_state=dataset_parameters.seed)
+            random_state=dataset_parameters.random_state)
 
+        # TODO
         if dataset_parameters.normalize:
             pass
 
diff --git a/bolsonaro/data/dataset_parameters.py b/bolsonaro/data/dataset_parameters.py
index e820b8f..556c960 100644
--- a/bolsonaro/data/dataset_parameters.py
+++ b/bolsonaro/data/dataset_parameters.py
@@ -1,3 +1,7 @@
+import json
+import os
+
+
 class DatasetParameters(object):
 
     def __init__(self, name, test_size, dev_size, random_state, normalize):
@@ -26,3 +30,15 @@ class DatasetParameters(object):
     @property
     def normalize(self):
         return self._normalize
+
+    def save(self, directory_path, experiment_id):
+        with open(directory_path + os.sep + 'dataset_parameters_{}.json'.format(experiment_id), 'w') as output_file:
+            json.dump({
+                'name': self._name,
+                'test_size': self._test_size,
+                'dev_size': self._dev_size,
+                'random_state': self._random_state,
+                'normalize': self._normalize
+            },
+            output_file,
+            indent=4)
diff --git a/bolsonaro/error_handling/__init__.py b/bolsonaro/error_handling/__init__.py
new file mode 100644
index 0000000..a8ca18d
--- /dev/null
+++ b/bolsonaro/error_handling/__init__.py
@@ -0,0 +1,29 @@
+ #####################################################################################
+ # MIT License                                                                       #
+ #                                                                                   #
+ # Copyright (C) 2019 Charly Lamothe                                                 #
+ #                                                                                   #
+ # This file is part of VQ-VAE-Speech.                                               #
+ #                                                                                   #
+ #   Permission is hereby granted, free of charge, to any person obtaining a copy    #
+ #   of this software and associated documentation files (the "Software"), to deal   #
+ #   in the Software without restriction, including without limitation the rights    #
+ #   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell       #
+ #   copies of the Software, and to permit persons to whom the Software is           #
+ #   furnished to do so, subject to the following conditions:                        #
+ #                                                                                   #
+ #   The above copyright notice and this permission notice shall be included in all  #
+ #   copies or substantial portions of the Software.                                 #
+ #                                                                                   #
+ #   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR      #
+ #   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,        #
+ #   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE     #
+ #   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER          #
+ #   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,   #
+ #   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE   #
+ #   SOFTWARE.                                                                       #
+ #####################################################################################
+
+import os
+
+LOG_PATH = os.path.abspath(os.path.dirname(__file__) + os.sep + '..' + os.sep + '..' + os.sep + 'log')
diff --git a/bolsonaro/error_handling/color_print.py b/bolsonaro/error_handling/color_print.py
new file mode 100644
index 0000000..b577e5a
--- /dev/null
+++ b/bolsonaro/error_handling/color_print.py
@@ -0,0 +1,61 @@
+ #####################################################################################
+ # MIT License                                                                       #
+ #                                                                                   #
+ # Copyright (C) 2019 Charly Lamothe                                                 #
+ #                                                                                   #
+ # This file is part of VQ-VAE-Speech.                                               #
+ #                                                                                   #
+ #   Permission is hereby granted, free of charge, to any person obtaining a copy    #
+ #   of this software and associated documentation files (the "Software"), to deal   #
+ #   in the Software without restriction, including without limitation the rights    #
+ #   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell       #
+ #   copies of the Software, and to permit persons to whom the Software is           #
+ #   furnished to do so, subject to the following conditions:                        #
+ #                                                                                   #
+ #   The above copyright notice and this permission notice shall be included in all  #
+ #   copies or substantial portions of the Software.                                 #
+ #                                                                                   #
+ #   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR      #
+ #   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,        #
+ #   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE     #
+ #   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER          #
+ #   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,   #
+ #   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE   #
+ #   SOFTWARE.                                                                       #
+ #####################################################################################
+
+import sys
+
+
+class ColorPrint(object):
+    """ Colored printing functions for strings that use universal ANSI escape sequences.
+
+    fail: bold red, pass: bold green, warn: bold yellow, 
+    info: bold blue, bold: bold white
+
+    :source: https://stackoverflow.com/a/47622205
+    """
+
+    @staticmethod
+    def print_fail(message, end='\n'):
+        sys.stderr.write('\x1b[1;31m' + message.strip() + '\x1b[0m' + end)
+
+    @staticmethod
+    def print_pass(message, end='\n'):
+        sys.stdout.write('\x1b[1;32m' + message.strip() + '\x1b[0m' + end)
+
+    @staticmethod
+    def print_warn(message, end='\n'):
+        sys.stderr.write('\x1b[1;33m' + message.strip() + '\x1b[0m' + end)
+
+    @staticmethod
+    def print_info(message, end='\n'):
+        sys.stdout.write('\x1b[1;34m' + message.strip() + '\x1b[0m' + end)
+
+    @staticmethod
+    def print_major_fail(message, end='\n'):
+        sys.stdout.write('\x1b[1;35m' + message.strip() + '\x1b[0m' + end)
+
+    @staticmethod
+    def print_bold(message, end='\n'):
+        sys.stdout.write('\x1b[1;37m' + message.strip() + '\x1b[0m' + end)
diff --git a/bolsonaro/error_handling/console_logger.py b/bolsonaro/error_handling/console_logger.py
new file mode 100644
index 0000000..7014b4c
--- /dev/null
+++ b/bolsonaro/error_handling/console_logger.py
@@ -0,0 +1,81 @@
+ #####################################################################################
+ # MIT License                                                                       #
+ #                                                                                   #
+ # Copyright (C) 2019 Charly Lamothe                                                 #
+ #                                                                                   #
+ # This file is part of VQ-VAE-Speech.                                               #
+ #                                                                                   #
+ #   Permission is hereby granted, free of charge, to any person obtaining a copy    #
+ #   of this software and associated documentation files (the "Software"), to deal   #
+ #   in the Software without restriction, including without limitation the rights    #
+ #   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell       #
+ #   copies of the Software, and to permit persons to whom the Software is           #
+ #   furnished to do so, subject to the following conditions:                        #
+ #                                                                                   #
+ #   The above copyright notice and this permission notice shall be included in all  #
+ #   copies or substantial portions of the Software.                                 #
+ #                                                                                   #
+ #   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR      #
+ #   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,        #
+ #   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE     #
+ #   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER          #
+ #   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,   #
+ #   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE   #
+ #   SOFTWARE.                                                                       #
+ #####################################################################################
+
+from error_handling.color_print import ColorPrint
+
+import sys
+import traceback
+import os
+
+
+class ConsoleLogger(object):
+
+    @staticmethod
+    def status(message):
+        if os.name == 'nt':
+            print('[~] {message}'.format(message=message))
+        else:
+            ColorPrint.print_info('[~] {message}'.format(message=message))
+
+    @staticmethod
+    def success(message):
+        if os.name == 'nt':
+            print('[+] {message}'.format(message=message))
+        else:
+            ColorPrint.print_pass('[+] {message}'.format(message=message))
+
+    @staticmethod
+    def error(message):
+        if sys.exc_info()[2]:
+            line = traceback.extract_tb(sys.exc_info()[2])[-1].lineno
+            error_message = '[-] {message} with cause: {cause} (line {line})'.format( \
+                message=message, cause=str(sys.exc_info()[1]), line=line)
+        else:
+            error_message = '[-] {message}'.format(message=message)
+        if os.name == 'nt':
+            print(error_message)
+        else:
+            ColorPrint.print_fail(error_message)
+
+    @staticmethod
+    def warn(message):
+        if os.name == 'nt':
+            print('[-] {message}'.format(message=message))
+        else:
+            ColorPrint.print_warn('[-] {message}'.format(message=message))
+
+    @staticmethod
+    def critical(message):
+        if sys.exc_info()[2]:
+            line = traceback.extract_tb(sys.exc_info()[2])[-1].lineno
+            error_message = '[!] {message} with cause: {cause} (line {line})'.format( \
+                message=message, cause=str(sys.exc_info()[1]), line=line)
+        else:
+            error_message = '[!] {message}'.format(message=message)
+        if os.name == 'nt':
+            print(error_message)
+        else:
+            ColorPrint.print_major_fail(error_message)
diff --git a/bolsonaro/error_handling/exception_decorators.py b/bolsonaro/error_handling/exception_decorators.py
new file mode 100644
index 0000000..428c618
--- /dev/null
+++ b/bolsonaro/error_handling/exception_decorators.py
@@ -0,0 +1,55 @@
+ #####################################################################################
+ # MIT License                                                                       #
+ #                                                                                   #
+ # Copyright (C) 2019 Charly Lamothe                                                 #
+ #                                                                                   #
+ # This file is part of VQ-VAE-Speech.                                               #
+ #                                                                                   #
+ #   Permission is hereby granted, free of charge, to any person obtaining a copy    #
+ #   of this software and associated documentation files (the "Software"), to deal   #
+ #   in the Software without restriction, including without limitation the rights    #
+ #   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell       #
+ #   copies of the Software, and to permit persons to whom the Software is           #
+ #   furnished to do so, subject to the following conditions:                        #
+ #                                                                                   #
+ #   The above copyright notice and this permission notice shall be included in all  #
+ #   copies or substantial portions of the Software.                                 #
+ #                                                                                   #
+ #   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR      #
+ #   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,        #
+ #   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE     #
+ #   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER          #
+ #   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,   #
+ #   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE   #
+ #   SOFTWARE.                                                                       #
+ #####################################################################################
+
+from functools import wraps
+
+
+class InvalidRaiseException(Exception):
+    pass
+
+
+def only_throws(E):
+    """
+    :source: https://stackoverflow.com/a/18289516
+    """
+    
+    def decorator(f):
+        @wraps(f)
+        def wrapped(*args, **kwargs):
+            try:
+                return f(*args, **kwargs)
+            except E:
+                raise
+            except InvalidRaiseException:
+                raise
+            except Exception as e:
+                raise InvalidRaiseException(
+                    'got %s, expected %s, from %s' % (e.__class__.__name__, E.__name__, f.__name__)
+                )
+
+        return wrapped
+
+    return decorator
diff --git a/bolsonaro/error_handling/logger_factory.py b/bolsonaro/error_handling/logger_factory.py
new file mode 100644
index 0000000..f524851
--- /dev/null
+++ b/bolsonaro/error_handling/logger_factory.py
@@ -0,0 +1,84 @@
+ #####################################################################################
+ # MIT License                                                                       #
+ #                                                                                   #
+ # Copyright (C) 2019 Charly Lamothe                                                 #
+ #                                                                                   #
+ # This file is part of VQ-VAE-Speech.                                               #
+ #                                                                                   #
+ #   Permission is hereby granted, free of charge, to any person obtaining a copy    #
+ #   of this software and associated documentation files (the "Software"), to deal   #
+ #   in the Software without restriction, including without limitation the rights    #
+ #   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell       #
+ #   copies of the Software, and to permit persons to whom the Software is           #
+ #   furnished to do so, subject to the following conditions:                        #
+ #                                                                                   #
+ #   The above copyright notice and this permission notice shall be included in all  #
+ #   copies or substantial portions of the Software.                                 #
+ #                                                                                   #
+ #   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR      #
+ #   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,        #
+ #   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE     #
+ #   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER          #
+ #   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,   #
+ #   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE   #
+ #   SOFTWARE.                                                                       #
+ #####################################################################################
+
+import logging
+from logging.handlers import RotatingFileHandler
+import os
+import errno
+
+
+class LoggerFactory(object):
+    """ Builds `logging` loggers that write both to a rotating file and to the console. """
+
+    @staticmethod
+    def create(path, module_name):
+        """
+        Return the logger named `module_name`, configuring it on first use.
+
+        The logger is set to DEBUG and gets two handlers sharing one format:
+        a rotating file handler writing to `<path>/<module_name>.log`
+        (1,000,000 bytes per file, 5 backups) and a console stream handler.
+
+        :param path: Directory of the log file; created if it does not exist.
+        :param module_name: Name of the logger, also used as log file name.
+        :return: The configured `logging.Logger`.
+        :raises OSError: If `path` cannot be created.
+        """
+        # Create logger
+        logger = logging.getLogger(module_name)
+        logger.setLevel(logging.DEBUG)
+
+        # logging.getLogger returns the same object for the same name, so
+        # adding handlers on every call would emit each record several times.
+        if logger.handlers:
+            return logger
+
+        try:
+            os.makedirs(path)
+        except OSError as e:
+            if e.errno != errno.EEXIST:
+                raise
+
+        # Create file handler
+        fh = RotatingFileHandler(path + os.sep + module_name + '.log', maxBytes=1000000, backupCount=5)
+        fh.setLevel(logging.DEBUG)
+
+        # Create console handler
+        ch = logging.StreamHandler()
+        ch.setLevel(logging.DEBUG)
+
+        # Create formatter
+        formatter = logging.Formatter('%(asctime)s - %(filename)s:%(lineno)s - %(name)s - %(levelname)s - %(message)s')
+
+        # Add formatter to handlers
+        fh.setFormatter(formatter)
+        ch.setFormatter(formatter) # TODO: add another formatter to the console logger?
+
+        # Add fh and ch to logger
+        logger.addHandler(fh)
+        logger.addHandler(ch)
+
+        return logger
diff --git a/bolsonaro/models/model_factory.py b/bolsonaro/models/model_factory.py
index a2f02a6..5bad7f4 100644
--- a/bolsonaro/models/model_factory.py
+++ b/bolsonaro/models/model_factory.py
@@ -6,15 +6,11 @@ from bolsonaro.data.task import Task
 class ModelFactory(object):
 
     @staticmethod
-    def build(task, forest_size, extracted_forest_size, seed=None):
+    def build(task, model_parameters):
         if task == Task.CLASSIFICATION:
             model_func = OmpForestClassifier
         elif task == Task.REGRESSION:
             model_func = OmpForestRegressor
         else:
             raise ValueError("Unsupported task '{}'".format(task))
-        return model_func(
-            forest_size=forest_size,
-            extracted_forest_size=extracted_forest_size,
-            seed=seed
-        )
+        return model_func(model_parameters)
diff --git a/bolsonaro/models/model_parameters.py b/bolsonaro/models/model_parameters.py
new file mode 100644
index 0000000..b1fec8c
--- /dev/null
+++ b/bolsonaro/models/model_parameters.py
@@ -0,0 +1,32 @@
+import json
+import os
+
+
+class ModelParameters(object):
+
+    def __init__(self, forest_size, extracted_forest_size, seed=None):
+        self._forest_size = forest_size
+        self._extracted_forest_size = extracted_forest_size
+        self._seed = seed
+
+    @property
+    def forest_size(self):
+        return self._forest_size
+
+    @property
+    def extracted_forest_size(self):
+        return self._extracted_forest_size
+
+    @property
+    def seed(self):
+        return self._seed
+
+    def save(self, directory_path, experiment_id):
+        with open(directory_path + os.sep + 'model_parameters_{}.json'.format(experiment_id), 'w') as output_file:
+            json.dump({
+                'forest_size': self._forest_size,
+                'extracted_forest_size': self._extracted_forest_size,
+                'seed': self._seed
+            },
+            output_file,
+            indent=4)
diff --git a/bolsonaro/models/omp_forest_regressor.py b/bolsonaro/models/omp_forest_regressor.py
index 17d99aa..be60cae 100644
--- a/bolsonaro/models/omp_forest_regressor.py
+++ b/bolsonaro/models/omp_forest_regressor.py
@@ -1,14 +1,14 @@
-from sklearn.base import BaseEstimator
 from sklearn.ensemble import RandomForestRegressor
 from sklearn.linear_model import OrthogonalMatchingPursuit
+from sklearn.base import BaseEstimator
 
 
 class OmpForestRegressor(BaseEstimator):
 
-    def __init__(self, forest_size, extracted_forest_size, seed=None):
-        self._regressor = RandomForestRegressor(n_estimators=forest_size,
-            random_state=seed)
-        self._extracted_forest_size = extracted_forest_size
+    def __init__(self, models_parameters):
+        self._regressor = RandomForestRegressor(n_estimators=models_parameters.forest_size,
+            random_state=models_parameters.seed)
+        self._models_parameters = models_parameters
 
     def fit(self, X_train, y_train):
         self._forest = self._train_forest(X_train, y_train)
@@ -25,14 +25,19 @@ class OmpForestRegressor(BaseEstimator):
     def weights(self):
         return self._weights
 
+    @property
+    def models_parameters(self):
+        return self._models_parameters
+
     def _train_forest(self, X_train, y_train):
         self._regressor.fit(X_train, y_train)
         forest = self._regressor.estimators_
         return forest
     
     def _extract_subforest(self, X_train, y_train):
-        D = [[tree.predict([elem])[0] for tree in forest] for elem in X_train]
-        omp = OrthogonalMatchingPursuit(n_nonzero_coefs=self._extracted_forest_size,
+        D = [[tree.predict([elem])[0] for tree in self._forest] for elem in X_train]
+        omp = OrthogonalMatchingPursuit(
+            n_nonzero_coefs=self._models_parameters.extracted_forest_size,
             fit_intercept=False, normalize=False)
         omp.fit(D, y_train)
         weights = omp.coef_
diff --git a/bolsonaro/trainer.py b/bolsonaro/trainer.py
index cb9f9fe..7c1436b 100644
--- a/bolsonaro/trainer.py
+++ b/bolsonaro/trainer.py
@@ -1,26 +1,30 @@
-from bolsonaro.utils import resolve_output_file_name
+from bolsonaro.error_handling.logger_factory import LoggerFactory
+from . import LOG_PATH
 
 import pickle
+import os
+import time
+import datetime
 
 
 class Trainer(object):
 
-    def __init__(self, dataset, model, results_dir, models_dir):
+    def __init__(self, dataset):
         self._dataset = dataset
-        self._model = model
-        self._results_dir = results_dir
-        self._models_dir = models_dir
+        self._logger = LoggerFactory.create(LOG_PATH, __name__)
 
-    def process(self):
-        self._model.fit(self._dataset.X_train, self._dataset.y_train)
-        output_file_name = resolve_output_file_name(
-            self._dataset.dataset_parameters,
-            self._model.model_parameters,
-            self._results_dir,
-            self._models_dir
-        )
-        with open(output_file_name, 'wb') as output_file:
-            pickle.dump(output_file, {
+    def iterate(self, model, models_dir):
+        self._logger.info('Training model using train set...')
+        begin_time = time.time()
+        model.fit(self._dataset.X_train, self._dataset.y_train)
+        end_time = time.time()
 
-            })
-        # save forest and weights here
+        output_file_path = models_dir + os.sep + 'model.pickle'
+        self._logger.info('Saving trained model to {}'.format(output_file_path))
+        with open(output_file_path, 'wb') as output_file:
+            pickle.dump({
+                'forest': model.forest,
+                'weights': model.weights,
+                'training_time': end_time - begin_time,
+                'datetime': datetime.datetime.now()
+            }, output_file)
diff --git a/bolsonaro/utils.py b/bolsonaro/utils.py
index e69de29..2affd37 100644
--- a/bolsonaro/utils.py
+++ b/bolsonaro/utils.py
@@ -0,0 +1,22 @@
+import os
+
+
+def resolve_experiment_id(models_dir):
+    """
+    Return the next free experiment id for `models_dir`.
+
+    Experiments are stored in sub-directories named after their integer id.
+    Entries that are not directories, or whose name is not a plain
+    non-negative integer, are ignored.
+
+    :param models_dir: Existing directory that contains one sub-directory
+        per experiment.
+    :return: The highest existing id plus one, or 1 if there is none.
+    """
+    # Compare ids as integers: comparing the directory names as strings
+    # would rank '9' above '10' and hand out an id that is already taken.
+    ids = [int(name) for name in os.listdir(models_dir)
+        if name.isdecimal() and os.path.isdir(os.path.join(models_dir, name))]
+    if ids:
+        return max(ids) + 1
+    return 1
diff --git a/bolsonaro/visualization/plotter.py b/bolsonaro/visualization/plotter.py
index 01f0f03..c119d47 100644
--- a/bolsonaro/visualization/plotter.py
+++ b/bolsonaro/visualization/plotter.py
@@ -1,3 +1,27 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from sklearn.neighbors.kde import KernelDensity
+
+
 class Plotter(object):
 
-    
\ No newline at end of file
+    @staticmethod
+    def weight_density(weights):
+        """
+        TODO: to complete
+        """
+        X_plot = [np.exp(elem) for elem in weights]
+        fig, ax = plt.subplots()
+
+        for kernel in ['gaussian', 'tophat', 'epanechnikov']:
+            kde = KernelDensity(kernel=kernel, bandwidth=0.5).fit(X_plot)
+            log_dens = kde.score_samples(X_plot)
+            ax.plot(X_plot[:, 0], np.exp(log_dens), '-',
+                    label="kernel = '{0}'".format(kernel))
+
+        ax.legend(loc='upper left')
+        ax.plot(X[:, 0], -0.005 - 0.01 * np.random.random(X.shape[0]), '+k')
+
+        ax.set_xlim(-4, 9)
+        ax.set_ylim(-0.02, 0.4)
+        plt.show()
diff --git a/compute_results.py b/compute_results.py
new file mode 100644
index 0000000..ba80f0b
--- /dev/null
+++ b/compute_results.py
@@ -0,0 +1,21 @@
+import argparse
+import pathlib
+
+
+if __name__ == "__main__":
+    default_results_dir = 'results'
+    default_models_dir = 'models'
+    default_experiment_id = -1
+
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--results_dir', nargs='?', type=str, default=default_results_dir, help='The output directory of the results.')
+    parser.add_argument('--models_dir', nargs='?', type=str, default=default_models_dir, help='The output directory of the trained models.')
+    parser.add_argument('--experiment_id', nargs='?', type=int, default=default_experiment_id, help='Compute the results of a single experiment id')
+    args = parser.parse_args()
+
+    pathlib.Path(args.results_dir).mkdir(parents=True, exist_ok=True)
+
+    if args.experiment_id == -1:
+        pass
+    else:
+        pass
diff --git a/train.py b/train.py
index cd6f228..0e6896c 100644
--- a/train.py
+++ b/train.py
@@ -1,11 +1,15 @@
 from bolsonaro.data.dataset_parameters import DatasetParameters
 from bolsonaro.data.dataset_loader import DatasetLoader
 from bolsonaro.models.model_factory import ModelFactory
+from bolsonaro.models.model_parameters import ModelParameters
 from bolsonaro.trainer import Trainer
+from bolsonaro.utils import resolve_experiment_id
 
 import argparse
 import pathlib
 import random
+import os
+import errno
 
 
 if __name__ == "__main__":
@@ -13,7 +17,6 @@ if __name__ == "__main__":
     default_normalize = False
     default_forest_size = 100
     default_extracted_forest_size = 10
-    default_results_dir = 'results'
     default_models_dir = 'models'
     default_dev_size = 0.2
     default_test_size = 0.2
@@ -27,7 +30,6 @@ if __name__ == "__main__":
     parser.add_argument('--normalize', action='store_true', default=default_normalize, help='Normalize the data by doing the L2 division of the pred vectors.')
     parser.add_argument('--forest_size', nargs='?', type=int, default=default_forest_size, help='The number of trees of the random forest.')
     parser.add_argument('--extracted_forest_size', nargs='+', type=int, default=default_extracted_forest_size, help='The number of trees selected by OMP.')
-    parser.add_argument('--results_dir', nargs='?', type=str, default=default_results_dir, help='The output directory of the results.')
     parser.add_argument('--models_dir', nargs='?', type=str, default=default_models_dir, help='The output directory of the trained models.')
     parser.add_argument('--dev_size', nargs='?', type=float, default=default_dev_size, help='Dev subset ratio')
     parser.add_argument('--test_size', nargs='?', type=float, default=default_test_size, help='Test subset ratio')
@@ -35,35 +37,57 @@ if __name__ == "__main__":
     parser.add_argument('--random_seed_number', nargs='?', type=int, default=default_random_seed_number, help='Number of random seeds used')
     args = parser.parse_args()
 
-    pathlib.Path(args.results_dir).mkdir(parents=True, exist_ok=True)
     pathlib.Path(args.models_dir).mkdir(parents=True, exist_ok=True)
 
-    random_seeds = [random.randint(begin_random_seed_range, end_random_seed_range) for i in range(args.random_seed_number)] \
+    args.extracted_forest_size = args.extracted_forest_size \
+        if type(args.extracted_forest_size) == list \
+        else [args.extracted_forest_size]
+
+    random_seeds = [random.randint(begin_random_seed_range, end_random_seed_range) \
+        for i in range(args.random_seed_number)] \
         if args.use_random_seed else None
 
+    experiment_id = resolve_experiment_id(args.models_dir)
+    experiment_id_str = str(experiment_id)
+
     for random_seed in random_seeds:
-        dataset = DatasetLoader.load_from_name(
-            DatasetParameters(
-                name=args.dataset_name,
-                test_size=args.test_size,
-                dev_size=args.dev_size,
-                random_state=random_seed,
-                normalize=args.normalize
-            )
+        random_seed_str = str(random_seed)
+        models_dir = args.models_dir + os.sep + experiment_id_str + os.sep + 'seeds' + \
+            os.sep + random_seed_str
+        try:
+            os.makedirs(models_dir)
+        except OSError as e:
+            if e.errno != errno.EEXIST:
+                raise
+
+        dataset_parameters = DatasetParameters(
+            name=args.dataset_name,
+            test_size=args.test_size,
+            dev_size=args.dev_size,
+            random_state=random_seed,
+            normalize=args.normalize
         )
+        dataset_parameters.save(models_dir, experiment_id_str)
+
+        dataset = DatasetLoader.load_from_name(dataset_parameters)
+
+        trainer = Trainer(dataset)
 
         for extracted_forest_size in args.extracted_forest_size:
-            model = ModelFactory(
-                task=dataset.task,
+            sub_models_dir = models_dir + os.sep + 'extracted_forest_size' + os.sep + str(extracted_forest_size)
+            try:
+                os.makedirs(sub_models_dir)
+            except OSError as e:
+                if e.errno != errno.EEXIST:
+                    raise
+
+            model_parameters = ModelParameters(
                 forest_size=args.forest_size,
                 extracted_forest_size=extracted_forest_size,
                 seed=random_seed
             )
+            model_parameters.save(sub_models_dir, experiment_id)
 
-            trainer = Trainer(
-                dataset=dataset,
-                model=model,
-                results_dir=args.results_dir,
-                models_dir=args.models_dir
-            )
-            trainer.process()
+            model = ModelFactory.build(dataset.task, model_parameters)
+
+            trainer.iterate(model, sub_models_dir)
diff --git a/visualize.py b/visualize.py
index e69de29..6ae9da1 100644
--- a/visualize.py
+++ b/visualize.py
@@ -0,0 +1,11 @@
+import argparse
+
+
+if __name__ == "__main__":
+    # The flag is off unless --weight_density is passed on the command line.
+    default_use_weight_density = False
+
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--weight_density', action='store_true', default=default_use_weight_density, help='')
+    args = parser.parse_args()
+
-- 
GitLab