-
- Downloads
Last changes'
%% Cell type:markdown id: tags: | ||
# Groupe de travail | ||
Le but de ce notebook est de tester l'idée de réduction des random forest | ||
%% Cell type:markdown id: tags: | ||
## Import scikit-learn | ||
%% Cell type:code id: tags: | ||
``` python | ||
from statistics import mean | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import pandas as pd | ||
from sklearn.datasets import load_boston, load_breast_cancer, fetch_california_housing | ||
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor | ||
from sklearn.linear_model import OrthogonalMatchingPursuit, OrthogonalMatchingPursuitCV | ||
from sklearn.metrics import mean_squared_error | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.neighbors.kde import KernelDensity | ||
``` | ||
%% Cell type:markdown id: tags: | ||
## Variables globales | ||
%% Cell type:code id: tags: | ||
``` python | ||
NB_TREES = 100 | ||
``` | ||
%% Cell type:markdown id: tags: | ||
## Load jeu de donnée | ||
%% Cell type:code id: tags: | ||
``` python | ||
X, y = fetch_california_housing(return_X_y=True) | ||
``` | ||
%% Output | ||
Downloading Cal. housing from https://ndownloader.figshare.com/files/5976036 to /home/l_bouscarrat/scikit_learn_data | ||
%% Cell type:code id: tags: | ||
``` python | ||
def train_forest(X_train, y_train, nb_trees, random_seed): | ||
''' | ||
Function that will train a random forest with nb_tress | ||
:param X_train: list of inputs | ||
:param y_train: list of results | ||
:param nb_trees: int, number of trees in the forest | ||
:param random_seed: int, seed for the random_states | ||
:return: a RandomForestRegressor | ||
''' | ||
# Entraînement de la forêt aléatoire | ||
regressor = RandomForestRegressor(n_estimators=nb_trees, random_state = random_seed) | ||
regressor.fit(X_train, y_train) | ||
return regressor | ||
def extract_subforest(random_forest, X_train, y_train, nb_trees_extracted): | ||
''' | ||
Function use to get the weight list of a subforest of size nb_trees_extracted for random_forest | ||
using OMP. | ||
:param random_forest: a RandomForestRegressor | ||
:param X_train: list of inputs | ||
:param y_train: list of results | ||
:param nb_trees_extracted: int, number of trees extracted | ||
:return: a list of int, weight of each tree | ||
''' | ||
# Accès à la la liste des arbres | ||
tree_list = random_forest.estimators_ | ||
# Création de la matrice des prédictions de chaque arbre | ||
# L'implémentation de scikit-learn est un peu différente que celle vue en réunion, D est de même taille que X | ||
# et chaque élément est composé de d signaux, d'où la création suivante de D où on créé une liste pour chaque | ||
# élément comprenant les valeurs prédites par chaque arbre | ||
D = [[tree.predict([elem])[0] for tree in tree_list] for elem in X_train] | ||
# OMP | ||
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=nb_trees_extracted, fit_intercept = False, normalize=False) | ||
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=nb_trees_extracted, fit_intercept = False, normalize = False) | ||
omp.fit(D, y_train) | ||
weights = omp.coef_ | ||
return weights | ||
def compute_results(weights, random_forest, X_train, X_dev, X_test, y_train, y_dev, y_test, | ||
nb_trees, nb_trees_extracted, random_seed): | ||
''' | ||
Compute the score of the different techniques | ||
:param weights: weights given by the OMP | ||
:param random_forest: a RandomForestRegressor | ||
:param X_train: list of inputs | ||
:param X_dev: list of inputs | ||
:param X_test: list of inputs | ||
:param y_train: list of results | ||
:param y_dev: list of results | ||
:param y_test: list of results | ||
:param nb_trees: int, number of trees in the main forest | ||
:param nb_trees_extracted: int, number of trees extracted from the main forest | ||
:param random_seed: int, seed for the random_states | ||
:return: 4 results of 4 different methods, in order: results of the main forest, | ||
results of the weighted results of the extracted trees, results of the mean results | ||
of the extracted trees, results of a random_forest train with nb_trees_extracted directly | ||
''' | ||
# Calcul des différents résultats | ||
res_base_forest = mean_squared_error(random_forest.predict(X_test), y_test) | ||
# Résultat de la forêt extraite avec l'OMP, où chaque arbre est multiplié par son poids | ||
y_pred = [sum([random_forest.estimators_[i].predict([elem])[0] * weights[i] for i in range(nb_trees)]) | ||
for elem in X_test] | ||
res_extract_weight = mean_squared_error(y_pred, y_test) | ||
# Résultat de la forêt extraite avec l'OMP, où chaque arbre est multiplié par son poids | ||
y_pred = [sum([random_forest.estimators_[i].predict([elem])[0] * weights[i] for i in range(nb_trees)])/sum(weights) | ||
for elem in X_test] | ||
res_extract_weight_norm = mean_squared_error(y_pred, y_test) | ||
# Résultat de la forêt extraite avec l'OMP, où on prends la moyenne des arbres extraits | ||
y_pred = [mean([random_forest.estimators_[i].predict([elem])[0] for i in range(nb_trees) if abs(weights[i]) >= 0.01]) | ||
for elem in X_test] | ||
res_extract_mean = mean_squared_error(y_pred, y_test) | ||
# Résultat d'une forêt avec le même nombre d'arbre que le nombre d'arbre extrait | ||
small_forest = train_forest(np.concatenate((X_train, X_dev)), np.concatenate((y_train, y_dev)), nb_trees_extracted, random_seed) | ||
res_small_forest = mean_squared_error(small_forest.predict(X_test), y_test) | ||
return res_base_forest, res_extract_weight, res_extract_weight_norm, res_extract_mean, res_small_forest, weights | ||
def extract_and_get_results(random_forest, X_train, X_dev, X_test, y_train, y_dev, y_test, nb_trees, | ||
nb_trees_extracted, random_seed): | ||
''' | ||
Extract the subforest and returns the resuts of the different methods | ||
:param X_train: list of inputs | ||
:param X_dev: list of inputs | ||
:param X_test: list of inputs | ||
:param y_train: list of results | ||
:param y_dev: list of results | ||
:param y_test: list of results | ||
:param nb_trees: int, number of trees in the main forest | ||
:param nb_trees_extracted: int, number of trees extracted from the main forest | ||
:param random_seed: int, seed for the random_states | ||
:return: 4 results of 4 different methods, in order: results of the main forest, | ||
results of the weighted results of the extracted trees, results of the mean results | ||
of the extracted trees, results of a random_forest train with nb_trees_extracted directly | ||
''' | ||
weights = extract_subforest(random_forest, X_dev, y_dev, nb_trees_extracted) | ||
res_base_forest, res_extract_weight, res_extract_weight_norm, res_extract_mean, res_small_forest = \ | ||
compute_results(weights, random_forest, X_train, X_dev, X_test, y_train, y_dev, y_test, | ||
nb_trees, nb_trees_extracted, random_seed) | ||
return res_base_forest, res_extract_weight, res_extract_weight_norm, res_extract_mean, res_small_forest, weights | ||
def train_extract_subforest(X_train, X_test, y_train, y_test, nb_trees, nb_trees_extracted, random_seed): | ||
''' | ||
Function that takes data, number of trees and a random seed. Train a forest with nb_trees, extract | ||
with OMP nb_trees_extracted and compare the results of the different method | ||
:param X_train: list of inputs | ||
:param X_test: list of inputs | ||
:param y_train: list of results | ||
:param y_test: list of results | ||
:param nb_trees: int, number of trees in the main forest | ||
:param nb_trees_extracted: int, number of trees extracted from the main forest | ||
:param random_seed: int, seed for the random_states | ||
:return: 4 results of 4 different methods, in order: results of the main forest, | ||
results of the weighted results of the extracted trees, results of the mean results | ||
of the extracted trees, results of a random_forest train with nb_trees_extracted directly | ||
''' | ||
random_forest = train_forest(X_train, y_train, nb_trees, random_seed) | ||
weight = extract_subforest(random_forest, X_train, y_train, nb_trees_extracted) | ||
res_base_forest, res_extract_weight, res_extract_mean, res_small_forest = \ | ||
compute_results(weight, random_forest, X_train, X_test, y_train, y_test, | ||
nb_trees, nb_trees_extracted, random_seed) | ||
return res_base_forest, res_extract_weight, res_extract_mean, res_small_forest | ||
``` | ||
%% Cell type:code id: tags: | ||
``` python | ||
results_global = [] | ||
results_dev_global = [] | ||
results_without_dev_global = [] | ||
nb_trees = 100 | ||
random_seeds = list(range(10)) | ||
for random_seed in random_seeds: | ||
# Séparation train_test avec random_state | ||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = random_seed) | ||
X_train, X_dev, y_train, y_dev = train_test_split(X_train, y_train, test_size = 0.2, random_state = random_seed) | ||
random_forest = train_forest(X_train, y_train, NB_TREES, random_seed) | ||
results = [] | ||
results_dev = [] | ||
results_without_dev = [] | ||
for nb_trees_extracted in [int(NB_TREES/k) for k in [2, 5, 10, 20, 50, 100]]: | ||
weights = extract_subforest(random_forest, X_dev, y_dev, nb_trees_extracted) | ||
weights_train = extract_subforest(random_forest, X_train, y_train, nb_trees_extracted) | ||
results.append(compute_results(weights, random_forest, X_train, X_dev, X_test, y_train, y_dev, y_test, | ||
nb_trees, nb_trees_extracted, random_seed)) | ||
results_without_dev.append(compute_results(weights_train, random_forest, X_train, X_train, | ||
X_test, y_train, y_train, y_test, | ||
nb_trees, nb_trees_extracted, random_seed) | ||
) | ||
results_dev.append(compute_results(weights, random_forest, X_train, X_dev, X_dev, y_train, y_dev, y_dev, | ||
nb_trees, nb_trees_extracted, random_seed)) | ||
results_global.append(results) | ||
results_dev_global.append(results_dev) | ||
results_without_dev_global.append(results_without_dev) | ||
print('over') | ||
``` | ||
%% Output | ||
over | ||
over | ||
over | ||
over | ||
over | ||
over | ||
over | ||
%% Cell type:code id: tags: | ||
``` python | ||
def plot_results(results_global, title_graph): | ||
def plot_mean_and_CI(mean, lb, ub, x_value, color_mean=None, color_shading=None, label=None): | ||
# plot the shaded range of the confidence intervals | ||
plt.fill_between(x_value, ub, lb, | ||
color=color_shading, alpha=.5) | ||
# plot the mean on top | ||
plt.plot(x_value, mean, color_mean, label = label) | ||
means_results = np.array( | ||
[ | ||
[mean( | ||
[results[i][k] for results in results_global] # loop over the different experiments | ||
) for i in range(len(results_global[0]))] # loop over the different number of trees extracted | ||
for k in range(5)]) # loop over the different methods | ||
std_results = np.array( | ||
[ | ||
[np.std( | ||
[results[i][k] for results in results_global] | ||
) for i in range(len(results_global[0]))] | ||
for k in range(5)]) | ||
x_value = [int(NB_TREES/k) for k in [2, 5, 10, 20, 50, 100]] | ||
# plot the data | ||
fig = plt.figure(1, figsize=(15, 10)) | ||
plot_mean_and_CI(means_results[0], means_results[0] + std_results[0], means_results[0] - std_results[0], | ||
x_value, color_mean='k', color_shading='k', label='Results of the base forest (on train set)') | ||
plot_mean_and_CI(means_results[1], means_results[1] + std_results[1], means_results[1] - std_results[1], | ||
x_value, color_mean='darkorange', color_shading='darkorange', | ||
label='Weighted results of the extracted trees') | ||
plot_mean_and_CI(means_results[2], means_results[2] + std_results[2], means_results[2] - std_results[2], | ||
x_value, color_mean='red', color_shading='red', | ||
label='Weighted results of the extracted trees normalized') | ||
plot_mean_and_CI(means_results[3], means_results[3] + std_results[3], means_results[3] - std_results[3], | ||
x_value, color_mean='b', color_shading='b', | ||
label='Mean results of the extracted trees') | ||
plot_mean_and_CI(means_results[4], means_results[4] + std_results[4], means_results[4] - std_results[4], | ||
x_value, color_mean='g', color_shading='g', | ||
label='Results of a forest train with number of trees extracted (train+dev set)') | ||
plt.xlabel('Number of trees extracted') | ||
plt.ylabel('MSE') | ||
plt.title(title_graph) | ||
plt.legend(loc="upper right") | ||
``` | ||
%% Cell type:code id: tags: | ||
``` python | ||
plot_results(results_global, 'Reduction of a forest with 100 trees, 10 iterations with different seed, score on train set') | ||
``` | ||
%% Cell type:code id: tags: | ||
``` python | ||
plot_results(results_dev_global, 'Reduction of a forest with 100 trees, 10 iterations with different seed, score on dev set') | ||
``` | ||
%% Cell type:code id: tags: | ||
``` python | ||
plot_results(results_without_dev, | ||
'Reduction of a forest with 100 trees, 10 iterations with different seed, score when there is no dev set') | ||
``` | ||
%% Cell type:code id: tags: | ||
``` python | ||
for results in results_global: | ||
x_value = [int(NB_TREES/k) for k in [5, 10, 50, 100, 500, 1000]] | ||
plt.xlabel('Number of trees extracted') | ||
plt.ylabel('MSE') | ||
plt.plot(x_value, [elem[1] for elem in results], color='darkorange', | ||
label='Weighted results of the average trees') | ||
plt.plot(x_value, [elem[2] for elem in results], color='red', | ||
label='Weighted results of the average trees normalized') | ||
plt.plot(x_value, [elem[3] for elem in results], color='blue', | ||
label='Mean results of the average trees') | ||
plt.plot(x_value, [elem[4] for elem in results], color='green', | ||
label='Results of a forest train with number of trees extracted') | ||
plt.plot(x_value, [elem[0] for elem in results], color='black', | ||
label='Results of the base forest') | ||
plt.figure(1, figsize=(15, 10)) | ||
plt.legend(loc="upper right") | ||
fig_acc_rec = plt.gcf() | ||
plt.show() | ||
``` | ||
%% Output | ||
%% Cell type:code id: tags: | ||
``` python | ||
def weight_density(list_weight): | ||
print(list_weight) | ||
X_plot = [np.exp(elem) for elem in list_weight] | ||
fig, ax = plt.subplots() | ||
for kernel in ['gaussian', 'tophat', 'epanechnikov']: | ||
kde = KernelDensity(kernel=kernel, bandwidth=0.5).fit(X_plot) | ||
log_dens = kde.score_samples(X_plot) | ||
ax.plot(X_plot[:, 0], np.exp(log_dens), '-', | ||
label="kernel = '{0}'".format(kernel)) | ||
ax.legend(loc='upper left') | ||
ax.plot(X[:, 0], -0.005 - 0.01 * np.random.random(X.shape[0]), '+k') | ||
ax.set_xlim(-4, 9) | ||
ax.set_ylim(-0.02, 0.4) | ||
plt.show() | ||
``` | ||
%% Cell type:code id: tags: | ||
``` python | ||
for results in results_global: | ||
ax = pd.Series([[e for e in test[5] if e != 0] for test in results][1]).plot.kde(figsize=(15, 10)) | ||
legends = ['Experience '+ str(i+1) for i in range(10)] | ||
ax.legend(legends) | ||
``` | ||
%% Output | ||
<matplotlib.legend.Legend at 0x7f1437754c10> | ||
%% Cell type:code id: tags: | ||
``` python | ||
np.array( | ||
[ | ||
[ | ||
[results[i][k] for results in results_global] | ||
for i in range(len(results_global[0]))] | ||
for k in range(5)]) | ||
``` | ||
%% Output | ||
array([[[0.26899689, 0.26359377, 0.2780403 , 0.25029723, 0.26674508, | ||
0.25602716, 0.28057576, 0.2761758 , 0.25817293, 0.26356801], | ||
[0.26899689, 0.26359377, 0.2780403 , 0.25029723, 0.26674508, | ||
0.25602716, 0.28057576, 0.2761758 , 0.25817293, 0.26356801], | ||
[0.26899689, 0.26359377, 0.2780403 , 0.25029723, 0.26674508, | ||
0.25602716, 0.28057576, 0.2761758 , 0.25817293, 0.26356801], | ||
[0.26899689, 0.26359377, 0.2780403 , 0.25029723, 0.26674508, | ||
0.25602716, 0.28057576, 0.2761758 , 0.25817293, 0.26356801], | ||
[0.26899689, 0.26359377, 0.2780403 , 0.25029723, 0.26674508, | ||
0.25602716, 0.28057576, 0.2761758 , 0.25817293, 0.26356801], | ||
[0.26899689, 0.26359377, 0.2780403 , 0.25029723, 0.26674508, | ||
0.25602716, 0.28057576, 0.2761758 , 0.25817293, 0.26356801]], | ||
[[0.27542295, 0.27749768, 0.28513058, 0.26038702, 0.27043376, | ||
0.2655008 , 0.28448981, 0.28333658, 0.27387447, 0.2769381 ], | ||
[0.27746557, 0.27723817, 0.28723859, 0.26434651, 0.27067318, | ||
0.26330039, 0.28196962, 0.28240111, 0.27509222, 0.28088583], | ||
[0.28995364, 0.29198289, 0.29873153, 0.27618004, 0.2848853 , | ||
0.27857491, 0.29298835, 0.30077324, 0.2886711 , 0.28905086], | ||
[0.32365526, 0.32322906, 0.32710513, 0.29903915, 0.31318329, | ||
0.30669926, 0.32434317, 0.32110736, 0.31463418, 0.321466 ], | ||
[0.39986111, 0.42484653, 0.42855969, 0.39370378, 0.39935977, | ||
0.38460084, 0.41563938, 0.40814036, 0.39929003, 0.38932494], | ||
[0.58523066, 0.55891364, 0.59428021, 0.60547191, 0.5266932 , | ||
0.54086835, 0.57100958, 0.54292164, 0.53241884, 0.59593718]], | ||
[[0.27521601, 0.27770826, 0.28523181, 0.26038166, 0.27049563, | ||
0.26550442, 0.28434805, 0.28336574, 0.27317227, 0.27730912], | ||
[0.27744027, 0.27741863, 0.28736133, 0.26435666, 0.2708491 , | ||
0.26327205, 0.28195336, 0.28236196, 0.2744054 , 0.28125145], | ||
[0.29039242, 0.2925128 , 0.29908397, 0.27622541, 0.28540109, | ||
0.27863392, 0.2930088 , 0.30074197, 0.28805612, 0.28951146], | ||
[0.32484297, 0.32460791, 0.32853218, 0.29964306, 0.31475447, | ||
0.30729241, 0.3249432 , 0.32151156, 0.31435099, 0.32265734], | ||
[0.40624496, 0.4302276 , 0.43530539, 0.39867506, 0.40663919, | ||
0.38977503, 0.42116368, 0.41272401, 0.40457793, 0.39346472], | ||
[0.61110865, 0.57982481, 0.62469263, 0.63199171, 0.54630055, | ||
0.56558963, 0.59272349, 0.56177889, 0.55353829, 0.61516357]], | ||
[[0.27065035, 0.26652983, 0.27690332, 0.25221968, 0.26834736, | ||
0.25713719, 0.27970688, 0.27633639, 0.26457284, 0.26459536], | ||
[0.27745036, 0.27779328, 0.28618091, 0.26025249, 0.27057945, | ||
0.26101474, 0.28177424, 0.27897506, 0.27042052, 0.27522275], | ||
[0.29080491, 0.29277669, 0.29888537, 0.27686572, 0.28509987, | ||
0.2785929 , 0.29306583, 0.29749568, 0.28659403, 0.28920128], | ||
[0.32643403, 0.32525994, 0.32820348, 0.30076919, 0.31430383, | ||
0.30854091, 0.3247721 , 0.31945693, 0.31341696, 0.32260233], | ||
[0.40894501, 0.43023553, 0.43584108, 0.40142063, 0.40520717, | ||
0.39108446, 0.4205632 , 0.41085099, 0.40442852, 0.39712488], | ||
[0.61110865, 0.57982481, 0.62469263, 0.63199171, 0.54630055, | ||
0.56558963, 0.59272349, 0.56177889, 0.55353829, 0.61516357]], | ||
[[0.26184144, 0.25626252, 0.26511056, 0.24293248, 0.25853787, | ||
0.24899595, 0.27433988, 0.27443584, 0.24968665, 0.25521777], | ||
[0.27008137, 0.26487907, 0.27600019, 0.25217323, 0.26622961, | ||
0.25721161, 0.28795328, 0.28488014, 0.25274144, 0.26086461], | ||
[0.28205227, 0.2757543 , 0.29624245, 0.27094063, 0.29016011, | ||
0.27193868, 0.30978997, 0.29614998, 0.26827511, 0.27353369], | ||
[0.30842355, 0.30288144, 0.32913279, 0.30527809, 0.32101279, | ||
0.31426529, 0.3350261 , 0.33831256, 0.30012365, 0.30287159], | ||
[0.39608144, 0.38109562, 0.41662116, 0.40638002, 0.41791456, | ||
0.40641226, 0.44955332, 0.44545138, 0.40050687, 0.39659829], | ||
[0.52593732, 0.54251735, 0.57760869, 0.56082674, 0.56808121, | ||
0.55389761, 0.59152879, 0.62432776, 0.52053009, 0.54411424]]]) | ||
%% Cell type:code id: tags: | ||
``` python | ||
[[sum(elem[5]) for elem in results] for results in results_global] | ||
``` | ||
%% Output | ||
[[1.0019333893291256, | ||
1.0002339744254798, | ||
0.9965284128922761, | ||
0.9930020768164572, | ||
0.9713515521464255, | ||
0.9355587965148584], | ||
[0.9964172479701133, | ||
0.9965133913097529, | ||
0.9921861297905141, | ||
0.9842416540561515, | ||
0.9682575760922218, | ||
0.933691714442025], | ||
[0.9983332596273122, | ||
0.9980542966631237, | ||
0.9953384960946179, | ||
0.9863737744133908, | ||
0.9713936104999941, | ||
0.9322186005941735], | ||
[1.0038094772994455, | ||
1.0018194366490565, | ||
0.9977426702506628, | ||
0.9901515373986598, | ||
0.9656704215772347, | ||
0.9225711413699638], | ||
[0.9990610710125596, | ||
0.997516190665299, | ||
0.994208907140429, | ||
0.9866287246076226, | ||
0.9654844823617819, | ||
0.9349411984209011], | ||
[0.9988517681736824, | ||
0.9978719264484801, | ||
0.9951808165785492, | ||
0.9869717461401798, | ||
0.96626061534528, | ||
0.9275270848174226], | ||
[1.0036372042892556, | ||
1.0021281690286359, | ||
0.9992184310564347, | ||
0.9936213667248184, | ||
0.9726814533989055, | ||
0.9287908504721321], | ||
[1.0072738855021581, | ||
1.0058719271210204, | ||
1.0008091849171328, | ||
0.9924905567327407, | ||
0.9726743101033102, | ||
0.9357752656379753], | ||
[1.0124396297482838, | ||
1.0124081517188515, | ||
1.008883039483941, | ||
1.0028553033696677, | ||
0.9743542569764307, | ||
0.942597999044375], | ||
[0.9944072082781984, | ||
0.9943119815126942, | ||
0.9936514348889314, | ||
0.9866366329633924, | ||
0.9668343182393178, | ||
0.9188473811851859]] | ||
%% Cell type:code id: tags: | ||
``` python | ||
results_global[0] | ||
``` | ||
%% Cell type:markdown id: tags: | ||
## Entraînement de la forêt aléatoire | ||
%% Cell type:code id: tags: | ||
``` python | ||
regressor = RandomForestRegressor(n_estimators=NB_TREES, random_state = RANDOM_SEED) | ||
regressor.fit(X_train, y_train) | ||
``` | ||
%% Cell type:code id: tags: | ||
``` python | ||
# Accès à la la liste des arbres | ||
tree_list = regressor.estimators_ | ||
``` | ||
%% Cell type:markdown id: tags: | ||
## Création de la matrice des prédictions de chaque arbre | ||
%% Cell type:code id: tags: | ||
``` python | ||
# L'implémentation de scikit-learn est un peu différente que celle vue en réunion, D est de même taille que X | ||
# et chaque élément est composé de d signaux, d'où la création suivante de D où on créé une liste pour chaque | ||
# élément comprenant les valeurs prédites par chaque arbre | ||
D = [[tree.predict([elem])[0] for tree in tree_list] for elem in X_train] | ||
``` | ||
%% Cell type:code id: tags: | ||
``` python | ||
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=NB_TREES_EXTRACTED) | ||
omp.fit(D, y_train) | ||
``` | ||
%% Cell type:code id: tags: | ||
``` python | ||
# Matrice avec poids de chaque arbre | ||
omp.coef_ | ||
``` | ||
%% Cell type:markdown id: tags: | ||
## Calcul des résultats des différentes méthodes | ||
%% Cell type:markdown id: tags: | ||
### Résultat de la forêt de base | ||
%% Cell type:code id: tags: | ||
``` python | ||
mean_squared_error(regressor.predict(X_test), y_test) | ||
``` | ||
%% Cell type:markdown id: tags: | ||
### Résultat de la forêt extraite avec l'OMP, où chaque arbre est multiplié par son poids | ||
%% Cell type:code id: tags: | ||
``` python | ||
y_pred = [sum([tree_list[i].predict([elem])[0] * omp.coef_[i] for i in range(NB_TREES)]) for elem in X_test] | ||
``` | ||
%% Cell type:code id: tags: | ||
``` python | ||
mean_squared_error(y_pred, y_test) | ||
``` | ||
%% Cell type:markdown id: tags: | ||
### Résultat de la forêt extraite avec l'OMP, où on prends la moyenne des arbres extraits | ||
%% Cell type:code id: tags: | ||
``` python | ||
y_pred = [mean([tree_list[i].predict([elem])[0] for i in range(NB_TREES) if omp.coef_[i] != 0])for elem in X_test] | ||
mean_squared_error(y_pred, y_test) | ||
``` | ||
%% Cell type:markdown id: tags: | ||
### Résultat d'une forêt avec le même nombre d'arbre que le nombre d'arbre extrait | ||
%% Cell type:code id: tags: | ||
``` python | ||
regressor_small = RandomForestRegressor(n_estimators=NB_TREES_EXTRACTED, random_state=RANDOM_SEED) | ||
regressor_small.fit(X_train, y_train) | ||
``` | ||
%% Cell type:code id: tags: | ||
``` python | ||
mean_squared_error(regressor_small.predict(X_test), y_test) | ||
``` | ||
%% Cell type:code id: tags: | ||
``` python | ||
``` | ||
... | ... |
Please register or sign in to comment