From 834ffa8b0a8f432806eaccd0c040435f99b7a8ed Mon Sep 17 00:00:00 2001 From: Alain Riou <alain.riou14000@yahoo.com> Date: Wed, 17 Jan 2024 18:35:40 +0100 Subject: [PATCH] roll activations + better export --- pesto/model.py | 2 +- pesto/utils/export.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pesto/model.py b/pesto/model.py index 9fbaddf..b9cf5ed 100644 --- a/pesto/model.py +++ b/pesto/model.py @@ -215,7 +215,7 @@ class PESTO(nn.Module): if batch_size: activations = activations.view(batch_size, -1, activations.size(-1)) - activations = activations.roll(-self.shift.cpu().item(), dims=-1) + activations = activations.roll(-round(self.shift.cpu().item() * self.bins_per_semitone), -1) preds = reduce_activations(activations, reduction=self.reduction) diff --git a/pesto/utils/export.py b/pesto/utils/export.py index 05c7c5a..70aeb54 100644 --- a/pesto/utils/export.py +++ b/pesto/utils/export.py @@ -44,8 +44,10 @@ def export_png(output_file: str, timesteps, confidence, activations, lims=(21, 1 activations = activations * confidence[:, None] plt.imshow(activations.T, aspect='auto', origin='lower', cmap='inferno', - extent=(timesteps[0], timesteps[-1]) + lims) + extent=(timesteps[0] / 1000, timesteps[-1] / 1000) + lims) + plt.xlabel("Time (s)") + plt.ylabel("Pitch (semitones)") plt.title(output_file.rsplit('.', 2)[0]) plt.tight_layout() plt.savefig(output_file) -- GitLab