From bf678a8399d1c7c4d9b47290e62d0720c0c7fd36 Mon Sep 17 00:00:00 2001
From: "paul.best" <paul.best@lis-lab.fr>
Date: Wed, 12 Jun 2024 12:22:22 +0200
Subject: [PATCH] small fixes

---
 pesto/data.py       | 8 ++++++++
 pesto/loader.py     | 2 +-
 pesto/model.py      | 2 +-
 pesto/utils/hcqt.py | 2 +-
 4 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/pesto/data.py b/pesto/data.py
index 3044bbd..ebdd151 100644
--- a/pesto/data.py
+++ b/pesto/data.py
@@ -2,6 +2,7 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
+from torchaudio import functional as taudio
 
 from .utils import HarmonicCQT
 
@@ -45,6 +46,9 @@ class Preprocessor(nn.Module):
 
         # if the sampling rate is provided, instantiate the CQT kernels
         if sampling_rate is not None:
+            top_bin = hcqt_kwargs['fmin'] * 2 ** (hcqt_kwargs['n_bins'] / hcqt_kwargs['bins_per_semitone'] / 12)
+            if sampling_rate / 2 < top_bin:
+                sampling_rate = int(round(top_bin*1e-3) * 2e3)
             self.hcqt_sr = sampling_rate
             self._reset_hcqt_kernels()
 
@@ -60,6 +64,10 @@ class Preprocessor(nn.Module):
             torch.Tensor: log-magnitude CQT of batch of CQTs,
                 shape (batch_size?, num_timesteps, num_harmonics, num_freqs)
         """
+
+        if sr < self.hcqt_sr:
+            x, sr = taudio.resample(x, sr, self.hcqt_sr), self.hcqt_sr
+
         # compute CQT from input waveform, and invert dims for (time_steps, num_harmonics, freq_bins)
         # in other words, time becomes the batch dimension, enabling efficient processing for long audios.
         complex_cqt = torch.view_as_complex(self.hcqt(x, sr=sr)).permute(0, 3, 1, 2)
diff --git a/pesto/loader.py b/pesto/loader.py
index a270615..1f0c77d 100644
--- a/pesto/loader.py
+++ b/pesto/loader.py
@@ -44,7 +44,7 @@ def load_model(checkpoint: str,
     model = PESTO(encoder,
                   preprocessor=preprocessor,
                   crop_kwargs=hparams["pitch_shift"],
-                  reduction=hparams["reduction"])
+                  reduction="awa")
     model.load_state_dict(state_dict, strict=False)
     model.eval()
 
diff --git a/pesto/model.py b/pesto/model.py
index b9cf5ed..748b6ee 100644
--- a/pesto/model.py
+++ b/pesto/model.py
@@ -138,7 +138,7 @@ class Resnet1d(nn.Module):
         Args:
             x (torch.Tensor): shape (batch, channels, freq_bins)
         """
-        x = self.layernorm(x)
+        x = self.layernorm(x.unsqueeze(1))
 
         x = self.conv1(x)
         for p in range(0, self.n_prefilt_layers - 1):
diff --git a/pesto/utils/hcqt.py b/pesto/utils/hcqt.py
index f7ff561..2701936 100644
--- a/pesto/utils/hcqt.py
+++ b/pesto/utils/hcqt.py
@@ -104,7 +104,7 @@ def create_cqt_kernels(
         freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
 
     else:
-        warnings.warn("If fmax is given, n_bins will be ignored", SyntaxWarning)
+        #warnings.warn("If fmax is given, n_bins will be ignored", SyntaxWarning)
         n_bins = np.ceil(
             bins_per_octave * np.log2(fmax / fmin)
         )  # Calculate the number of bins
-- 
GitLab