diff --git a/Code/MonoMultiViewClassifiers/MonoviewClassifiers/RandomForest.py b/Code/MonoMultiViewClassifiers/MonoviewClassifiers/RandomForest.py index 2d29ba19af94489728a161b9e36e9196c6f22006..66caa2d4c30baeaff8da66a9ae07651168e1ef0a 100644 --- a/Code/MonoMultiViewClassifiers/MonoviewClassifiers/RandomForest.py +++ b/Code/MonoMultiViewClassifiers/MonoviewClassifiers/RandomForest.py @@ -3,6 +3,7 @@ from sklearn.pipeline import Pipeline from sklearn.model_selection import RandomizedSearchCV from scipy.stats import randint import numpy as np +import cPickle from .. import Metrics from ..utils.HyperParameterSearch import genHeatMaps @@ -89,4 +90,18 @@ def getConfig(config): def getInterpret(classifier, directory): - pass + featureImportances = classifier.feature_importances_ + sortedArgs = np.argsort(-featureImportances) + featureImportancesSorted = featureImportances[sortedArgs][:50] + featureIndicesSorted = sortedArgs[:50] + featuresImportancesDict = dict((featureIndex, featureImportance) + for featureIndex, featureImportance in enumerate(featureImportances) + if featureImportance != 0) + with open(directory+'-feature_importances.pickle', 'wb') as handle: + cPickle.dump(featuresImportancesDict, handle) + interpretString = "Feature importances : \n" + for featureIndex, featureImportance in zip(featureIndicesSorted, featureImportancesSorted): + if featureImportance>0: + interpretString+="- Feature index : "+str(featureIndex)+ \ + ", feature importance : "+str(featureImportance)+"\n" + return interpretString