Skip to content
Snippets Groups Projects
Commit eadac78d authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Start to add back density plot

parent c80ddd61
No related branches found
No related tags found
1 merge request!16Resolve "Experiment pipeline"
......@@ -6,12 +6,12 @@ import datetime
class ModelRawResults(object):
def __init__(self, model_object, training_time,
def __init__(self, model_weights, training_time,
datetime, train_score, dev_score, test_score,
train_score_base, dev_score_base,
test_score_base, score_metric, base_score_metric):
self._model_object = model_object
self._model_weights = model_weights
self._training_time = training_time
self._datetime = datetime
self._train_score = train_score
......@@ -24,8 +24,8 @@ class ModelRawResults(object):
self._base_score_metric = base_score_metric
@property
def model_object(self):
return self.model_object
def model_weights(self):
return self.model_weights
@property
def training_time(self):
......
......@@ -8,6 +8,7 @@ from sklearn.base import BaseEstimator
class OmpForest(BaseEstimator, metaclass=ABCMeta):
def __init__(self, models_parameters, base_forest_estimator):
self._base_forest_estimator = base_forest_estimator
self._models_parameters = models_parameters
......@@ -95,6 +96,7 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
pass
class SingleOmpForest(OmpForest):
def __init__(self, models_parameters, base_forest_estimator):
# fit_intercept shouldn't be set to False as the data isn't necessarily centered here
# normalization is handled outsite OMP
......
......@@ -126,8 +126,15 @@ class Trainer(object):
:param model: Object with
:param models_dir: Where the results will be saved
"""
model_weights = ''
if type(model) == RandomForestRegressor:
model_weights = model.coef_
elif type(model) == OmpForestRegressor:
model_weights = model._omp.coef_
results = ModelRawResults(
model_object='',
model_weights=model_weights,
training_time=self._end_time - self._begin_time,
datetime=datetime.datetime.now(),
train_score=self.__score_func(model, self._dataset.X_train, self._dataset.y_train),
......
......@@ -33,6 +33,8 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
# Used to check if all losses were computed using the same metric (it should be the case)
experiment_score_metrics = list()
all_weights = list()
# For each seed results stored in models/{experiment_id}/seeds
seeds = os.listdir(experiment_seed_root_path)
seeds.sort(key=int)
......@@ -120,6 +122,7 @@ if __name__ == "__main__":
DEFAULT_RESULTS_DIR = os.environ["project_dir"] + os.sep + 'results'
DEFAULT_MODELS_DIR = os.environ["project_dir"] + os.sep + 'models'
DEFAULT_PLOT_WEIGHT_DENSITY = False
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--stage', nargs='?', type=int, required=True, help='Specify the stage number among [1, 5].')
......@@ -130,6 +133,7 @@ if __name__ == "__main__":
parser.add_argument('--dataset_name', nargs='?', type=str, required=True, help='Specify the dataset name. TODO: read it from models dir directly.')
parser.add_argument('--results_dir', nargs='?', type=str, default=DEFAULT_RESULTS_DIR, help='The output directory of the results.')
parser.add_argument('--models_dir', nargs='?', type=str, default=DEFAULT_MODELS_DIR, help='The output directory of the trained models.')
parser.add_argument('--plot_weight_density', action='store_true', default=DEFAULT_PLOT_WEIGHT_DENSITY, help='Plot the weight density. Only working for regressor models for now.')
args = parser.parse_args()
if args.stage not in list(range(1, 6)):
......@@ -224,6 +228,8 @@ if __name__ == "__main__":
ylabel=experiments_score_metric,
title='Loss values of {}\nusing best and default hyperparameters'.format(args.dataset_name)
)
Plotter.plot_weight_density()
elif args.stage == 2:
if len(args.experiment_ids) != 4:
raise ValueError('In the case of stage 2, the number of specified experiment ids must be 4.')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment