Skip to content
Snippets Groups Projects
Commit aa540f42 authored by valentin.emiya's avatar valentin.emiya
Browse files

Merge branch 'py' of gitlab.lis-lab.fr:skmad-suite/tff2020 into approx

parents 5ffdf5b3 c4fb1342
No related branches found
No related tags found
No related merge requests found
...@@ -898,8 +898,9 @@ def perf_measures(task_params, source_data, problem_data, ...@@ -898,8 +898,9 @@ def perf_measures(task_params, source_data, problem_data,
------- -------
dict dict
All data useful for result analysis including SDR and Itakura-Saito All data useful for result analysis including SDR and Itakura-Saito
performance, running times, hyperparameter values, mask size and performance, running times, hyperparameter values, mask size,
number of sub-regions. number of sub-regions, estimated rank (summed over sub-regions),
lowest singular value.
""" """
x_tff = solved_data['x_tff'] x_tff = solved_data['x_tff']
x_zero = solved_data['x_zero'] x_zero = solved_data['x_zero']
...@@ -955,7 +956,9 @@ def perf_measures(task_params, source_data, problem_data, ...@@ -955,7 +956,9 @@ def perf_measures(task_params, source_data, problem_data,
features = dict(mask_size=np.sum(gmtff.mask > 0), features = dict(mask_size=np.sum(gmtff.mask > 0),
mask_ratio=np.mean(gmtff.mask > 0), mask_ratio=np.mean(gmtff.mask > 0),
n_subregions=gmtff.n_areas, n_subregions=gmtff.n_areas,
rank_sum=np.sum([s.size for s in gmtff.s_vec_list])) rank_sum=np.sum([s.size for s in gmtff.s_vec_list]),
lowest_sv=np.min([np.min(s) for s in gmtff.s_vec_list])
)
return dict(**running_times, **sdr_res, **is_res, **features) return dict(**running_times, **sdr_res, **is_res, **features)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment