Commit 29d4fc58 authored by Charly Lamothe

Merge branch '12-experiment-pipeline' into 'master'

Resolve "Experiment pipeline"

See merge request !16
parents 28c3d874 881106ae
@@ -6,12 +6,12 @@ import datetime
class ModelRawResults(object):

-    def __init__(self, model_object, training_time,
+    def __init__(self, model_weights, training_time,
                 datetime, train_score, dev_score, test_score,
                 train_score_base, dev_score_base,
                 test_score_base, score_metric, base_score_metric):
-        self._model_object = model_object
+        self._model_weights = model_weights
        self._training_time = training_time
        self._datetime = datetime
        self._train_score = train_score
@@ -24,8 +24,8 @@ class ModelRawResults(object):
        self._base_score_metric = base_score_metric

    @property
-    def model_object(self):
-        return self.model_object
+    def model_weights(self):
+        return self._model_weights  # note the underscore: returning self.model_weights would recurse into the property

    @property
    def training_time(self):
...
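For context, a minimal sketch of how a ModelRawResults instance carrying the renamed model_weights field might be persisted and read back. The save/load helpers and the pickle format are assumptions for illustration, not necessarily the repository's actual I/O code:

import pickle

def save_results(results, path):
    # Persist the raw-results object so model_weights travels with the scores.
    with open(path, 'wb') as output_file:
        pickle.dump(results, output_file)

def load_results(path):
    # Read it back; load_results(path).model_weights is then available downstream.
    with open(path, 'rb') as input_file:
        return pickle.load(input_file)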
@@ -8,6 +8,7 @@ from sklearn.base import BaseEstimator
class OmpForest(BaseEstimator, metaclass=ABCMeta):

    def __init__(self, models_parameters, base_forest_estimator):
+        self._base_forest_estimator = base_forest_estimator
        self._models_parameters = models_parameters
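One caveat worth noting, since OmpForest subclasses sklearn's BaseEstimator: get_params() (and therefore clone()) looks up attributes named exactly after the __init__ arguments, so storing them only under underscore-prefixed names will raise AttributeError there in current scikit-learn versions. A minimal illustration of the convention:

from sklearn.base import BaseEstimator

class Example(BaseEstimator):
    def __init__(self, alpha=1.0):
        self.alpha = alpha  # stored under the argument's exact name

print(Example(alpha=0.5).get_params())  # {'alpha': 0.5}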
@@ -96,6 +97,7 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
        pass

class SingleOmpForest(OmpForest):

    def __init__(self, models_parameters, base_forest_estimator):
        # fit_intercept shouldn't be set to False, as the data isn't necessarily centered here
        # normalization is handled outside OMP
...
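The two comments above carry the key reasoning: the intercept must stay enabled because nothing guarantees the data is centered, while feature scaling happens before the solver is called. A self-contained sketch of that setup on synthetic data (n_nonzero_coefs=5 is an arbitrary choice for illustration):

import numpy as np
from sklearn.linear_model import OrthogonalMatchingPursuit
from sklearn.preprocessing import StandardScaler

rng = np.random.RandomState(0)
X = rng.randn(100, 20)
y = X @ rng.randn(20) + 3.0  # non-centered target

# Normalize outside OMP, then keep fit_intercept=True so the
# non-centered target is still handled correctly by the solver.
X_scaled = StandardScaler().fit_transform(X)
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=5, fit_intercept=True)
omp.fit(X_scaled, y)
print(omp.coef_.nonzero()[0])  # indices of the selected predictors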
@@ -126,8 +126,17 @@ class Trainer(object):
        :param model: The trained model whose results are computed and saved
        :param models_dir: Where the results will be saved
        """
+        model_weights = ''
+        if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier]:
+            model_weights = model._omp.coef_
+        elif type(model) == OmpForestMulticlassClassifier:
+            model_weights = model._dct_class_omp
+        elif type(model) == OmpForestBinaryClassifier:  # unreachable: already matched by the first branch above
+            model_weights = model._omp
        results = ModelRawResults(
-            model_object='',
+            model_weights=model_weights,
            training_time=self._end_time - self._begin_time,
            datetime=datetime.datetime.now(),
            train_score=self.__score_func(model, self._dataset.X_train, self._dataset.y_train),
...
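The type dispatch above can be written without the unreachable branch. A duck-typed alternative sketch (the attribute names _omp and _dct_class_omp come from the diff; the helper itself is illustrative, not part of the commit):

def get_model_weights(model):
    # Multiclass forests keep one fitted OMP per class.
    if hasattr(model, '_dct_class_omp'):
        return model._dct_class_omp
    # Single-output forests expose one OMP whose coef_ holds the tree weights.
    if hasattr(model, '_omp'):
        return model._omp.coef_
    return ''  # unknown model type: keep the old empty-string sentinel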
@@ -28,6 +28,7 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
    experiment_train_scores = dict()
    experiment_dev_scores = dict()
    experiment_test_scores = dict()
+    experiment_weights = dict()
    all_extracted_forest_sizes = list()

    # Used to check if all losses were computed using the same metric (it should be the case)
@@ -44,6 +45,7 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
        experiment_train_scores[seed] = list()
        experiment_dev_scores[seed] = list()
        experiment_test_scores[seed] = list()
+        experiment_weights[seed] = list()
        # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
        extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
@@ -60,6 +62,8 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
            experiment_test_scores[seed].append(model_raw_results.test_score)
            # Save the metric
            experiment_score_metrics.append(model_raw_results.score_metric)
+            # Save the weights
+            #experiment_weights[seed].append(model_raw_results.model_weights)

    # Sanity checks
    if len(set(experiment_score_metrics)) > 1:
@@ -67,7 +71,8 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
    if len(set([sum(extracted_forest_sizes) for extracted_forest_sizes in all_extracted_forest_sizes])) != 1:
        raise ValueError("The extracted forest sizes aren't the same across seeds.")
-    return experiment_train_scores, experiment_dev_scores, experiment_test_scores, all_extracted_forest_sizes[0], experiment_score_metrics[0]
+    return experiment_train_scores, experiment_dev_scores, experiment_test_scores, \
+        all_extracted_forest_sizes[0], experiment_score_metrics[0]#, experiment_weights
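A hedged usage sketch of the function above; the directory names and experiment id are placeholders, and the commented-out experiment_weights would become a sixth return value if re-enabled:

train_scores, dev_scores, test_scores, extracted_forest_sizes, score_metric = \
    extract_scores_across_seeds_and_extracted_forest_sizes('models', 'results', 2)

# Each score dict maps a seed to one value per extracted forest size.
for seed, scores in test_scores.items():
    print(seed, dict(zip(extracted_forest_sizes, scores)))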
def extract_scores_across_seeds_and_forest_size(models_dir, results_dir, experiment_id, extracted_forest_sizes_number):
    experiment_id_path = models_dir + os.sep + str(experiment_id)  # models/{experiment_id}
@@ -120,6 +125,7 @@ if __name__ == "__main__":
    DEFAULT_RESULTS_DIR = os.environ["project_dir"] + os.sep + 'results'
    DEFAULT_MODELS_DIR = os.environ["project_dir"] + os.sep + 'models'
+    DEFAULT_PLOT_WEIGHT_DENSITY = False

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--stage', nargs='?', type=int, required=True, help='Specify the stage number among [1, 5].')
@@ -130,6 +136,7 @@ if __name__ == "__main__":
    parser.add_argument('--dataset_name', nargs='?', type=str, required=True, help='Specify the dataset name. TODO: read it from the models dir directly.')
    parser.add_argument('--results_dir', nargs='?', type=str, default=DEFAULT_RESULTS_DIR, help='The output directory of the results.')
    parser.add_argument('--models_dir', nargs='?', type=str, default=DEFAULT_MODELS_DIR, help='The output directory of the trained models.')
+    parser.add_argument('--plot_weight_density', action='store_true', default=DEFAULT_PLOT_WEIGHT_DENSITY, help='Plot the weight density. Only works for regressor models for now.')

    args = parser.parse_args()
    if args.stage not in list(range(1, 6)):
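Since --plot_weight_density is a store_true flag, passing it with no value enables the plot. A hypothetical invocation (the script name and argument values are placeholders, not taken from the repository):

python compute_results.py --stage 4 --experiment_ids 1 2 3 --dataset_name my_dataset --plot_weight_density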
@@ -347,6 +354,9 @@ if __name__ == "__main__":
            extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[1])
        # omp_with_params
        logger.info('Loading omp_with_params experiment scores...')
+        """omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores, _, \
+            omp_with_params_experiment_score_metric, experiment_weights = extract_scores_across_seeds_and_extracted_forest_sizes(
+            args.models_dir, args.results_dir, args.experiment_ids[2])"""
        omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores, _, \
            omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
                args.models_dir, args.results_dir, args.experiment_ids[2])
@@ -369,7 +379,7 @@ if __name__ == "__main__":
            raise ValueError('Score metrics of all experiments must be the same.')
        experiments_score_metric = base_with_params_experiment_score_metric

-        output_path = os.path.join(args.results_dir, args.dataset_name, 'stage4')
+        output_path = os.path.join(args.results_dir, args.dataset_name, 'stage4_fix')
        pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

        Plotter.plot_stage2_losses(
@@ -380,6 +390,9 @@ if __name__ == "__main__":
            xlabel='Number of trees extracted',
            ylabel=experiments_score_metric,
            title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name))
+        # experiment_weights
+        #Plotter.weight_density(experiment_weights, output_path + os.sep + 'weight_density.png')
    else:
        raise ValueError('This stage number is not supported yet, but it will be!')
...
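The commented-out Plotter.weight_density call marks where the density plot would be produced, but its implementation isn't shown in this diff. A minimal matplotlib sketch of what such a plot could look like, assuming experiment_weights maps each seed to a list of OMP coefficient arrays (one per extracted forest size):

import numpy as np
import matplotlib.pyplot as plt

def plot_weight_density(experiment_weights, output_file):
    # Flatten every coefficient vector across seeds and forest sizes.
    values = np.concatenate([np.ravel(w)
                             for per_seed in experiment_weights.values()
                             for w in per_seed])
    values = values[values != 0]  # OMP solutions are sparse; drop the exact zeros
    plt.hist(values, bins=100, density=True)
    plt.xlabel('Tree weight')
    plt.ylabel('Density')
    plt.savefig(output_file)
    plt.close()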