Commit 8d6147c2 authored by Luc Giffon

visualization scripts

parent bbad0e52
Merge request !23: Resolve "integration-sota"
@@ -49,7 +49,6 @@ class OmpForestBinaryClassifier(SingleOmpForest):
         result_omp = np.mean(omp_trees_predictions, axis=1)
         return result_omp

     def score(self, X, y, metric=DEFAULT_SCORE_METRIC):
...
from dotenv import load_dotenv, find_dotenv
from pathlib import Path
import os
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
lst_skip_strategy = ["None", "OMP Distillation", "OMP Distillation w/o weights"]
lst_skip_task = ["correlation", "coherence"]
# lst_skip_subset = ["train/dev"]
lst_skip_subset = []

tasks = [
    "train_score",
    "dev_score",
    "test_score",
    "coherence",
    "correlation"
]

dct_score_metric_fancy = {
    "accuracy_score": "% Accuracy",
    "mean_squared_error": "MSE"
}
pio.templates.default = "plotly_white"
dct_color_by_strategy = {
    "OMP": (255, 0, 0),  # red
    "OMP Distillation": (255, 0, 0),  # red
    "OMP Distillation w/o weights": (255, 128, 0),  # orange
    "OMP w/o weights": (255, 128, 0),  # orange
    "Random": (0, 0, 0),  # black
    "Zhang Similarities": (255, 255, 0),  # yellow
    'Zhang Predictions': (128, 0, 128),  # purple
    'Ensemble': (0, 0, 255),  # blue
    "Kmeans": (0, 255, 0)  # green
}
dct_dash_by_strategy = {
    "OMP": None,
    "OMP Distillation": "dash",
    "OMP Distillation w/o weights": "dash",
    "OMP w/o weights": None,
    "Random": "dot",
    "Zhang Similarities": "dash",
    'Zhang Predictions': "dash",
    'Ensemble': "dash",
    "Kmeans": "dash"
}
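A note on how these tuples are consumed below: Python's str.format renders a tuple with parentheses and commas, which happens to match Plotly's CSS-style color syntax, so no manual string building is needed. A minimal illustration (the variable names here are illustrative, not from the commit):

# "rgb{}".format((255, 0, 0)) -> "rgb(255, 0, 0)", a valid Plotly color string
tpl_color = (255, 0, 0)
line_color = "rgb{}".format(tpl_color)            # "rgb(255, 0, 0)"
fill_color = "rgba{}".format(tpl_color + (0.1,))  # "rgba(255, 0, 0, 0.1)"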
# transparency of the +/- std band drawn around each mean curve
tpl_transparency = (0.1,)

def add_trace_from_df(df, fig):
    # Note: relies on the module-level names `task` and `strat` set in the
    # main loop below, plus the color/dash dictionaries defined above.
    df.sort_values(by="forest_size", inplace=True)
    df_groupby_forest_size = df.groupby(['forest_size'])
    forest_sizes = list(df_groupby_forest_size["forest_size"].mean().values)
    mean_value = df_groupby_forest_size[task].mean().values
    std_value = df_groupby_forest_size[task].std().values
    std_value_upper = list(mean_value + std_value)
    std_value_lower = list(mean_value - std_value)

    # mean curve of the metric as a function of forest size
    fig.add_trace(go.Scatter(x=forest_sizes, y=mean_value,
                             mode='lines',
                             name=strat,
                             line=dict(dash=dct_dash_by_strategy[strat],
                                       color="rgb{}".format(dct_color_by_strategy[strat]))
                             ))
    # +/- one standard deviation band, drawn as a closed, filled polygon
    fig.add_trace(go.Scatter(
        x=forest_sizes + forest_sizes[::-1],
        y=std_value_upper + std_value_lower[::-1],
        fill='toself',
        showlegend=False,
        fillcolor='rgba{}'.format(dct_color_by_strategy[strat] + tpl_transparency),
        line_color='rgba(255,255,255,0)',
        name=strat
    ))
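For context, add_trace_from_df implements the classic mean ± standard-deviation band: one line trace for the mean, plus a closed polygon (the x values traversed forward, then backward) filled with 'toself'. A self-contained sketch of the same pattern, with made-up data and names (nothing below comes from the results files):

import numpy as np
import plotly.graph_objects as go

xs = [10, 20, 30]
mean = np.array([0.80, 0.85, 0.88])
std = np.array([0.02, 0.01, 0.01])

fig_demo = go.Figure()
fig_demo.add_trace(go.Scatter(x=xs, y=mean, mode='lines', name="demo"))
fig_demo.add_trace(go.Scatter(
    x=xs + xs[::-1],                              # x forward, then backward
    y=list(mean + std) + list(mean - std)[::-1],  # upper bound, then lower bound
    fill='toself', fillcolor='rgba(255,0,0,0.1)',
    line_color='rgba(255,255,255,0)', showlegend=False))
# fig_demo.show()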
if __name__ == "__main__":
    load_dotenv(find_dotenv('.env'))

    dir_name = "bolsonaro_models_25-03-20"
    dir_path = Path(os.environ["project_dir"]) / "results" / dir_name
    out_dir = Path(os.environ["project_dir"]) / "reports/figures" / dir_name

    input_dir_file = dir_path / "results.csv"
    df_results = pd.read_csv(input_dir_file)

    datasets = set(df_results["dataset"].values)
    strategies = set(df_results["strategy"].values)
    subsets = set(df_results["subset"].values)

    for task in tasks:
        if task in lst_skip_task:
            continue
        for data_name in datasets:
            df_data = df_results[df_results["dataset"] == data_name]
            score_metric_name = df_data["score_metric"].values[0]
            for subset_name in subsets:
                if subset_name in lst_skip_subset:
                    continue
                df_subset = df_data[df_data["subset"] == subset_name]

                fig = go.Figure()

                ##################
                # all techniques #
                ##################
                for strat in strategies:
                    if strat in lst_skip_strategy:
                        continue
                    df_strat = df_subset[df_subset["strategy"] == strat]

                    if "OMP" in strat:
                        ###########################
                        # processing with weights #
                        ###########################
                        # wo_weights == False selects the runs *with* weights
                        df_strat_w_weights = df_strat[df_strat["wo_weights"] == False]
                        # restrict Boston train+dev curves to forest sizes below 400
                        if data_name == "Boston" and subset_name == "train+dev/train+dev":
                            df_strat_w_weights = df_strat_w_weights[df_strat_w_weights["forest_size"] < 400]
                        add_trace_from_df(df_strat_w_weights, fig)

                    # on the train/dev subset, only Random-based strategies are plotted
                    if "OMP" in strat and subset_name == "train/dev":
                        continue
                    elif "Random" not in strat and subset_name == "train/dev":
                        continue

                    #######################################
                    # general processing without weights  #
                    #######################################
                    if "Random" in strat:
                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
                    else:
                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == True]

                    if "OMP" in strat:
                        strat = "{} w/o weights".format(strat)

                    add_trace_from_df(df_strat_wo_weights, fig)
                title = "{} {} {}".format(task, data_name, subset_name)

                fig.update_layout(barmode='group',
                                  # title=title,
                                  xaxis_title="# Selected Trees",
                                  yaxis_title=dct_score_metric_fancy[score_metric_name],
                                  font=dict(size=18,
                                            color="black"),
                                  showlegend=False,
                                  margin=dict(l=1, r=1, b=1, t=1),
                                  legend=dict(traceorder="normal",
                                              font=dict(family="sans-serif",
                                                        size=18,
                                                        color="black"),
                                              borderwidth=1)
                                  )
                # fig.show()

                sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
                filename = sanitize(title)
                output_dir = out_dir / sanitize(subset_name) / sanitize(task)
                output_dir.mkdir(parents=True, exist_ok=True)
                fig.write_image(str((output_dir / filename).absolute()) + ".png")
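One practical caveat, not stated in the commit: with the plotly version pinned in this project (4.5.2), fig.write_image exports static images through the external orca engine, which must be installed separately (newer plotly releases use kaleido instead). If orca is installed but not on the PATH, something like the following is needed before the export loop (the path below is hypothetical):

import plotly.io as pio
# Point plotly at the local orca binary; adjust to your installation.
pio.orca.config.executable = "/usr/local/bin/orca"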
@@ -11,22 +11,30 @@ from dotenv import load_dotenv, find_dotenv

 dct_experiment_id_subset = dict((str(idx), "train+dev/train+dev") for idx in range(1, 9))
 dct_experiment_id_subset.update(dict((str(idx), "train/dev") for idx in range(9, 17)))

-dct_experiment_id_technique = {"1": 'None',
-                               "2": 'Random',
-                               "3": 'OMP',
-                               "4": 'OMP Distillation',
-                               "5": 'Kmeans',
-                               "6": 'Zhang Similarities',
-                               "7": 'Zhang Predictions',
-                               "8": 'Ensemble',
-                               "9": 'None',
-                               "10": 'Random',
-                               "11": 'OMP',
-                               "12": 'OMP Distillation',
-                               "13": 'Kmeans',
-                               "14": 'Zhang Similarities',
-                               "15": 'Zhang Predictions',
-                               "16": 'Ensemble'
+NONE = 'None'
+Random = 'Random'
+OMP = 'OMP'
+OMP_Distillation = 'OMP Distillation'
+Kmeans = 'Kmeans'
+Zhang_Similarities = 'Zhang Similarities'
+Zhang_Predictions = 'Zhang Predictions'
+Ensemble = 'Ensemble'
+
+dct_experiment_id_technique = {"1": NONE,
+                               "2": Random,
+                               "3": OMP,
+                               "4": OMP_Distillation,
+                               "5": Kmeans,
+                               "6": Zhang_Similarities,
+                               "7": Zhang_Predictions,
+                               "8": Ensemble,
+                               "9": NONE,
+                               "10": Random,
+                               "11": OMP,
+                               "12": OMP_Distillation,
+                               "13": Kmeans,
+                               "14": Zhang_Similarities,
+                               "15": Zhang_Predictions,
+                               "16": Ensemble
                                }
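Since experiment ids 9-16 map to the same techniques as 1-8 (only the subset differs, per dct_experiment_id_subset), the mapping could equivalently be built in one comprehension. A sketch, not part of the commit:

lst_techniques = [NONE, Random, OMP, OMP_Distillation,
                  Kmeans, Zhang_Similarities, Zhang_Predictions, Ensemble]
# ids 1-8 and 9-16 cycle through the same eight techniques
dct_experiment_id_technique = {str(i): lst_techniques[(i - 1) % 8] for i in range(1, 17)}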
@@ -49,7 +57,8 @@ dct_dataset_fancy = {
 }

 skip_attributes = ["datetime", "model_weights"]
+set_no_coherence = set()
+set_no_corr = set()

 if __name__ == "__main__":
...
@@ -63,9 +72,14 @@ if __name__ == "__main__":
     for root, dirs, files in os.walk(dir_path, topdown=False):
         for file_str in files:
+            if file_str == "results.csv":
+                continue
             path_dir = Path(root)
             path_file = path_dir / file_str
-            obj_results = pickle.load(open(path_file, 'rb'))
+            try:
+                obj_results = pickle.load(open(path_file, 'rb'))
+            except:
+                print("problem loading pickle file {}".format(path_file))

             path_dir_split = str(path_dir).split("/")
...
@@ -92,9 +106,31 @@ if __name__ == "__main__":
                     continue
                 if val_result == "":
                     val_result = None
+                if key_result == "coherence" and val_result is None:
+                    set_no_coherence.add(id_xp)
+                if key_result == "correlation" and val_result is None:
+                    set_no_corr.add(id_xp)
                 dct_results[key_result].append(val_result)
-            print(path_file)
+            # class 'dict'>: {'model_weights': '',
+            #  'training_time': 0.0032033920288085938,
+            #  'datetime': datetime.datetime(2020, 3, 25, 0, 28, 34, 938400),
+            #  'train_score': 1.0,
+            #  'dev_score': 0.978021978021978,
+            #  'test_score': 0.9736842105263158,
+            #  'train_score_base': 1.0,
+            #  'dev_score_base': 0.978021978021978,
+            #  'test_score_base': 0.9736842105263158,
+            #  'score_metric': 'accuracy_score',
+            #  'base_score_metric': 'accuracy_score',
+            #  'coherence': 0.9892031711775613,
+            #  'correlation': 0.9510700193340448}
+            # print(path_file)

+    print("coh", set_no_coherence, len(set_no_coherence))
+    print("cor", set_no_corr, len(set_no_corr))
     final_df = pd.DataFrame.from_dict(dct_results)
...
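One caveat about the new try/except: as committed, it swallows the error but does not skip the file, so after a failed load the code would fall through with obj_results still bound to the previous iteration's value. A more defensive variant of that inner-loop body (a sketch, not what the commit does):

try:
    with open(path_file, 'rb') as f:
        obj_results = pickle.load(f)
except (pickle.UnpicklingError, OSError, EOFError) as e:
    print("problem loading pickle file {}: {}".format(path_file, e))
    continue  # skip this file instead of reusing stale data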
-# local package
--e .
-
-# external requirements
-click
-Sphinx
-coverage
-awscli
-flake8
-pytest
-scikit-learn
-git+git://github.com/darenr/scikit-optimize@master
-python-dotenv
-matplotlib
-pandas
+alabaster==0.7.12
+attrs==19.3.0
+awscli==1.16.272
+Babel==2.7.0
+backcall==0.1.0
+-e git+git@gitlab.lis-lab.fr:luc.giffon/bolsonaro.git@bbad0e522d6b4b392f1926fa935f2a7fac093411#egg=bolsonaro
+botocore==1.13.8
+certifi==2019.11.28
+chardet==3.0.4
+Click==7.0
+colorama==0.4.1
+coverage==4.5.4
+cycler==0.10.0
+decorator==4.4.2
+docutils==0.15.2
+entrypoints==0.3
+flake8==3.7.9
+idna==2.8
+imagesize==1.1.0
+importlib-metadata==1.5.0
+ipython==7.13.0
+ipython-genutils==0.2.0
+jedi==0.16.0
+Jinja2==2.10.3
+jmespath==0.9.4
+joblib==0.14.0
+kiwisolver==1.1.0
+MarkupSafe==1.1.1
+matplotlib==3.1.1
+mccabe==0.6.1
+mkl-fft==1.0.14
+mkl-random==1.1.0
+mkl-service==2.3.0
+more-itertools==8.2.0
+numpy==1.17.3
+packaging==20.3
+pandas==0.25.3
+parso==0.6.2
+pexpect==4.8.0
+pickleshare==0.7.5
+plotly==4.5.2
+pluggy==0.13.1
+prompt-toolkit==3.0.3
+psutil==5.7.0
+ptyprocess==0.6.0
+py==1.8.1
+pyaml==20.3.1
+pyasn1==0.4.7
+pycodestyle==2.5.0
+pyflakes==2.1.1
+Pygments==2.6.1
+pyparsing==2.4.5
+pytest==5.4.1
+python-dateutil==2.8.1
+python-dotenv==0.10.3
+pytz==2019.3
+PyYAML==5.1.2
+requests==2.22.0
+retrying==1.3.3
+rsa==3.4.2
+s3transfer==0.2.1
+scikit-learn==0.21.3
+scikit-optimize==0.7.4
+scipy==1.3.1
+six==1.12.0
+snowballstemmer==2.0.0
+Sphinx==2.2.1
+sphinxcontrib-applehelp==1.0.1
+sphinxcontrib-devhelp==1.0.1
+sphinxcontrib-htmlhelp==1.0.2
+sphinxcontrib-jsmath==1.0.1
+sphinxcontrib-qthelp==1.0.2
+sphinxcontrib-serializinghtml==1.1.3
+tornado==6.0.3
+tqdm==4.43.0
+traitlets==4.3.3
+urllib3==1.25.6
+wcwidth==0.1.8
+zipp==2.2.0