Skip to content
Snippets Groups Projects
Commit 21ccc627 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

- Add a temp fix for the subset used in base and random strategies;

- Add new results for stage4.
parent 30a57834
No related branches found
No related tags found
1 merge request!9Resolve "Experiment pipeline"
......@@ -42,10 +42,14 @@ class Trainer(object):
def base_score_metric_name(self):
return self._base_score_metric_name
def init(self, model):
def init(self, model, subsets_used='train,dev'):
if type(model) in [RandomForestRegressor, RandomForestClassifier]:
self._X_forest = self._dataset.X_train
self._y_forest = self._dataset.y_train
if subsets_used == 'train,dev':
self._X_forest = self._dataset.X_train
self._y_forest = self._dataset.y_train
else:
self._X_forest = np.concatenate([self._dataset.X_train, self._dataset.X_dev])
self._y_forest = np.concatenate([self._dataset.y_train, self._dataset.y_dev])
self._logger.debug('Fitting the forest on train subset')
elif model.models_parameters.subsets_used == 'train,dev':
self._X_forest = self._dataset.X_train
......
......@@ -328,7 +328,7 @@ if __name__ == "__main__":
ylabel=experiments_score_metric,
title='Loss values of {}\nusing different training subsets'.format(args.dataset_name))"""
elif args.stage == 4:
if len(args.experiment_ids) != 5:
if len(args.experiment_ids) != 3:
raise ValueError('In the case of stage 4, the number of specified experiment ids must be 3.')
# Retreive the extracted forest sizes number used in order to have a base forest axis as long as necessary
......@@ -351,7 +351,7 @@ if __name__ == "__main__":
omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
args.models_dir, args.results_dir, args.experiment_ids[2])
# base_with_params
"""# base_with_params
logger.info('Loading base_with_params experiment scores 2...')
_, _, base_with_params_test_scores_2, \
_ = \
......@@ -361,7 +361,7 @@ if __name__ == "__main__":
logger.info('Loading random_with_params experiment scores 2...')
_, _, random_with_params_test_scores_2, \
_, _ = \
extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[4])
extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[4])"""
# Sanity check on the metrics retreived
if not (base_with_params_experiment_score_metric == random_with_params_experiment_score_metric
......@@ -374,10 +374,8 @@ if __name__ == "__main__":
Plotter.plot_stage2_losses(
file_path=output_path + os.sep + 'losses.png',
all_experiment_scores=[base_with_params_test_scores, base_with_params_test_scores_2, random_with_params_test_scores,
random_with_params_test_scores_2,
omp_with_params_test_scores],
all_labels=['base_train-dev', 'base', 'random_train-dev', 'random', 'omp'],
all_experiment_scores=[base_with_params_test_scores, random_with_params_test_scores, omp_with_params_test_scores],
all_labels=['base', 'random', 'omp'],
x_value=with_params_extracted_forest_sizes,
xlabel='Number of trees extracted',
ylabel=experiments_score_metric,
......
......@@ -73,7 +73,7 @@ def process_job(seed, parameters, experiment_id, hyperparameters):
model = ModelFactory.build(dataset.task, model_parameters)
trainer.init(model)
trainer.init(model, subsets_used=parameters['subsets_used'])
trainer.train(model)
trainer.compute_results(model, sub_models_dir)
else:
......@@ -95,7 +95,7 @@ def process_job(seed, parameters, experiment_id, hyperparameters):
model = ModelFactory.build(dataset.task, model_parameters)
trainer.init(model)
trainer.init(model, subsets_used=parameters['subsets_used'])
trainer.train(model)
trainer.compute_results(model, sub_models_dir)
logger.info('Training done')
......
results/boston/stage4/losses.png

43.8 KiB | W: | H:

results/boston/stage4/losses.png

43.7 KiB | W: | H:

results/boston/stage4/losses.png
results/boston/stage4/losses.png
results/boston/stage4/losses.png
results/boston/stage4/losses.png
  • 2-up
  • Swipe
  • Onion skin
results/breast_cancer/stage4/losses.png

53.9 KiB | W: | H:

results/breast_cancer/stage4/losses.png

54.4 KiB | W: | H:

results/breast_cancer/stage4/losses.png
results/breast_cancer/stage4/losses.png
results/breast_cancer/stage4/losses.png
results/breast_cancer/stage4/losses.png
  • 2-up
  • Swipe
  • Onion skin
results/california_housing/stage4/losses.png

37.5 KiB | W: | H:

results/california_housing/stage4/losses.png

41 KiB | W: | H:

results/california_housing/stage4/losses.png
results/california_housing/stage4/losses.png
results/california_housing/stage4/losses.png
results/california_housing/stage4/losses.png
  • 2-up
  • Swipe
  • Onion skin
results/diabetes/stage4/losses.png

42.7 KiB | W: | H:

results/diabetes/stage4/losses.png

42.5 KiB | W: | H:

results/diabetes/stage4/losses.png
results/diabetes/stage4/losses.png
results/diabetes/stage4/losses.png
results/diabetes/stage4/losses.png
  • 2-up
  • Swipe
  • Onion skin
results/diamonds/stage4/losses.png

48.1 KiB | W: | H:

results/diamonds/stage4/losses.png

49.9 KiB | W: | H:

results/diamonds/stage4/losses.png
results/diamonds/stage4/losses.png
results/diamonds/stage4/losses.png
results/diamonds/stage4/losses.png
  • 2-up
  • Swipe
  • Onion skin
results/digits/stage4/losses.png

48.5 KiB | W: | H:

results/digits/stage4/losses.png

55.6 KiB | W: | H:

results/digits/stage4/losses.png
results/digits/stage4/losses.png
results/digits/stage4/losses.png
results/digits/stage4/losses.png
  • 2-up
  • Swipe
  • Onion skin
results/iris/stage4/losses.png

26 KiB | W: | H:

results/iris/stage4/losses.png

30.7 KiB | W: | H:

results/iris/stage4/losses.png
results/iris/stage4/losses.png
results/iris/stage4/losses.png
results/iris/stage4/losses.png
  • 2-up
  • Swipe
  • Onion skin
results/lfw_pairs/stage4/losses.png

56.6 KiB

results/olivetti_faces/stage4/losses.png

34 KiB | W: | H:

results/olivetti_faces/stage4/losses.png

33.8 KiB | W: | H:

results/olivetti_faces/stage4/losses.png
results/olivetti_faces/stage4/losses.png
results/olivetti_faces/stage4/losses.png
results/olivetti_faces/stage4/losses.png
  • 2-up
  • Swipe
  • Onion skin
results/wine/stage4/losses.png

32.2 KiB | W: | H:

results/wine/stage4/losses.png

32.8 KiB | W: | H:

results/wine/stage4/losses.png
results/wine/stage4/losses.png
results/wine/stage4/losses.png
results/wine/stage4/losses.png
  • 2-up
  • Swipe
  • Onion skin
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment