Skip to content
Snippets Groups Projects
Commit 834ffa8b authored by Alain Riou's avatar Alain Riou
Browse files

roll activations + better export

parent d0c6153b
No related branches found
No related tags found
No related merge requests found
...@@ -215,7 +215,7 @@ class PESTO(nn.Module): ...@@ -215,7 +215,7 @@ class PESTO(nn.Module):
if batch_size: if batch_size:
activations = activations.view(batch_size, -1, activations.size(-1)) activations = activations.view(batch_size, -1, activations.size(-1))
activations = activations.roll(-self.shift.cpu().item(), dims=-1) activations = activations.roll(-round(self.shift.cpu().item() * self.bins_per_semitone), -1)
preds = reduce_activations(activations, reduction=self.reduction) preds = reduce_activations(activations, reduction=self.reduction)
......
...@@ -44,8 +44,10 @@ def export_png(output_file: str, timesteps, confidence, activations, lims=(21, 1 ...@@ -44,8 +44,10 @@ def export_png(output_file: str, timesteps, confidence, activations, lims=(21, 1
activations = activations * confidence[:, None] activations = activations * confidence[:, None]
plt.imshow(activations.T, plt.imshow(activations.T,
aspect='auto', origin='lower', cmap='inferno', aspect='auto', origin='lower', cmap='inferno',
extent=(timesteps[0], timesteps[-1]) + lims) extent=(timesteps[0] / 1000, timesteps[-1] / 1000) + lims)
plt.xlabel("Time (s)")
plt.ylabel("Pitch (semitones)")
plt.title(output_file.rsplit('.', 2)[0]) plt.title(output_file.rsplit('.', 2)[0])
plt.tight_layout() plt.tight_layout()
plt.savefig(output_file) plt.savefig(output_file)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment