diff --git a/python/tffpy/experiments/exp_solve_tff.py b/python/tffpy/experiments/exp_solve_tff.py index 55f008ff8398dd8e532e61b7d85ba57224ee8c2a..d0d8dfc13dac60457c68623bd03cfa710edaa479 100644 --- a/python/tffpy/experiments/exp_solve_tff.py +++ b/python/tffpy/experiments/exp_solve_tff.py @@ -90,9 +90,20 @@ class SolveTffExperiment(Experiment): suffix : str Suffix that is appended to the name of the experiment, useful to save results in a specific folder. + keep_eigenvectors : 'all' or list + Use this parameter to remove eigenvectors from + `GabMulTff` object after computing performance in order to save + space. In this case, the `GabMulTff` will not be usable anymore + after the computation of the performance results. + If 'all', all eigenvectors are kept. To keep only some eigenvectors, + set this parameter to the list of task IDs for which eigenvectors + should be kept (usefull if you want to use or plot some task data + after the experiments, e.g., using method `plot_task`). If the list + is empty, all eigenvectors will be removed. """ - def __init__(self, force_reset=False, suffix=''): + def __init__(self, force_reset=False, suffix='', + keep_eigenvectors='all'): Experiment.__init__(self, name='SolveTffExperiment' + suffix, get_data=get_data, @@ -105,6 +116,7 @@ class SolveTffExperiment(Experiment): self.fig_dir = self.xp_path / 'figures' # a little trick to save collections when computing performance self.measure = lambda **x: perf_measures(**x, exp=self) + self.keep_eigenvectors = keep_eigenvectors @property def n_tasks(self): @@ -118,7 +130,8 @@ class SolveTffExperiment(Experiment): return len(list((self.xp_path / 'tasks').glob('0*'))) @staticmethod - def get_experiment(setting='full', force_reset=False): + def get_experiment(setting='full', force_reset=False, + keep_eigenvectors=None): """ Get the experiment instance with default values in order to handle it. @@ -133,6 +146,9 @@ class SolveTffExperiment(Experiment): If true, reset the experiment by erasing all previous results in order to run it from scratch. If False, the existing results are kept in order to proceed with the existing experiment. + keep_eigenvectors = 'all' or list + See constructor of `SolveTffExperiment`. If None, default + values are used. Returns ------- @@ -157,6 +173,8 @@ class SolveTffExperiment(Experiment): tolerance_arrf=1e-3, proba_arrf=1 - 1e-4, rand_state=0) + + keep_eigenvectors = [12, 13] if setting == 'light': data_params['loc_source'] = 'bird' data_params['wideband_src'] = 'car' @@ -166,11 +184,14 @@ class SolveTffExperiment(Experiment): problem_params['wb_to_loc_ratio_db'] = 16 solver_params['tolerance_arrf'] = 1e-2 solver_params['proba_arrf'] = 1 - 1e-2 + keep_eigenvectors = [0, 1] + # Create Experiment suffix = '' if setting == 'full' else '_Light' exp = SolveTffExperiment(force_reset=force_reset, - suffix=suffix) + suffix=suffix, + keep_eigenvectors=keep_eigenvectors) exp.add_tasks(data_params=data_params, problem_params=problem_params, solver_params=solver_params) @@ -645,6 +666,12 @@ class SolveTffExperiment(Experiment): plt.tight_layout() plt.savefig(fig_dir / 'spectrogram_true_wb_source.pdf') + def get_idt_from_params(self, data_params, problem_params, solver_params): + d = self.get_task_data_by_params(data_params=data_params, + problem_params=problem_params, + solver_params=solver_params) + return d['id_task'] + def get_data(loc_source, wideband_src): """ @@ -959,6 +986,11 @@ def perf_measures(task_params, source_data, problem_data, rank_sum=np.sum([s.size for s in gmtff.s_vec_list]), lowest_sv=np.min([np.min(s) for s in gmtff.s_vec_list]) ) + idt = exp.get_idt_from_params(**task_params) + if exp.keep_eigenvectors != 'all' and idt not in exp.keep_eigenvectors: + sd = exp._read_item(type_item='solved_data', idt=idt) + sd['gmtff'].u_mat_list = [None for _ in sd['gmtff'].u_mat_list] + exp._write_item(type_item='solved_data', idt=idt, content=sd) return dict(**running_times, **sdr_res, **is_res, **features) diff --git a/python/tffpy/scripts/script_exp_solve_tff.py b/python/tffpy/scripts/script_exp_solve_tff.py index 51071ec44a0133b924e3a9ad67511855a206696b..3cebe337de35065b3dfa2eae7db78cfde6ceefd8 100644 --- a/python/tffpy/scripts/script_exp_solve_tff.py +++ b/python/tffpy/scripts/script_exp_solve_tff.py @@ -62,7 +62,8 @@ from tffpy.experiments.exp_solve_tff import \ try: experiment = SolveTffExperiment.get_experiment(setting='full', - force_reset=False) + force_reset=False, + keep_eigenvectors=True) except RuntimeError as e: experiment = None print(e) @@ -92,13 +93,13 @@ if __name__ == '__main__': elif answer == 2: light_exp = SolveTffExperiment.get_experiment( setting='light', force_reset=False) - for idt in range(light_exp.n_tasks): + for idt in light_exp.keep_eigenvectors: light_exp.plot_task(idt=idt, fontsize=16) plt.close('all') light_exp.plot_results() elif answer == 3: experiment = SolveTffExperiment.get_experiment( - setting='full', force_reset=True) + setting='full', force_reset=True, keep_eigenvectors=True) experiment.display_status() elif answer == 4: experiment.display_status() diff --git a/python/tffpy/tests/ci_config.py b/python/tffpy/tests/ci_config.py index 2bb36b8ff3985a3a1910a077c7e74a203cc47ad4..c1a568cbf92b9905fe3c7222e6684e6ae000f3a8 100644 --- a/python/tffpy/tests/ci_config.py +++ b/python/tffpy/tests/ci_config.py @@ -87,5 +87,6 @@ def create_config_files(): yafe_config_parser.set('LOGGER', 'path', str(yafe_logger_path)) yafe_config_parser.write(open(yafe_config_file, 'w')) + if __name__ == '__main__': create_config_files() \ No newline at end of file