Skip to content
Snippets Groups Projects
Commit eefd7b6b authored by Charly LAMOTHE's avatar Charly LAMOTHE
Browse files

- Fix weight_density. TODO: add legend (colored by seed)

- Add a TODO for the subset in train
- Remove an useless new line
parent 553764d7
No related branches found
2 merge requests!5Resolve "Add plots",!3clean scripts
...@@ -30,6 +30,8 @@ class Trainer(object): ...@@ -30,6 +30,8 @@ class Trainer(object):
y_omp = y_forest y_omp = y_forest
self._logger.debug('Fitting both the forest and OMP on train+dev subsets.') self._logger.debug('Fitting both the forest and OMP on train+dev subsets.')
# TODO: add an option to train forest to train+dev and OMP to dev
model.fit( model.fit(
X_forest=X_forest, X_forest=X_forest,
y_forest=y_forest, y_forest=y_forest,
......
...@@ -12,14 +12,19 @@ class Plotter(object): ...@@ -12,14 +12,19 @@ class Plotter(object):
:param all_experiment_weights: The weights for the different experiments :param all_experiment_weights: The weights for the different experiments
:param file path: str, path where the figure will be saved :param file path: str, path where the figure will be saved
''' '''
all_experiment_weights = np.array(list(all_experiment_weights.values()))
n = len(all_experiment_weights)
colors = Plotter.get_colors_from_cmap(n)
fig, ax = plt.subplots() fig, ax = plt.subplots()
for weights in all_experiment_weights.values(): for i in range(n):
pd.Series([weight for weight in weights if weight != 0]).plot.kde(figsize=(15, 10), ax=ax) for weights in all_experiment_weights[i]:
pd.Series([weight for weight in weights if weight != 0]).plot.kde(
figsize=(15, 10), ax=ax, color=colors[i])
legends = ['Experience ' + str(i+1) for i in range(len(all_experiment_weights))] ax.set_title('Density weights of the OMP')
ax.legend(legends)
fig.savefig(file_path, dpi=fig.dpi) fig.savefig(file_path, dpi=fig.dpi)
fig.title('Density weights of the OMP')
plt.close(fig) plt.close(fig)
@staticmethod @staticmethod
......
...@@ -117,7 +117,6 @@ if __name__ == "__main__": ...@@ -117,7 +117,6 @@ if __name__ == "__main__":
) )
# Plot the density of the weights # Plot the density of the weights
Plotter.weight_density( Plotter.weight_density(
file_path=args.results_dir + os.sep + experiment_id + os.sep + 'density_weight.png', file_path=args.results_dir + os.sep + experiment_id + os.sep + 'density_weight.png',
all_experiment_weights=experiment_weights all_experiment_weights=experiment_weights
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment