Skip to content
Snippets Groups Projects
Commit 18c4ac3a authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Last modifs

parent 1388f134
No related branches found
No related tags found
No related merge requests found
Showing
with 23 additions and 37 deletions
......@@ -2,7 +2,6 @@
<module type="WEB_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="module" module-name="multiview_generator" />
<orderEntry type="module" module-name="short_projects" />
......
......@@ -3,15 +3,15 @@
# Enable logging
log: True
# The name of each dataset in the directory on which the benchmark should be run
name: ["doc_summit"]
name: ["tuto"]
# A label for the resul directory
label: "example_1"
label:
# The type of dataset, currently supported ".hdf5", and ".csv"
file_type: ".hdf5"
# The views to use in the banchmark, an empty value will result in using all the views
views:
# The path to the directory where the datasets are stored, an absolute path is advised
pathf: "examples/data/"
pathf: "."
# The niceness of the processes, useful to lower their priority
nice: 0
# The random state of the benchmark, useful for reproducibility
......@@ -23,7 +23,7 @@ full: True
# Used to be able to run more than one benchmark per minute
debug: False
# The directory in which the results will be stored, an absolute path is advised
res_dir: "examples/results/example_1/"
res_dir: "."
# If an error occurs in a classifier, if track_tracebacks is set to True, the
# benchmark saves the traceback and continues, if it is set to False, it will
# stop the benchmark and raise the error
......@@ -34,25 +34,23 @@ track_tracebacks: True
# The ratio of test examples/number of train examples
split: 0.35
# The nubmer of folds in the cross validation process when hyper-paramter optimization is performed
nb_folds: 2
nb_folds: 5
# The number of classes to select in the dataset
nb_class:
# The name of the classes to select in the dataset
classes:
# The type of algorithms to run during the benchmark (monoview and/or multiview)
type: ["monoview","multiview"]
type: ["monoview"]
# The name of the monoview algorithms to run, ["all"] to run all the available classifiers
algos_monoview: ["decision_tree"]
# The names of the multiview algorithms to run, ["all"] to run all the available classifiers
algos_multiview: ["weighted_linear_late_fusion",]
algos_multiview: []
# The number of times the benchamrk is repeated with different train/test
# split, to have more statistically significant results
stats_iter: 1
stats_iter: 5
# The metrics that will be use din the result analysis
metrics:
accuracy_score: {}
f1_score:
average: "micro"
# The metric that will be used in the hyper-parameter optimization process
metric_princ: "accuracy_score"
# The type of hyper-parameter optimization method
......@@ -65,14 +63,3 @@ hps_args: {}
decision_tree:
max_depth: 3
weighted_linear_early_fusion:
monoview_classifier_name: "decision_tree"
monoview_classifier_config:
decision_tree:
max_depth: 6
weighted_linear_late_fusion:
classifiers_names: "decision_tree"
classifier_configs:
decision_tree:
max_depth: 3
......@@ -2,13 +2,13 @@
import os
def execute(config_path=os.path.join(os.path.dirname(os.path.realpath(__file__)), "examples", "config_files", "config_example_1.yml")):
def execute(config_path=None):
from multiview_platform import versions as vs
vs.test_versions()
import sys
from multiview_platform.mono_multi_view_classifiers import exec_classif
if sys.argv[1:]:
if config_path is None:
exec_classif.exec_classif(sys.argv[1:])
else:
if config_path == "example 0":
......
......@@ -26,7 +26,7 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier):
----------
random_state : int seed, RandomState instance, or None (default=None)
The seed of the pseudo random number generator to use when
The seed of the pseudo random number multiview_generator to use when
shuffling the data.
n_estimators : int number of estimators
......
......@@ -15,7 +15,7 @@ class RandomForest(RandomForestClassifier, BaseMonoviewClassifier):
Parameters
----------
random_state : int seed, RandomState instance, or None (default=None)
The seed of the pseudo random number generator to use when
The seed of the pseudo random number multiview_generator to use when
shuffling the data.
n_estimators : int (default : 10) number of estimators
......
......@@ -15,7 +15,7 @@ class SGD(SGDClassifier, BaseMonoviewClassifier):
Parameters
----------
random_state : int seed, RandomState instance, or None (default=None)
The seed of the pseudo random number generator to use when
The seed of the pseudo random number multiview_generator to use when
shuffling the data.
loss : str , (default = "hinge")
......
......@@ -15,7 +15,7 @@ class SVMLinear(SVCClassifier, BaseMonoviewClassifier):
Parameters
----------
random_state : int seed, RandomState instance, or None (default=None)
The seed of the pseudo random number generator to use when
The seed of the pseudo random number multiview_generator to use when
shuffling the data.
......
......@@ -17,7 +17,7 @@ class SVMPoly(SVCClassifier, BaseMonoviewClassifier):
Parameters
----------
random_state : int seed, RandomState instance, or None (default=None)
The seed of the pseudo random number generator to use when
The seed of the pseudo random number multiview_generator to use when
shuffling the data.
......
......@@ -16,7 +16,7 @@ class SVMRBF(SVCClassifier, BaseMonoviewClassifier):
Parameters
----------
random_state : int seed, RandomState instance, or None (default=None)
The seed of the pseudo random number generator to use when
The seed of the pseudo random number multiview_generator to use when
shuffling the data.
C :
......
......@@ -152,7 +152,7 @@ def exec_multiview_multicore(directory, core_index, name, learning_rate,
labels_dictionary
random_state : int seed, RandomState instance, or None (default=None)
The seed of the pseudo random number generator to use when
The seed of the pseudo random number multiview_generator to use when
shuffling the data.
labels :
......@@ -217,7 +217,7 @@ def exec_multiview(directory, dataset_var, name, classification_indices,
labels_dictionary : dict dictionary of labels
random_state : int seed, RandomState instance, or None (default=None)
The seed of the pseudo random number generator to use when
The seed of the pseudo random number multiview_generator to use when
shuffling the data.
labels
......
......@@ -20,7 +20,7 @@ class BaseMultiviewClassifier(BaseClassifier):
Parameters
----------
random_state : int seed, RandomState instance, or None (default=None)
The seed of the pseudo random number generator to use when
The seed of the pseudo random number multiview_generator to use when
shuffling the data.
"""
......
......@@ -226,7 +226,7 @@ def plot_2d(data, classifiers_names, nb_classifiers, file_name, labels=None,
fig.update_layout(paper_bgcolor='rgba(0,0,0,0)',
plot_bgcolor='rgba(0,0,0,0)')
fig.update_xaxes(showticklabels=True, )
plotly.offline.plot(fig, filename=file_name + "err.html",
plotly.offline.plot(fig, filename=file_name + "error_analysis_2D.html",
auto_open=False)
del fig
......
......@@ -153,9 +153,9 @@ def init_log_file(name, views, cl_type, log, debug, label,
"""
if views is None:
views = []
result_directory = os.path.join(os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
result_directory)
# result_directory = os.path.join(os.path.dirname(
# os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
# result_directory)
if debug:
result_directory = os.path.join(result_directory, name,
"debug_started_" + time.strftime(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment