diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py index 002dd93d1cee15c5ac4e0a87918dad043c368f1b..2381937b214ab37e0f6e18f96971df9606ec52e5 100644 --- a/code/bolsonaro/models/omp_forest_classifier.py +++ b/code/bolsonaro/models/omp_forest_classifier.py @@ -49,7 +49,6 @@ class OmpForestBinaryClassifier(SingleOmpForest): result_omp = np.mean(omp_trees_predictions, axis=1) - return result_omp def score(self, X, y, metric=DEFAULT_SCORE_METRIC): diff --git a/code/vizualisation/csv_to_figure.py b/code/vizualisation/csv_to_figure.py new file mode 100644 index 0000000000000000000000000000000000000000..879cbdcbf767aab3f1e55db8c17c67cc461b8af1 --- /dev/null +++ b/code/vizualisation/csv_to_figure.py @@ -0,0 +1,181 @@ +from dotenv import load_dotenv, find_dotenv +from pathlib import Path +import os +import pandas as pd +import numpy as np +import plotly.graph_objects as go +import plotly.io as pio + + +lst_skip_strategy = ["None", "OMP Distillation", "OMP Distillation w/o weights"] +lst_skip_task = ["correlation", "coherence"] +# lst_skip_subset = ["train/dev"] +lst_skip_subset = [] + +tasks = [ + "train_score", + "dev_score", + "test_score", + "coherence", + "correlation" +] + +dct_score_metric_fancy = { + "accuracy_score": "% Accuracy", + "mean_squared_error": "MSE" +} + +pio.templates.default = "plotly_white" + +dct_color_by_strategy = { + "OMP": (255, 0, 0), # red + "OMP Distillation": (255, 0, 0), # red + "OMP Distillation w/o weights": (255, 128, 0), # orange + "OMP w/o weights": (255, 128, 0), # orange + "Random": (0, 0, 0), # black + "Zhang Similarities": (255, 255, 0), # yellow + 'Zhang Predictions': (128, 0, 128), # purple + 'Ensemble': (0, 0, 255), # blue + "Kmeans": (0, 255, 0) # green +} + +dct_dash_by_strategy = { + "OMP": None, + "OMP Distillation": "dash", + "OMP Distillation w/o weights": "dash", + "OMP w/o weights": None, + "Random": "dot", + "Zhang Similarities": "dash", + 'Zhang Predictions': "dash", 
+ 'Ensemble': "dash", + "Kmeans": "dash" +} + +def add_trace_from_df(df, fig): + df.sort_values(by="forest_size", inplace=True) + df_groupby_forest_size = df.groupby(['forest_size']) + forest_sizes = list(df_groupby_forest_size["forest_size"].mean().values) + mean_value = df_groupby_forest_size[task].mean().values + std_value = df_groupby_forest_size[task].std().values + std_value_upper = list(mean_value + std_value) + std_value_lower = list(mean_value - std_value) + # print(df_strat) + fig.add_trace(go.Scatter(x=forest_sizes, y=mean_value, + mode='lines', + name=strat, + line=dict(dash=dct_dash_by_strategy[strat], color="rgb{}".format(dct_color_by_strategy[strat])) + )) + + fig.add_trace(go.Scatter( + x=forest_sizes + forest_sizes[::-1], + y=std_value_upper + std_value_lower[::-1], + fill='toself', + showlegend=False, + fillcolor='rgba{}'.format(dct_color_by_strategy[strat] + tpl_transparency), + line_color='rgba(255,255,255,0)', + name=strat + )) + +tpl_transparency = (0.1,) + +if __name__ == "__main__": + + load_dotenv(find_dotenv('.env')) + dir_name = "bolsonaro_models_25-03-20" + dir_path = Path(os.environ["project_dir"]) / "results" / dir_name + + out_dir = Path(os.environ["project_dir"]) / "reports/figures" / dir_name + + input_dir_file = dir_path / "results.csv" + df_results = pd.read_csv(open(input_dir_file, 'rb')) + + datasets = set(df_results["dataset"].values) + strategies = set(df_results["strategy"].values) + subsets = set(df_results["subset"].values) + + for task in tasks: + if task in lst_skip_task: + continue + for data_name in datasets: + df_data = df_results[df_results["dataset"] == data_name] + score_metric_name = df_data["score_metric"].values[0] + + for subset_name in subsets: + if subset_name in lst_skip_subset: + continue + df_subset = df_data[df_data["subset"] == subset_name] + fig = go.Figure() + + ################## + # all techniques # + ################## + for strat in strategies: + if strat in lst_skip_strategy: + continue + df_strat 
= df_subset[df_subset["strategy"] == strat] + + if "OMP" in strat: + ########################### + # processing with weights # + ########################### + df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False] + if data_name == "Boston" and subset_name == "train+dev/train+dev": + df_strat_wo_weights = df_strat_wo_weights[df_strat_wo_weights["forest_size"] < 400] + add_trace_from_df(df_strat_wo_weights, fig) + + if "OMP" in strat and subset_name == "train/dev": + continue + elif "Random" not in strat and subset_name == "train/dev": + continue + + ################################# + # general processing wo_weights # + ################################# + if "Random" in strat: + df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False] + else: + df_strat_wo_weights = df_strat[df_strat["wo_weights"] == True] + + if "OMP" in strat: + strat = "{} w/o weights".format(strat) + + add_trace_from_df(df_strat_wo_weights, fig) + + title = "{} {} {}".format(task, data_name, subset_name) + fig.update_layout(barmode='group', + # title=title, + xaxis_title="# Selected Trees", + yaxis_title=dct_score_metric_fancy[score_metric_name], + font=dict( + # family="Courier New, monospace", + size=18, + color="black" + ), + showlegend = False, + margin = dict( + l=1, + r=1, + b=1, + t=1, + # pad=4 + ), + legend=dict( + traceorder="normal", + font=dict( + family="sans-serif", + size=18, + color="black" + ), + # bgcolor="LightSteelBlue", + # bordercolor="Black", + borderwidth=1, + ) + ) + # fig.show() + sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_") + filename = sanitize(title) + output_dir = out_dir / sanitize(subset_name) / sanitize(task) + output_dir.mkdir(parents=True, exist_ok=True) + fig.write_image(str((output_dir / filename).absolute()) + ".png") + + # exit() diff --git a/code/vizualisation/results_to_csv.py b/code/vizualisation/results_to_csv.py index d800d6468ff8534a439e69c3b33936ee949a24c3..ba54f35ee7d045b9fc3826ab0ad9db4198ba93a3 
100644 --- a/code/vizualisation/results_to_csv.py +++ b/code/vizualisation/results_to_csv.py @@ -11,22 +11,30 @@ from dotenv import load_dotenv, find_dotenv dct_experiment_id_subset = dict((str(idx), "train+dev/train+dev") for idx in range(1, 9)) dct_experiment_id_subset.update(dict((str(idx), "train/dev") for idx in range(9, 17))) -dct_experiment_id_technique = {"1": 'None', - "2": 'Random', - "3": 'OMP', - "4": 'OMP Distillation', - "5": 'Kmeans', - "6": 'Zhang Similarities', - "7": 'Zhang Predictions', - "8": 'Ensemble', - "9": 'None', - "10": 'Random', - "11": 'OMP', - "12": 'OMP Distillation', - "13": 'Kmeans', - "14": 'Zhang Similarities', - "15": 'Zhang Predictions', - "16": 'Ensemble' +NONE = 'None' +Random = 'Random' +OMP = 'OMP' +OMP_Distillation = 'OMP Distillation' +Kmeans = 'Kmeans' +Zhang_Similarities = 'Zhang Similarities' +Zhang_Predictions = 'Zhang Predictions' +Ensemble = 'Ensemble' +dct_experiment_id_technique = {"1": NONE, + "2": Random, + "3": OMP, + "4": OMP_Distillation, + "5": Kmeans, + "6": Zhang_Similarities, + "7": Zhang_Predictions, + "8": Ensemble, + "9": NONE, + "10": Random, + "11": OMP, + "12": OMP_Distillation, + "13": Kmeans, + "14": Zhang_Similarities, + "15": Zhang_Predictions, + "16": Ensemble } @@ -49,7 +57,8 @@ dct_dataset_fancy = { } skip_attributes = ["datetime", "model_weights"] - +set_no_coherence = set() +set_no_corr = set() if __name__ == "__main__": @@ -63,9 +72,14 @@ if __name__ == "__main__": for root, dirs, files in os.walk(dir_path, topdown=False): for file_str in files: + if file_str == "results.csv": + continue path_dir = Path(root) path_file = path_dir / file_str - obj_results = pickle.load(open(path_file, 'rb')) + try: + obj_results = pickle.load(open(path_file, 'rb')) + except: + print("problem loading pickle file {}".format(path_file)) path_dir_split = str(path_dir).split("/") @@ -92,9 +106,31 @@ if __name__ == "__main__": continue if val_result == "": val_result = None + if key_result == "coherence" and 
val_result is None: + set_no_coherence.add(id_xp) + if key_result == "correlation" and val_result is None: + set_no_corr.add(id_xp) + dct_results[key_result].append(val_result) - print(path_file) + # class 'dict'>: {'model_weights': '', + # 'training_time': 0.0032033920288085938, + # 'datetime': datetime.datetime(2020, 3, 25, 0, 28, 34, 938400), + # 'train_score': 1.0, + # 'dev_score': 0.978021978021978, + # 'test_score': 0.9736842105263158, + # 'train_score_base': 1.0, + # 'dev_score_base': 0.978021978021978, + # 'test_score_base': 0.9736842105263158, + # 'score_metric': 'accuracy_score', + # 'base_score_metric': 'accuracy_score', + # 'coherence': 0.9892031711775613, + # 'correlation': 0.9510700193340448} + + # print(path_file) + + print("coh", set_no_coherence, len(set_no_coherence)) + print("cor", set_no_corr, len(set_no_corr)) final_df = pd.DataFrame.from_dict(dct_results) diff --git a/requirements.txt b/requirements.txt index 38a47c2beeff7ee073c27b9dd7ed9cabfbc12c4f..ef5021d7e1d513be852d7af1bbfae18e95ca08ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,79 @@ -# local package --e . 
- -# external requirements -click -Sphinx -coverage -awscli -flake8 -pytest -scikit-learn -git+git://github.com/darenr/scikit-optimize@master -python-dotenv -matplotlib -pandas +alabaster==0.7.12 +attrs==19.3.0 +awscli==1.16.272 +Babel==2.7.0 +backcall==0.1.0 +-e git+git@gitlab.lis-lab.fr:luc.giffon/bolsonaro.git@bbad0e522d6b4b392f1926fa935f2a7fac093411#egg=bolsonaro +botocore==1.13.8 +certifi==2019.11.28 +chardet==3.0.4 +Click==7.0 +colorama==0.4.1 +coverage==4.5.4 +cycler==0.10.0 +decorator==4.4.2 +docutils==0.15.2 +entrypoints==0.3 +flake8==3.7.9 +idna==2.8 +imagesize==1.1.0 +importlib-metadata==1.5.0 +ipython==7.13.0 +ipython-genutils==0.2.0 +jedi==0.16.0 +Jinja2==2.10.3 +jmespath==0.9.4 +joblib==0.14.0 +kiwisolver==1.1.0 +MarkupSafe==1.1.1 +matplotlib==3.1.1 +mccabe==0.6.1 +mkl-fft==1.0.14 +mkl-random==1.1.0 +mkl-service==2.3.0 +more-itertools==8.2.0 +numpy==1.17.3 +packaging==20.3 +pandas==0.25.3 +parso==0.6.2 +pexpect==4.8.0 +pickleshare==0.7.5 +plotly==4.5.2 +pluggy==0.13.1 +prompt-toolkit==3.0.3 +psutil==5.7.0 +ptyprocess==0.6.0 +py==1.8.1 +pyaml==20.3.1 +pyasn1==0.4.7 +pycodestyle==2.5.0 +pyflakes==2.1.1 +Pygments==2.6.1 +pyparsing==2.4.5 +pytest==5.4.1 +python-dateutil==2.8.1 +python-dotenv==0.10.3 +pytz==2019.3 +PyYAML==5.1.2 +requests==2.22.0 +retrying==1.3.3 +rsa==3.4.2 +s3transfer==0.2.1 +scikit-learn==0.21.3 +scikit-optimize==0.7.4 +scipy==1.3.1 +six==1.12.0 +snowballstemmer==2.0.0 +Sphinx==2.2.1 +sphinxcontrib-applehelp==1.0.1 +sphinxcontrib-devhelp==1.0.1 +sphinxcontrib-htmlhelp==1.0.2 +sphinxcontrib-jsmath==1.0.1 +sphinxcontrib-qthelp==1.0.2 +sphinxcontrib-serializinghtml==1.1.3 +tornado==6.0.3 +tqdm==4.43.0 +traitlets==4.3.3 +urllib3==1.25.6 +wcwidth==0.1.8 +zipp==2.2.0