Skip to content
Snippets Groups Projects
Commit 881106ae authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Add weights saving. TODO: density plots at least

parent 07127e25
No related branches found
1 merge request!16Resolve "Experiment pipeline"
...@@ -128,10 +128,12 @@ class Trainer(object): ...@@ -128,10 +128,12 @@ class Trainer(object):
""" """
model_weights = '' model_weights = ''
if type(model) == RandomForestRegressor: if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier]:
model_weights = model.coef_
elif type(model) == OmpForestRegressor:
model_weights = model._omp.coef_ model_weights = model._omp.coef_
elif type(model) == OmpForestMulticlassClassifier:
model_weights = model._dct_class_omp
elif type(model) == OmpForestBinaryClassifier:
model_weights = model._omp
results = ModelRawResults( results = ModelRawResults(
model_weights=model_weights, model_weights=model_weights,
......
...@@ -28,13 +28,12 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d ...@@ -28,13 +28,12 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
experiment_train_scores = dict() experiment_train_scores = dict()
experiment_dev_scores = dict() experiment_dev_scores = dict()
experiment_test_scores = dict() experiment_test_scores = dict()
experiment_weights = dict()
all_extracted_forest_sizes = list() all_extracted_forest_sizes = list()
# Used to check if all losses were computed using the same metric (it should be the case) # Used to check if all losses were computed using the same metric (it should be the case)
experiment_score_metrics = list() experiment_score_metrics = list()
all_weights = list()
# For each seed results stored in models/{experiment_id}/seeds # For each seed results stored in models/{experiment_id}/seeds
seeds = os.listdir(experiment_seed_root_path) seeds = os.listdir(experiment_seed_root_path)
seeds.sort(key=int) seeds.sort(key=int)
...@@ -46,6 +45,7 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d ...@@ -46,6 +45,7 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
experiment_train_scores[seed] = list() experiment_train_scores[seed] = list()
experiment_dev_scores[seed] = list() experiment_dev_scores[seed] = list()
experiment_test_scores[seed] = list() experiment_test_scores[seed] = list()
experiment_weights[seed] = list()
# List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path) extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
...@@ -62,6 +62,8 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d ...@@ -62,6 +62,8 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
experiment_test_scores[seed].append(model_raw_results.test_score) experiment_test_scores[seed].append(model_raw_results.test_score)
# Save the metric # Save the metric
experiment_score_metrics.append(model_raw_results.score_metric) experiment_score_metrics.append(model_raw_results.score_metric)
# Save the weights
#experiment_weights[seed].append(model_raw_results.model_weights)
# Sanity checks # Sanity checks
if len(set(experiment_score_metrics)) > 1: if len(set(experiment_score_metrics)) > 1:
...@@ -69,7 +71,8 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d ...@@ -69,7 +71,8 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
if len(set([sum(extracted_forest_sizes) for extracted_forest_sizes in all_extracted_forest_sizes])) != 1: if len(set([sum(extracted_forest_sizes) for extracted_forest_sizes in all_extracted_forest_sizes])) != 1:
raise ValueError("The extracted forest sizes aren't the sames across seeds.") raise ValueError("The extracted forest sizes aren't the sames across seeds.")
return experiment_train_scores, experiment_dev_scores, experiment_test_scores, all_extracted_forest_sizes[0], experiment_score_metrics[0] return experiment_train_scores, experiment_dev_scores, experiment_test_scores, \
all_extracted_forest_sizes[0], experiment_score_metrics[0]#, experiment_weights
def extract_scores_across_seeds_and_forest_size(models_dir, results_dir, experiment_id, extracted_forest_sizes_number): def extract_scores_across_seeds_and_forest_size(models_dir, results_dir, experiment_id, extracted_forest_sizes_number):
experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id} experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
...@@ -228,8 +231,6 @@ if __name__ == "__main__": ...@@ -228,8 +231,6 @@ if __name__ == "__main__":
ylabel=experiments_score_metric, ylabel=experiments_score_metric,
title='Loss values of {}\nusing best and default hyperparameters'.format(args.dataset_name) title='Loss values of {}\nusing best and default hyperparameters'.format(args.dataset_name)
) )
Plotter.plot_weight_density()
elif args.stage == 2: elif args.stage == 2:
if len(args.experiment_ids) != 4: if len(args.experiment_ids) != 4:
raise ValueError('In the case of stage 2, the number of specified experiment ids must be 4.') raise ValueError('In the case of stage 2, the number of specified experiment ids must be 4.')
...@@ -353,6 +354,9 @@ if __name__ == "__main__": ...@@ -353,6 +354,9 @@ if __name__ == "__main__":
extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[1]) extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[1])
# omp_with_params # omp_with_params
logger.info('Loading omp_with_params experiment scores...') logger.info('Loading omp_with_params experiment scores...')
"""omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores, _, \
omp_with_params_experiment_score_metric, experiment_weights = extract_scores_across_seeds_and_extracted_forest_sizes(
args.models_dir, args.results_dir, args.experiment_ids[2])"""
omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores, _, \ omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores, _, \
omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes( omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
args.models_dir, args.results_dir, args.experiment_ids[2]) args.models_dir, args.results_dir, args.experiment_ids[2])
...@@ -375,7 +379,7 @@ if __name__ == "__main__": ...@@ -375,7 +379,7 @@ if __name__ == "__main__":
raise ValueError('Score metrics of all experiments must be the same.') raise ValueError('Score metrics of all experiments must be the same.')
experiments_score_metric = base_with_params_experiment_score_metric experiments_score_metric = base_with_params_experiment_score_metric
output_path = os.path.join(args.results_dir, args.dataset_name, 'stage4') output_path = os.path.join(args.results_dir, args.dataset_name, 'stage4_fix')
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True) pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
Plotter.plot_stage2_losses( Plotter.plot_stage2_losses(
...@@ -386,6 +390,9 @@ if __name__ == "__main__": ...@@ -386,6 +390,9 @@ if __name__ == "__main__":
xlabel='Number of trees extracted', xlabel='Number of trees extracted',
ylabel=experiments_score_metric, ylabel=experiments_score_metric,
title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name)) title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name))
# experiment_weights
#Plotter.weight_density(experiment_weights, output_path + os.sep + 'weight_density.png')
else: else:
raise ValueError('This stage number is not supported yet, but it will be!') raise ValueError('This stage number is not supported yet, but it will be!')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment