Skip to content
Snippets Groups Projects
Commit b279acb5 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added graphs to vizualize feature importances

parent ddc12c13
Branches
No related tags found
No related merge requests found
...@@ -4,10 +4,13 @@ from sklearn.model_selection import RandomizedSearchCV ...@@ -4,10 +4,13 @@ from sklearn.model_selection import RandomizedSearchCV
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from scipy.stats import randint from scipy.stats import randint
import numpy as np import numpy as np
import cPickle # import cPickle
# import matplotlib.pyplot as plt
# from matplotlib.ticker import FuncFormatter
from .. import Metrics from .. import Metrics
from ..utils.HyperParameterSearch import genHeatMaps from ..utils.HyperParameterSearch import genHeatMaps
from ..utils.Interpret import getFeatureImportance
# Author-Info # Author-Info
__author__ = "Baptiste Bauvin" __author__ = "Baptiste Bauvin"
...@@ -82,18 +85,5 @@ def getConfig(config): ...@@ -82,18 +85,5 @@ def getConfig(config):
config["1"]) config["1"])
def getInterpret(classifier, directory): def getInterpret(classifier, directory):
featureImportances = classifier.feature_importances_ interpretString = getFeatureImportance(classifier, directory)
sortedArgs = np.argsort(-featureImportances)
featureImportancesSorted = featureImportances[sortedArgs][:50]
featureIndicesSorted = sortedArgs[:50]
featuresImportancesDict = dict((featureIndex, featureImportance)
for featureIndex, featureImportance in enumerate(featureImportances)
if featureImportance != 0)
with open(directory+'-feature_importances.pickle', 'wb') as handle:
cPickle.dump(featuresImportancesDict, handle)
interpretString = "Feature importances : \n"
for featureIndex, featureImportance in zip(featureIndicesSorted, featureImportancesSorted):
if featureImportance>0:
interpretString+="- Feature index : "+str(featureIndex)+\
", feature importance : "+str(featureImportance)+"\n"
return interpretString return interpretString
\ No newline at end of file
...@@ -4,10 +4,11 @@ from sklearn.model_selection import RandomizedSearchCV ...@@ -4,10 +4,11 @@ from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import randint from scipy.stats import randint
import numpy as np import numpy as np
import graphviz import graphviz
import cPickle # import cPickle
from .. import Metrics from .. import Metrics
from ..utils.HyperParameterSearch import genHeatMaps from ..utils.HyperParameterSearch import genHeatMaps
from ..utils.Interpret import getFeatureImportance
# Author-Info # Author-Info
__author__ = "Baptiste Bauvin" __author__ = "Baptiste Bauvin"
...@@ -92,18 +93,5 @@ def getInterpret(classifier, directory): ...@@ -92,18 +93,5 @@ def getInterpret(classifier, directory):
dot_data = tree.export_graphviz(classifier, out_file=None) dot_data = tree.export_graphviz(classifier, out_file=None)
graph = graphviz.Source(dot_data) graph = graphviz.Source(dot_data)
graph.render(directory+"-tree.pdf") graph.render(directory+"-tree.pdf")
featureImportances = classifier.feature_importances_ interpretString = getFeatureImportance(classifier, directory)
sortedArgs = np.argsort(-featureImportances)
featureImportancesSorted = featureImportances[sortedArgs][:50]
featureIndicesSorted = sortedArgs[:50]
featuresImportancesDict = dict((featureIndex, featureImportance)
for featureIndex, featureImportance in enumerate(featureImportances)
if featureImportance != 0)
with open(directory + '-feature_importances.pickle', 'wb') as handle:
cPickle.dump(featuresImportancesDict, handle)
interpretString = "Feature importances : \n"
for featureIndex, featureImportance in zip(featureIndicesSorted, featureImportancesSorted):
if featureImportance > 0:
interpretString += "- Feature index : " + str(featureIndex) + \
", feature importance : " + str(featureImportance) + "\n"
return interpretString return interpretString
\ No newline at end of file
...@@ -3,10 +3,11 @@ from sklearn.pipeline import Pipeline ...@@ -3,10 +3,11 @@ from sklearn.pipeline import Pipeline
from sklearn.model_selection import RandomizedSearchCV from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import randint from scipy.stats import randint
import numpy as np import numpy as np
import cPickle # import cPickle
from .. import Metrics from .. import Metrics
from ..utils.HyperParameterSearch import genHeatMaps from ..utils.HyperParameterSearch import genHeatMaps
from ..utils.Interpret import getFeatureImportance
# Author-Info # Author-Info
__author__ = "Baptiste Bauvin" __author__ = "Baptiste Bauvin"
...@@ -90,18 +91,5 @@ def getConfig(config): ...@@ -90,18 +91,5 @@ def getConfig(config):
def getInterpret(classifier, directory): def getInterpret(classifier, directory):
featureImportances = classifier.feature_importances_ interpretString = getFeatureImportance(classifier, directory)
sortedArgs = np.argsort(-featureImportances)
featureImportancesSorted = featureImportances[sortedArgs][:50]
featureIndicesSorted = sortedArgs[:50]
featuresImportancesDict = dict((featureIndex, featureImportance)
for featureIndex, featureImportance in enumerate(featureImportances)
if featureImportance != 0)
with open(directory+'-feature_importances.pickle', 'wb') as handle:
cPickle.dump(featuresImportancesDict, handle)
interpretString = "Feature importances : \n"
for featureIndex, featureImportance in zip(featureIndicesSorted, featureImportancesSorted):
if featureImportance>0:
interpretString+="- Feature index : "+str(featureIndex)+ \
", feature importance : "+str(featureImportance)+"\n"
return interpretString return interpretString
...@@ -78,10 +78,6 @@ def paramsToSet(nIter, randomState): ...@@ -78,10 +78,6 @@ def paramsToSet(nIter, randomState):
return paramsSet return paramsSet
def getInterpret(classifier, directory):
return ""
def getKWARGS(kwargsList): def getKWARGS(kwargsList):
kwargsDict = {} kwargsDict = {}
for (kwargName, kwargValue) in kwargsList: for (kwargName, kwargValue) in kwargsList:
...@@ -134,3 +130,7 @@ def getConfig(config): ...@@ -134,3 +130,7 @@ def getConfig(config):
except: except:
return "\n\t\t- SCM with model_type: " + config["0"] + ", max_rules : " + str(config["1"]) + ", p : " + \ return "\n\t\t- SCM with model_type: " + config["0"] + ", max_rules : " + str(config["1"]) + ", p : " + \
str(config["2"]) str(config["2"])
def getInterpret(classifier, directory):
return ""
...@@ -91,4 +91,6 @@ def getConfig(config): ...@@ -91,4 +91,6 @@ def getConfig(config):
"1"] + ", alpha : " + str(config["2"]) "1"] + ", alpha : " + str(config["2"])
def getInterpret(classifier, directory): def getInterpret(classifier, directory):
# TODO : coeffs
return "" return ""
#
\ No newline at end of file
...@@ -75,4 +75,5 @@ def getConfig(config): ...@@ -75,4 +75,5 @@ def getConfig(config):
return "\n\t\t- SVM Linear with C : " + str(config["0"]) return "\n\t\t- SVM Linear with C : " + str(config["0"])
def getInterpret(classifier, directory): def getInterpret(classifier, directory):
# TODO : coeffs
return "" return ""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import cPickle
def percent(x, pos):
'The two args are the value and tick position'
return '%1.1f %%' % (x * 100)
def getFeatureImportance(classifier, directory, interpretString=""):
featureImportances = classifier.feature_importances_
sortedArgs = np.argsort(-featureImportances)
featureImportancesSorted = featureImportances[sortedArgs][:50]
featureIndicesSorted = sortedArgs[:50]
fig, ax = plt.subplots()
x = np.arange(50)
formatter = FuncFormatter(percent)
ax.yaxis.set_major_formatter(formatter)
plt.bar(x, featureImportancesSorted)
plt.title("Importance depending on feature")
fig.savefig(directory + "-feature_importances.png")
plt.close()
featuresImportancesDict = dict((featureIndex, featureImportance)
for featureIndex, featureImportance in enumerate(featureImportances)
if featureImportance != 0)
with open(directory+'-feature_importances.pickle', 'wb') as handle:
cPickle.dump(featuresImportancesDict, handle)
interpretString += "Feature importances : \n"
for featureIndex, featureImportance in zip(featureIndicesSorted, featureImportancesSorted):
if featureImportance>0:
interpretString+="- Feature index : "+str(featureIndex)+\
", feature importance : "+str(featureImportance)+"\n"
return interpretString
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment