Skip to content
Snippets Groups Projects
Commit 07127e25 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Merge branch '12-experiment-pipeline-density-plot' into 12-experiment-pipeline

parents 28dacbd4 eadac78d
No related branches found
No related tags found
1 merge request!16Resolve "Experiment pipeline"
...@@ -6,12 +6,12 @@ import datetime ...@@ -6,12 +6,12 @@ import datetime
class ModelRawResults(object): class ModelRawResults(object):
def __init__(self, model_object, training_time, def __init__(self, model_weights, training_time,
datetime, train_score, dev_score, test_score, datetime, train_score, dev_score, test_score,
train_score_base, dev_score_base, train_score_base, dev_score_base,
test_score_base, score_metric, base_score_metric): test_score_base, score_metric, base_score_metric):
self._model_object = model_object self._model_weights = model_weights
self._training_time = training_time self._training_time = training_time
self._datetime = datetime self._datetime = datetime
self._train_score = train_score self._train_score = train_score
...@@ -24,8 +24,8 @@ class ModelRawResults(object): ...@@ -24,8 +24,8 @@ class ModelRawResults(object):
self._base_score_metric = base_score_metric self._base_score_metric = base_score_metric
@property @property
def model_object(self): def model_weights(self):
return self.model_object return self.model_weights
@property @property
def training_time(self): def training_time(self):
......
...@@ -8,6 +8,7 @@ from sklearn.base import BaseEstimator ...@@ -8,6 +8,7 @@ from sklearn.base import BaseEstimator
class OmpForest(BaseEstimator, metaclass=ABCMeta): class OmpForest(BaseEstimator, metaclass=ABCMeta):
def __init__(self, models_parameters, base_forest_estimator): def __init__(self, models_parameters, base_forest_estimator):
self._base_forest_estimator = base_forest_estimator self._base_forest_estimator = base_forest_estimator
self._models_parameters = models_parameters self._models_parameters = models_parameters
...@@ -96,6 +97,7 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta): ...@@ -96,6 +97,7 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
pass pass
class SingleOmpForest(OmpForest): class SingleOmpForest(OmpForest):
def __init__(self, models_parameters, base_forest_estimator): def __init__(self, models_parameters, base_forest_estimator):
# fit_intercept shouldn't be set to False as the data isn't necessarily centered here # fit_intercept shouldn't be set to False as the data isn't necessarily centered here
# normalization is handled outsite OMP # normalization is handled outsite OMP
......
...@@ -126,8 +126,15 @@ class Trainer(object): ...@@ -126,8 +126,15 @@ class Trainer(object):
:param model: Object with :param model: Object with
:param models_dir: Where the results will be saved :param models_dir: Where the results will be saved
""" """
model_weights = ''
if type(model) == RandomForestRegressor:
model_weights = model.coef_
elif type(model) == OmpForestRegressor:
model_weights = model._omp.coef_
results = ModelRawResults( results = ModelRawResults(
model_object='', model_weights=model_weights,
training_time=self._end_time - self._begin_time, training_time=self._end_time - self._begin_time,
datetime=datetime.datetime.now(), datetime=datetime.datetime.now(),
train_score=self.__score_func(model, self._dataset.X_train, self._dataset.y_train), train_score=self.__score_func(model, self._dataset.X_train, self._dataset.y_train),
......
...@@ -33,6 +33,8 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d ...@@ -33,6 +33,8 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
# Used to check if all losses were computed using the same metric (it should be the case) # Used to check if all losses were computed using the same metric (it should be the case)
experiment_score_metrics = list() experiment_score_metrics = list()
all_weights = list()
# For each seed results stored in models/{experiment_id}/seeds # For each seed results stored in models/{experiment_id}/seeds
seeds = os.listdir(experiment_seed_root_path) seeds = os.listdir(experiment_seed_root_path)
seeds.sort(key=int) seeds.sort(key=int)
...@@ -120,6 +122,7 @@ if __name__ == "__main__": ...@@ -120,6 +122,7 @@ if __name__ == "__main__":
DEFAULT_RESULTS_DIR = os.environ["project_dir"] + os.sep + 'results' DEFAULT_RESULTS_DIR = os.environ["project_dir"] + os.sep + 'results'
DEFAULT_MODELS_DIR = os.environ["project_dir"] + os.sep + 'models' DEFAULT_MODELS_DIR = os.environ["project_dir"] + os.sep + 'models'
DEFAULT_PLOT_WEIGHT_DENSITY = False
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--stage', nargs='?', type=int, required=True, help='Specify the stage number among [1, 5].') parser.add_argument('--stage', nargs='?', type=int, required=True, help='Specify the stage number among [1, 5].')
...@@ -130,6 +133,7 @@ if __name__ == "__main__": ...@@ -130,6 +133,7 @@ if __name__ == "__main__":
parser.add_argument('--dataset_name', nargs='?', type=str, required=True, help='Specify the dataset name. TODO: read it from models dir directly.') parser.add_argument('--dataset_name', nargs='?', type=str, required=True, help='Specify the dataset name. TODO: read it from models dir directly.')
parser.add_argument('--results_dir', nargs='?', type=str, default=DEFAULT_RESULTS_DIR, help='The output directory of the results.') parser.add_argument('--results_dir', nargs='?', type=str, default=DEFAULT_RESULTS_DIR, help='The output directory of the results.')
parser.add_argument('--models_dir', nargs='?', type=str, default=DEFAULT_MODELS_DIR, help='The output directory of the trained models.') parser.add_argument('--models_dir', nargs='?', type=str, default=DEFAULT_MODELS_DIR, help='The output directory of the trained models.')
parser.add_argument('--plot_weight_density', action='store_true', default=DEFAULT_PLOT_WEIGHT_DENSITY, help='Plot the weight density. Only working for regressor models for now.')
args = parser.parse_args() args = parser.parse_args()
if args.stage not in list(range(1, 6)): if args.stage not in list(range(1, 6)):
...@@ -224,6 +228,8 @@ if __name__ == "__main__": ...@@ -224,6 +228,8 @@ if __name__ == "__main__":
ylabel=experiments_score_metric, ylabel=experiments_score_metric,
title='Loss values of {}\nusing best and default hyperparameters'.format(args.dataset_name) title='Loss values of {}\nusing best and default hyperparameters'.format(args.dataset_name)
) )
Plotter.plot_weight_density()
elif args.stage == 2: elif args.stage == 2:
if len(args.experiment_ids) != 4: if len(args.experiment_ids) != 4:
raise ValueError('In the case of stage 2, the number of specified experiment ids must be 4.') raise ValueError('In the case of stage 2, the number of specified experiment ids must be 4.')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment