Commit 881106ae authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Add weights saving. TODO: density plots at least

parent 07127e25
......@@ -128,10 +128,12 @@ class Trainer(object):
"""
model_weights = ''
if type(model) == RandomForestRegressor:
model_weights = model.coef_
elif type(model) == OmpForestRegressor:
if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier]:
model_weights = model._omp.coef_
elif type(model) == OmpForestMulticlassClassifier:
model_weights = model._dct_class_omp
elif type(model) == OmpForestBinaryClassifier:
model_weights = model._omp
results = ModelRawResults(
model_weights=model_weights,
......
......@@ -28,13 +28,12 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
experiment_train_scores = dict()
experiment_dev_scores = dict()
experiment_test_scores = dict()
experiment_weights = dict()
all_extracted_forest_sizes = list()
# Used to check if all losses were computed using the same metric (it should be the case)
experiment_score_metrics = list()
all_weights = list()
# For each seed results stored in models/{experiment_id}/seeds
seeds = os.listdir(experiment_seed_root_path)
seeds.sort(key=int)
......@@ -46,6 +45,7 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
experiment_train_scores[seed] = list()
experiment_dev_scores[seed] = list()
experiment_test_scores[seed] = list()
experiment_weights[seed] = list()
# List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
......@@ -62,6 +62,8 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
experiment_test_scores[seed].append(model_raw_results.test_score)
# Save the metric
experiment_score_metrics.append(model_raw_results.score_metric)
# Save the weights
#experiment_weights[seed].append(model_raw_results.model_weights)
# Sanity checks
if len(set(experiment_score_metrics)) > 1:
......@@ -69,7 +71,8 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
if len(set([sum(extracted_forest_sizes) for extracted_forest_sizes in all_extracted_forest_sizes])) != 1:
raise ValueError("The extracted forest sizes aren't the sames across seeds.")
return experiment_train_scores, experiment_dev_scores, experiment_test_scores, all_extracted_forest_sizes[0], experiment_score_metrics[0]
return experiment_train_scores, experiment_dev_scores, experiment_test_scores, \
all_extracted_forest_sizes[0], experiment_score_metrics[0]#, experiment_weights
def extract_scores_across_seeds_and_forest_size(models_dir, results_dir, experiment_id, extracted_forest_sizes_number):
experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
......@@ -228,8 +231,6 @@ if __name__ == "__main__":
ylabel=experiments_score_metric,
title='Loss values of {}\nusing best and default hyperparameters'.format(args.dataset_name)
)
Plotter.plot_weight_density()
elif args.stage == 2:
if len(args.experiment_ids) != 4:
raise ValueError('In the case of stage 2, the number of specified experiment ids must be 4.')
......@@ -353,6 +354,9 @@ if __name__ == "__main__":
extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[1])
# omp_with_params
logger.info('Loading omp_with_params experiment scores...')
"""omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores, _, \
omp_with_params_experiment_score_metric, experiment_weights = extract_scores_across_seeds_and_extracted_forest_sizes(
args.models_dir, args.results_dir, args.experiment_ids[2])"""
omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores, _, \
omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
args.models_dir, args.results_dir, args.experiment_ids[2])
......@@ -375,7 +379,7 @@ if __name__ == "__main__":
raise ValueError('Score metrics of all experiments must be the same.')
experiments_score_metric = base_with_params_experiment_score_metric
output_path = os.path.join(args.results_dir, args.dataset_name, 'stage4')
output_path = os.path.join(args.results_dir, args.dataset_name, 'stage4_fix')
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
Plotter.plot_stage2_losses(
......@@ -386,6 +390,9 @@ if __name__ == "__main__":
xlabel='Number of trees extracted',
ylabel=experiments_score_metric,
title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name))
# experiment_weights
#Plotter.weight_density(experiment_weights, output_path + os.sep + 'weight_density.png')
else:
raise ValueError('This stage number is not supported yet, but it will be!')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment