diff --git a/.github/workflows/test-workflow.yml b/.github/workflows/test-workflow.yml index 7fc7156c6d6d3ee4b417cbd211d322afb807ebc3..43bda75133354d34cb8d02e4c7a5e27b44110171 100644 --- a/.github/workflows/test-workflow.yml +++ b/.github/workflows/test-workflow.yml @@ -2,7 +2,7 @@ name: Test Workflow on: pull_request: - branches: [ "master" ] + branches: [ "dev" ] jobs: build: @@ -21,7 +21,10 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | + sudo apt-get install -y libsox-dev python -m pip install --upgrade pip + python -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu + python -m pip install matplotlib python -m pip install pytest python -m pip install . - name: Test with pytest diff --git a/.gitignore b/.gitignore index cc684f882c3744dc83af39ef5e8cbfb7e934fcb9..daf4967a0f60ce293927d6e5b44629d6af654304 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,10 @@ .idea -*/__pycache__ +**/__pycache__ *.egg-info/ .DS_Store dist/ build/ + +**/*.csv +**/*.mid +**/*.png diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000000000000000000000000000000000..b6fcb6ecf6f2179b378c9192e2ba3d3143d5cbe1 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,21 @@ +# Changelog + +## v1.0.0 + +- Change API under the hood to make it more object-oriented + - store all utilities inside a `PESTO` object that is a subclass of `nn.Module` + - make the API compatible with the checkpoints generated by the training repo +- add tests +- replace `setup.py` by `pyproject.toml` +- fix a few issues +- improve README and documentation + +## v0.1.1 + +- solve issue when exporting in PNG +- solve device issue when changing sampling rate (#17) + + +## v0.1.0 - 2023-10-17 + +Initial version \ No newline at end of file diff --git a/README.md b/README.md index a41a69ccdbddcd652cd2f309d84248276834478a..93c9572fae10943ea42b9988d572bfcbb4747c41 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,8 @@ This repository is 
implemented in [PyTorch](https://pytorch.org/) and has the fo - [torchaudio](https://pytorch.org/audio/stable/) for audio loading - `matplotlib` for exporting pitch predictions as images (optional) +**Warning:** If installing in a clean environment, it may be safer to first install PyTorch [the recommended way](https://pytorch.org/get-started/locally/) before PESTO. + ## Usage ### Command-line interface @@ -107,48 +109,64 @@ import pesto # predict the pitch of your audio tensors directly within your own Python code x, sr = torchaudio.load("my_file.wav") -timesteps, pitch, confidence, activations = pesto.predict(x, sr, step_size=10.) +x = x.mean(dim=0) # PESTO takes mono audio as input +timesteps, pitch, confidence, activations = pesto.predict(x, sr) + +# or predict using your own custom checkpoint +predictions = pesto.predict(x, sr, model_name="/path/to/checkpoint.ckpt") # You can also predict pitches from audio files directly -pesto.predict_from_files(["example1.wav", "example2.mp3", "example3.ogg"], step_size=10., export_format=["csv"]) +pesto.predict_from_files(["example1.wav", "example2.mp3", "example3.ogg"], export_format=["csv"]) ``` +`pesto.predict` supports batched inputs, which should then have shape `(batch_size, num_samples)`. + +**Warning:** If you forget to convert a stereo audio in mono, channels will be treated as batch dimensions and you will +get predictions for each channel separately. #### Advanced usage -If not provided, `pesto.predict` will first load the CQT kernels and the model before performing +`pesto.predict` will first load the CQT kernels and the model before performing any pitch estimation. If you want to process a significant number of files, calling `predict` several times will then re-initialize the same model for each tensor. -To avoid this time-consuming step, one can manually instantiate the model and data processor, then pass them directly -as args to the `predict` function. 
To do so, one has to use the underlying methods from `pesto.utils`: +To avoid this time-consuming step, one can manually instantiate the model with `load_model`, +then call the forward method of the model instead: ```python import torch -from pesto import predict -from pesto.utils import load_model, load_dataprocessor +from pesto import load_model device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -model = load_model("mir-1k", device=device) -data_processor = load_dataprocessor(step_size=0.01, device=device) +pesto_model = load_model("mir-1k", step_size=20.).to(device) for x, sr in ...: - data_processor.sampling_rate = sr # The data_processor handles waveform->CQT conversion so it must know the sampling rate - predictions = predict(x, sr, model=model) + x = x.to(device) + predictions, confidence, activations = pesto_model(x, sr) ... ``` -Note that when passing a list of files to `pesto.predict_from_files(...)` or the CLI directly, the model is loaded only -once so you don't have to bother with that in general. -#### Batched pitch estimation +Note that the `PESTO` object returned by `load_model` is a subclass of `nn.Module` +and its `forward` method also supports batched inputs. +One can therefore easily integrate PESTO within their own architecture by doing: +```python +import torch +import torch.nn as nn + +from pesto import load_model -By default, the function `pesto.predict` takes an audio waveform represented as a Tensor object of shape `(num_channels, num_samples)`. -However, one may want to estimate the pitch of batches of (cropped) waveforms within a training pipeline, e.g. for DDSP-related applications. -`pesto.predict` therefore accepts Tensor inputs of shape `(batch_size, num_channels, num_samples)` and returns batch-wise pitch predictions accordingly. 
-Note that batched predictions are available only from the Python API and not from the CLI because: -- handling audios of different lengths is annoying, I don't want to bother with that -- when estimating pitch on +class MyGreatModel(nn.Module): + def __init__(self, step_size, sr=44100, *args, **kwargs): + super(MyGreatModel, self).__init__() + self.f0_estimator = load_model("mir-1k", step_size, sampling_rate=sr) + ... + + def forward(self, x): + with torch.no_grad(): + f0, conf = self.f0_estimator(x, convert_to_freq=True, return_activations=False) + ... +``` ## Performances diff --git a/pesto/__init__.py b/pesto/__init__.py index 82a92f8908b9f434140e3c1f7f44899a11cb8655..167a098923d22ab94ffa03e186547d9c6f3c2bfb 100644 --- a/pesto/__init__.py +++ b/pesto/__init__.py @@ -1 +1,4 @@ -from .core import predict, predict_from_files +from .core import load_model, predict, predict_from_files + + +__version__ = '1.0.0' diff --git a/pesto/core.py b/pesto/core.py index c55ac4afa24b020ac45e84e7eed4b46380c4524f..dd28a5872e8caa8a1f49160743b9d9fe05138a42 100644 --- a/pesto/core.py +++ b/pesto/core.py @@ -1,91 +1,87 @@ import os import warnings -from typing import Optional, Sequence, Union +from typing import Optional, Sequence, Tuple, Union import torch import torchaudio from tqdm import tqdm -from .utils import load_model, load_dataprocessor, reduce_activation -from .export import export +from .loader import load_model +from .model import PESTO +from .utils import export -@torch.inference_mode() -def predict( - x: torch.Tensor, - sr: Optional[int] = None, - model: Union[torch.nn.Module, str] = "mir-1k", - data_preprocessor=None, - step_size: Optional[float] = None, - reduction: str = "argmax", - num_chunks: int = 1, - convert_to_freq: bool = False -): +def _predict(x: torch.Tensor, + sr: int, + model: PESTO, + num_chunks: int = 1, + convert_to_freq: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + preds, confidence, activations = [], [], [] + try: 
+ for chunk in x.chunk(chunks=num_chunks): + pred, conf, act = model(chunk, sr=sr, convert_to_freq=convert_to_freq, return_activations=True) + preds.append(pred) + confidence.append(conf) + activations.append(act) + except torch.cuda.OutOfMemoryError: + raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. " + "Please increase the number of chunks with option `-c`/`--chunks` " + "to reduce GPU memory usage.") + + preds = torch.cat(preds, dim=0) + confidence = torch.cat(confidence, dim=0) + activations = torch.cat(activations, dim=0) + + # compute timesteps + timesteps = torch.arange(preds.size(-1), device=x.device) * model.hop_size + + return timesteps, preds, confidence, activations + + +def predict(x: torch.Tensor, + sr: int, + step_size: float = 10., + model_name: str = "mir-1k", + reduction: str = "alwa", + num_chunks: int = 1, + convert_to_freq: bool = True, + inference_mode: bool = True, + no_grad: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: r"""Main prediction function. Args: - x (torch.Tensor): input audio tensor, - shape (num_channels, num_samples) or (batch_size, num_channels, num_samples) + x (torch.Tensor): input audio tensor, can be provided as a batch but should be mono, + shape (num_samples) or (batch_size, num_samples) sr (int, optional): sampling rate. If not specified, uses the current sampling rate of the model. - model: PESTO model. If a string is passed, it will load the model with the corresponding name. - Otherwise, the actual nn.Module will be used for doing predictions. - data_preprocessor: Module handling the data processing pipeline (waveform to CQT, cropping, etc.) step_size (float, optional): step size between each CQT frame in milliseconds. - If the data_preprocessor is passed, its value will be used instead. + If a `PESTO` object is passed as `model`, this will be ignored. + model_name: name of PESTO model. 
Can be a path to a custom PESTO checkpoint or the name of a standard model. reduction (str): reduction method for converting activation probabilities to log-frequencies. num_chunks (int): number of chunks to split the input audios in. Default is 1 (all CQT frames in parallel) but it can be increased to reduce memory usage and prevent out-of-memory errors. convert_to_freq (bool): whether predictions should be converted to frequencies or not. - """ - # convert to mono - assert 2 <= x.ndim <= 3, f"Audio file should have two dimensions, but found shape {x.size()}" - batch_size = x.size(0) if x.ndim == 3 else None - x = x.mean(dim=-2) - - if data_preprocessor is None: - assert step_size is not None, \ - "If you don't use a predefined data preprocessor, you must at least indicate a step size (in milliseconds)" - data_preprocessor = load_dataprocessor(step_size=step_size / 1000., device=x.device) - - # If the sampling rate has changed, change the sampling rate accordingly - # It will automatically recompute the CQT kernels if needed - data_preprocessor.sampling_rate = sr - - if isinstance(model, str): - model = load_model(model, device=x.device) + inference_mode (bool): whether to run with `torch.inference_mode`. + no_grad (bool): whether to run with `torch.no_grad`. If set to `False`, argument `inference_mode` is ignored. - # apply model - cqt = data_preprocessor(x) - try: - activations = torch.cat([ - model(chunk) for chunk in cqt.chunk(chunks=num_chunks) - ]) - except torch.cuda.OutOfMemoryError: - raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. 
" - "Please increase the number of chunks with option `-c`/`--chunks` " - "to reduce GPU memory usage.") - - if batch_size: - total_batch_size, num_predictions = activations.size() - activations = activations.view(batch_size, total_batch_size // batch_size, num_predictions) - - # shift activations as it should (PESTO predicts pitches up to an additive constant) - activations = activations.roll(model.abs_shift.cpu().item(), dims=-1) - - # convert model predictions to pitch values - pitch = reduce_activation(activations, reduction=reduction) - if convert_to_freq: - pitch = 440 * 2 ** ((pitch - 69) / 12) - - # for now, confidence is computed very naively just based on volume - confidence = cqt.squeeze(1).max(dim=1).values.view_as(pitch) - conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values - confidence = (confidence - conf_min) / (conf_max - conf_min) + Returns: + timesteps (torch.Tensor): timesteps corresponding to each pitch prediction, shape (num_timesteps) + preds (torch.Tensor): pitch predictions in SEMITONES, shape (batch_size?, num_timesteps) + where `num_timesteps` ~= `num_samples` / (`self.hop_size` * `sr`) + confidence (torch.Tensor): confidence of whether frame is voiced or unvoiced in [0, 1], + shape (batch_size?, num_timesteps) + activations (torch.Tensor): activations of the model, shape (batch_size?, num_timesteps, output_dim) + """ + # sanity checks + assert x.ndim <= 2, \ + f"Audio file should have shape (num_samples) or (batch_size, num_samples), but found shape {x.size()}." 
- timesteps = torch.arange(pitch.size(-1), device=x.device) * data_preprocessor.step_size + inference_mode = inference_mode and no_grad + with torch.no_grad() if no_grad and not inference_mode else torch.inference_mode(mode=inference_mode): + model = load_model(model_name, step_size, sampling_rate=sr).to(x.device) + model.reduction = reduction - return timesteps, pitch, confidence, activations + return _predict(x, sr, model, num_chunks=num_chunks, convert_to_freq=convert_to_freq) def predict_from_files( @@ -97,8 +93,7 @@ def predict_from_files( export_format: Sequence[str] = ("csv",), no_convert_to_freq: bool = False, num_chunks: int = 1, - gpu: int = -1 -): + gpu: int = -1): r""" Args: @@ -107,12 +102,11 @@ def predict_from_files( output: step_size: hop length in milliseconds reduction: - export_format: + export_format (Sequence[str]): format to export the predictions to. + Currently formats supported are: ["csv", "npz", "json"]. no_convert_to_freq: whether convert output values to Hz or keep fractional MIDI pitches + num_chunks (int): number of chunks to divide the inputs into. Increase this value if you encounter OOM errors. gpu: index of GPU to use (-1 for CPU) - - Returns: - Pitch predictions, see `predict` for more details. 
""" if isinstance(audio_files, str): audio_files = [audio_files] @@ -122,37 +116,33 @@ def predict_from_files( gpu = -1 device = torch.device(f"cuda:{gpu:d}" if gpu >= 0 else "cpu") - # define data preprocessing - data_preprocessor = load_dataprocessor(step_size / 1000., device=device) - # define model - model = load_model(model_name, device=device) - predictions = None + model = load_model(model_name, step_size=step_size).to(device) + model.reduction = reduction pbar = tqdm(audio_files) - for file in pbar: - pbar.set_description(file) - # load audio file - try: - x, sr = torchaudio.load(file) - except Exception as e: - print(e, f"Skipping {file}...") - continue + with torch.inference_mode(): # here the purpose is to write results in disk, so there is no point storing gradients + for file in pbar: + pbar.set_description(file) - x = x.to(device) + # load audio file + try: + x, sr = torchaudio.load(file) + except Exception as e: + print(e, f"Skipping {file}...") + continue - # compute the predictions - predictions = predict(x, sr, model=model, data_preprocessor=data_preprocessor, reduction=reduction, - convert_to_freq=not no_convert_to_freq, num_chunks=num_chunks) + x = x.mean(dim=0).to(device) # convert to mono then pass to the right device - output_file = file.rsplit('.', 1)[0] + "." + ("semitones" if no_convert_to_freq else "f0") - if output is not None: - os.makedirs(output, exist_ok=True) - output_file = os.path.join(output, os.path.basename(output_file)) + # compute the predictions + predictions = _predict(x, sr, model=model, convert_to_freq=not no_convert_to_freq, num_chunks=num_chunks) - predictions = [p.cpu().numpy() for p in predictions] - for fmt in export_format: - export(fmt, output_file, *predictions) + output_file = file.rsplit('.', 1)[0] + "." 
+ ("semitones" if no_convert_to_freq else "f0") + if output is not None: + os.makedirs(output, exist_ok=True) + output_file = os.path.join(output, os.path.basename(output_file)) - return predictions + predictions = [p.cpu().numpy() for p in predictions] + for fmt in export_format: + export(fmt, output_file, *predictions) diff --git a/pesto/data.py b/pesto/data.py index ea053726694aa78a6645a7d0c2223b02d81a2f2a..3044bbdf3fcfc0818b198555f7711623517d8e62 100644 --- a/pesto/data.py +++ b/pesto/data.py @@ -1,74 +1,95 @@ +from typing import Optional + import torch import torch.nn as nn -from .cqt import CQT +from .utils import HarmonicCQT + + +class ToLogMagnitude(nn.Module): + def __init__(self): + super(ToLogMagnitude, self).__init__() + self.eps = torch.finfo(torch.float32).eps + + def forward(self, x): + x = x.abs() + x.clamp_(min=self.eps).log10_().mul_(20) + return x + -class DataProcessor(nn.Module): +class Preprocessor(nn.Module): r""" Args: - step_size (float): step size between consecutive CQT frames (in SECONDS) + hop_size (float): step size between consecutive CQT frames (in milliseconds) """ def __init__(self, - step_size: float, - bins_per_semitone: int = 3, - device: torch.device = torch.device("cpu"), - **cqt_kwargs): - super(DataProcessor, self).__init__() - self.bins_per_semitone = bins_per_semitone - - # CQT-related stuff - self.cqt_kwargs = cqt_kwargs - self.cqt_kwargs["bins_per_octave"] = 12 * bins_per_semitone - self.cqt = None + hop_size: float, + sampling_rate: Optional[int] = None, + **hcqt_kwargs): + super(Preprocessor, self).__init__() + + # HCQT + self.hcqt_sr = None + self.hcqt_kernels = None + self.hop_size = hop_size + + self.hcqt_kwargs = hcqt_kwargs # log-magnitude - self.eps = torch.finfo(torch.float32).eps + self.to_log = ToLogMagnitude() - # cropping - self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5) - self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone + # register a dummy tensor to get implicit access to the 
module's device + self.register_buffer("_device", torch.zeros(()), persistent=False) - # handling different sampling rates - self._sampling_rate = None - self.step_size = step_size - self.device = device + # if the sampling rate is provided, instantiate the CQT kernels + if sampling_rate is not None: + self.hcqt_sr = sampling_rate + self._reset_hcqt_kernels() - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, sr: Optional[int] = None) -> torch.Tensor: r""" Args: - x: audio waveform, any sampling rate, shape (num_samples) + x (torch.Tensor): audio waveform or batch of audio waveforms, any sampling rate, + shape (batch_size?, num_samples) + sr (int, optional): sampling rate Returns: - log-magnitude CQT, shape ( + torch.Tensor: log-magnitude CQT of batch of CQTs, + shape (batch_size?, num_timesteps, num_harmonics, num_freqs) """ - # compute CQT from input waveform, and invert dims for (batch_size, time_steps, freq_bins) - complex_cqt = torch.view_as_complex(self.cqt(x)).transpose(1, 2) - - # reshape and crop borders to fit training input shape - complex_cqt = complex_cqt[..., self.lowest_bin: self.highest_bin] - - # flatten eventual batch dimensions so that batched audios can be processed in parallel - complex_cqt = complex_cqt.flatten(0, 1).unsqueeze(1) + # compute CQT from input waveform, and invert dims for (time_steps, num_harmonics, freq_bins) + # in other words, time becomes the batch dimension, enabling efficient processing for long audios. 
+ complex_cqt = torch.view_as_complex(self.hcqt(x, sr=sr)).permute(0, 3, 1, 2) + complex_cqt.squeeze_(0) # convert to dB - log_cqt = complex_cqt.abs().clamp_(min=self.eps).log10_().mul_(20) - return log_cqt - - def _init_cqt_layer(self, sr: int, hop_length: int): - self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self.device) + return self.to_log(complex_cqt) - @property - def sampling_rate(self) -> int: - return self._sampling_rate + def hcqt(self, audio: torch.Tensor, sr: Optional[int] = None) -> torch.Tensor: + r"""Compute the Harmonic CQT of the input audio after eventually recreating the kernels + (in case the sampling rate has changed). - @sampling_rate.setter - def sampling_rate(self, sr: int): - if sr == self._sampling_rate: - return + Args: + audio (torch.Tensor): mono audio waveform, shape (batch_size, num_samples) + sr (int): sampling rate of the audio waveform. + If not specified, we assume it is the same as the previous processed audio waveform. - hop_length = int(self.step_size * sr + 0.5) - self._init_cqt_layer(sr, hop_length) - self._sampling_rate = sr + Returns: + torch.Tensor: Complex Harmonic CQT (HCQT) of the input, + shape (batch_size, num_harmonics, num_freqs, num_timesteps, 2) + """ + # compute HCQT kernels if it does not exist or if the sampling rate has changed + if sr is not None and sr != self.hcqt_sr: + self.hcqt_sr = sr + self._reset_hcqt_kernels() + + return self.hcqt_kernels(audio) + + def _reset_hcqt_kernels(self) -> None: + hop_length = int(self.hop_size * self.hcqt_sr / 1000 + 0.5) + self.hcqt_kernels = HarmonicCQT(sr=self.hcqt_sr, + hop_length=hop_length, + **self.hcqt_kwargs).to(self._device.device) diff --git a/pesto/loader.py b/pesto/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..a270615997197cfa4abdcfeef30de6a68d67be98 --- /dev/null +++ b/pesto/loader.py @@ -0,0 +1,51 @@ +import os +from typing import Optional + +import torch + +from .data import Preprocessor +from .model 
import PESTO, Resnet1d + + +def load_model(checkpoint: str, + step_size: float, + sampling_rate: Optional[int] = None) -> PESTO: + r"""Load a trained model from a checkpoint file. + See https://github.com/SonyCSLParis/pesto-full/blob/master/src/models/pesto.py for the structure of the checkpoint. + + Args: + checkpoint (str): path to the checkpoint or name of the checkpoint file (if using a provided checkpoint) + step_size (float): hop size in milliseconds + sampling_rate (int, optional): sampling rate of the audios. + If not provided, it can be inferred dynamically as well. + Returns: + PESTO: instance of PESTO model + """ + if os.path.exists(checkpoint): # handle user-provided checkpoints + model_path = checkpoint + else: + model_path = os.path.join(os.path.dirname(__file__), "weights", checkpoint + ".ckpt") + if not os.path.exists(model_path): + raise FileNotFoundError(f"You passed an invalid checkpoint file: {checkpoint}.") + + # load checkpoint + checkpoint = torch.load(model_path, map_location=torch.device("cpu")) + hparams = checkpoint["hparams"] + hcqt_params = checkpoint["hcqt_params"] + state_dict = checkpoint["state_dict"] + + # instantiate preprocessor + preprocessor = Preprocessor(hop_size=step_size, sampling_rate=sampling_rate, **hcqt_params) + + # instantiate PESTO encoder + encoder = Resnet1d(**hparams["encoder"]) + + # instantiate main PESTO module and load its weights + model = PESTO(encoder, + preprocessor=preprocessor, + crop_kwargs=hparams["pitch_shift"], + reduction=hparams["reduction"]) + model.load_state_dict(state_dict, strict=False) + model.eval() + + return model diff --git a/pesto/main.py b/pesto/main.py index bc52df6cb0374c6798e23fdfe117e224c7983369..d117941f84eed1535970c16f5f4fa53029bf91f5 100644 --- a/pesto/main.py +++ b/pesto/main.py @@ -1,4 +1,4 @@ -from .parser import parse_args +from pesto.utils.parser import parse_args from .core import predict_from_files diff --git a/pesto/model.py b/pesto/model.py index 
94337ec84570b1cd3d5aaf1609b06f3f49a30ae6..b9cf5eda18349e38a2d4de1f739798bd5b1df0e6 100644 --- a/pesto/model.py +++ b/pesto/model.py @@ -1,8 +1,15 @@ from functools import partial +from typing import Any, Mapping, Optional, Tuple, Union import torch import torch.nn as nn +from .utils import CropCQT +from .utils import reduce_activations + + +OUTPUT_TYPE = Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + class ToeplitzLinear(nn.Conv1d): def __init__(self, in_features, out_features): @@ -18,7 +25,7 @@ class ToeplitzLinear(nn.Conv1d): return super(ToeplitzLinear, self).forward(input.unsqueeze(-2)).squeeze(-2) -class PESTOEncoder(nn.Module): +class Resnet1d(nn.Module): """ Basic CNN similar to the one in Johannes Zeitler's report, but for longer HCQT input (always stride 1 in time) @@ -29,7 +36,8 @@ class PESTOEncoder(nn.Module): not over time (in order to work with variable length input). Outputs one channel with sigmoid activation. - Args (Defaults: BasicCNN by Johannes Zeitler but with 1 input channel): + Args (Defaults: BasicCNN by Johannes Zeitler but with 6 input channels): + n_chan_input: Number of input channels (harmonics in HCQT) n_chan_layers: Number of channels in the hidden layers (list) n_prefilt_layers: Number of repetitions of the prefiltering layer residual: If True, use residual connections for prefiltering (default: False) @@ -39,100 +47,191 @@ class PESTOEncoder(nn.Module): p_dropout: Dropout probability """ - def __init__( - self, - n_chan_layers=(20, 20, 10, 1), - n_prefilt_layers=1, - residual=False, - n_bins_in=216, - output_dim=128, - num_output_layers: int = 1 - ): - super(PESTOEncoder, self).__init__() - - activation_layer = partial(nn.LeakyReLU, negative_slope=0.3) - + def __init__(self, + n_chan_input=1, + n_chan_layers=(20, 20, 10, 1), + n_prefilt_layers=1, + prefilt_kernel_size=15, + residual=False, + n_bins_in=216, + output_dim=128, + activation_fn: str = "leaky", + a_lrelu=0.3, + p_dropout=0.2): 
+ super(Resnet1d, self).__init__() + + self.hparams = dict(n_chan_input=n_chan_input, + n_chan_layers=n_chan_layers, + n_prefilt_layers=n_prefilt_layers, + prefilt_kernel_size=prefilt_kernel_size, + residual=residual, + n_bins_in=n_bins_in, + output_dim=output_dim, + activation_fn=activation_fn, + a_lrelu=a_lrelu, + p_dropout=p_dropout) + + if activation_fn == "relu": + activation_layer = nn.ReLU + elif activation_fn == "silu": + activation_layer = nn.SiLU + elif activation_fn == "leaky": + activation_layer = partial(nn.LeakyReLU, negative_slope=a_lrelu) + else: + raise ValueError + + n_in = n_chan_input n_ch = n_chan_layers if len(n_ch) < 5: n_ch.append(1) - # Layer normalization over frequency - self.layernorm = nn.LayerNorm(normalized_shape=[1, n_bins_in]) + # Layer normalization over frequency and channels (harmonics of HCQT) + self.layernorm = nn.LayerNorm(normalized_shape=[n_in, n_bins_in]) # Prefiltering + prefilt_padding = prefilt_kernel_size // 2 self.conv1 = nn.Sequential( - nn.Conv1d(in_channels=1, out_channels=n_ch[0], kernel_size=15, padding=7, stride=1), - activation_layer() + nn.Conv1d(in_channels=n_in, + out_channels=n_ch[0], + kernel_size=prefilt_kernel_size, + padding=prefilt_padding, + stride=1), + activation_layer(), + nn.Dropout(p=p_dropout) ) self.n_prefilt_layers = n_prefilt_layers - self.prefilt_list = nn.ModuleList() - for p in range(1, n_prefilt_layers): - self.prefilt_list.append(nn.Sequential( - nn.Conv1d(in_channels=n_ch[0], out_channels=n_ch[0], kernel_size=15, padding=7, stride=1), - activation_layer() - )) + self.prefilt_layers = nn.ModuleList(*[ + nn.Sequential( + nn.Conv1d(in_channels=n_ch[0], + out_channels=n_ch[0], + kernel_size=prefilt_kernel_size, + padding=prefilt_padding, + stride=1), + activation_layer(), + nn.Dropout(p=p_dropout) + ) + for _ in range(n_prefilt_layers-1) + ]) self.residual = residual - self.conv2 = nn.Sequential( - nn.Conv1d( - in_channels=n_ch[0], - out_channels=n_ch[1], - kernel_size=1, - stride=1, - 
padding=0 - ), - activation_layer() - ) - - self.conv3 = nn.Sequential( - nn.Conv1d(in_channels=n_ch[1], out_channels=n_ch[2], kernel_size=1, padding=0, stride=1), - activation_layer() - ) - - self.conv4 = nn.Sequential( - nn.Conv1d(in_channels=n_ch[2], out_channels=n_ch[3], kernel_size=1, padding=0, stride=1), - activation_layer(), - nn.Dropout(), - nn.Conv1d(in_channels=n_ch[3], out_channels=n_ch[4], kernel_size=1, padding=0, stride=1) - ) + conv_layers = [] + for i in range(len(n_chan_layers)-1): + conv_layers.extend([ + nn.Conv1d(in_channels=n_ch[i], + out_channels=n_ch[i + 1], + kernel_size=1, + padding=0, + stride=1), + activation_layer(), + nn.Dropout(p=p_dropout) + ]) + self.conv_layers = nn.Sequential(*conv_layers) self.flatten = nn.Flatten(start_dim=1) - - layers = [] - pre_fc_dim = n_bins_in * n_ch[4] - for i in range(num_output_layers-1): - layers.extend([ - ToeplitzLinear(pre_fc_dim, pre_fc_dim), - activation_layer() - ]) - self.pre_fc = nn.Sequential(*layers) - self.fc = ToeplitzLinear(pre_fc_dim, output_dim) + self.fc = ToeplitzLinear(n_bins_in * n_ch[-1], output_dim) self.final_norm = nn.Softmax(dim=-1) - self.register_buffer("abs_shift", torch.zeros((), dtype=torch.long), persistent=True) - def forward(self, x): r""" Args: x (torch.Tensor): shape (batch, channels, freq_bins) """ - x_norm = self.layernorm(x) + x = self.layernorm(x) - x = self.conv1(x_norm) + x = self.conv1(x) for p in range(0, self.n_prefilt_layers - 1): - prefilt_layer = self.prefilt_list[p] + prefilt_layer = self.prefilt_layers[p] if self.residual: x_new = prefilt_layer(x) x = x_new + x else: x = prefilt_layer(x) - conv2_lrelu = self.conv2(x) - conv3_lrelu = self.conv3(conv2_lrelu) - y_pred = self.conv4(conv3_lrelu) - y_pred = self.flatten(y_pred) - y_pred = self.pre_fc(y_pred) - y_pred = self.fc(y_pred) + x = self.conv_layers(x) + x = self.flatten(x) + + y_pred = self.fc(x) + return self.final_norm(y_pred) + + +class PESTO(nn.Module): + def __init__(self, + encoder: nn.Module, + 
preprocessor: nn.Module, + crop_kwargs: Optional[Mapping[str, Any]] = None, + reduction: str = "alwa"): + super(PESTO, self).__init__() + self.encoder = encoder + self.preprocessor = preprocessor + + # crop CQT + if crop_kwargs is None: + crop_kwargs = {} + self.crop_cqt = CropCQT(**crop_kwargs) + + self.reduction = reduction + + # constant shift to get absolute pitch from predictions + self.register_buffer('shift', torch.zeros((), dtype=torch.float), persistent=True) + + def forward(self, + audio_waveforms: torch.Tensor, + sr: Optional[int] = None, + convert_to_freq: bool = False, + return_activations: bool = False) -> OUTPUT_TYPE: + r""" + + Args: + audio_waveforms (torch.Tensor): mono audio waveform or batch of mono audio waveforms, + shape (batch_size?, num_samples) + sr (int, optional): sampling rate, defaults to the previously used sampling rate + convert_to_freq (bool): whether to convert the result to frequencies or return fractional semitones instead. + return_activations (bool): whether to return activations or pitch predictions only + + Returns: + preds (torch.Tensor): pitch predictions in SEMITONES, shape (batch_size?, num_timesteps) + where `num_timesteps` ~= `num_samples` / (`self.hop_size` * `sr`) + confidence (torch.Tensor): confidence of whether frame is voiced or unvoiced in [0, 1], + shape (batch_size?, num_timesteps) + activations (torch.Tensor): activations of the model, shape (batch_size?, num_timesteps, output_dim) + """ + batch_size = audio_waveforms.size(0) if audio_waveforms.ndim == 2 else None + x = self.preprocessor(audio_waveforms, sr=sr) + x = self.crop_cqt(x) # the CQT has to be cropped beforehand + + # for now, confidence is computed very naively just based on energy in the CQT + confidence = x.mean(dim=-2).max(dim=-1).values + conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values + confidence = (confidence - conf_min) / (conf_max - conf_min) + + # flatten batch_size and 
time_steps since anyway predictions are made on CQT frames independently + if batch_size: + x = x.flatten(0, 1) + + activations = self.encoder(x) + if batch_size: + activations = activations.view(batch_size, -1, activations.size(-1)) + + activations = activations.roll(-round(self.shift.cpu().item() * self.bins_per_semitone), -1) + + preds = reduce_activations(activations, reduction=self.reduction) + + if convert_to_freq: + preds = 440 * 2 ** ((preds - 69) / 12) + + if return_activations: + return preds, confidence, activations + + return preds, confidence + + @property + def bins_per_semitone(self) -> int: + return self.preprocessor.hcqt_kwargs["bins_per_semitone"] + + @property + def hop_size(self) -> float: + r"""Returns the hop size of the model (in milliseconds)""" + return self.preprocessor.hop_size diff --git a/pesto/utils.py b/pesto/utils.py deleted file mode 100644 index 46281ebc1f4fff05c5c10fd14bd056244482039b..0000000000000000000000000000000000000000 --- a/pesto/utils.py +++ /dev/null @@ -1,53 +0,0 @@ -import os -from typing import Optional - -import torch - -from .config import model_args, cqt_args, bins_per_semitone -from .data import DataProcessor -from .model import PESTOEncoder - - -def load_dataprocessor(step_size, device: Optional[torch.device] = None): - return DataProcessor(step_size=step_size, device=device, **cqt_args).to(device) - - -def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder: - model = PESTOEncoder(**model_args).to(device) - model.eval() - - model_path = os.path.join(os.path.dirname(__file__), "weights", model_name + ".pth") - model.load_state_dict(torch.load(model_path, map_location=device)) - - return model - - -def reduce_activation(activations: torch.Tensor, reduction: str) -> torch.Tensor: - r"""Computes the pitch predictions from the activation outputs of the encoder. - Pitch predictions are returned in semitones, NOT in frequencies. 
- - Args: - activations: tensor of probability activations, shape (*, num_bins) - reduction: - - Returns: - torch.Tensor: pitch predictions, shape (*,) - """ - bps = bins_per_semitone - if reduction == "argmax": - pred = activations.argmax(dim=-1) - return pred.float() / bps - - all_pitches = (torch.arange(activations.size(-1), dtype=torch.float, device=activations.device)) / bps - if reduction == "mean": - return activations @ all_pitches - - if reduction == "alwa": # argmax-local weighted averaging, see https://github.com/marl/crepe - center_bin = activations.argmax(dim=-1, keepdim=True) - window = torch.arange(-bps+1, bps, device=activations.device) - indices = window + center_bin - cropped_activations = activations.gather(-1, indices) - cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices) - return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1) - - raise ValueError diff --git a/pesto/utils/__init__.py b/pesto/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9be9b1b72be5579424e7411fe2291823838cd3d4 --- /dev/null +++ b/pesto/utils/__init__.py @@ -0,0 +1,4 @@ +from .crop_cqt import CropCQT +from .export import export +from .hcqt import HarmonicCQT +from .reduce_activations import reduce_activations \ No newline at end of file diff --git a/pesto/utils/crop_cqt.py b/pesto/utils/crop_cqt.py new file mode 100644 index 0000000000000000000000000000000000000000..65829c70e8ee90554acee1c027178f6292cd4230 --- /dev/null +++ b/pesto/utils/crop_cqt.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + + +class CropCQT(nn.Module): + def __init__(self, min_steps: int, max_steps: int): + super(CropCQT, self).__init__() + self.min_steps = min_steps + self.max_steps = max_steps + + # lower bin + self.lower_bin = self.max_steps + + def forward(self, spectrograms: torch.Tensor) -> torch.Tensor: + # WARNING: untested and possibly dangerous: equivalent to the old implementation below only if `min_steps` < 0 + return
spectrograms[..., self.max_steps: self.min_steps] + + # old implementation (unreachable after the `return` above, kept for reference) + batch_size, _, input_height = spectrograms.size() + + output_height = input_height - self.max_steps + self.min_steps + assert output_height > 0, \ + f"With input height {input_height:d} and output height {output_height:d}, impossible " \ + f"to have a range of {self.max_steps - self.min_steps:d} bins." + + return spectrograms[..., self.lower_bin: self.lower_bin + output_height] diff --git a/pesto/export.py b/pesto/utils/export.py similarity index 89% rename from pesto/export.py rename to pesto/utils/export.py index daab6a96f5d474f18e75cb914c3e7171ccdc3a49..70aeb54861b9bff67661fa58793be8e3230ae507 100644 --- a/pesto/export.py +++ b/pesto/utils/export.py @@ -41,11 +41,13 @@ def export_png(output_file: str, timesteps, confidence, activations, lims=(21, 1 bps = activations.shape[1] // 128 activations = activations[:, bps*lims[0]: bps*lims[1]] - activations = activations * confidence.unsqueeze(1) + activations = activations * confidence[:, None] plt.imshow(activations.T, aspect='auto', origin='lower', cmap='inferno', - extent=(timesteps[0], timesteps[-1]) + lims) + extent=(timesteps[0] / 1000, timesteps[-1] / 1000) + lims) + plt.xlabel("Time (s)") + plt.ylabel("Pitch (semitones)") plt.title(output_file.rsplit('.', 2)[0]) plt.tight_layout() plt.savefig(output_file) diff --git a/pesto/cqt.py b/pesto/utils/hcqt.py similarity index 91% rename from pesto/cqt.py rename to pesto/utils/hcqt.py index 6fdfccb6372077fd1f94b645fded1178c90983d7..f7ff5610feb4b3c8d3bb29affc9a4ea5d2661538 100644 --- a/pesto/cqt.py +++ b/pesto/utils/hcqt.py @@ -354,3 +354,39 @@ class CQT(nn.Module): phase_real = torch.cos(torch.atan2(CQT_imag, CQT_real)) phase_imag = torch.sin(torch.atan2(CQT_imag, CQT_real)) return torch.stack((phase_real, phase_imag), -1) + + +class HarmonicCQT(nn.Module): + r"""Harmonic CQT layer, as described in Bittner et al.
(2017)""" + def __init__( + self, + harmonics, + sr: int = 22050, + hop_length: int = 512, + fmin: float = 32.7, + fmax: Optional[float] = None, + bins_per_semitone: int = 1, + n_bins: int = 84, + center_bins: bool = True + ): + super(HarmonicCQT, self).__init__() + + if center_bins: + fmin = fmin / 2 ** ((bins_per_semitone - 1) / (24 * bins_per_semitone)) + + self.cqt_kernels = nn.ModuleList([ + CQT(sr=sr, hop_length=hop_length, fmin=h*fmin, fmax=fmax, n_bins=n_bins, + bins_per_octave=12*bins_per_semitone, output_format="Complex") + for h in harmonics + ]) + + def forward(self, audio_waveforms: torch.Tensor): + r"""Converts a batch of waveforms into a batch of HCQTs. + + Args: + audio_waveforms (torch.Tensor): Batch of waveforms, shape (batch_size, num_samples) + + Returns: + Harmonic CQT, shape (batch_size, num_harmonics, num_freqs, num_timesteps, 2) + """ + return torch.stack([cqt(audio_waveforms) for cqt in self.cqt_kernels], dim=1) diff --git a/pesto/parser.py b/pesto/utils/parser.py similarity index 100% rename from pesto/parser.py rename to pesto/utils/parser.py diff --git a/pesto/utils/reduce_activations.py b/pesto/utils/reduce_activations.py new file mode 100644 index 0000000000000000000000000000000000000000..b22dc508c2c0299b03c3038c9c50bf282bf1b76f --- /dev/null +++ b/pesto/utils/reduce_activations.py @@ -0,0 +1,36 @@ +import torch + + +def reduce_activations(activations: torch.Tensor, reduction: str = "alwa") -> torch.Tensor: + r"""Computes the pitch predictions from the activation outputs of the encoder. + + Args: + activations: tensor of probability activations, shape (*, num_bins) + reduction (str): reduction method to compute pitch out of activations, + choose between "argmax", "mean", "alwa". + + Returns: + torch.Tensor: pitches as fractions of MIDI semitones, shape (*) + """ + device = activations.device + num_bins = activations.size(-1) + bps, r = divmod(num_bins, 128) + assert r == 0, f"Activations should have output size 128*bins_per_semitone, got {num_bins}."
+ + if reduction == "argmax": + pred = activations.argmax(dim=-1) + return pred.float() / bps + + all_pitches = torch.arange(num_bins, dtype=torch.float, device=device).div_(bps) + if reduction == "mean": + return torch.matmul(activations, all_pitches) + + if reduction == "alwa": # argmax-local weighted averaging, see https://github.com/marl/crepe + center_bin = activations.argmax(dim=-1, keepdim=True) + window = torch.arange(1, 2 * bps, device=device) - bps # [-bps+1, -bps+2, ..., bps-2, bps-1] + indices = (center_bin + window).clip_(min=0, max=num_bins - 1) + cropped_activations = activations.gather(-1, indices) + cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices) + return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1) + + raise ValueError diff --git a/pesto/weights/mir-1k.ckpt b/pesto/weights/mir-1k.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..6fa22686ecb3f763c98d42f290527797da2bd3ae Binary files /dev/null and b/pesto/weights/mir-1k.ckpt differ diff --git a/pesto/weights/mir-1k.pth b/pesto/weights/mir-1k.pth deleted file mode 100644 index aadf78ceeff86ae4e86ac305053c5bcc9791ec7f..0000000000000000000000000000000000000000 Binary files a/pesto/weights/mir-1k.pth and /dev/null differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..b2fb7d05dcc35b499ede3b0562769602baaab66d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,48 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "pesto-pitch" +dynamic = ["version"] +authors = [ + {name = "Alain Riou", email = "alain.riou@sony.com"} +] +description = "Efficient pitch estimation with self-supervised learning" +readme = {file = "README.md", content-type = "text/markdown"} +requires-python = ">=3.8" +classifiers = [ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'License :: OSI 
Approved :: GNU Lesser General Public License v3 (LGPLv3)', # If licence is provided must be on the repository + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', +] +dependencies = [ + 'numpy>=1.21.5', + 'scipy>=1.8.1', + 'tqdm>=4.66.1', + 'torch>=2.0.1', + 'torchaudio>=2.0.2' +] + +[project.optional-dependencies] +matplotlib = ["matplotlib"] +test = ["pytest"] + +[project.scripts] +pesto = "pesto.main:pesto" + +[project.urls] +source = "https://github.com/SonyCSLParis/pesto" + +[tool.pytest.ini_options] +testpaths = "tests/" + +[tool.setuptools.dynamic] +version = {attr = "pesto.__version__"} + +[tool.setuptools.package-data] +pesto = ["weights/*.ckpt"] diff --git a/setup.py b/setup.py deleted file mode 100644 index 6eafee00df5b97574f0ba784f53a51978fc303c3..0000000000000000000000000000000000000000 --- a/setup.py +++ /dev/null @@ -1,52 +0,0 @@ -from pathlib import Path -from setuptools import setup, find_packages - -def get_readme_text(): - root_dir = Path(__file__).parent - readme_path = root_dir / "README.md" - return readme_path.read_text() - - -setup( - name='pesto-pitch', - version='0.1.0', - description='Efficient pitch estimation with self-supervised learning', - long_description=get_readme_text(), - long_description_content_type='text/markdown', - author='Alain Riou', - url='https://github.com/SonyCSLParis/pesto', - license='LGPL-3.0', - packages=find_packages(), - include_package_data=True, - package_data={ - 'pesto': ['weights/*'], # To include the .pth - }, - install_requires=[ - 'numpy>=1.21.5', - 'scipy>=1.8.1', - 'tqdm>=4.66.1', - 'torch>=2.0.1', - 'torchaudio>=2.0.2' - ], - classifiers=[ - # 'Development Status :: 1 - Planning', - # 'Development Status :: 2 - Pre-Alpha', - # 'Development Status :: 3 - Alpha', - # 'Development Status :: 4 - Beta', - 'Development Status :: 5 - Production/Stable', - # 'Development Status :: 6 - 
Mature', - # 'Development Status :: 7 - Inactive', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)', # If licence is provided must be on the repository - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - ], - entry_points={ - 'console_scripts': [ - 'pesto=pesto.main:pesto', # For the command line, executes function pesto() in pesto/main as 'pesto' - ], - }, - python_requires='>=3.8', -) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/audios/example.wav b/tests/audios/example.wav new file mode 100644 index 0000000000000000000000000000000000000000..9f592085cff7d80a91c27b0f5920df629b36cd36 Binary files /dev/null and b/tests/audios/example.wav differ diff --git a/tests/test_basic.py b/tests/test_basic.py deleted file mode 100644 index 1f4096eab0b3ed39a4b697104f2f9d33ddad8dc8..0000000000000000000000000000000000000000 --- a/tests/test_basic.py +++ /dev/null @@ -1,11 +0,0 @@ -import unittest -import pesto - - -class MyTestCase(unittest.TestCase): - def test_something(self): - self.assertEqual(True, True) # add assertion here - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..ef40055cf51781976bb5f07757cbf48406ecde78 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,41 @@ +import glob +import itertools +import os + +import pytest + +import torch + + +AUDIOS_DIR = os.path.join(os.path.dirname(__file__), "audios") + + +@pytest.mark.parametrize("file, fmt, convert_to_freq", + itertools.product(glob.glob(AUDIOS_DIR + "/*.wav"), ["csv", "npz", "png"], [True, False])) +def test_cli(file, fmt, convert_to_freq): + if convert_to_freq: + suffix = ".f0." 
+ fmt + option = "" + else: + suffix = ".semitones." + fmt + option = " -F" + os.system(f"pesto {file} --export_format " + fmt + option) + out_file = file.rsplit('.', 1)[0] + suffix + assert os.path.isfile(out_file) + os.unlink(out_file) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") +@pytest.mark.parametrize("file, fmt, convert_to_freq", + itertools.product(glob.glob(AUDIOS_DIR + "/*.wav"), ["csv", "npz", "png"], [True, False])) +def test_cli_gpu(file, fmt, convert_to_freq): + if convert_to_freq: + suffix = ".f0." + fmt + option = "" + else: + suffix = ".semitones." + fmt + option = " -F" + os.system(f"pesto {file} --gpu 0 --export_format " + fmt + option) + out_file = file.rsplit('.', 1)[0] + suffix + assert os.path.isfile(out_file) + os.unlink(out_file) diff --git a/tests/test_performances.py b/tests/test_performances.py new file mode 100644 index 0000000000000000000000000000000000000000..9aed8de0b4172658236ddbfce55f15549346b7d2 --- /dev/null +++ b/tests/test_performances.py @@ -0,0 +1,23 @@ +import itertools + +import pytest + +import torch + +from pesto import load_model +from .utils import generate_synth_data + + +@pytest.fixture +def model(): + return load_model('mir-1k', step_size=10.) 
+ + +@pytest.mark.parametrize('pitch, sr, reduction', + itertools.product(range(50, 80), [16000, 44100, 48000], ["argmax", "alwa"])) +def test_performances(model, pitch, sr, reduction): + x = generate_synth_data(pitch, sr=sr) + + preds, conf = model(x, sr=sr, return_activations=False) + + torch.testing.assert_close(preds, torch.full_like(preds, pitch), atol=0.1, rtol=0.1) diff --git a/tests/test_predict.py b/tests/test_predict.py deleted file mode 100644 index 464090415c47109523e91779d4f40e19495c9cf1..0000000000000000000000000000000000000000 --- a/tests/test_predict.py +++ /dev/null @@ -1 +0,0 @@ -# TODO diff --git a/tests/test_shape.py b/tests/test_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..eed58bb8e4179c5f91c242fefe4e398f574512c6 --- /dev/null +++ b/tests/test_shape.py @@ -0,0 +1,97 @@ +import itertools + +import pytest + +import torch + +from pesto import load_model, predict +from .utils import generate_synth_data + + +@pytest.fixture +def model(): + return load_model('mir-1k', step_size=10.) 
+ + +@pytest.fixture +def synth_data_16k(): + return generate_synth_data(pitch=69, duration=5., sr=16000), 16000 + + +@pytest.mark.parametrize('reduction', ["argmax", "mean", "alwa"]) +def test_shape_no_batch(model, synth_data_16k, reduction): + x, sr = synth_data_16k + + model.reduction = reduction + + num_samples = x.size(-1) + + num_timesteps = int(num_samples * 1000 / (model.hop_size * sr)) + 1 + + preds, conf, activations = model(x, sr=sr, return_activations=True) + + assert preds.shape == (num_timesteps,) + assert conf.shape == (num_timesteps,) + assert activations.shape == (num_timesteps, 128 * model.bins_per_semitone) + + +@pytest.mark.parametrize('sr, reduction', + itertools.product([16000, 44100, 48000], ["argmax", "mean", "alwa"])) +def test_shape_batch(model, sr, reduction): + model.reduction = reduction + + batch_size = 13 + + batch = torch.stack([ + generate_synth_data(pitch=p, duration=5., sr=sr) + for p in range(50, 50+batch_size) + ]) + + num_timesteps = int(batch.size(-1) * 1000 / (model.hop_size * sr)) + 1 + + preds, conf, activations = model(batch, sr=sr, return_activations=True) + + assert preds.shape == (batch_size, num_timesteps) + assert conf.shape == (batch_size, num_timesteps) + assert activations.shape == (batch_size, num_timesteps, 128 * model.bins_per_semitone) + + +@pytest.mark.parametrize('step_size, reduction', + itertools.product([10., 20., 50., 100], ["argmax", "mean", "alwa"])) +def test_predict_shape_no_batch(synth_data_16k, step_size, reduction): + x, sr = synth_data_16k + + num_samples = x.size(-1) + + num_timesteps = int(num_samples * 1000 / (step_size * sr)) + 1 + + timesteps, preds, conf, activations = predict(x, + sr, + step_size=step_size, + reduction=reduction) + + assert timesteps.shape == (num_timesteps,) + assert preds.shape == (num_timesteps,) + assert conf.shape == (num_timesteps,) + + +@pytest.mark.parametrize('sr, step_size, reduction', + itertools.product([16000, 44100, 48000], [10., 20., 50., 100.], ["argmax", 
"mean", "alwa"])) +def test_predict_shape_batch(sr, step_size, reduction): + batch_size = 13 + + batch = torch.stack([ + generate_synth_data(pitch=p, duration=5., sr=sr) + for p in range(50, 50+batch_size) + ]) + + num_timesteps = int(batch.size(-1) * 1000 / (step_size * sr)) + 1 + + timesteps, preds, conf, activations = predict(batch, + sr=sr, + step_size=step_size, + reduction=reduction) + + assert timesteps.shape == (num_timesteps,) + assert preds.shape == (batch_size, num_timesteps) + assert conf.shape == (batch_size, num_timesteps) diff --git a/tests/test_timesteps.py b/tests/test_timesteps.py new file mode 100644 index 0000000000000000000000000000000000000000..546ee09beb2b00cfd3a0c316c88d5167c69d9c72 --- /dev/null +++ b/tests/test_timesteps.py @@ -0,0 +1,18 @@ +import pytest + +import torch + +from pesto import predict +from .utils import generate_synth_data + + +@pytest.fixture +def synth_data_16k(): + return generate_synth_data(pitch=69, duration=5., sr=16000), 16000 + + +@pytest.mark.parametrize('step_size', [10., 20., 50., 100]) +def test_build_timesteps(synth_data_16k, step_size): + timesteps, *_ = predict(*synth_data_16k, step_size=step_size) + diff = timesteps[1:] - timesteps[:-1] + torch.testing.assert_close(diff, torch.full_like(diff, step_size)) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..71210e340044aa7e463908c841b45d8e719d457e --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,20 @@ +import torch + + +def mid_to_hz(pitch: int): + return 440 * 2 ** ((pitch - 69) / 12) + + +def generate_synth_data(pitch: int, num_harmonics: int = 5, duration=2, sr=16000): + f0 = mid_to_hz(pitch) + t = torch.arange(0, duration, 1/sr) + harmonics = torch.stack([ + torch.cos(2 * torch.pi * k * f0 * t + torch.rand(())) + for k in range(1, num_harmonics+1) + ], dim=1) + # volume = torch.randn(()) * torch.arange(num_harmonics).neg().div(0.5).exp() + volume = torch.rand(num_harmonics) + volume[0] = 1 + volume 
*= torch.randn(()) + audio = torch.sum(volume * harmonics, dim=1) + return audio \ No newline at end of file