Skip to content
Snippets Groups Projects
Commit bf678a83 authored by Paul Best
Browse files

small fixes

parent beccc9d8
Branches main
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ from typing import Optional ...@@ -2,6 +2,7 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from torchaudio import functional as taudio
from .utils import HarmonicCQT from .utils import HarmonicCQT
...@@ -45,6 +46,9 @@ class Preprocessor(nn.Module): ...@@ -45,6 +46,9 @@ class Preprocessor(nn.Module):
# if the sampling rate is provided, instantiate the CQT kernels # if the sampling rate is provided, instantiate the CQT kernels
if sampling_rate is not None: if sampling_rate is not None:
top_bin = hcqt_kwargs['fmin'] * 2 ** (hcqt_kwargs['n_bins'] / hcqt_kwargs['bins_per_semitone'] / 12)
if sampling_rate / 2 < top_bin:
sampling_rate = int(round(top_bin*1e-3) * 2e3)
self.hcqt_sr = sampling_rate self.hcqt_sr = sampling_rate
self._reset_hcqt_kernels() self._reset_hcqt_kernels()
...@@ -60,6 +64,10 @@ class Preprocessor(nn.Module): ...@@ -60,6 +64,10 @@ class Preprocessor(nn.Module):
torch.Tensor: log-magnitude CQT of batch of CQTs, torch.Tensor: log-magnitude CQT of batch of CQTs,
shape (batch_size?, num_timesteps, num_harmonics, num_freqs) shape (batch_size?, num_timesteps, num_harmonics, num_freqs)
""" """
if sr < self.hcqt_sr:
x, sr = taudio.resample(x, sr, self.hcqt_sr), self.hcqt_sr
# compute CQT from input waveform, and invert dims for (time_steps, num_harmonics, freq_bins) # compute CQT from input waveform, and invert dims for (time_steps, num_harmonics, freq_bins)
# in other words, time becomes the batch dimension, enabling efficient processing for long audios. # in other words, time becomes the batch dimension, enabling efficient processing for long audios.
complex_cqt = torch.view_as_complex(self.hcqt(x, sr=sr)).permute(0, 3, 1, 2) complex_cqt = torch.view_as_complex(self.hcqt(x, sr=sr)).permute(0, 3, 1, 2)
......
...@@ -44,7 +44,7 @@ def load_model(checkpoint: str, ...@@ -44,7 +44,7 @@ def load_model(checkpoint: str,
model = PESTO(encoder, model = PESTO(encoder,
preprocessor=preprocessor, preprocessor=preprocessor,
crop_kwargs=hparams["pitch_shift"], crop_kwargs=hparams["pitch_shift"],
reduction=hparams["reduction"]) reduction="awa")
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
model.eval() model.eval()
......
...@@ -138,7 +138,7 @@ class Resnet1d(nn.Module): ...@@ -138,7 +138,7 @@ class Resnet1d(nn.Module):
Args: Args:
x (torch.Tensor): shape (batch, channels, freq_bins) x (torch.Tensor): shape (batch, channels, freq_bins)
""" """
x = self.layernorm(x) x = self.layernorm(x.unsqueeze(1))
x = self.conv1(x) x = self.conv1(x)
for p in range(0, self.n_prefilt_layers - 1): for p in range(0, self.n_prefilt_layers - 1):
......
...@@ -104,7 +104,7 @@ def create_cqt_kernels( ...@@ -104,7 +104,7 @@ def create_cqt_kernels(
freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave)) freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
else: else:
warnings.warn("If fmax is given, n_bins will be ignored", SyntaxWarning) #warnings.warn("If fmax is given, n_bins will be ignored", SyntaxWarning)
n_bins = np.ceil( n_bins = np.ceil(
bins_per_octave * np.log2(fmax / fmin) bins_per_octave * np.log2(fmax / fmin)
) # Calculate the number of bins ) # Calculate the number of bins
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment