Skip to content
Snippets Groups Projects
Commit 42e53ce7 authored by valentin.emiya's avatar valentin.emiya
Browse files

add mechanism to remove eigenvectors (memory saving)

parent 7077f140
No related branches found
No related tags found
No related merge requests found
Pipeline #6090 failed
......@@ -90,9 +90,20 @@ class SolveTffExperiment(Experiment):
suffix : str
Suffix that is appended to the name of the experiment, useful to
save results in a specific folder.
"""
def __init__(self, force_reset=False, suffix=''):
keep_eigenvectors : 'all' or list
Use this parameter to remove eigenvectors from
`GabMulTff` object after computing performance in order to save
space. In this case, the `GabMulTff` will not be usable anymore
after the computation of the performance results.
If 'all', all eigenvectors are kept. To keep only some eigenvectors,
set this parameter to the list of task IDs for which eigenvectors
should be kept (usefull if you want to use or plot some task data
after the experiments, e.g., using method `plot_task`). If the list
is empty, all eigenvectors will be removed.
"""
def __init__(self, force_reset=False, suffix='',
keep_eigenvectors='all'):
Experiment.__init__(self,
name='SolveTffExperiment' + suffix,
get_data=get_data,
......@@ -105,6 +116,7 @@ class SolveTffExperiment(Experiment):
self.fig_dir = self.xp_path / 'figures'
# a little trick to save collections when computing performance
self.measure = lambda **x: perf_measures(**x, exp=self)
self.keep_eigenvectors = keep_eigenvectors
@property
def n_tasks(self):
......@@ -118,7 +130,8 @@ class SolveTffExperiment(Experiment):
return len(list((self.xp_path / 'tasks').glob('0*')))
@staticmethod
def get_experiment(setting='full', force_reset=False):
def get_experiment(setting='full', force_reset=False,
keep_eigenvectors=None):
"""
Get the experiment instance with default values in order to handle it.
......@@ -133,6 +146,9 @@ class SolveTffExperiment(Experiment):
If true, reset the experiment by erasing all previous results
in order to run it from scratch. If False, the existing results are
kept in order to proceed with the existing experiment.
keep_eigenvectors = 'all' or list
See constructor of `SolveTffExperiment`. If None, default
values are used.
Returns
-------
......@@ -157,6 +173,8 @@ class SolveTffExperiment(Experiment):
tolerance_arrf=1e-3,
proba_arrf=1 - 1e-4,
rand_state=0)
keep_eigenvectors = [12, 13]
if setting == 'light':
data_params['loc_source'] = 'bird'
data_params['wideband_src'] = 'car'
......@@ -166,11 +184,14 @@ class SolveTffExperiment(Experiment):
problem_params['wb_to_loc_ratio_db'] = 16
solver_params['tolerance_arrf'] = 1e-2
solver_params['proba_arrf'] = 1 - 1e-2
keep_eigenvectors = [0, 1]
# Create Experiment
suffix = '' if setting == 'full' else '_Light'
exp = SolveTffExperiment(force_reset=force_reset,
suffix=suffix)
suffix=suffix,
keep_eigenvectors=keep_eigenvectors)
exp.add_tasks(data_params=data_params,
problem_params=problem_params,
solver_params=solver_params)
......@@ -645,6 +666,12 @@ class SolveTffExperiment(Experiment):
plt.tight_layout()
plt.savefig(fig_dir / 'spectrogram_true_wb_source.pdf')
def get_idt_from_params(self, data_params, problem_params, solver_params):
d = self.get_task_data_by_params(data_params=data_params,
problem_params=problem_params,
solver_params=solver_params)
return d['id_task']
def get_data(loc_source, wideband_src):
"""
......@@ -959,6 +986,11 @@ def perf_measures(task_params, source_data, problem_data,
rank_sum=np.sum([s.size for s in gmtff.s_vec_list]),
lowest_sv=np.min([np.min(s) for s in gmtff.s_vec_list])
)
idt = exp.get_idt_from_params(**task_params)
if exp.keep_eigenvectors != 'all' and idt not in exp.keep_eigenvectors:
sd = exp._read_item(type_item='solved_data', idt=idt)
sd['gmtff'].u_mat_list = [None for _ in sd['gmtff'].u_mat_list]
exp._write_item(type_item='solved_data', idt=idt, content=sd)
return dict(**running_times, **sdr_res, **is_res, **features)
......
......@@ -62,7 +62,8 @@ from tffpy.experiments.exp_solve_tff import \
try:
experiment = SolveTffExperiment.get_experiment(setting='full',
force_reset=False)
force_reset=False,
keep_eigenvectors=True)
except RuntimeError as e:
experiment = None
print(e)
......@@ -92,13 +93,13 @@ if __name__ == '__main__':
elif answer == 2:
light_exp = SolveTffExperiment.get_experiment(
setting='light', force_reset=False)
for idt in range(light_exp.n_tasks):
for idt in light_exp.keep_eigenvectors:
light_exp.plot_task(idt=idt, fontsize=16)
plt.close('all')
light_exp.plot_results()
elif answer == 3:
experiment = SolveTffExperiment.get_experiment(
setting='full', force_reset=True)
setting='full', force_reset=True, keep_eigenvectors=True)
experiment.display_status()
elif answer == 4:
experiment.display_status()
......
......@@ -87,5 +87,6 @@ def create_config_files():
yafe_config_parser.set('LOGGER', 'path', str(yafe_logger_path))
yafe_config_parser.write(open(yafe_config_file, 'w'))
if __name__ == '__main__':
create_config_files()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment