Skip to content
Snippets Groups Projects
Commit 0fce0319 authored by Charly LAMOTHE's avatar Charly LAMOTHE
Browse files

Add the weights normalization parameter (but not implemented yet)

parent 7455fd98
No related branches found
No related tags found
1 merge request!3clean scripts
...@@ -5,11 +5,12 @@ import os ...@@ -5,11 +5,12 @@ import os
class ModelParameters(object): class ModelParameters(object):
def __init__(self, forest_size, extracted_forest_size, normalize_D, subsets_used, seed=None): def __init__(self, forest_size, extracted_forest_size, normalize_D, subsets_used, normalize_weights, seed):
self._forest_size = forest_size self._forest_size = forest_size
self._extracted_forest_size = extracted_forest_size self._extracted_forest_size = extracted_forest_size
self._normalize_D = normalize_D self._normalize_D = normalize_D
self._subsets_used = subsets_used self._subsets_used = subsets_used
self._normalize_weights = normalize_weights
self._seed = seed self._seed = seed
@property @property
...@@ -28,6 +29,10 @@ class ModelParameters(object): ...@@ -28,6 +29,10 @@ class ModelParameters(object):
def subsets_used(self): def subsets_used(self):
return self._subsets_used return self._subsets_used
@property
def normalize_weights(self):
return self._normalize_weights
@property @property
def seed(self): def seed(self):
return self._seed return self._seed
......
...@@ -56,7 +56,8 @@ class OmpForestRegressor(BaseEstimator): ...@@ -56,7 +56,8 @@ class OmpForestRegressor(BaseEstimator):
if self._models_parameters.normalize_D: if self._models_parameters.normalize_D:
D /= self._forest_norms D /= self._forest_norms
predictions = D @ self.weights # TODO: use self._models_parameters.normalize_weights here
predictions = D @ self._weights
return predictions return predictions
......
...@@ -31,6 +31,7 @@ if __name__ == "__main__": ...@@ -31,6 +31,7 @@ if __name__ == "__main__":
DEFAULT_RANDOM_SEED_NUMBER = 1 DEFAULT_RANDOM_SEED_NUMBER = 1
DEFAULT_SUBSETS_USED = 'train,dev' DEFAULT_SUBSETS_USED = 'train,dev'
DEFAULT_DISABLE_PROGRESS = False DEFAULT_DISABLE_PROGRESS = False
DEFAULT_normalize_weights = False
begin_random_seed_range = 1 begin_random_seed_range = 1
end_random_seed_range = 2000 end_random_seed_range = 2000
...@@ -48,6 +49,7 @@ if __name__ == "__main__": ...@@ -48,6 +49,7 @@ if __name__ == "__main__":
parser.add_argument('--seeds', nargs='+', type=int, default=None, help='Specific a list of seeds instead of generate them randomly') parser.add_argument('--seeds', nargs='+', type=int, default=None, help='Specific a list of seeds instead of generate them randomly')
parser.add_argument('--subsets_used', nargs='+', type=str, default=DEFAULT_SUBSETS_USED, help='train,dev: forest on train, OMP on dev. train+dev,train+dev: both forest and OMP on train+dev. train,train+dev: forest on train+dev and OMP on dev.') parser.add_argument('--subsets_used', nargs='+', type=str, default=DEFAULT_SUBSETS_USED, help='train,dev: forest on train, OMP on dev. train+dev,train+dev: both forest and OMP on train+dev. train,train+dev: forest on train+dev and OMP on dev.')
parser.add_argument('--disable_progress', action='store_true', default=DEFAULT_DISABLE_PROGRESS, help='Disable the progress bars.') parser.add_argument('--disable_progress', action='store_true', default=DEFAULT_DISABLE_PROGRESS, help='Disable the progress bars.')
parser.add_argument('--normalize_weights', action='store_true', default=DEFAULT_normalize_weights, help='Divide the predictions by the weights sum.')
args = parser.parse_args() args = parser.parse_args()
pathlib.Path(args.models_dir).mkdir(parents=True, exist_ok=True) pathlib.Path(args.models_dir).mkdir(parents=True, exist_ok=True)
...@@ -102,6 +104,7 @@ if __name__ == "__main__": ...@@ -102,6 +104,7 @@ if __name__ == "__main__":
extracted_forest_size=extracted_forest_size, extracted_forest_size=extracted_forest_size,
normalize_D=args.normalize_D, normalize_D=args.normalize_D,
subsets_used=args.subsets_used, subsets_used=args.subsets_used,
normalize_weights=args.normalize_weights,
seed=seed seed=seed
) )
model_parameters.save(sub_models_dir, experiment_id) model_parameters.save(sub_models_dir, experiment_id)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment