Commit 690cf820 authored by Luc Giffon

Add a few comments/docstrings + update .gitignore with model files + update README for .env variables

parent a826e7cc
models/*
*/.kile/*
*.kilepr
# Byte-compiled / optimized / DLL files
......
......@@ -49,5 +49,16 @@ Project Organization
Install project
---------------
First install the project requirements:
pip install -r requirements.txt
Then create a file `.env` by copying the file `.env.example`:
cp .env.example .env
Then set the project directory in the `.env` file:
project_dir = "path/to/your/project/directory"
This directory will be used for storing the model parameters.
\ No newline at end of file
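For reference, a minimal sketch of how the `project_dir` variable is consumed at runtime, assuming python-dotenv is installed (the training script further down loads it the same way):

import os
from dotenv import find_dotenv, load_dotenv

load_dotenv(find_dotenv())  # load the key/value pairs declared in .env
project_dir = os.environ["project_dir"]  # e.g. "path/to/your/project/directory"
models_dir = os.path.join(project_dir, 'models')  # where the model parameters are stored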
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import OrthogonalMatchingPursuit
from sklearn.base import BaseEstimator
import numpy as np
class OmpForestRegressor(BaseEstimator):
......@@ -35,10 +35,20 @@ class OmpForestRegressor(BaseEstimator):
return forest
def _extract_subforest(self, X_train, y_train):
D = [[tree.predict([elem])[0] for tree in self._forest] for elem in X_train]
"""
Given an already estimated regressor: apply OMP to get the weight of each tree.
The X_train data is used for interrogation of every tree in the forest. The y_train data
is used for finding the weights in OMP.
:param X_train: (n_sample, n_features) array
:param y_train: (n_sample,) array
:return:
"""
D = np.array([tree.predict(X_train) for tree in self._forest]).T
omp = OrthogonalMatchingPursuit(
n_nonzero_coefs=self._models_parameters.extracted_forest_size,
fit_intercept=False, normalize=False)
omp.fit(D, y_train)
weights = omp.coef_
weights = omp.coef_ # why not use the omp estimator directly instead of extracting its coefficients?
return weights
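For context, a standalone sketch of the same OMP-based tree selection on a freshly trained forest; the dataset and sizes are made up for the example, and `normalize=False` mirrors the call above (recent scikit-learn versions drop that argument):

import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import OrthogonalMatchingPursuit

X_train, y_train = make_regression(n_samples=200, n_features=10, random_state=0)
forest = RandomForestRegressor(n_estimators=100, random_state=0).fit(X_train, y_train)

# D has one column per tree: D[i, j] is the prediction of tree j on sample i
D = np.array([tree.predict(X_train) for tree in forest.estimators_]).T
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=10, fit_intercept=False, normalize=False)
omp.fit(D, y_train)
selected_trees = np.flatnonzero(omp.coef_)  # indices of the (at most) 10 trees kept by OMP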
......@@ -14,6 +14,7 @@ class Trainer(object):
self._logger = LoggerFactory.create(LOG_PATH, __name__)
def iterate(self, model, models_dir):
# why is this function named iterate?
self._logger.info('Training model using train set...')
begin_time = time.time()
model.fit(self._dataset.X_train, self._dataset.y_train)
......
......@@ -2,6 +2,14 @@ import os
def resolve_experiment_id(models_dir):
"""
Return the ID of the next experiment.
The ID is an int equal to n+1 where n is the current number of directory in `models_dir
`
:param models_dir:
:return:
"""
ids = [x for x in os.listdir(models_dir)
if os.path.isdir(models_dir + os.sep + x)]
if len(ids) > 0:
......
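The tail of the function is collapsed in this diff; one plausible completion, following only the docstring's n + 1 rule, would be:

import os

def resolve_experiment_id(models_dir):
    # count the existing experiment directories and return the next ID
    ids = [x for x in os.listdir(models_dir)
           if os.path.isdir(models_dir + os.sep + x)]
    if len(ids) > 0:
        return len(ids) + 1  # assumption: a plain count, not max(existing IDs) + 1
    return 1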
from dotenv import load_dotenv
from bolsonaro.data.dataset_parameters import DatasetParameters
from bolsonaro.data.dataset_loader import DatasetLoader
from bolsonaro.models.model_factory import ModelFactory
......@@ -5,6 +7,7 @@ from bolsonaro.models.model_parameters import ModelParameters
from bolsonaro.trainer import Trainer
from bolsonaro.utils import resolve_experiment_id
from dotenv import find_dotenv, load_dotenv
import argparse
import pathlib
import random
......@@ -13,11 +16,15 @@ import errno
if __name__ == "__main__":
# get environment variables in .env
load_dotenv(find_dotenv())
default_dataset_name = 'boston'
default_normalize = False
default_forest_size = 100
default_extracted_forest_size = 10
default_models_dir = 'models'
# the models will be stored in a directory structure like: models/{experiment_id}/seeds/{seed_nb}/extracted_forest_size/{nb_extracted_trees}
default_models_dir = os.environ["project_dir"] + os.sep + 'models'
default_dev_size = 0.2
default_test_size = 0.2
default_use_random_seed = True
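A sketch of the directory layout described in the comment above, using the pathlib import already present in the script (the variable values here are illustrative):

import os
import pathlib

experiment_id, seed, extracted_forest_size = 1, 42, 10  # illustrative values
models_dir = os.environ["project_dir"] + os.sep + 'models'
sub_models_dir = pathlib.Path(models_dir) / str(experiment_id) / 'seeds' / str(seed) \
    / 'extracted_forest_size' / str(extracted_forest_size)
sub_models_dir.mkdir(parents=True, exist_ok=True)  # create the nested directories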
......@@ -43,6 +50,7 @@ if __name__ == "__main__":
if type(args.extracted_forest_size) == list \
else [args.extracted_forest_size]
# TODO: the seeds shouldn't be randomly generated but fixed in a range instead. We want this to be reproducible: the exact same arguments should return the exact same results.
random_seeds = [random.randint(begin_random_seed_range, end_random_seed_range) \
for i in range(args.random_seed_number)] \
if args.use_random_seed else None
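A sketch of what the TODO above suggests, not the committed behaviour: draw the seeds deterministically so the exact same arguments give the exact same runs.

random_seeds = list(range(1, args.random_seed_number + 1)) \
    if args.use_random_seed else None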
......