Select Git revision
plot_usecase_exampleMumBo.py.md5
LateFusion.py 2.29 KiB
#!/usr/bin/env python
# -*- encoding: utf-8
import numpy as np
from joblib import Parallel, delayed
from sklearn.multiclass import OneVsOneClassifier
from sklearn.svm import SVC
import MonoviewClassifiers
from utils.Dataset import getV
def fitMonoviewClassifier(classifierName, data, labels, classifierConfig, needProbas):
    """Fit the monoview classifier named ``classifierName`` on one view.

    :param classifierName: attribute name looked up in ``MonoviewClassifiers``.
    :param data: training data for the view.
    :param labels: training labels.
    :param classifierConfig: sequence of configuration values; forwarded to
        ``fit`` as keyword arguments named "0", "1", ... by position.
    :param needProbas: if True and the requested classifier's ``canProbas()``
        returns False, "DecisionTree" is fitted instead.
    :return: whatever the selected classifier's ``fit`` returns.
    """
    chosenClassifier = getattr(MonoviewClassifiers, classifierName)
    if needProbas and not chosenClassifier.canProbas():
        # NOTE(review): the fallback still receives the config of the originally
        # requested classifier; confirm DecisionTree.fit accepts it.
        chosenClassifier = getattr(MonoviewClassifiers, "DecisionTree")
    positionalConfig = {str(position): value
                        for position, value in enumerate(classifierConfig)}
    return chosenClassifier.fit(data, labels, **positionalConfig)
def getAccuracies(LateFusionClassifiers):
    """Placeholder for accuracy reporting; not implemented yet.

    :param LateFusionClassifiers: ignored.
    :return: always the empty string.
    """
    emptyReport = ""
    return emptyReport
def Intersect(resMono):
    """Placeholder, presumably for a monoview-selection strategy; not implemented.

    :param resMono: ignored.
    :return: None.
    """
    return None
class LateFusionClassifier(object):
    """Holds one monoview classifier per view and fits them in parallel."""

    def __init__(self, randomState, monoviewClassifiersNames, monoviewClassifiersConfigs, monoviewSelection, NB_CORES=1):
        """
        :param randomState: random state; stored as ``self.randomState``.
        :param monoviewClassifiersNames: classifier names (looked up in
            ``MonoviewClassifiers``), one per view used by ``fit_hdf5``.
        :param monoviewClassifiersConfigs: config sequences aligned with
            ``monoviewClassifiersNames``.
        :param monoviewSelection: name of a monoview-selection function;
            stored but not applied by ``fit_hdf5`` (see TODO there).
        :param NB_CORES: number of joblib jobs used by ``fit_hdf5``.
        """
        self.monoviewClassifiersNames = monoviewClassifiersNames
        self.monoviewClassifiersConfigs = monoviewClassifiersConfigs
        self.monoviewClassifiers = []
        self.nbCores = NB_CORES
        # One slot per configured monoview classifier; not filled in by this class.
        self.accuracies = np.zeros(len(monoviewClassifiersNames))
        # Forwarded to fitMonoviewClassifier: when True, classifiers that cannot
        # produce probabilities are replaced by "DecisionTree".
        self.needProbas = False
        self.monoviewSelection = monoviewSelection
        self.randomState = randomState

    def fit_hdf5(self, DATASET, trainIndices=None, viewsIndices=None):
        """Fit one monoview classifier per selected view.

        Sets ``self.monoviewClassifiers`` to the list of fitted classifiers,
        where the i-th entry of ``viewsIndices`` is paired with the i-th
        configured classifier name/config.

        :param DATASET: dataset object (presumably an h5py File) exposing
            ``get("Metadata").attrs["nbView"]``,
            ``get("Metadata").attrs["datasetLength"]`` and ``get("Labels")``.
        :param trainIndices: indices of training examples; defaults to all.
        :param viewsIndices: indices of views to use; defaults to all views.
        """
        if viewsIndices is None:
            viewsIndices = np.arange(DATASET.get("Metadata").attrs["nbView"])
        # ``is None`` rather than ``== None``: for a NumPy index array ``==``
        # compares element-wise and using the result in ``if`` raises ValueError.
        if trainIndices is None:
            trainIndices = range(DATASET.get("Metadata").attrs["datasetLength"])
        # TODO(review): self.monoviewSelection is not applied yet. The selection
        # functions in this module (e.g. Intersect) are unimplemented stubs, and
        # looking them up through locals() inside this method can never find a
        # module-level function; their result was also overwritten just below.
        self.monoviewClassifiers = Parallel(n_jobs=self.nbCores)(
            delayed(fitMonoviewClassifier)(
                self.monoviewClassifiersNames[index],
                getV(DATASET, viewIndex, trainIndices),
                DATASET.get("Labels")[trainIndices],
                self.monoviewClassifiersConfigs[index],
                self.needProbas)
            for index, viewIndex in enumerate(viewsIndices))