diff --git a/pesto/model.py b/pesto/model.py index 9fbaddf6e7323a2b9007ba7e5171f2e5b1989b11..b9cf5eda18349e38a2d4de1f739798bd5b1df0e6 100644 --- a/pesto/model.py +++ b/pesto/model.py @@ -215,7 +215,7 @@ class PESTO(nn.Module): if batch_size: activations = activations.view(batch_size, -1, activations.size(-1)) - activations = activations.roll(-self.shift.cpu().item(), dims=-1) + activations = activations.roll(-round(self.shift.cpu().item() * self.bins_per_semitone), -1) preds = reduce_activations(activations, reduction=self.reduction) diff --git a/pesto/utils/export.py b/pesto/utils/export.py index 05c7c5af93008a5e6c57a301ff621c14b740a0c9..70aeb54861b9bff67661fa58793be8e3230ae507 100644 --- a/pesto/utils/export.py +++ b/pesto/utils/export.py @@ -44,8 +44,10 @@ def export_png(output_file: str, timesteps, confidence, activations, lims=(21, 1 activations = activations * confidence[:, None] plt.imshow(activations.T, aspect='auto', origin='lower', cmap='inferno', - extent=(timesteps[0], timesteps[-1]) + lims) + extent=(timesteps[0] / 1000, timesteps[-1] / 1000) + lims) + plt.xlabel("Time (s)") + plt.ylabel("Pitch (semitones)") plt.title(output_file.rsplit('.', 2)[0]) plt.tight_layout() plt.savefig(output_file)