Skip to content
Snippets Groups Projects
Commit 0fce0319 authored by Charly LAMOTHE's avatar Charly LAMOTHE
Browse files

Add the weights normalization parameter (but not implemented yet)

parent 7455fd98
No related branches found
No related tags found
1 merge request!3clean scripts
......@@ -5,11 +5,12 @@ import os
class ModelParameters(object):
def __init__(self, forest_size, extracted_forest_size, normalize_D, subsets_used, seed=None):
def __init__(self, forest_size, extracted_forest_size, normalize_D, subsets_used, normalize_weights, seed):
self._forest_size = forest_size
self._extracted_forest_size = extracted_forest_size
self._normalize_D = normalize_D
self._subsets_used = subsets_used
self._normalize_weights = normalize_weights
self._seed = seed
@property
......@@ -28,6 +29,10 @@ class ModelParameters(object):
def subsets_used(self):
return self._subsets_used
@property
def normalize_weights(self):
return self._normalize_weights
@property
def seed(self):
return self._seed
......
......@@ -56,7 +56,8 @@ class OmpForestRegressor(BaseEstimator):
if self._models_parameters.normalize_D:
D /= self._forest_norms
predictions = D @ self.weights
# TODO: use self._models_parameters.normalize_weights here
predictions = D @ self._weights
return predictions
......
......@@ -31,6 +31,7 @@ if __name__ == "__main__":
DEFAULT_RANDOM_SEED_NUMBER = 1
DEFAULT_SUBSETS_USED = 'train,dev'
DEFAULT_DISABLE_PROGRESS = False
DEFAULT_normalize_weights = False
begin_random_seed_range = 1
end_random_seed_range = 2000
......@@ -48,6 +49,7 @@ if __name__ == "__main__":
parser.add_argument('--seeds', nargs='+', type=int, default=None, help='Specific a list of seeds instead of generate them randomly')
parser.add_argument('--subsets_used', nargs='+', type=str, default=DEFAULT_SUBSETS_USED, help='train,dev: forest on train, OMP on dev. train+dev,train+dev: both forest and OMP on train+dev. train,train+dev: forest on train+dev and OMP on dev.')
parser.add_argument('--disable_progress', action='store_true', default=DEFAULT_DISABLE_PROGRESS, help='Disable the progress bars.')
parser.add_argument('--normalize_weights', action='store_true', default=DEFAULT_normalize_weights, help='Divide the predictions by the weights sum.')
args = parser.parse_args()
pathlib.Path(args.models_dir).mkdir(parents=True, exist_ok=True)
......@@ -102,6 +104,7 @@ if __name__ == "__main__":
extracted_forest_size=extracted_forest_size,
normalize_D=args.normalize_D,
subsets_used=args.subsets_used,
normalize_weights=args.normalize_weights,
seed=seed
)
model_parameters.save(sub_models_dir, experiment_id)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment