From bf678a8399d1c7c4d9b47290e62d0720c0c7fd36 Mon Sep 17 00:00:00 2001 From: "paul.best" <paul.best@lis-lab.fr> Date: Wed, 12 Jun 2024 12:22:22 +0200 Subject: [PATCH] small fixes --- pesto/data.py | 8 ++++++++ pesto/loader.py | 2 +- pesto/model.py | 2 +- pesto/utils/hcqt.py | 2 +- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pesto/data.py b/pesto/data.py index 3044bbd..ebdd151 100644 --- a/pesto/data.py +++ b/pesto/data.py @@ -2,6 +2,7 @@ from typing import Optional import torch import torch.nn as nn +from torchaudio import functional as taudio from .utils import HarmonicCQT @@ -45,6 +46,9 @@ class Preprocessor(nn.Module): # if the sampling rate is provided, instantiate the CQT kernels if sampling_rate is not None: + top_bin = hcqt_kwargs['fmin'] * 2 ** (hcqt_kwargs['n_bins'] / hcqt_kwargs['bins_per_semitone'] / 12) + if sampling_rate / 2 < top_bin: + sampling_rate = int(round(top_bin*1e-3) * 2e3) self.hcqt_sr = sampling_rate self._reset_hcqt_kernels() @@ -60,6 +64,10 @@ class Preprocessor(nn.Module): torch.Tensor: log-magnitude CQT of batch of CQTs, shape (batch_size?, num_timesteps, num_harmonics, num_freqs) """ + + if sr < self.hcqt_sr: + x, sr = taudio.resample(x, sr, self.hcqt_sr), self.hcqt_sr + # compute CQT from input waveform, and invert dims for (time_steps, num_harmonics, freq_bins) # in other words, time becomes the batch dimension, enabling efficient processing for long audios. complex_cqt = torch.view_as_complex(self.hcqt(x, sr=sr)).permute(0, 3, 1, 2) diff --git a/pesto/loader.py b/pesto/loader.py index a270615..1f0c77d 100644 --- a/pesto/loader.py +++ b/pesto/loader.py @@ -44,7 +44,7 @@ def load_model(checkpoint: str, model = PESTO(encoder, preprocessor=preprocessor, crop_kwargs=hparams["pitch_shift"], - reduction=hparams["reduction"]) + reduction="awa") model.load_state_dict(state_dict, strict=False) model.eval() diff --git a/pesto/model.py b/pesto/model.py index b9cf5ed..748b6ee 100644 --- a/pesto/model.py +++ b/pesto/model.py @@ -138,7 +138,7 @@ class Resnet1d(nn.Module): Args: x (torch.Tensor): shape (batch, channels, freq_bins) """ - x = self.layernorm(x) + x = self.layernorm(x.unsqueeze(1)) x = self.conv1(x) for p in range(0, self.n_prefilt_layers - 1): diff --git a/pesto/utils/hcqt.py b/pesto/utils/hcqt.py index f7ff561..2701936 100644 --- a/pesto/utils/hcqt.py +++ b/pesto/utils/hcqt.py @@ -104,7 +104,7 @@ def create_cqt_kernels( freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave)) else: - warnings.warn("If fmax is given, n_bins will be ignored", SyntaxWarning) + #warnings.warn("If fmax is given, n_bins will be ignored", SyntaxWarning) n_bins = np.ceil( bins_per_octave * np.log2(fmax / fmin) ) # Calculate the number of bins -- GitLab