Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
skmad-suite
tff2020
Commits
c5e446d9
Commit
c5e446d9
authored
Dec 03, 2020
by
valentin.emiya
Browse files
fix conflict
parent
d3690e9b
Pipeline
#6062
canceled with stage
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
python/tffpy/experiments/exp_approx.py
0 → 100644
View file @
c5e446d9
# -*- coding: utf-8 -*-
"""
.. moduleauthor:: Valentin Emiya
"""
import
numpy
as
np
from
yafe
import
Experiment
from
tffpy.datasets
import
get_mix
,
get_dataset
from
tffpy.experiments.exp_solve_tff
import
SolveTffExperiment
class
ApproxExperiment
(
SolveTffExperiment
):
def
__init__
(
self
,
force_reset
=
False
,
suffix
=
''
):
SolveTffExperiment
.
__init__
(
self
,
force_reset
=
force_reset
,
suffix
=
'Approx'
+
suffix
)
def
display_results
(
self
):
res
=
self
.
load_results
(
array_type
=
'xarray'
)
res
=
res
.
squeeze
()
tff_list
=
res
.
to_dict
()[
'coords'
][
'solver_tol_subregions'
][
'data'
]
tol_list
=
res
.
to_dict
()[
'coords'
][
'solver_tolerance_arrf'
][
'data'
]
for
measure
in
[
'sdr_tff'
,
'sdr_tffo'
,
'sdr_tffe'
,
'is_tff'
,
'is_tffo'
,
'is_tffe'
]:
for
solver_tol_subregions
in
tff_list
:
for
tol
in
tol_list
:
mean_res
=
float
(
res
.
sel
(
solver_tolerance_arrf
=
tol
,
solver_tol_subregions
=
solver_tol_subregions
,
measure
=
measure
).
mean
())
std_res
=
float
(
res
.
sel
(
solver_tolerance_arrf
=
tol
,
solver_tol_subregions
=
solver_tol_subregions
,
measure
=
measure
).
std
())
t_res
=
float
(
res
.
sel
(
solver_tolerance_arrf
=
tol
,
solver_tol_subregions
=
solver_tol_subregions
,
measure
=
't_arrf'
).
mean
())
rank_res
=
float
(
res
.
sel
(
solver_tolerance_arrf
=
tol
,
solver_tol_subregions
=
solver_tol_subregions
,
measure
=
'rank_sum'
).
mean
())
if
solver_tol_subregions
is
None
:
measure_name
=
measure
+
'-1'
else
:
measure_name
=
measure
+
'-P'
print
(
'{}: mean={:.2f} std={:.2g} tol={}, t={}, rk={}'
.
format
(
measure_name
,
mean_res
,
std_res
,
tol
,
t_res
,
rank_res
))
def
plot_results
(
self
):
# No more need for this method
pass
def
plot_task
(
self
,
idt
,
fontsize
=
16
):
# No more need for this method
pass
@
staticmethod
def
get_experiment
(
setting
=
'full'
,
force_reset
=
False
):
assert
setting
in
(
'full'
,
'light'
)
dataset
=
get_dataset
()
# Set task parameters
data_params
=
dict
(
loc_source
=
'bird'
,
wideband_src
=
'car'
)
problem_params
=
dict
(
win_choice
=
'gauss 256'
,
# win_choice=['gauss 256', 'hann 512'],
wb_to_loc_ratio_db
=
8
,
n_iter_closing
=
3
,
n_iter_opening
=
3
,
closing_first
=
True
,
delta_mix_db
=
0
,
delta_loc_db
=
40
,
or_mask
=
True
,
crop
=
None
,
fig_dir
=
None
)
solver_params
=
dict
(
tol_subregions
=
[
None
,
1e-5
],
tolerance_arrf
=
list
(
10
**
np
.
arange
(
-
3
,
-
0.5
,
1
))
+
list
(
10
**
np
.
arange
(
-
1
,
0
,
0.2
)),
proba_arrf
=
1
-
1e-4
,
rand_state
=
np
.
arange
(
3
))
if
setting
==
'light'
:
problem_params
[
'win_choice'
]
=
'gauss 64'
,
problem_params
[
'crop'
]
=
4096
problem_params
[
'delta_loc_db'
]
=
20
problem_params
[
'wb_to_loc_ratio_db'
]
=
16
solver_params
[
'tolerance_arrf'
]
=
[
1e-1
,
1e-2
]
solver_params
[
'proba_arrf'
]
=
1
-
1e-2
solver_params
[
'tol_subregions'
]
=
1e-5
# Create Experiment
suffix
=
''
if
setting
==
'full'
else
'_Light'
exp
=
ApproxExperiment
(
force_reset
=
force_reset
,
suffix
=
suffix
)
exp
.
add_tasks
(
data_params
=
data_params
,
problem_params
=
problem_params
,
solver_params
=
solver_params
)
exp
.
generate_tasks
()
return
exp
def
create_and_run_light_experiment
():
"""
Create a light experiment and run it
"""
exp
=
ApproxExperiment
.
get_experiment
(
setting
=
'light'
,
force_reset
=
True
)
print
(
'*'
*
80
)
print
(
'Created experiment'
)
print
(
exp
)
print
(
exp
.
display_status
())
print
(
'*'
*
80
)
print
(
'Run task 0'
)
task_data
=
exp
.
get_task_data_by_id
(
idt
=
0
)
print
(
task_data
.
keys
())
print
(
task_data
[
'task_params'
][
'data_params'
])
problem
=
exp
.
get_problem
(
**
task_data
[
'task_params'
][
'problem_params'
])
print
(
problem
)
print
(
'*'
*
80
)
print
(
'Run all'
)
exp
.
launch_experiment
()
print
(
'*'
*
80
)
print
(
'Collect and plot results'
)
exp
.
collect_results
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment