From 3173c5aba135365339ab1f6a35f68f813c402041 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin.1@ulaval.ca>
Date: Fri, 19 Jan 2018 20:04:07 +0100
Subject: [PATCH] Added arguments for FLF

---
 .../FatLateFusion/FatLateFusionModule.py                  | 8 ++++++--
 .../Fusion/Methods/LateFusionPackage/BayesianInference.py | 2 +-
 Code/MonoMultiViewClassifiers/utils/execution.py          | 8 +++++++-
 3 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/Code/MonoMultiViewClassifiers/MultiviewClassifiers/FatLateFusion/FatLateFusionModule.py b/Code/MonoMultiViewClassifiers/MultiviewClassifiers/FatLateFusion/FatLateFusionModule.py
index 14c8ff09..ad57342b 100644
--- a/Code/MonoMultiViewClassifiers/MultiviewClassifiers/FatLateFusion/FatLateFusionModule.py
+++ b/Code/MonoMultiViewClassifiers/MultiviewClassifiers/FatLateFusion/FatLateFusionModule.py
@@ -19,7 +19,8 @@ def getArgs(args, benchmark, views, viewsIndices, randomState, directory, result
                  "NB_CLASS": len(args.CL_classes),
                  "LABELS_NAMES": args.CL_classes,
                  "FatLateFusionKWARGS": {
-                     "monoviewDecisions": monoviewDecisions
+                     "monoviewDecisions": monoviewDecisions,
+                     "weights": args.FLF_weights
                  }
                  }
     argumentsList.append(arguments)
@@ -36,7 +37,10 @@ def genParamsSets(classificationKWARGS, randomState, nIter=1):
 class FatLateFusionClass:
 
     def __init__(self, randomState, NB_CORES=1, **kwargs):
-        self.weights = [1.0/len(["monoviewDecisions"]) for _ in range(len(["monoviewDecisions"]))]
+        if kwargs["weights"] == []:
+            self.weights = [1.0/len(["monoviewDecisions"]) for _ in range(len(["monoviewDecisions"]))]
+        else:
+            self.weights = np.array(kwargs["weights"])/np.sum(np.array(kwargs["weights"]))
         self.monoviewDecisions = kwargs["monoviewDecisions"]
 
     def setParams(self, paramsSet):
diff --git a/Code/MonoMultiViewClassifiers/MultiviewClassifiers/Fusion/Methods/LateFusionPackage/BayesianInference.py b/Code/MonoMultiViewClassifiers/MultiviewClassifiers/Fusion/Methods/LateFusionPackage/BayesianInference.py
index 25a5968d..a040afb2 100644
--- a/Code/MonoMultiViewClassifiers/MultiviewClassifiers/Fusion/Methods/LateFusionPackage/BayesianInference.py
+++ b/Code/MonoMultiViewClassifiers/MultiviewClassifiers/Fusion/Methods/LateFusionPackage/BayesianInference.py
@@ -59,7 +59,7 @@ class BayesianInference(LateFusionClassifier):
                                       NB_CORES=NB_CORES)
 
         if kwargs['fusionMethodConfig'][0] is None or kwargs['fusionMethodConfig'] == ['']:
-            self.weights = np.array([1.0 for classifier in kwargs['classifiersNames']])
+            self.weights = np.array([1.0 for _ in kwargs['classifiersNames']])
         else:
             self.weights = np.array(map(float, kwargs['fusionMethodConfig'][0]))
         self.needProbas = True
diff --git a/Code/MonoMultiViewClassifiers/utils/execution.py b/Code/MonoMultiViewClassifiers/utils/execution.py
index 325b05a5..e8d2d0e1 100644
--- a/Code/MonoMultiViewClassifiers/utils/execution.py
+++ b/Code/MonoMultiViewClassifiers/utils/execution.py
@@ -178,7 +178,7 @@ def parseTheArgs(arguments):
     groupEarlyFusion.add_argument('--FU_E_cl_names', metavar='STRING', action='store', nargs='+',
                                   help='Name of the classifiers used for each early fusion method', default=[''])
 
-    groupLateFusion = parser.add_argument_group('Late Early Fusion arguments')
+    groupLateFusion = parser.add_argument_group('Late Fusion arguments')
     groupLateFusion.add_argument('--FU_late_methods', metavar='STRING', action='store', nargs="+",
                                  help='Determine which late fusion method of fusion to use',
                                  default=[''])
@@ -191,6 +191,12 @@ def parseTheArgs(arguments):
     groupLateFusion.add_argument('--FU_L_select_monoview', metavar='STRING', action='store',
                                  help='Determine which method to use to select the monoview classifiers',
                                  default="intersect")
+
+    groupFatLateFusion = parser.add_argument_group('Fat Late Fusion arguments')
+    groupFatLateFusion.add_argument('--FLF_weights', metavar='FLOAT', action='store', nargs="+",
+                                 help='Determine which late fusion method of fusion to use', type=float,
+                                 default=[])
+
     args = parser.parse_args(arguments)
     return args
 
-- 
GitLab