diff --git a/README.md b/README.md index f703e23ae3c086fe4150d364527ebbc89b32d56b..ac4d40043b01c9d10152d6a10327a61d3f040847 100644 --- a/README.md +++ b/README.md @@ -107,46 +107,64 @@ import pesto # predict the pitch of your audio tensors directly within your own Python code x, sr = torchaudio.load("my_file.wav") -timesteps, pitch, confidence, activations = pesto.predict(x, sr, step_size=10.) +x = x.mean(dim=0) # PESTO takes mono audio as input +timesteps, pitch, confidence, activations = pesto.predict(x, sr) + +# or predict using your own custom checkpoint +predictions = pesto.predict(x, sr, model_name="/path/to/checkpoint.ckpt") # You can also predict pitches from audio files directly -pesto.predict_from_files(["example1.wav", "example2.mp3", "example3.ogg"], step_size=10., export_format=["csv"]) +pesto.predict_from_files(["example1.wav", "example2.mp3", "example3.ogg"], export_format=["csv"]) ``` +`pesto.predict` supports batched inputs, which should then have shape `(batch_size, num_samples)`. + +**Warning:** If you forget to convert a stereo audio in mono, channels will be treated as batch dimensions and you will +get predictions for each channel separately. #### Advanced usage -If not provided, `pesto.predict` will first load the CQT kernels and the model before performing +`pesto.predict` will first load the CQT kernels and the model before performing any pitch estimation. If you want to process a significant number of files, calling `predict` several times will then re-initialize the same model for each tensor. -To avoid this time-consuming step, one can manually instantiate the model and data processor, then pass them directly -as args to the `predict` function. To do so, one has to use the underlying methods from `pesto.utils`: +To avoid this time-consuming step, one can manually instantiate the model with `load_model`, +then call the forward method of the model instead: ```python import torch -from pesto import predict -from pesto.utils import load_model, load_dataprocessor +from pesto import load_model device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -model = load_model("mir-1k", device=device) -data_processor = load_dataprocessor(step_size=0.01, device=device) +pesto_model = load_model("mir-1k", step_size=20.).to(device) for x, sr in ...: - data_processor.sampling_rate = sr # The data_processor handles waveform->CQT conversion so it must know the sampling rate - predictions = predict(x, sr, model=model) + x = x.to(device) + predictions, confidence, activations = pesto_model(x, sr) ... ``` -Note that when passing a list of files to `pesto.predict_from_files(...)` or the CLI directly, the model is loaded only -once so you don't have to bother with that in general. -#### Batched pitch estimation +Note that the `PESTO` object returned by `load_model` is a subclass of `nn.Module` +and its `forward` method also supports batched inputs. +One can therefore easily integrate PESTO within their own architecture by doing: +```python +import torch +import torch.nn as nn -By default, the function `pesto.predict` takes an audio waveform represented as a Tensor object of shape `(num_channels, num_samples)`. -However, one may want to estimate the pitch of batches of (cropped) waveforms within a training pipeline, e.g. for DDSP-related applications. -`pesto.predict` therefore accepts Tensor inputs of shape `(batch_size, num_channels, num_samples)` and returns batch-wise pitch predictions accordingly. +from pesto import load_model -Note that batched predictions are available only from the Python API and not from the CLI because handling audios of different lengths is annoying, and I don't want to bother with that. + +class MyGreatModel(nn.Module): + def __init__(self, step_size, sr=44100, *args, **kwargs): + super(MyGreatModel, self).__init__() + self.f0_estimator = load_model("mir-1k", step_size, sampling_rate=sr) + ... + + def forward(self, x): + with torch.no_grad(): + f0, conf = self.f0_estimator(x, convert_to_freq=True, return_activations=False) + ... +``` ## Performances diff --git a/pesto/__init__.py b/pesto/__init__.py index 82a92f8908b9f434140e3c1f7f44899a11cb8655..cff2aebc8285d3f463252869e8f432cb0189bade 100644 --- a/pesto/__init__.py +++ b/pesto/__init__.py @@ -1 +1 @@ -from .core import predict, predict_from_files +from .core import load_model, predict, predict_from_files diff --git a/pesto/core.py b/pesto/core.py index 30e2e98a6a6ed409ff8c003214273c4aadab4799..899baa6f933fb7d8403e8a5b0adfaa3b2aaa3f0e 100644 --- a/pesto/core.py +++ b/pesto/core.py @@ -1,38 +1,61 @@ import os import warnings -from typing import Optional, Sequence, Union +from typing import Optional, Sequence, Tuple, Union import torch import torchaudio from tqdm import tqdm -from .export import export -from .utils import load_model, load_dataprocessor, reduce_activation - - -def predict( - x: torch.Tensor, - sr: Optional[int] = None, - model: Union[torch.nn.Module, str] = "mir-1k", - data_preprocessor=None, - step_size: Optional[float] = None, - reduction: str = "argmax", - num_chunks: int = 1, - convert_to_freq: bool = False, - inference_mode: bool = True, - no_grad: bool = True -): +from .loader import load_model +from .model import PESTO +from .utils import export + + +def _predict(x: torch.Tensor, + sr: int, + model: PESTO, + num_chunks: int = 1, + convert_to_freq: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + preds, confidence, activations = [], [], [] + try: + for chunk in x.chunk(chunks=num_chunks): + pred, conf, act = model(chunk, sr=sr, convert_to_freq=convert_to_freq, return_activations=True) + preds.append(pred) + confidence.append(conf) + activations.append(act) + except torch.cuda.OutOfMemoryError: + raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. " + "Please increase the number of chunks with option `-c`/`--chunks` " + "to reduce GPU memory usage.") + + preds = torch.cat(preds, dim=0) + confidence = torch.cat(confidence, dim=0) + activations = torch.cat(activations, dim=0) + + # compute timesteps + timesteps = torch.arange(preds.size(-1), device=x.device) * model.hop_size + + return timesteps, preds, confidence, activations + + +def predict(x: torch.Tensor, + sr: int, + step_size: float = 10., + model_name: str = "mir-1k", + reduction: str = "alwa", + num_chunks: int = 1, + convert_to_freq: bool = True, + inference_mode: bool = True, + no_grad: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: r"""Main prediction function. Args: - x (torch.Tensor): input audio tensor, - shape (num_channels, num_samples) or (batch_size, num_channels, num_samples) + x (torch.Tensor): input audio tensor, can be provided as a batch but should be mono, + shape (num_samples) or (batch_size, num_samples) sr (int, optional): sampling rate. If not specified, uses the current sampling rate of the model. - model: PESTO model. If a string is passed, it will load the model with the corresponding name. - Otherwise, the actual nn.Module will be used for doing predictions. - data_preprocessor: Module handling the data processing pipeline (waveform to CQT, cropping, etc.) step_size (float, optional): step size between each CQT frame in milliseconds. - If the data_preprocessor is passed, its value will be used instead. + If a `PESTO` object is passed as `model`, this will be ignored. + model_name: name of PESTO model. Can be a path to a custom PESTO checkpoint or the name of a standard model. reduction (str): reduction method for converting activation probabilities to log-frequencies. num_chunks (int): number of chunks to split the input audios in. Default is 1 (all CQT frames in parallel) but it can be increased to reduce memory usage @@ -40,58 +63,25 @@ def predict( convert_to_freq (bool): whether predictions should be converted to frequencies or not. inference_mode (bool): whether to run with `torch.inference_mode`. no_grad (bool): whether to run with `torch.no_grad`. If set to `False`, argument `inference_mode` is ignored. + + Returns: + timesteps (torch.Tensor): timesteps corresponding to each pitch prediction, shape (num_timesteps) + preds (torch.Tensor): pitch predictions in SEMITONES, shape (batch_size?, num_timesteps) + where `num_timesteps` ~= `num_samples` / (`self.hop_size` * `sr`) + confidence (torch.Tensor): confidence of whether frame is voiced or unvoiced in [0, 1], + shape (batch_size?, num_timesteps) + activations (torch.Tensor): activations of the model, shape (batch_size?, num_timesteps, output_dim) """ + # sanity checks + assert x.ndim <= 2, \ + f"Audio file should have shape (num_samples) or (batch_size, num_samples), but found shape {x.size()}." + inference_mode = inference_mode and no_grad with torch.no_grad() if no_grad and not inference_mode else torch.inference_mode(mode=inference_mode): - # convert to mono - assert 2 <= x.ndim <= 3, f"Audio file should have two dimensions, but found shape {x.size()}" - batch_size = x.size(0) if x.ndim == 3 else None - x = x.mean(dim=-2) - - if data_preprocessor is None: - assert step_size is not None and sr is not None, \ - "If you don't use a predefined data preprocessor, you must at least indicate a step size (in milliseconds)" - data_preprocessor = load_dataprocessor(step_size=step_size / 1000., sampling_rate=sr, device=x.device) - - # If the sampling rate has changed, change the sampling rate accordingly - # It will automatically recompute the CQT kernels if needed - if sr is not None: - data_preprocessor.sampling_rate = sr - - if isinstance(model, str): - model = load_model(model, device=x.device) + model = load_model(model_name, step_size, sampling_rate=sr).to(x.device) + model.reduction = reduction - # apply model - cqt = data_preprocessor(x) - try: - activations = torch.cat([ - model(chunk) for chunk in cqt.chunk(chunks=num_chunks) - ]) - except torch.cuda.OutOfMemoryError: - raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. " - "Please increase the number of chunks with option `-c`/`--chunks` " - "to reduce GPU memory usage.") - - if batch_size: - total_batch_size, num_predictions = activations.size() - activations = activations.view(batch_size, total_batch_size // batch_size, num_predictions) - - # shift activations as it should (PESTO predicts pitches up to an additive constant) - activations = activations.roll(model.abs_shift.cpu().item(), dims=-1) - - # convert model predictions to pitch values - pitch = reduce_activation(activations, reduction=reduction) - if convert_to_freq: - pitch = 440 * 2 ** ((pitch - 69) / 12) - - # for now, confidence is computed very naively just based on volume - confidence = cqt.squeeze(1).max(dim=1).values.view_as(pitch) - conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values - confidence = (confidence - conf_min) / (conf_max - conf_min) - - timesteps = torch.arange(pitch.size(-1), device=x.device) * data_preprocessor.step_size - - return timesteps, pitch, confidence, activations + return _predict(x, sr, model, num_chunks=num_chunks, convert_to_freq=convert_to_freq) def predict_from_files( @@ -113,12 +103,11 @@ def predict_from_files( output: step_size: hop length in milliseconds reduction: - export_format: + export_format (Sequence[str]): format to export the predictions to. + Currently formats supported are: ["csv", "npz", "json"]. no_convert_to_freq: whether convert output values to Hz or keep fractional MIDI pitches + num_chunks (int): number of chunks to divide the inputs into. Increase this value if you encounter OOM errors. gpu: index of GPU to use (-1 for CPU) - - Returns: - Pitch predictions, see `predict` for more details. """ if isinstance(audio_files, str): audio_files = [audio_files] @@ -128,37 +117,34 @@ def predict_from_files( gpu = -1 device = torch.device(f"cuda:{gpu:d}" if gpu >= 0 else "cpu") - # define data preprocessing - data_preprocessor = load_dataprocessor(step_size / 1000., device=device) - # define model - model = load_model(model_name, device=device) - predictions = None + model = load_model(model_name, step_size=step_size).to(device) + model.reduction = reduction pbar = tqdm(audio_files) - for file in pbar: - pbar.set_description(file) - # load audio file - try: - x, sr = torchaudio.load(file) - except Exception as e: - print(e, f"Skipping {file}...") - continue + with torch.inference_mode(): # here the purpose is to write results in disk, so there is no point storing gradients + for file in pbar: + pbar.set_description(file) + + # load audio file + try: + x, sr = torchaudio.load(file) + except Exception as e: + print(e, f"Skipping {file}...") + continue - x = x.to(device) + x = x.mean(dim=0).to(device) # convert to mono then pass to the right device - # compute the predictions - predictions = predict(x, sr, model=model, data_preprocessor=data_preprocessor, reduction=reduction, - convert_to_freq=not no_convert_to_freq, num_chunks=num_chunks) + # compute the predictions + predictions = _predict(x, sr, model=model, convert_to_freq=not no_convert_to_freq, num_chunks=num_chunks) - output_file = file.rsplit('.', 1)[0] + "." + ("semitones" if no_convert_to_freq else "f0") - if output is not None: - os.makedirs(output, exist_ok=True) - output_file = os.path.join(output, os.path.basename(output_file)) + output_file = file.rsplit('.', 1)[0] + "." + ("semitones" if no_convert_to_freq else "f0") + if output is not None: + os.makedirs(output, exist_ok=True) + output_file = os.path.join(output, os.path.basename(output_file)) - predictions = [p.cpu().numpy() for p in predictions] - for fmt in export_format: - export(fmt, output_file, *predictions) + predictions = [p.cpu().numpy() for p in predictions] + for fmt in export_format: + export(fmt, output_file, *predictions) - return predictions diff --git a/pesto/data.py b/pesto/data.py index f5ed52e7fe383d4b3d54a3fbf19fd627b6e99b67..7d8b81d77b2a982c559e258d3f7c8cde301f6d75 100644 --- a/pesto/data.py +++ b/pesto/data.py @@ -3,79 +3,92 @@ from typing import Optional import torch import torch.nn as nn -from .cqt import CQT +from .utils import HarmonicCQT -class DataProcessor(nn.Module): +class ToLogMagnitude(nn.Module): + def __init__(self): + super(ToLogMagnitude, self).__init__() + self.eps = torch.finfo(torch.float32).eps + + def forward(self, x): + x = x.abs() + x.clamp_(min=self.eps).log10_().mul_(20) + return x + + + +class Preprocessor(nn.Module): r""" Args: - step_size (float): step size between consecutive CQT frames (in SECONDS) + hop_size (float): step size between consecutive CQT frames (in milliseconds) """ - _sampling_rate: Optional[int] = None - def __init__(self, - step_size: float, - bins_per_semitone: int = 3, + hop_size: float, sampling_rate: Optional[int] = None, - **cqt_kwargs): - super(DataProcessor, self).__init__() - self.step_size = step_size - self.bins_per_semitone = bins_per_semitone + **hcqt_kwargs): + super(Preprocessor, self).__init__() - # CQT-related stuff - self.cqt_kwargs = cqt_kwargs - self.cqt_kwargs["bins_per_octave"] = 12 * bins_per_semitone - self.cqt = None + # HCQT + self.hcqt_sr = None + self.hcqt_kernels = None + self.hop_size = hop_size - # log-magnitude - self.eps = torch.finfo(torch.float32).eps + self.hcqt_kwargs = hcqt_kwargs - # cropping - self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5) - self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone + # log-magnitude + self.to_log = ToLogMagnitude() # register a dummy tensor to get implicit access to the module's device self.register_buffer("_device", torch.zeros(()), persistent=False) - # sampling rate is lazily initialized + # if the sampling rate is provided, instantiate the CQT kernels if sampling_rate is not None: - self.sampling_rate = sampling_rate + self.hcqt_sr = sampling_rate + self._reset_hcqt_kernels() - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, sr: Optional[int] = None) -> torch.Tensor: r""" Args: - x: audio waveform, any sampling rate, shape (num_samples) + x (torch.Tensor): audio waveform or batch of audio waveforms, any sampling rate, + shape (batch_size?, num_samples) + sr (int, optional): sampling rate Returns: - log-magnitude CQT, shape ( + torch.Tensor: log-magnitude CQT of batch of CQTs, + shape (batch_size?, num_timesteps, num_harmonics, num_freqs) """ - # compute CQT from input waveform, and invert dims for (batch_size, time_steps, freq_bins) - complex_cqt = torch.view_as_complex(self.cqt(x)).transpose(1, 2) - - # reshape and crop borders to fit training input shape - complex_cqt = complex_cqt[..., self.lowest_bin: self.highest_bin] - - # flatten eventual batch dimensions so that batched audios can be processed in parallel - complex_cqt = complex_cqt.flatten(0, 1).unsqueeze(1) + # compute CQT from input waveform, and invert dims for (time_steps, num_harmonics, freq_bins) + # in other words, time becomes the batch dimension, enabling efficient processing for long audios. + complex_cqt = torch.view_as_complex(self.hcqt(x, sr=sr)).permute(0, 3, 1, 2) # convert to dB - log_cqt = complex_cqt.abs().clamp_(min=self.eps).log10_().mul_(20) - return log_cqt - - def _init_cqt_layer(self, sr: int, hop_length: int): - self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self._device.device) + return self.to_log(complex_cqt) - @property - def sampling_rate(self) -> int: - return self._sampling_rate + def hcqt(self, audio: torch.Tensor, sr: Optional[int] = None) -> torch.Tensor: + r"""Compute the Harmonic CQT of the input audio after eventually recreating the kernels + (in case the sampling rate has changed). - @sampling_rate.setter - def sampling_rate(self, sr: int) -> None: - if sr == self._sampling_rate: - return + Args: + audio (torch.Tensor): mono audio waveform, shape (batch_size, num_samples) + sr (int): sampling rate of the audio waveform. + If not specified, we assume it is the same as the previous processed audio waveform. - hop_length = int(self.step_size * sr + 0.5) - self._init_cqt_layer(sr, hop_length) - self._sampling_rate = sr + Returns: + torch.Tensor: Complex Harmonic CQT (HCQT) of the input, + shape (batch_size, num_harmonics, num_freqs, num_timesteps, 2) + """ + # compute HCQT kernels if it does not exist or if the sampling rate has changed + if sr is not None and sr != self.hcqt_sr: + self.hcqt_sr = sr + self._reset_hcqt_layer() + + return self.hcqt_kernels(audio) + + def _reset_hcqt_layer(self) -> None: + hop_length = int(self.hop_size * self.hcqt_sr / 1000 + 0.5) + self.hcqt_kernels = HarmonicCQT(sr=self.hcqt_sr, + hop_length=hop_length, + **self.hcqt_kwargs).to(self._device.device) diff --git a/pesto/loader.py b/pesto/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..609d4b5df3cc4b9e2e1c891034b9479379cb774e --- /dev/null +++ b/pesto/loader.py @@ -0,0 +1,51 @@ +import os +from typing import Optional + +import torch + +from .data import Preprocessor +from .model import PESTO, Resnet1d + + +def load_model(checkpoint: str, + step_size: float, + sampling_rate: Optional[int] = None) -> PESTO: + r"""Load a trained model from a checkpoint file. + See https://github.com/SonyCSLParis/pesto-full/blob/master/src/models/pesto.py for the structure of the checkpoint. + + Args: + checkpoint (str): path to the checkpoint or name of the checkpoint file (if using a provided checkpoint) + step_size (float): hop size in milliseconds + sampling_rate (int, optional): sampling rate of the audios. + If not provided, it can be inferred dynamically as well. + Returns: + PESTO: instance of PESTO model + """ + if os.path.exists(checkpoint): # handle user-provided checkpoints + model_path = checkpoint + else: + model_path = os.path.join(os.path.dirname(__file__), "weights", checkpoint + ".pth") + if not os.path.exists(model_path): + raise FileNotFoundError(f"You passed an invalid checkpoint file: {checkpoint}.") + + # load checkpoint + checkpoint = torch.load(model_path, map_location=torch.device("cpu")) + hparams = checkpoint["hparams"] + hcqt_params = checkpoint["hcqt_params"] + state_dict = checkpoint["state_dict"] + + # instantiate preprocessor + preprocessor = Preprocessor(hop_size=step_size, sampling_rate=sampling_rate, **hcqt_params) + + # instantiate PESTO encoder + encoder = Resnet1d(**hparams["encoder"]) + + # instantiate main PESTO module and load its weights + model = PESTO(encoder, + preprocessor=preprocessor, + crop_kwargs=hparams["pitch_shift"], + reduction=hparams["reduction"]) + model.load_state_dict(state_dict) + model.eval() + + return model diff --git a/pesto/main.py b/pesto/main.py index bc52df6cb0374c6798e23fdfe117e224c7983369..d117941f84eed1535970c16f5f4fa53029bf91f5 100644 --- a/pesto/main.py +++ b/pesto/main.py @@ -1,4 +1,4 @@ -from .parser import parse_args +from pesto.utils.parser import parse_args from .core import predict_from_files diff --git a/pesto/model.py b/pesto/model.py index 94337ec84570b1cd3d5aaf1609b06f3f49a30ae6..27edd932faa90d51e8d23f5a55dde498a33fbb0d 100644 --- a/pesto/model.py +++ b/pesto/model.py @@ -1,8 +1,15 @@ from functools import partial +from typing import Any, Mapping, Optional, Tuple, Union import torch import torch.nn as nn +from .utils import CropCQT +from .utils import reduce_activations + + +OUTPUT_TYPE = Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + class ToeplitzLinear(nn.Conv1d): def __init__(self, in_features, out_features): @@ -18,7 +25,7 @@ class ToeplitzLinear(nn.Conv1d): return super(ToeplitzLinear, self).forward(input.unsqueeze(-2)).squeeze(-2) -class PESTOEncoder(nn.Module): +class Resnet1d(nn.Module): """ Basic CNN similar to the one in Johannes Zeitler's report, but for longer HCQT input (always stride 1 in time) @@ -29,7 +36,8 @@ class PESTOEncoder(nn.Module): not over time (in order to work with variable length input). Outputs one channel with sigmoid activation. - Args (Defaults: BasicCNN by Johannes Zeitler but with 1 input channel): + Args (Defaults: BasicCNN by Johannes Zeitler but with 6 input channels): + n_chan_input: Number of input channels (harmonics in HCQT) n_chan_layers: Number of channels in the hidden layers (list) n_prefilt_layers: Number of repetitions of the prefiltering layer residual: If True, use residual connections for prefiltering (default: False) @@ -39,100 +47,188 @@ class PESTOEncoder(nn.Module): p_dropout: Dropout probability """ - def __init__( - self, - n_chan_layers=(20, 20, 10, 1), - n_prefilt_layers=1, - residual=False, - n_bins_in=216, - output_dim=128, - num_output_layers: int = 1 - ): - super(PESTOEncoder, self).__init__() - - activation_layer = partial(nn.LeakyReLU, negative_slope=0.3) - + def __init__(self, + n_chan_input=1, + n_chan_layers=(20, 20, 10, 1), + n_prefilt_layers=1, + prefilt_kernel_size=15, + residual=False, + n_bins_in=216, + output_dim=128, + activation_fn: str = "leaky", + a_lrelu=0.3, + p_dropout=0.2): + super(Resnet1d, self).__init__() + + self.hparams = dict(n_chan_input=n_chan_input, + n_chan_layers=n_chan_layers, + n_prefilt_layers=n_prefilt_layers, + prefilt_kernel_size=prefilt_kernel_size, + residual=residual, + n_bins_in=n_bins_in, + output_dim=output_dim, + activation_fn=activation_fn, + a_lrelu=a_lrelu, + p_dropout=p_dropout) + + if activation_fn == "relu": + activation_layer = nn.ReLU + elif activation_fn == "silu": + activation_layer = nn.SiLU + elif activation_fn == "leaky": + activation_layer = partial(nn.LeakyReLU, negative_slope=a_lrelu) + else: + raise ValueError + + n_in = n_chan_input n_ch = n_chan_layers if len(n_ch) < 5: n_ch.append(1) - # Layer normalization over frequency - self.layernorm = nn.LayerNorm(normalized_shape=[1, n_bins_in]) + # Layer normalization over frequency and channels (harmonics of HCQT) + self.layernorm = nn.LayerNorm(normalized_shape=[n_in, n_bins_in]) # Prefiltering + prefilt_padding = prefilt_kernel_size // 2 self.conv1 = nn.Sequential( - nn.Conv1d(in_channels=1, out_channels=n_ch[0], kernel_size=15, padding=7, stride=1), - activation_layer() + nn.Conv1d(in_channels=n_in, + out_channels=n_ch[0], + kernel_size=prefilt_kernel_size, + padding=prefilt_padding, + stride=1), + activation_layer(), + nn.Dropout(p=p_dropout) ) self.n_prefilt_layers = n_prefilt_layers - self.prefilt_list = nn.ModuleList() - for p in range(1, n_prefilt_layers): - self.prefilt_list.append(nn.Sequential( - nn.Conv1d(in_channels=n_ch[0], out_channels=n_ch[0], kernel_size=15, padding=7, stride=1), - activation_layer() - )) + self.prefilt_layers = nn.ModuleList(*[ + nn.Sequential( + nn.Conv1d(in_channels=n_ch[0], + out_channels=n_ch[0], + kernel_size=prefilt_kernel_size, + padding=prefilt_padding, + stride=1), + activation_layer(), + nn.Dropout(p=p_dropout) + ) + for _ in range(n_prefilt_layers-1) + ]) self.residual = residual - self.conv2 = nn.Sequential( - nn.Conv1d( - in_channels=n_ch[0], - out_channels=n_ch[1], - kernel_size=1, - stride=1, - padding=0 - ), - activation_layer() - ) - - self.conv3 = nn.Sequential( - nn.Conv1d(in_channels=n_ch[1], out_channels=n_ch[2], kernel_size=1, padding=0, stride=1), - activation_layer() - ) - - self.conv4 = nn.Sequential( - nn.Conv1d(in_channels=n_ch[2], out_channels=n_ch[3], kernel_size=1, padding=0, stride=1), - activation_layer(), - nn.Dropout(), - nn.Conv1d(in_channels=n_ch[3], out_channels=n_ch[4], kernel_size=1, padding=0, stride=1) - ) + conv_layers = [] + for i in range(len(n_chan_layers)-1): + conv_layers.extend([ + nn.Conv1d(in_channels=n_ch[i], + out_channels=n_ch[i + 1], + kernel_size=1, + padding=0, + stride=1), + activation_layer(), + nn.Dropout(p=p_dropout) + ]) + self.conv_layers = nn.Sequential(*conv_layers) self.flatten = nn.Flatten(start_dim=1) - - layers = [] - pre_fc_dim = n_bins_in * n_ch[4] - for i in range(num_output_layers-1): - layers.extend([ - ToeplitzLinear(pre_fc_dim, pre_fc_dim), - activation_layer() - ]) - self.pre_fc = nn.Sequential(*layers) - self.fc = ToeplitzLinear(pre_fc_dim, output_dim) + self.fc = ToeplitzLinear(n_bins_in * n_ch[-1], output_dim) self.final_norm = nn.Softmax(dim=-1) - self.register_buffer("abs_shift", torch.zeros((), dtype=torch.long), persistent=True) - def forward(self, x): r""" Args: x (torch.Tensor): shape (batch, channels, freq_bins) """ - x_norm = self.layernorm(x) + x = self.layernorm(x) - x = self.conv1(x_norm) + x = self.conv1(x) for p in range(0, self.n_prefilt_layers - 1): - prefilt_layer = self.prefilt_list[p] + prefilt_layer = self.prefilt_layers[p] if self.residual: x_new = prefilt_layer(x) x = x_new + x else: x = prefilt_layer(x) - conv2_lrelu = self.conv2(x) - conv3_lrelu = self.conv3(conv2_lrelu) - y_pred = self.conv4(conv3_lrelu) - y_pred = self.flatten(y_pred) - y_pred = self.pre_fc(y_pred) - y_pred = self.fc(y_pred) + x = self.conv_layers(x) + x = self.flatten(x) + + y_pred = self.fc(x) + return self.final_norm(y_pred) + + +class PESTO(nn.Module): + def __init__(self, + encoder: nn.Module, + preprocessor: nn.Module, + crop_kwargs: Mapping[str, Any] | None = None, + reduction: str = "alwa"): + super(PESTO, self).__init__() + self.encoder = encoder + self.preprocessor = preprocessor + + # crop CQT + if crop_kwargs is None: + crop_kwargs = {} + self.crop_cqt = CropCQT(**crop_kwargs) + + self.reduction = reduction + + # constant shift to get absolute pitch from predictions + self.register_buffer('shift', torch.zeros((), dtype=torch.float), persistent=True) + + def forward(self, + audio_waveforms: torch.Tensor, + sr: Optional[int] = None, + convert_to_freq: bool = False, + return_activations: bool = False) -> OUTPUT_TYPE: + r""" + + Args: + audio_waveforms (torch.Tensor): mono audio waveform or batch of mono audio waveforms, + shape (batch_size?, num_samples) + sr (int, optional): sampling rate, defaults to the previously used sampling rate + convert_to_freq (bool): whether to convert the result to frequencies or return fractional semitones instead. + return_activations (bool): whether to return activations or pitch predictions only + + Returns: + preds (torch.Tensor): pitch predictions in SEMITONES, shape (batch_size?, num_timesteps) + where `num_timesteps` ~= `num_samples` / (`self.hop_size` * `sr`) + confidence (torch.Tensor): confidence of whether frame is voiced or unvoiced in [0, 1], + shape (batch_size?, num_timesteps) + activations (torch.Tensor): activations of the model, shape (batch_size?, num_timesteps, output_dim) + """ + batch_size = audio_waveforms.size(0) if audio_waveforms.ndim == 2 else None + x = self.preprocessor(audio_waveforms, sr=sr) + x = self.crop_cqt(x) # the CQT has to be cropped beforehand + + # for now, confidence is computed very naively just based on energy in the CQT + confidence = x.mean(dim=-2).max(dim=-1).values + conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values + confidence = (confidence - conf_min) / (conf_max - conf_min) + + # flatten batch_size and time_steps since anyway predictions are made on CQT frames independently + if batch_size: + x = x.flatten(0, 1) + + activations = self.encoder(x) + if batch_size: + activations = activations.view(batch_size, -1, activations.size(-1)) + + preds = reduce_activations(activations, reduction=self.reduction) + + # decrease by shift to get absolute pitch + preds.sub_(self.shift) + + if convert_to_freq: + preds = 440 * 2 ** ((preds - 69) / 12) + + if return_activations: + return preds, confidence, activations + + return preds, confidence + + @property + def hop_size(self) -> float: + r"""Returns the hop size of the model (in milliseconds)""" + return self.preprocessor.hop_size diff --git a/pesto/utils.py b/pesto/utils.py deleted file mode 100644 index 18722603ce06cde714ff3030405e904d1c16b71e..0000000000000000000000000000000000000000 --- a/pesto/utils.py +++ /dev/null @@ -1,53 +0,0 @@ -import os -from typing import Optional - -import torch - -from .config import model_args, cqt_args, bins_per_semitone -from .data import DataProcessor -from .model import PESTOEncoder - - -def load_dataprocessor(step_size, sampling_rate: Optional[int] = None, device: Optional[torch.device] = None): - return DataProcessor(step_size=step_size, sampling_rate=sampling_rate, **cqt_args).to(device) - - -def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder: - model = PESTOEncoder(**model_args).to(device) - model.eval() - - model_path = os.path.join(os.path.dirname(__file__), "weights", model_name + ".pth") - model.load_state_dict(torch.load(model_path, map_location=device)) - - return model - - -def reduce_activation(activations: torch.Tensor, reduction: str) -> torch.Tensor: - r"""Computes the pitch predictions from the activation outputs of the encoder. - Pitch predictions are returned in semitones, NOT in frequencies. - - Args: - activations: tensor of probability activations, shape (*, num_bins) - reduction: - - Returns: - torch.Tensor: pitch predictions, shape (*,) - """ - bps = bins_per_semitone - if reduction == "argmax": - pred = activations.argmax(dim=-1) - return pred.float() / bps - - all_pitches = (torch.arange(activations.size(-1), dtype=torch.float, device=activations.device)) / bps - if reduction == "mean": - return activations @ all_pitches - - if reduction == "alwa": # argmax-local weighted averaging, see https://github.com/marl/crepe - center_bin = activations.argmax(dim=-1, keepdim=True) - window = torch.arange(-bps+1, bps, device=activations.device) - indices = window + center_bin - cropped_activations = activations.gather(-1, indices) - cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices) - return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1) - - raise ValueError diff --git a/pesto/utils/__init__.py b/pesto/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9be9b1b72be5579424e7411fe2291823838cd3d4 --- /dev/null +++ b/pesto/utils/__init__.py @@ -0,0 +1,4 @@ +from .crop_cqt import CropCQT +from .export import export +from .hcqt import HarmonicCQT +from .reduce_activations import reduce_activations \ No newline at end of file diff --git a/pesto/utils/crop_cqt.py b/pesto/utils/crop_cqt.py new file mode 100644 index 0000000000000000000000000000000000000000..65829c70e8ee90554acee1c027178f6292cd4230 --- /dev/null +++ b/pesto/utils/crop_cqt.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + + +class CropCQT(nn.Module): + def __init__(self, min_steps: int, max_steps: int): + super(CropCQT, self).__init__() + self.min_steps = min_steps + self.max_steps = max_steps + + # lower bin + self.lower_bin = self.max_steps + + def forward(self, spectrograms: torch.Tensor) -> torch.Tensor: + # WARNING: didn't check that it works, it may be dangerous + return spectrograms[..., self.max_steps: self.min_steps] + + # old implementation + batch_size, _, input_height = spectrograms.size() + + output_height = input_height - self.max_steps + self.min_steps + assert output_height > 0, \ + f"With input height {input_height:d} and output height {output_height:d}, impossible " \ + f"to have a range of {self.max_steps - self.min_steps:d} bins." + + return spectrograms[..., self.lower_bin: self.lower_bin + output_height] diff --git a/pesto/export.py b/pesto/utils/export.py similarity index 100% rename from pesto/export.py rename to pesto/utils/export.py diff --git a/pesto/cqt.py b/pesto/utils/hcqt.py similarity index 91% rename from pesto/cqt.py rename to pesto/utils/hcqt.py index 6fdfccb6372077fd1f94b645fded1178c90983d7..4c92564ce46c05002f0d7c9ffca548818b7db0ca 100644 --- a/pesto/cqt.py +++ b/pesto/utils/hcqt.py @@ -354,3 +354,39 @@ class CQT(nn.Module): phase_real = torch.cos(torch.atan2(CQT_imag, CQT_real)) phase_imag = torch.sin(torch.atan2(CQT_imag, CQT_real)) return torch.stack((phase_real, phase_imag), -1) + + +class HarmonicCQT(nn.Module): + r"""Harmonic CQT layer, as described in Bittner et al. (20??)""" + def __init__( + self, + harmonics, + sr: int = 22050, + hop_length: int = 512, + fmin: float = 32.7, + fmax: float | None = None, + bins_per_semitone: int = 1, + n_bins: int = 84, + center_bins: bool = True + ): + super(HarmonicCQT, self).__init__() + + if center_bins: + fmin = fmin / 2 ** ((bins_per_semitone - 1) / (24 * bins_per_semitone)) + + self.cqt_kernels = nn.ModuleList([ + CQT(sr=sr, hop_length=hop_length, fmin=h*fmin, fmax=fmax, n_bins=n_bins, + bins_per_octave=12*bins_per_semitone, output_format="Complex") + for h in harmonics + ]) + + def forward(self, audio_waveforms: torch.Tensor): + r"""Converts a batch of waveforms into a batch of HCQTs. + + Args: + audio_waveforms (torch.Tensor): Batch of waveforms, shape (batch_size, num_samples) + + Returns: + Harmonic CQT, shape (batch_size, num_harmonics, num_freqs, num_timesteps, 2) + """ + return torch.stack([cqt(audio_waveforms) for cqt in self.cqt_kernels], dim=1) diff --git a/pesto/parser.py b/pesto/utils/parser.py similarity index 100% rename from pesto/parser.py rename to pesto/utils/parser.py diff --git a/pesto/utils/reduce_activations.py b/pesto/utils/reduce_activations.py new file mode 100644 index 0000000000000000000000000000000000000000..191cb4f060ce53e1081b8696b5408ab0af1d6b28 --- /dev/null +++ b/pesto/utils/reduce_activations.py @@ -0,0 +1,36 @@ +import torch + + +def reduce_activations(activations: torch.Tensor, reduction: str = "alwa") -> torch.Tensor: + r""" + + Args: + activations: tensor of probability activations, shape (*, num_bins) + reduction (str): reduction method to compute pitch out of activations, + choose between "argmax", "mean", "alwa". + + Returns: + torch.Tensor: pitches as fractions of MIDI semitones, shape (*) + """ + device = activations.device + num_bins = activations.size(-1) + bps, r = divmod(num_bins, 128) + assert r == 0, f"Activations should have output size 128*bins_per_semitone, got {num_bins}." + + if reduction == "argmax": + pred = activations.argmax(dim=-1) + return pred.float() / bps + + all_pitches = torch.arange(num_bins, dtype=torch.float, device=device).div_(bps) + if reduction == "mean": + return torch.matmul(activations, all_pitches) + + if reduction == "alwa": # argmax-local weighted averaging, see https://github.com/marl/crepe + center_bin = activations.argmax(dim=-1, keepdim=True) + window = torch.arange(1, 2 * bps, device=device) - bps # [-bps+1, -bps+2, ..., bps-2, bps-1] + indices = (center_bin + window).clip_(min=0, max=num_bins - 1) + cropped_activations = activations.gather(-1, indices) + cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(1, indices) + return (cropped_activations * cropped_pitches).sum(dim=1) / cropped_activations.sum(dim=1) + + raise ValueError diff --git a/pesto/weights/mir-1k.pth b/pesto/weights/mir-1k.pth deleted file mode 100644 index aadf78ceeff86ae4e86ac305053c5bcc9791ec7f..0000000000000000000000000000000000000000 Binary files a/pesto/weights/mir-1k.pth and /dev/null differ