From 834ffa8b0a8f432806eaccd0c040435f99b7a8ed Mon Sep 17 00:00:00 2001
From: Alain Riou <alain.riou14000@yahoo.com>
Date: Wed, 17 Jan 2024 18:35:40 +0100
Subject: [PATCH] roll activations + better export

---
 pesto/model.py        | 2 +-
 pesto/utils/export.py | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/pesto/model.py b/pesto/model.py
index 9fbaddf..b9cf5ed 100644
--- a/pesto/model.py
+++ b/pesto/model.py
@@ -215,7 +215,7 @@ class PESTO(nn.Module):
         if batch_size:
             activations = activations.view(batch_size, -1, activations.size(-1))
 
-        activations = activations.roll(-self.shift.cpu().item(), dims=-1)
+        activations = activations.roll(-round(self.shift.cpu().item() * self.bins_per_semitone), -1)
 
         preds = reduce_activations(activations, reduction=self.reduction)
 
diff --git a/pesto/utils/export.py b/pesto/utils/export.py
index 05c7c5a..70aeb54 100644
--- a/pesto/utils/export.py
+++ b/pesto/utils/export.py
@@ -44,8 +44,10 @@ def export_png(output_file: str, timesteps, confidence, activations, lims=(21, 1
     activations = activations * confidence[:, None]
     plt.imshow(activations.T,
                aspect='auto', origin='lower', cmap='inferno',
-               extent=(timesteps[0], timesteps[-1]) + lims)
+               extent=(timesteps[0] / 1000, timesteps[-1] / 1000) + lims)
 
+    plt.xlabel("Time (s)")
+    plt.ylabel("Pitch (semitones)")
     plt.title(output_file.rsplit('.', 2)[0])
     plt.tight_layout()
     plt.savefig(output_file)
-- 
GitLab