Skip to content
Snippets Groups Projects
Commit b16a8704 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

MERGE

parents c61cb784 c603d6b3
No related branches found
No related tags found
No related merge requests found
# The base configuration of the benchmark
# Enable logging
log: True
# The name of each dataset in the directory on which the benchmark should be run
name: "multiview_mnist"
# A label for the result directory
label: "mnist"
# The type of dataset, currently supported ".hdf5", and ".csv"
file_type: ".hdf5"
# The views to use in the benchmark, an empty value will result in using all the views
views:
# The path to the directory where the datasets are stored, an absolute path is advised
pathf: "examples/data/"
# The niceness of the processes, useful to lower their priority
nice: 0
# The random state of the benchmark, useful for reproducibility
random_state: 42
# The number of parallel computing threads
nb_cores: 4
# Used to run the benchmark on the full dataset
full: True
# Used to be able to run more than one benchmark per minute
debug: False
# The directory in which the results will be stored, an absolute path is advised
res_dir: "examples/results/example_3/"
# If an error occurs in a classifier, if track_tracebacks is set to True, the
# benchmark saves the traceback and continues, if it is set to False, it will
# stop the benchmark and raise the error
track_tracebacks: True
# All the classification-related configuration options
# If the dataset is multiclass, will use this multiclass-to-biclass method
multiclass_method: "oneVersusOne"
# The ratio of the number of test examples to the number of train samples
split: 0.8
# The number of folds in the cross validation process when hyper-parameter optimization is performed
nb_folds: 5
# The number of classes to select in the dataset
nb_class: 2
# The name of the classes to select in the dataset
classes:
# The type of algorithms to run during the benchmark (monoview and/or multiview)
type: ["monoview","multiview"]
# The name of the monoview algorithms to run, ["all"] to run all the available classifiers
algos_monoview: ["decision_tree", "adaboost", ]
# The names of the multiview algorithms to run, ["all"] to run all the available classifiers
algos_multiview: ["early_fusion_decision_tree", "early_fusion_adaboost"]
# The number of times the benchmark is repeated with different train/test
# split, to have more statistically significant results
stats_iter: 5
# The metrics that will be used in the result analysis
metrics:
accuracy_score: {}
f1_score:
average: "micro"
# The metric that will be used in the hyper-parameter optimization process
metric_princ: "accuracy_score"
# The type of hyper-parameter optimization method
hps_type: 'Random'
# The number of iterations in the hyper-parameter optimization process
hps_args:
n_iter: 10
decision_tree:
max_depth: 3
adaboost:
base_estimator: "DecisionTreeClassifier"
n_estimators: 10
weighted_linear_late_fusion:
classifiers_names: "decision_tree"
classifier_configs:
decision_tree:
max_depth: 2
# The following arguments are classifier-specific, and are documented in each
# of the corresponding modules.
# In order to run multiple sets of parameters, use multiple values in the
# following lists, and set hps_type to None.
......@@ -12,9 +12,8 @@ class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier):
def __init__(self, random_state=None, base_estimator="DecisionTreeClassifier",
n_estimators=10, sampling_strategy="auto",
replacement=False, base_estimator_config=None):
base_estimator = self.get_base_estimator(base_estimator,
base_estimator_config)
replacement=False, base_estimator_config=None, **kwargs):
base_estimator = self.get_base_estimator(base_estimator, base_estimator_config, **kwargs)
super(ImbalanceBagging, self).__init__(random_state=random_state, base_estimator=base_estimator,
n_estimators=n_estimators,
sampling_strategy=sampling_strategy,
......
......@@ -267,7 +267,6 @@ def exec_multiview(directory, dataset_var, name, classification_indices,
logging.info("Start:\t Optimizing hyperparameters")
hps_beg = time.monotonic()
print(dataset_var.view_dict)
if hps_method != "None":
hps_method_class = getattr(hyper_parameter_search, hps_method)
estimator = getattr(classifier_module, classifier_name)(
......
from .additions.early_fusion_from_monoview import BaseEarlyFusion
from ..utils.base import base_boosting_estimators
from ..utils.hyper_parameter_search import CustomRandint, CustomUniform
classifier_class_name = "EarlyFusionImbalanceBagging"
class EarlyFusionImbalanceBagging(BaseEarlyFusion):
    """Early-fusion multiview classifier backed by the imbalance-bagging
    monoview classifier.

    All constructor arguments are forwarded to :class:`BaseEarlyFusion`,
    which runs the underlying monoview algorithm on the concatenated views.
    The ``param_names`` / ``distribs`` attributes describe the search space
    used by the random hyper-parameter optimization.
    """

    def __init__(self, random_state=None, base_estimator="DecisionTreeClassifier",
                 n_estimators=10, sampling_strategy="auto",
                 replacement=False, base_estimator_config=None, **kwargs):
        # Delegate everything to the early-fusion base, naming the wrapped
        # monoview classifier explicitly.
        super(EarlyFusionImbalanceBagging, self).__init__(
            random_state=random_state,
            monoview_classifier="imbalance_bagging",
            base_estimator=base_estimator,
            n_estimators=n_estimators,
            sampling_strategy=sampling_strategy,
            replacement=replacement,
            base_estimator_config=base_estimator_config,
            **kwargs)
        # Hyper-parameter search space: names, which ones are classes, and
        # the distributions/choices to sample from (index-aligned lists).
        self.param_names = ["n_estimators", "base_estimator",
                            "sampling_strategy", ]
        self.classed_params = ["base_estimator"]
        self.distribs = [CustomRandint(low=1, high=50),
                         base_boosting_estimators,
                         ["auto"]]
        # base_estimator is rendered by its class name in result reports.
        self.weird_strings = {"base_estimator": "class_name"}
        self.base_estimator_config = base_estimator_config
\ No newline at end of file
from .additions.early_fusion_from_monoview import BaseEarlyFusion
from ..utils.hyper_parameter_search import CustomUniform, CustomRandint
classifier_class_name = "EarlyFusionSCM"
class EarlyFusionSCM(BaseEarlyFusion):
    """Early-fusion multiview classifier backed by the SCM (Set Covering
    Machine) monoview classifier.

    Arguments are forwarded to :class:`BaseEarlyFusion`, which runs the
    wrapped monoview algorithm on the concatenated views. ``param_names`` /
    ``distribs`` define the random hyper-parameter search space.
    """

    def __init__(self, random_state=None, model_type="conjunction",
                 max_rules=10, p=0.1, **kwargs):
        # BUG FIX: the original passed random_state=None here, silently
        # discarding the caller-supplied random_state and breaking
        # reproducibility. Forward it, as the sibling early-fusion
        # classifiers do.
        BaseEarlyFusion.__init__(self, random_state=random_state,
                                 monoview_classifier="scm",
                                 model_type=model_type,
                                 max_rules=max_rules, p=p, **kwargs)
        # Hyper-parameter search space (index-aligned with param_names).
        self.param_names = ["model_type", "max_rules", "p", "random_state"]
        self.distribs = [["conjunction", "disjunction"],
                         CustomRandint(low=1, high=15),
                         CustomUniform(loc=0, state=1), [random_state]]
        self.classed_params = []
        self.weird_strings = {}
\ No newline at end of file
......@@ -121,9 +121,9 @@ def plot_feature_importances(file_name, feature_importance,
hoverinfo=["text"],
colorscale="Hot",
reversescale=True))
fig.update_layout(
xaxis={"showgrid": False, "showticklabels": False, "ticks": ''},
yaxis={"showgrid": False, "showticklabels": False, "ticks": ''})
# fig.update_layout(
# xaxis={"showgrid": False, "showticklabels": False, "ticks": ''},
# yaxis={"showgrid": False, "showticklabels": False, "ticks": ''})
fig.update_layout(paper_bgcolor='rgba(0,0,0,0)',
plot_bgcolor='rgba(0,0,0,0)')
plotly.offline.plot(fig, filename=file_name + ".html", auto_open=False)
......
......@@ -65,7 +65,14 @@ class BaseClassifier(BaseEstimator, ):
else:
return self.__class__.__name__ + " with no config."
def get_base_estimator(self, base_estimator, estimator_config):
def get_base_estimator(self, base_estimator, estimator_config, **kwargs):
if not kwargs is None:
for key, value in kwargs.items():
if key.startswith("base_estimator__"):
if estimator_config is None:
estimator_config = {}
param_name = key.split("base_estimator__")[1]
estimator_config[param_name] = value
if estimator_config is None:
estimator_config = {}
if base_estimator is None:
......
......@@ -43,12 +43,5 @@ def remove_compressed(exp_path):
if __name__=="__main__":
import os
explore_files("/home/baptiste/Documents/Gitwork/summit/results/BioBanQ_mv_status/debug_started_2022_07_01-22_23_04_bal_acc/")
# print(os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"))
# for dataset in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"):
# print(dataset)
# dataset_path = os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dataset)
# for exp_name in os.listdir(dataset_path):
# exp_path = os.path.join(dataset_path, exp_name)
# explore_files(exp_path, compress=False)
\ No newline at end of file
simplify_plotly("/home/baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html")
......@@ -166,7 +166,6 @@ class Random(RandomizedSearchCV, HPSearch):
param_distributions=param_distributions,
refit=refit, n_jobs=n_jobs, scoring=scoring,
cv=cv, random_state=random_state)
self.framework = framework
self.available_indices = learning_indices
self.view_indices = view_indices
......@@ -282,8 +281,6 @@ class CustomUniform(CustomDist):
return self.multiply(unif)
def format_params(params, pref=""):
if isinstance(params, dict):
dictionary = {}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment