diff --git a/pesto/data.py b/pesto/data.py index 3044bbdf3fcfc0818b198555f7711623517d8e62..ebdd151fe5eb3a4e02a9591fe9615aad5d8bd7fc 100644 --- a/pesto/data.py +++ b/pesto/data.py @@ -2,6 +2,7 @@ from typing import Optional import torch import torch.nn as nn +from torchaudio import functional as taudio from .utils import HarmonicCQT @@ -45,6 +46,9 @@ class Preprocessor(nn.Module): # if the sampling rate is provided, instantiate the CQT kernels if sampling_rate is not None: + top_bin = hcqt_kwargs['fmin'] * 2 ** (hcqt_kwargs['n_bins'] / hcqt_kwargs['bins_per_semitone'] / 12) + if sampling_rate / 2 < top_bin: + sampling_rate = int(round(top_bin*1e-3) * 2e3) self.hcqt_sr = sampling_rate self._reset_hcqt_kernels() @@ -60,6 +64,10 @@ class Preprocessor(nn.Module): torch.Tensor: log-magnitude CQT of batch of CQTs, shape (batch_size?, num_timesteps, num_harmonics, num_freqs) """ + + if sr < self.hcqt_sr: + x, sr = taudio.resample(x, sr, self.hcqt_sr), self.hcqt_sr + # compute CQT from input waveform, and invert dims for (time_steps, num_harmonics, freq_bins) # in other words, time becomes the batch dimension, enabling efficient processing for long audios. complex_cqt = torch.view_as_complex(self.hcqt(x, sr=sr)).permute(0, 3, 1, 2) diff --git a/pesto/loader.py b/pesto/loader.py index a270615997197cfa4abdcfeef30de6a68d67be98..1f0c77d91f2e9781d3b6698652be6e421156dc5d 100644 --- a/pesto/loader.py +++ b/pesto/loader.py @@ -44,7 +44,7 @@ def load_model(checkpoint: str, model = PESTO(encoder, preprocessor=preprocessor, crop_kwargs=hparams["pitch_shift"], - reduction=hparams["reduction"]) + reduction="awa") model.load_state_dict(state_dict, strict=False) model.eval() diff --git a/pesto/model.py b/pesto/model.py index b9cf5eda18349e38a2d4de1f739798bd5b1df0e6..748b6eeb82dc71d1fcc29dac5ca25d80e64fbcc9 100644 --- a/pesto/model.py +++ b/pesto/model.py @@ -138,7 +138,7 @@ class Resnet1d(nn.Module): Args: x (torch.Tensor): shape (batch, channels, freq_bins) """ - x = self.layernorm(x) + x = self.layernorm(x.unsqueeze(1)) x = self.conv1(x) for p in range(0, self.n_prefilt_layers - 1): diff --git a/pesto/utils/hcqt.py b/pesto/utils/hcqt.py index f7ff5610feb4b3c8d3bb29affc9a4ea5d2661538..270193625bd75bef1eb6115d27d91dc94590d51e 100644 --- a/pesto/utils/hcqt.py +++ b/pesto/utils/hcqt.py @@ -104,7 +104,7 @@ def create_cqt_kernels( freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave)) else: - warnings.warn("If fmax is given, n_bins will be ignored", SyntaxWarning) + #warnings.warn("If fmax is given, n_bins will be ignored", SyntaxWarning) n_bins = np.ceil( bins_per_octave * np.log2(fmax / fmin) ) # Calculate the number of bins