Unverified commit cb7c420c authored by Alain Riou, committed by GitHub

Merge pull request #26 from SonyCSLParis/dev

v1.0.0
parents a709b8f5 4b099e7b
Showing with 601 additions and 291 deletions
......@@ -2,7 +2,7 @@ name: Test Workflow
on:
pull_request:
branches: [ "master" ]
branches: [ "dev" ]
jobs:
build:
......@@ -21,7 +21,10 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt-get install -y libsox-dev
python -m pip install --upgrade pip
python -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
python -m pip install matplotlib
python -m pip install pytest
python -m pip install .
- name: Test with pytest
......
.idea
*/__pycache__
**/__pycache__
*.egg-info/
.DS_Store
dist/
build/
**/*.csv
**/*.mid
**/*.png
# Changelog
## v1.0.0
- Change API under the hood to make it more object-oriented
- store all utilities inside a `PESTO` object that is a subclass of `nn.Module`
- make the API compatible with the checkpoints generated by the training repo
- add tests
- replace `setup.py` by `pyproject.toml`
- fix a few issues
- improve README and documentation
## v0.1.1
- fix an issue when exporting to PNG
- fix a device issue when changing the sampling rate (#17)
## v0.1.0 - 2023-10-17
Initial version
\ No newline at end of file
......@@ -22,6 +22,8 @@ This repository is implemented in [PyTorch](https://pytorch.org/) and has the fo
- [torchaudio](https://pytorch.org/audio/stable/) for audio loading
- `matplotlib` for exporting pitch predictions as images (optional)
**Warning:** If installing in a clean environment, it may be safer to first install PyTorch [the recommended way](https://pytorch.org/get-started/locally/) before installing PESTO.
## Usage
### Command-line interface
......@@ -107,48 +109,64 @@ import pesto
# predict the pitch of your audio tensors directly within your own Python code
x, sr = torchaudio.load("my_file.wav")
timesteps, pitch, confidence, activations = pesto.predict(x, sr, step_size=10.)
x = x.mean(dim=0) # PESTO takes mono audio as input
timesteps, pitch, confidence, activations = pesto.predict(x, sr)
# or predict using your own custom checkpoint
predictions = pesto.predict(x, sr, model_name="/path/to/checkpoint.ckpt")
# You can also predict pitches from audio files directly
pesto.predict_from_files(["example1.wav", "example2.mp3", "example3.ogg"], step_size=10., export_format=["csv"])
pesto.predict_from_files(["example1.wav", "example2.mp3", "example3.ogg"], export_format=["csv"])
```
`pesto.predict` supports batched inputs, which should then have shape `(batch_size, num_samples)`.
**Warning:** If you forget to convert stereo audio to mono, the channels will be treated as batch dimensions and you will
get separate predictions for each channel.
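For instance, here is a minimal sketch of both options (the file name and the slicing are purely illustrative):
```python
import torch
import torchaudio

import pesto

x, sr = torchaudio.load("my_file.wav")  # shape (num_channels, num_samples)
x_mono = x.mean(dim=0)                  # average the channels, shape (num_samples,)
timesteps, pitch, confidence, activations = pesto.predict(x_mono, sr, step_size=10.)

# Batched input: stack mono excerpts of equal length
# (here two 1-second chunks, assuming the file is at least 2 seconds long).
batch = torch.stack([x_mono[:sr], x_mono[sr:2 * sr]])  # shape (2, sr)
timesteps, pitch, confidence, activations = pesto.predict(batch, sr, step_size=10.)
# pitch and confidence now have shape (2, num_timesteps),
# activations has shape (2, num_timesteps, output_dim).
```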
#### Advanced usage
If not provided, `pesto.predict` will first load the CQT kernels and the model before performing
`pesto.predict` will first load the CQT kernels and the model before performing
any pitch estimation. If you process many files, calling `predict` repeatedly will therefore
re-initialize the same model for every tensor.
To avoid this time-consuming step, one can manually instantiate the model and data processor, then pass them directly
as args to the `predict` function. To do so, one has to use the underlying methods from `pesto.utils`:
To avoid this time-consuming step, one can manually instantiate the model with `load_model`,
then call the forward method of the model instead:
```python
import torch
from pesto import predict
from pesto.utils import load_model, load_dataprocessor
from pesto import load_model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = load_model("mir-1k", device=device)
data_processor = load_dataprocessor(step_size=0.01, device=device)
pesto_model = load_model("mir-1k", step_size=20.).to(device)
for x, sr in ...:
data_processor.sampling_rate = sr # The data_processor handles waveform->CQT conversion so it must know the sampling rate
predictions = predict(x, sr, model=model)
x = x.to(device)
predictions, confidence, activations = pesto_model(x, sr)
...
```
Note that when passing a list of files to `pesto.predict_from_files(...)` or the CLI directly, the model is loaded only
once so you don't have to bother with that in general.
#### Batched pitch estimation
By default, the function `pesto.predict` takes an audio waveform represented as a Tensor object of shape `(num_channels, num_samples)`.
However, one may want to estimate the pitch of batches of (cropped) waveforms within a training pipeline, e.g. for DDSP-related applications.
`pesto.predict` therefore accepts Tensor inputs of shape `(batch_size, num_channels, num_samples)` and returns batch-wise pitch predictions accordingly.
Note that batched predictions are available only from the Python API and not from the CLI because:
- handling audios of different lengths is annoying, I don't want to bother with that
- when estimating pitch on
Note that the `PESTO` object returned by `load_model` is a subclass of `nn.Module`
and its `forward` method also supports batched inputs.
One can therefore easily integrate PESTO within their own architecture by doing:
```python
import torch
import torch.nn as nn

from pesto import load_model


class MyGreatModel(nn.Module):
    def __init__(self, step_size, sr=44100, *args, **kwargs):
        super(MyGreatModel, self).__init__()
        self.f0_estimator = load_model("mir-1k", step_size, sampling_rate=sr)
        ...

    def forward(self, x):
        with torch.no_grad():
            f0, conf = self.f0_estimator(x, convert_to_freq=True, return_activations=False)
        ...
```
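A hypothetical way of exercising the `MyGreatModel` sketch above (the random input and its shape are purely illustrative):
```python
model = MyGreatModel(step_size=20., sr=44100)
x = torch.randn(8, 44100)  # batch of eight 1-second mono excerpts at 44.1 kHz
out = model(x)             # inside forward(), f0 and conf have shape (8, num_timesteps)
```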
## Performances
......
from .core import predict, predict_from_files
from .core import load_model, predict, predict_from_files
__version__ = '1.0.0'
import os
import warnings
from typing import Optional, Sequence, Union
from typing import Optional, Sequence, Tuple, Union
import torch
import torchaudio
from tqdm import tqdm
from .utils import load_model, load_dataprocessor, reduce_activation
from .export import export
from .loader import load_model
from .model import PESTO
from .utils import export
@torch.inference_mode()
def predict(
x: torch.Tensor,
sr: Optional[int] = None,
model: Union[torch.nn.Module, str] = "mir-1k",
data_preprocessor=None,
step_size: Optional[float] = None,
reduction: str = "argmax",
def _predict(x: torch.Tensor,
sr: int,
model: PESTO,
num_chunks: int = 1,
convert_to_freq: bool = False
):
convert_to_freq: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
preds, confidence, activations = [], [], []
try:
for chunk in x.chunk(chunks=num_chunks):
pred, conf, act = model(chunk, sr=sr, convert_to_freq=convert_to_freq, return_activations=True)
preds.append(pred)
confidence.append(conf)
activations.append(act)
except torch.cuda.OutOfMemoryError:
raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. "
"Please increase the number of chunks with option `-c`/`--chunks` "
"to reduce GPU memory usage.")
preds = torch.cat(preds, dim=0)
confidence = torch.cat(confidence, dim=0)
activations = torch.cat(activations, dim=0)
# compute timesteps
timesteps = torch.arange(preds.size(-1), device=x.device) * model.hop_size
return timesteps, preds, confidence, activations
def predict(x: torch.Tensor,
sr: int,
step_size: float = 10.,
model_name: str = "mir-1k",
reduction: str = "alwa",
num_chunks: int = 1,
convert_to_freq: bool = True,
inference_mode: bool = True,
no_grad: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
r"""Main prediction function.
Args:
x (torch.Tensor): input audio tensor,
shape (num_channels, num_samples) or (batch_size, num_channels, num_samples)
x (torch.Tensor): input audio tensor, can be provided as a batch but should be mono,
shape (num_samples) or (batch_size, num_samples)
sr (int, optional): sampling rate. If not specified, uses the current sampling rate of the model.
model: PESTO model. If a string is passed, it will load the model with the corresponding name.
Otherwise, the actual nn.Module will be used for doing predictions.
data_preprocessor: Module handling the data processing pipeline (waveform to CQT, cropping, etc.)
step_size (float, optional): step size between each CQT frame in milliseconds.
If the data_preprocessor is passed, its value will be used instead.
If a `PESTO` object is passed as `model`, this will be ignored.
model_name: name of PESTO model. Can be a path to a custom PESTO checkpoint or the name of a standard model.
reduction (str): reduction method for converting activation probabilities to log-frequencies.
num_chunks (int): number of chunks to split the input audios in.
Default is 1 (all CQT frames in parallel) but it can be increased to reduce memory usage
and prevent out-of-memory errors.
convert_to_freq (bool): whether predictions should be converted to frequencies or not.
"""
# convert to mono
assert 2 <= x.ndim <= 3, f"Audio file should have two dimensions, but found shape {x.size()}"
batch_size = x.size(0) if x.ndim == 3 else None
x = x.mean(dim=-2)
if data_preprocessor is None:
assert step_size is not None, \
"If you don't use a predefined data preprocessor, you must at least indicate a step size (in milliseconds)"
data_preprocessor = load_dataprocessor(step_size=step_size / 1000., device=x.device)
# If the sampling rate has changed, change the sampling rate accordingly
# It will automatically recompute the CQT kernels if needed
data_preprocessor.sampling_rate = sr
inference_mode (bool): whether to run with `torch.inference_mode`.
no_grad (bool): whether to run with `torch.no_grad`. If set to `False`, argument `inference_mode` is ignored.
if isinstance(model, str):
model = load_model(model, device=x.device)
# apply model
cqt = data_preprocessor(x)
try:
activations = torch.cat([
model(chunk) for chunk in cqt.chunk(chunks=num_chunks)
])
except torch.cuda.OutOfMemoryError:
raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. "
"Please increase the number of chunks with option `-c`/`--chunks` "
"to reduce GPU memory usage.")
if batch_size:
total_batch_size, num_predictions = activations.size()
activations = activations.view(batch_size, total_batch_size // batch_size, num_predictions)
# shift activations as it should (PESTO predicts pitches up to an additive constant)
activations = activations.roll(model.abs_shift.cpu().item(), dims=-1)
# convert model predictions to pitch values
pitch = reduce_activation(activations, reduction=reduction)
if convert_to_freq:
pitch = 440 * 2 ** ((pitch - 69) / 12)
# for now, confidence is computed very naively just based on volume
confidence = cqt.squeeze(1).max(dim=1).values.view_as(pitch)
conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values
confidence = (confidence - conf_min) / (conf_max - conf_min)
Returns:
timesteps (torch.Tensor): timesteps corresponding to each pitch prediction, shape (num_timesteps)
preds (torch.Tensor): pitch predictions in SEMITONES, shape (batch_size?, num_timesteps)
where `num_timesteps` ~= `num_samples` / (`self.hop_size` * `sr` / 1000), since `hop_size` is in milliseconds
confidence (torch.Tensor): confidence of whether frame is voiced or unvoiced in [0, 1],
shape (batch_size?, num_timesteps)
activations (torch.Tensor): activations of the model, shape (batch_size?, num_timesteps, output_dim)
"""
# sanity checks
assert x.ndim <= 2, \
f"Audio file should have shape (num_samples) or (batch_size, num_samples), but found shape {x.size()}."
timesteps = torch.arange(pitch.size(-1), device=x.device) * data_preprocessor.step_size
inference_mode = inference_mode and no_grad
with torch.no_grad() if no_grad and not inference_mode else torch.inference_mode(mode=inference_mode):
model = load_model(model_name, step_size, sampling_rate=sr).to(x.device)
model.reduction = reduction
return timesteps, pitch, confidence, activations
return _predict(x, sr, model, num_chunks=num_chunks, convert_to_freq=convert_to_freq)
def predict_from_files(
......@@ -97,8 +93,7 @@ def predict_from_files(
export_format: Sequence[str] = ("csv",),
no_convert_to_freq: bool = False,
num_chunks: int = 1,
gpu: int = -1
):
gpu: int = -1):
r"""
Args:
......@@ -107,12 +102,11 @@ def predict_from_files(
output:
step_size: hop length in milliseconds
reduction:
export_format:
export_format (Sequence[str]): format to export the predictions to.
Currently formats supported are: ["csv", "npz", "json"].
no_convert_to_freq: if True, keep fractional MIDI pitches instead of converting the output values to Hz
num_chunks (int): number of chunks to divide the inputs into. Increase this value if you encounter OOM errors.
gpu: index of GPU to use (-1 for CPU)
Returns:
Pitch predictions, see `predict` for more details.
"""
if isinstance(audio_files, str):
audio_files = [audio_files]
......@@ -122,14 +116,13 @@ def predict_from_files(
gpu = -1
device = torch.device(f"cuda:{gpu:d}" if gpu >= 0 else "cpu")
# define data preprocessing
data_preprocessor = load_dataprocessor(step_size / 1000., device=device)
# define model
model = load_model(model_name, device=device)
predictions = None
model = load_model(model_name, step_size=step_size).to(device)
model.reduction = reduction
pbar = tqdm(audio_files)
with torch.inference_mode():  # the purpose here is to write results to disk, so there is no point in storing gradients
for file in pbar:
pbar.set_description(file)
......@@ -140,11 +133,10 @@ def predict_from_files(
print(e, f"Skipping {file}...")
continue
x = x.to(device)
x = x.mean(dim=0).to(device) # convert to mono then pass to the right device
# compute the predictions
predictions = predict(x, sr, model=model, data_preprocessor=data_preprocessor, reduction=reduction,
convert_to_freq=not no_convert_to_freq, num_chunks=num_chunks)
predictions = _predict(x, sr, model=model, convert_to_freq=not no_convert_to_freq, num_chunks=num_chunks)
output_file = file.rsplit('.', 1)[0] + "." + ("semitones" if no_convert_to_freq else "f0")
if output is not None:
......@@ -154,5 +146,3 @@ def predict_from_files(
predictions = [p.cpu().numpy() for p in predictions]
for fmt in export_format:
export(fmt, output_file, *predictions)
return predictions
from typing import Optional
import torch
import torch.nn as nn
from .cqt import CQT
from .utils import HarmonicCQT
class ToLogMagnitude(nn.Module):
def __init__(self):
super(ToLogMagnitude, self).__init__()
self.eps = torch.finfo(torch.float32).eps
def forward(self, x):
x = x.abs()
x.clamp_(min=self.eps).log10_().mul_(20)
return x
class DataProcessor(nn.Module):
class Preprocessor(nn.Module):
r"""
Args:
step_size (float): step size between consecutive CQT frames (in SECONDS)
hop_size (float): step size between consecutive CQT frames (in milliseconds)
"""
def __init__(self,
step_size: float,
bins_per_semitone: int = 3,
device: torch.device = torch.device("cpu"),
**cqt_kwargs):
super(DataProcessor, self).__init__()
self.bins_per_semitone = bins_per_semitone
# CQT-related stuff
self.cqt_kwargs = cqt_kwargs
self.cqt_kwargs["bins_per_octave"] = 12 * bins_per_semitone
self.cqt = None
hop_size: float,
sampling_rate: Optional[int] = None,
**hcqt_kwargs):
super(Preprocessor, self).__init__()
# HCQT
self.hcqt_sr = None
self.hcqt_kernels = None
self.hop_size = hop_size
self.hcqt_kwargs = hcqt_kwargs
# log-magnitude
self.eps = torch.finfo(torch.float32).eps
self.to_log = ToLogMagnitude()
# cropping
self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5)
self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
# register a dummy tensor to get implicit access to the module's device
self.register_buffer("_device", torch.zeros(()), persistent=False)
# handling different sampling rates
self._sampling_rate = None
self.step_size = step_size
self.device = device
# if the sampling rate is provided, instantiate the CQT kernels
if sampling_rate is not None:
self.hcqt_sr = sampling_rate
self._reset_hcqt_kernels()
def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor, sr: Optional[int] = None) -> torch.Tensor:
r"""
Args:
x: audio waveform, any sampling rate, shape (num_samples)
x (torch.Tensor): audio waveform or batch of audio waveforms, any sampling rate,
shape (batch_size?, num_samples)
sr (int, optional): sampling rate
Returns:
log-magnitude CQT, shape (
torch.Tensor: log-magnitude CQT or batch of log-magnitude CQTs,
shape (batch_size?, num_timesteps, num_harmonics, num_freqs)
"""
# compute CQT from input waveform, and invert dims for (batch_size, time_steps, freq_bins)
complex_cqt = torch.view_as_complex(self.cqt(x)).transpose(1, 2)
# reshape and crop borders to fit training input shape
complex_cqt = complex_cqt[..., self.lowest_bin: self.highest_bin]
# flatten any batch dimensions so that batched audios can be processed in parallel
complex_cqt = complex_cqt.flatten(0, 1).unsqueeze(1)
# compute CQT from input waveform, and invert dims for (time_steps, num_harmonics, freq_bins)
# in other words, time becomes the batch dimension, enabling efficient processing for long audios.
complex_cqt = torch.view_as_complex(self.hcqt(x, sr=sr)).permute(0, 3, 1, 2)
complex_cqt.squeeze_(0)
# convert to dB
log_cqt = complex_cqt.abs().clamp_(min=self.eps).log10_().mul_(20)
return log_cqt
def _init_cqt_layer(self, sr: int, hop_length: int):
self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self.device)
return self.to_log(complex_cqt)
@property
def sampling_rate(self) -> int:
return self._sampling_rate
def hcqt(self, audio: torch.Tensor, sr: Optional[int] = None) -> torch.Tensor:
r"""Compute the Harmonic CQT of the input audio, recreating the CQT kernels first
if the sampling rate has changed.
@sampling_rate.setter
def sampling_rate(self, sr: int):
if sr == self._sampling_rate:
return
Args:
audio (torch.Tensor): mono audio waveform, shape (batch_size, num_samples)
sr (int): sampling rate of the audio waveform.
If not specified, it is assumed to be the same as for the previously processed audio waveform.
hop_length = int(self.step_size * sr + 0.5)
self._init_cqt_layer(sr, hop_length)
self._sampling_rate = sr
Returns:
torch.Tensor: Complex Harmonic CQT (HCQT) of the input,
shape (batch_size, num_harmonics, num_freqs, num_timesteps, 2)
"""
# recompute the HCQT kernels if they do not exist yet or if the sampling rate has changed
if sr is not None and sr != self.hcqt_sr:
self.hcqt_sr = sr
self._reset_hcqt_kernels()
return self.hcqt_kernels(audio)
def _reset_hcqt_kernels(self) -> None:
hop_length = int(self.hop_size * self.hcqt_sr / 1000 + 0.5)
self.hcqt_kernels = HarmonicCQT(sr=self.hcqt_sr,
hop_length=hop_length,
**self.hcqt_kwargs).to(self._device.device)
import os
from typing import Optional
import torch
from .data import Preprocessor
from .model import PESTO, Resnet1d
def load_model(checkpoint: str,
step_size: float,
sampling_rate: Optional[int] = None) -> PESTO:
r"""Load a trained model from a checkpoint file.
See https://github.com/SonyCSLParis/pesto-full/blob/master/src/models/pesto.py for the structure of the checkpoint.
Args:
checkpoint (str): path to the checkpoint or name of the checkpoint file (if using a provided checkpoint)
step_size (float): hop size in milliseconds
sampling_rate (int, optional): sampling rate of the audios.
If not provided, it will be set dynamically from the `sr` argument passed at inference time.
Returns:
PESTO: instance of PESTO model
"""
if os.path.exists(checkpoint): # handle user-provided checkpoints
model_path = checkpoint
else:
model_path = os.path.join(os.path.dirname(__file__), "weights", checkpoint + ".ckpt")
if not os.path.exists(model_path):
raise FileNotFoundError(f"You passed an invalid checkpoint file: {checkpoint}.")
# load checkpoint
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
hparams = checkpoint["hparams"]
hcqt_params = checkpoint["hcqt_params"]
state_dict = checkpoint["state_dict"]
# instantiate preprocessor
preprocessor = Preprocessor(hop_size=step_size, sampling_rate=sampling_rate, **hcqt_params)
# instantiate PESTO encoder
encoder = Resnet1d(**hparams["encoder"])
# instantiate main PESTO module and load its weights
model = PESTO(encoder,
preprocessor=preprocessor,
crop_kwargs=hparams["pitch_shift"],
reduction=hparams["reduction"])
model.load_state_dict(state_dict, strict=False)
model.eval()
return model
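For reference, a minimal sketch of how this loader is meant to be used, with a bundled checkpoint and random audio (the step size and sampling rate below are arbitrary):
```python
import torch

from pesto import load_model

model = load_model("mir-1k", step_size=10., sampling_rate=44100)
with torch.inference_mode():
    pitch, confidence = model(torch.randn(44100), convert_to_freq=True)
```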
from .parser import parse_args
from pesto.utils.parser import parse_args
from .core import predict_from_files
......
from functools import partial
from typing import Any, Mapping, Optional, Tuple, Union
import torch
import torch.nn as nn
from .utils import CropCQT
from .utils import reduce_activations
OUTPUT_TYPE = Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
class ToeplitzLinear(nn.Conv1d):
def __init__(self, in_features, out_features):
......@@ -18,7 +25,7 @@ class ToeplitzLinear(nn.Conv1d):
return super(ToeplitzLinear, self).forward(input.unsqueeze(-2)).squeeze(-2)
class PESTOEncoder(nn.Module):
class Resnet1d(nn.Module):
"""
Basic CNN similar to the one in Johannes Zeitler's report,
but for longer HCQT input (always stride 1 in time)
......@@ -29,7 +36,8 @@ class PESTOEncoder(nn.Module):
not over time (in order to work with variable length input).
Outputs one channel with sigmoid activation.
Args (Defaults: BasicCNN by Johannes Zeitler but with 1 input channel):
Args (Defaults: BasicCNN by Johannes Zeitler but with 6 input channels):
n_chan_input: Number of input channels (harmonics in HCQT)
n_chan_layers: Number of channels in the hidden layers (list)
n_prefilt_layers: Number of repetitions of the prefiltering layer
residual: If True, use residual connections for prefiltering (default: False)
......@@ -39,100 +47,191 @@ class PESTOEncoder(nn.Module):
p_dropout: Dropout probability
"""
def __init__(
self,
def __init__(self,
n_chan_input=1,
n_chan_layers=(20, 20, 10, 1),
n_prefilt_layers=1,
prefilt_kernel_size=15,
residual=False,
n_bins_in=216,
output_dim=128,
num_output_layers: int = 1
):
super(PESTOEncoder, self).__init__()
activation_layer = partial(nn.LeakyReLU, negative_slope=0.3)
activation_fn: str = "leaky",
a_lrelu=0.3,
p_dropout=0.2):
super(Resnet1d, self).__init__()
self.hparams = dict(n_chan_input=n_chan_input,
n_chan_layers=n_chan_layers,
n_prefilt_layers=n_prefilt_layers,
prefilt_kernel_size=prefilt_kernel_size,
residual=residual,
n_bins_in=n_bins_in,
output_dim=output_dim,
activation_fn=activation_fn,
a_lrelu=a_lrelu,
p_dropout=p_dropout)
if activation_fn == "relu":
activation_layer = nn.ReLU
elif activation_fn == "silu":
activation_layer = nn.SiLU
elif activation_fn == "leaky":
activation_layer = partial(nn.LeakyReLU, negative_slope=a_lrelu)
else:
raise ValueError(f"Unsupported activation function: {activation_fn}")
n_in = n_chan_input
n_ch = n_chan_layers
if len(n_ch) < 5:
n_ch.append(1)
# Layer normalization over frequency
self.layernorm = nn.LayerNorm(normalized_shape=[1, n_bins_in])
# Layer normalization over frequency and channels (harmonics of HCQT)
self.layernorm = nn.LayerNorm(normalized_shape=[n_in, n_bins_in])
# Prefiltering
prefilt_padding = prefilt_kernel_size // 2
self.conv1 = nn.Sequential(
nn.Conv1d(in_channels=1, out_channels=n_ch[0], kernel_size=15, padding=7, stride=1),
activation_layer()
nn.Conv1d(in_channels=n_in,
out_channels=n_ch[0],
kernel_size=prefilt_kernel_size,
padding=prefilt_padding,
stride=1),
activation_layer(),
nn.Dropout(p=p_dropout)
)
self.n_prefilt_layers = n_prefilt_layers
self.prefilt_list = nn.ModuleList()
for p in range(1, n_prefilt_layers):
self.prefilt_list.append(nn.Sequential(
nn.Conv1d(in_channels=n_ch[0], out_channels=n_ch[0], kernel_size=15, padding=7, stride=1),
activation_layer()
))
self.prefilt_layers = nn.ModuleList([
nn.Sequential(
nn.Conv1d(in_channels=n_ch[0],
out_channels=n_ch[0],
kernel_size=prefilt_kernel_size,
padding=prefilt_padding,
stride=1),
activation_layer(),
nn.Dropout(p=p_dropout)
)
for _ in range(n_prefilt_layers-1)
])
self.residual = residual
self.conv2 = nn.Sequential(
nn.Conv1d(
in_channels=n_ch[0],
out_channels=n_ch[1],
conv_layers = []
for i in range(len(n_chan_layers)-1):
conv_layers.extend([
nn.Conv1d(in_channels=n_ch[i],
out_channels=n_ch[i + 1],
kernel_size=1,
stride=1,
padding=0
),
activation_layer()
)
self.conv3 = nn.Sequential(
nn.Conv1d(in_channels=n_ch[1], out_channels=n_ch[2], kernel_size=1, padding=0, stride=1),
activation_layer()
)
self.conv4 = nn.Sequential(
nn.Conv1d(in_channels=n_ch[2], out_channels=n_ch[3], kernel_size=1, padding=0, stride=1),
padding=0,
stride=1),
activation_layer(),
nn.Dropout(),
nn.Conv1d(in_channels=n_ch[3], out_channels=n_ch[4], kernel_size=1, padding=0, stride=1)
)
nn.Dropout(p=p_dropout)
])
self.conv_layers = nn.Sequential(*conv_layers)
self.flatten = nn.Flatten(start_dim=1)
layers = []
pre_fc_dim = n_bins_in * n_ch[4]
for i in range(num_output_layers-1):
layers.extend([
ToeplitzLinear(pre_fc_dim, pre_fc_dim),
activation_layer()
])
self.pre_fc = nn.Sequential(*layers)
self.fc = ToeplitzLinear(pre_fc_dim, output_dim)
self.fc = ToeplitzLinear(n_bins_in * n_ch[-1], output_dim)
self.final_norm = nn.Softmax(dim=-1)
self.register_buffer("abs_shift", torch.zeros((), dtype=torch.long), persistent=True)
def forward(self, x):
r"""
Args:
x (torch.Tensor): shape (batch, channels, freq_bins)
"""
x_norm = self.layernorm(x)
x = self.layernorm(x)
x = self.conv1(x_norm)
x = self.conv1(x)
for p in range(0, self.n_prefilt_layers - 1):
prefilt_layer = self.prefilt_list[p]
prefilt_layer = self.prefilt_layers[p]
if self.residual:
x_new = prefilt_layer(x)
x = x_new + x
else:
x = prefilt_layer(x)
conv2_lrelu = self.conv2(x)
conv3_lrelu = self.conv3(conv2_lrelu)
y_pred = self.conv4(conv3_lrelu)
y_pred = self.flatten(y_pred)
y_pred = self.pre_fc(y_pred)
y_pred = self.fc(y_pred)
x = self.conv_layers(x)
x = self.flatten(x)
y_pred = self.fc(x)
return self.final_norm(y_pred)
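A hypothetical instantiation of `Resnet1d`, just to make the expected shapes explicit (the hyperparameters below are illustrative, not the ones stored in the pretrained checkpoints):
```python
import torch

encoder = Resnet1d(n_chan_input=6, n_chan_layers=[20, 20, 10, 1], n_bins_in=216, output_dim=384)
frames = torch.randn(32, 6, 216)  # (batch of CQT frames, HCQT harmonics, frequency bins)
probs = encoder(frames)           # shape (32, 384); each row sums to 1 thanks to the final softmax
```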
class PESTO(nn.Module):
def __init__(self,
encoder: nn.Module,
preprocessor: nn.Module,
crop_kwargs: Optional[Mapping[str, Any]] = None,
reduction: str = "alwa"):
super(PESTO, self).__init__()
self.encoder = encoder
self.preprocessor = preprocessor
# crop CQT
if crop_kwargs is None:
crop_kwargs = {}
self.crop_cqt = CropCQT(**crop_kwargs)
self.reduction = reduction
# constant shift to get absolute pitch from predictions
self.register_buffer('shift', torch.zeros((), dtype=torch.float), persistent=True)
def forward(self,
audio_waveforms: torch.Tensor,
sr: Optional[int] = None,
convert_to_freq: bool = False,
return_activations: bool = False) -> OUTPUT_TYPE:
r"""
Args:
audio_waveforms (torch.Tensor): mono audio waveform or batch of mono audio waveforms,
shape (batch_size?, num_samples)
sr (int, optional): sampling rate, defaults to the previously used sampling rate
convert_to_freq (bool): whether to convert the result to frequencies or return fractional semitones instead.
return_activations (bool): whether to return activations or pitch predictions only
Returns:
preds (torch.Tensor): pitch predictions in SEMITONES, shape (batch_size?, num_timesteps)
where `num_timesteps` ~= `num_samples` / (`self.hop_size` * `sr` / 1000), since `hop_size` is in milliseconds
confidence (torch.Tensor): confidence of whether frame is voiced or unvoiced in [0, 1],
shape (batch_size?, num_timesteps)
activations (torch.Tensor): activations of the model, shape (batch_size?, num_timesteps, output_dim)
"""
batch_size = audio_waveforms.size(0) if audio_waveforms.ndim == 2 else None
x = self.preprocessor(audio_waveforms, sr=sr)
x = self.crop_cqt(x) # the CQT has to be cropped beforehand
# for now, confidence is computed very naively just based on energy in the CQT
confidence = x.mean(dim=-2).max(dim=-1).values
conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values
confidence = (confidence - conf_min) / (conf_max - conf_min)
# flatten batch_size and time_steps since anyway predictions are made on CQT frames independently
if batch_size:
x = x.flatten(0, 1)
activations = self.encoder(x)
if batch_size:
activations = activations.view(batch_size, -1, activations.size(-1))
activations = activations.roll(-round(self.shift.cpu().item() * self.bins_per_semitone), -1)
preds = reduce_activations(activations, reduction=self.reduction)
if convert_to_freq:
preds = 440 * 2 ** ((preds - 69) / 12)
if return_activations:
return preds, confidence, activations
return preds, confidence
@property
def bins_per_semitone(self) -> int:
return self.preprocessor.hcqt_kwargs["bins_per_semitone"]
@property
def hop_size(self) -> float:
r"""Returns the hop size of the model (in milliseconds)"""
return self.preprocessor.hop_size
from .crop_cqt import CropCQT
from .export import export
from .hcqt import HarmonicCQT
from .reduce_activations import reduce_activations
\ No newline at end of file
import torch
import torch.nn as nn
class CropCQT(nn.Module):
def __init__(self, min_steps: int, max_steps: int):
super(CropCQT, self).__init__()
self.min_steps = min_steps
self.max_steps = max_steps
# lower bin
self.lower_bin = self.max_steps
def forward(self, spectrograms: torch.Tensor) -> torch.Tensor:
# WARNING: didn't check that it works, it may be dangerous
return spectrograms[..., self.max_steps: self.min_steps]
# old implementation
batch_size, _, input_height = spectrograms.size()
output_height = input_height - self.max_steps + self.min_steps
assert output_height > 0, \
f"With input height {input_height:d} and output height {output_height:d}, impossible " \
f"to have a range of {self.max_steps - self.min_steps:d} bins."
return spectrograms[..., self.lower_bin: self.lower_bin + output_height]
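To illustrate the intended slicing, here is a toy example, assuming `min_steps` is negative and `max_steps` positive (both counted in CQT bins), which is what the slicing above expects; the values below are made up:
```python
import torch

crop = CropCQT(min_steps=-12, max_steps=12)  # hypothetical pitch-shift range of ±12 bins
x = torch.randn(100, 6, 216)                 # (num_timesteps, num_harmonics, num_freqs)
y = crop(x)                                  # shape (100, 6, 192): 12 bins dropped on each side
```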
......@@ -41,11 +41,13 @@ def export_png(output_file: str, timesteps, confidence, activations, lims=(21, 1
bps = activations.shape[1] // 128
activations = activations[:, bps*lims[0]: bps*lims[1]]
activations = activations * confidence.unsqueeze(1)
activations = activations * confidence[:, None]
plt.imshow(activations.T,
aspect='auto', origin='lower', cmap='inferno',
extent=(timesteps[0], timesteps[-1]) + lims)
extent=(timesteps[0] / 1000, timesteps[-1] / 1000) + lims)
plt.xlabel("Time (s)")
plt.ylabel("Pitch (semitones)")
plt.title(output_file.rsplit('.', 2)[0])
plt.tight_layout()
plt.savefig(output_file)
......@@ -354,3 +354,39 @@ class CQT(nn.Module):
phase_real = torch.cos(torch.atan2(CQT_imag, CQT_real))
phase_imag = torch.sin(torch.atan2(CQT_imag, CQT_real))
return torch.stack((phase_real, phase_imag), -1)
class HarmonicCQT(nn.Module):
r"""Harmonic CQT layer, as described in Bittner et al. (20??)"""
def __init__(
self,
harmonics,
sr: int = 22050,
hop_length: int = 512,
fmin: float = 32.7,
fmax: Optional[float] = None,
bins_per_semitone: int = 1,
n_bins: int = 84,
center_bins: bool = True
):
super(HarmonicCQT, self).__init__()
if center_bins:
fmin = fmin / 2 ** ((bins_per_semitone - 1) / (24 * bins_per_semitone))
self.cqt_kernels = nn.ModuleList([
CQT(sr=sr, hop_length=hop_length, fmin=h*fmin, fmax=fmax, n_bins=n_bins,
bins_per_octave=12*bins_per_semitone, output_format="Complex")
for h in harmonics
])
def forward(self, audio_waveforms: torch.Tensor):
r"""Converts a batch of waveforms into a batch of HCQTs.
Args:
audio_waveforms (torch.Tensor): Batch of waveforms, shape (batch_size, num_samples)
Returns:
Harmonic CQT, shape (batch_size, num_harmonics, num_freqs, num_timesteps, 2)
"""
return torch.stack([cqt(audio_waveforms) for cqt in self.cqt_kernels], dim=1)
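A rough shape check for this layer (the harmonics, hop length and number of bins below are illustrative, not the values stored in the pretrained checkpoints):
```python
import torch

hcqt = HarmonicCQT(harmonics=(1, 2, 3), sr=44100, hop_length=441,
                   bins_per_semitone=3, n_bins=216)
waveforms = torch.randn(2, 44100)  # batch of two 1-second waveforms
out = hcqt(waveforms)              # shape (2, 3, 216, num_timesteps, 2), real/imag parts last
```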
File moved
import os
from typing import Optional
import torch
from .config import model_args, cqt_args, bins_per_semitone
from .data import DataProcessor
from .model import PESTOEncoder
def load_dataprocessor(step_size, device: Optional[torch.device] = None):
return DataProcessor(step_size=step_size, device=device, **cqt_args).to(device)
def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder:
model = PESTOEncoder(**model_args).to(device)
model.eval()
model_path = os.path.join(os.path.dirname(__file__), "weights", model_name + ".pth")
model.load_state_dict(torch.load(model_path, map_location=device))
return model
def reduce_activation(activations: torch.Tensor, reduction: str) -> torch.Tensor:
r"""Computes the pitch predictions from the activation outputs of the encoder.
Pitch predictions are returned in semitones, NOT in frequencies.
def reduce_activations(activations: torch.Tensor, reduction: str = "alwa") -> torch.Tensor:
r"""
Args:
activations: tensor of probability activations, shape (*, num_bins)
reduction:
reduction (str): reduction method to compute pitch out of activations,
choose between "argmax", "mean", "alwa".
Returns:
torch.Tensor: pitch predictions, shape (*,)
torch.Tensor: pitches in fractional MIDI semitones, shape (*)
"""
bps = bins_per_semitone
device = activations.device
num_bins = activations.size(-1)
bps, r = divmod(num_bins, 128)
assert r == 0, f"Activations should have output size 128*bins_per_semitone, got {num_bins}."
if reduction == "argmax":
pred = activations.argmax(dim=-1)
return pred.float() / bps
all_pitches = (torch.arange(activations.size(-1), dtype=torch.float, device=activations.device)) / bps
all_pitches = torch.arange(num_bins, dtype=torch.float, device=device).div_(bps)
if reduction == "mean":
return activations @ all_pitches
return torch.matmul(activations, all_pitches)
if reduction == "alwa": # argmax-local weighted averaging, see https://github.com/marl/crepe
center_bin = activations.argmax(dim=-1, keepdim=True)
window = torch.arange(-bps+1, bps, device=activations.device)
indices = window + center_bin
window = torch.arange(1, 2 * bps, device=device) - bps # [-bps+1, -bps+2, ..., bps-2, bps-1]
indices = (center_bin + window).clip_(min=0, max=num_bins - 1)
cropped_activations = activations.gather(-1, indices)
cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices)
return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1)
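A toy sanity check of the three reductions, using arbitrary activations with 3 bins per semitone (384 output bins) peaking symmetrically around MIDI pitch 69:
```python
import torch

activations = torch.zeros(1, 128 * 3)                    # bins_per_semitone = 3
activations[0, 206:209] = torch.tensor([0.2, 0.6, 0.2])  # peak at bin 207, i.e. 69 semitones

print(reduce_activations(activations, "argmax"))  # ~69.0
print(reduce_activations(activations, "mean"))    # ~69.0, since the peak is symmetric
print(reduce_activations(activations, "alwa"))    # ~69.0, weighted average around the argmax
```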
......
File added
File deleted
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[project]
name = "pesto-pitch"
dynamic = ["version"]
authors = [
{name = "Alain Riou", email = "alain.riou@sony.com"}
]
description = "Efficient pitch estimation with self-supervised learning"
readme = {file = "README.md", content-type = "text/markdown"}
requires-python = ">=3.8"
classifiers = [
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)',  # if a license classifier is declared, the license file must be included in the repository
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
]
dependencies = [
'numpy>=1.21.5',
'scipy>=1.8.1',
'tqdm>=4.66.1',
'torch>=2.0.1',
'torchaudio>=2.0.2'
]
[project.optional-dependencies]
matplotlib = ["matplotlib"]
test = ["pytest"]
[project.scripts]
pesto = "pesto.main:pesto"
[project.urls]
source = "https://github.com/SonyCSLParis/pesto"
[tool.pytest.ini_options]
testpaths = "tests/"
[tool.setuptools.dynamic]
version = {attr = "pesto.__version__"}
[tool.setuptools.package-data]
pesto = ["weights/*.ckpt"]
from pathlib import Path
from setuptools import setup, find_packages
def get_readme_text():
root_dir = Path(__file__).parent
readme_path = root_dir / "README.md"
return readme_path.read_text()
setup(
name='pesto-pitch',
version='0.1.0',
description='Efficient pitch estimation with self-supervised learning',
long_description=get_readme_text(),
long_description_content_type='text/markdown',
author='Alain Riou',
url='https://github.com/SonyCSLParis/pesto',
license='LGPL-3.0',
packages=find_packages(),
include_package_data=True,
package_data={
'pesto': ['weights/*'], # To include the .pth
},
install_requires=[
'numpy>=1.21.5',
'scipy>=1.8.1',
'tqdm>=4.66.1',
'torch>=2.0.1',
'torchaudio>=2.0.2'
],
classifiers=[
# 'Development Status :: 1 - Planning',
# 'Development Status :: 2 - Pre-Alpha',
# 'Development Status :: 3 - Alpha',
# 'Development Status :: 4 - Beta',
'Development Status :: 5 - Production/Stable',
# 'Development Status :: 6 - Mature',
# 'Development Status :: 7 - Inactive',
'Intended Audience :: Developers',
'License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)',  # if a license classifier is declared, the license file must be included in the repository
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
],
entry_points={
'console_scripts': [
'pesto=pesto.main:pesto', # For the command line, executes function pesto() in pesto/main as 'pesto'
],
},
python_requires='>=3.8',
)