Commit c6eaa6f8 authored by Alain Riou

make compatible with training code

parent 1c57b6b1
@@ -107,46 +107,64 @@ import pesto
# predict the pitch of your audio tensors directly within your own Python code
x, sr = torchaudio.load("my_file.wav")
timesteps, pitch, confidence, activations = pesto.predict(x, sr, step_size=10.)
x = x.mean(dim=0) # PESTO takes mono audio as input
timesteps, pitch, confidence, activations = pesto.predict(x, sr)
# or predict using your own custom checkpoint
predictions = pesto.predict(x, sr, model_name="/path/to/checkpoint.ckpt")
# You can also predict pitches from audio files directly
pesto.predict_from_files(["example1.wav", "example2.mp3", "example3.ogg"], step_size=10., export_format=["csv"])
pesto.predict_from_files(["example1.wav", "example2.mp3", "example3.ogg"], export_format=["csv"])
```
`pesto.predict` supports batched inputs, which should then have shape `(batch_size, num_samples)`.
**Warning:** If you forget to convert stereo audio to mono, the channels will be treated as batch dimensions and you will
get separate predictions for each channel.
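For example, here is a minimal sketch (the file path and crop length are placeholders, and the file is assumed to be at least a few seconds long) that converts a stereo recording to mono and stacks fixed-length excerpts into a batch before calling `pesto.predict`:
```python
import torch
import torchaudio

import pesto

x, sr = torchaudio.load("my_stereo_file.wav")  # placeholder path, shape (2, num_samples)
x = x.mean(dim=0)                              # convert to mono, shape (num_samples,)

# stack non-overlapping 5-second excerpts into a batch
crop = 5 * sr
excerpts = x.unfold(dimension=0, size=crop, step=crop)  # shape (batch_size, crop)

timesteps, pitch, confidence, activations = pesto.predict(excerpts, sr)
print(pitch.shape)  # (batch_size, num_timesteps)
```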
#### Advanced usage
If not provided, `pesto.predict` will first load the CQT kernels and the model before performing
`pesto.predict` will first load the CQT kernels and the model before performing
any pitch estimation. If you have many files to process, calling `predict` several times will therefore
re-initialize the same model for every input.
To avoid this time-consuming step, one can manually instantiate the model and data processor, then pass them directly
as args to the `predict` function. To do so, one has to use the underlying methods from `pesto.utils`:
To avoid this time-consuming step, one can manually instantiate the model with `load_model`,
then call the forward method of the model instead:
```python
import torch
from pesto import predict
from pesto.utils import load_model, load_dataprocessor
from pesto import load_model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = load_model("mir-1k", device=device)
data_processor = load_dataprocessor(step_size=0.01, device=device)
pesto_model = load_model("mir-1k", step_size=20.).to(device)
for x, sr in ...:
    data_processor.sampling_rate = sr  # The data_processor handles waveform->CQT conversion so it must know the sampling rate
    predictions = predict(x, sr, model=model)
    x = x.to(device)
    predictions, confidence, activations = pesto_model(x, sr)
    ...
```
Note that when passing a list of files to `pesto.predict_from_files(...)` or to the CLI directly, the model is loaded only
once, so in that case you do not need to worry about this.
#### Batched pitch estimation
Note that the `PESTO` object returned by `load_model` is a subclass of `nn.Module`
and its `forward` method also supports batched inputs.
One can therefore easily integrate PESTO into their own architecture as follows:
```python
import torch
import torch.nn as nn
from pesto import load_model
By default, the function `pesto.predict` takes an audio waveform represented as a Tensor object of shape `(num_channels, num_samples)`.
However, one may want to estimate the pitch of batches of (cropped) waveforms within a training pipeline, e.g. for DDSP-related applications.
`pesto.predict` therefore accepts Tensor inputs of shape `(batch_size, num_channels, num_samples)` and returns batch-wise pitch predictions accordingly.
Note that batched predictions are available only from the Python API and not from the CLI, since handling audio files of different lengths in a single batch would add significant complexity.
class MyGreatModel(nn.Module):
    def __init__(self, step_size, sr=44100, *args, **kwargs):
        super(MyGreatModel, self).__init__()
        self.f0_estimator = load_model("mir-1k", step_size, sampling_rate=sr)
        ...

    def forward(self, x):
        with torch.no_grad():
            f0, conf = self.f0_estimator(x, convert_to_freq=True, return_activations=False)
        ...
```
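Here, the `...` placeholders stand for whatever layers the downstream task needs; wrapping the call in `torch.no_grad()` and setting `return_activations=False` keeps PESTO as a frozen feature extractor, so no gradients flow through it during training. Assuming the elided parts are filled in, such a module is then used like any other `nn.Module`:
```python
model = MyGreatModel(step_size=20., sr=16000)
audio = torch.randn(8, 2 * 16000)  # batch of eight 2-second mono excerpts at 16 kHz
out = model(audio)                 # f0 and confidence are computed without building a graph
```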
## Performances
......
from .core import predict, predict_from_files
from .core import load_model, predict, predict_from_files
import os
import warnings
from typing import Optional, Sequence, Union
from typing import Optional, Sequence, Tuple, Union
import torch
import torchaudio
from tqdm import tqdm
from .export import export
from .utils import load_model, load_dataprocessor, reduce_activation
from .loader import load_model
from .model import PESTO
from .utils import export
def predict(
x: torch.Tensor,
sr: Optional[int] = None,
model: Union[torch.nn.Module, str] = "mir-1k",
data_preprocessor=None,
step_size: Optional[float] = None,
reduction: str = "argmax",
def _predict(x: torch.Tensor,
sr: int,
model: PESTO,
num_chunks: int = 1,
convert_to_freq: bool = False,
convert_to_freq: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
preds, confidence, activations = [], [], []
try:
for chunk in x.chunk(chunks=num_chunks):
pred, conf, act = model(chunk, sr=sr, convert_to_freq=convert_to_freq, return_activations=True)
preds.append(pred)
confidence.append(conf)
activations.append(act)
except torch.cuda.OutOfMemoryError:
raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. "
"Please increase the number of chunks with option `-c`/`--chunks` "
"to reduce GPU memory usage.")
preds = torch.cat(preds, dim=0)
confidence = torch.cat(confidence, dim=0)
activations = torch.cat(activations, dim=0)
# compute timesteps
timesteps = torch.arange(preds.size(-1), device=x.device) * model.hop_size
return timesteps, preds, confidence, activations
def predict(x: torch.Tensor,
sr: int,
step_size: float = 10.,
model_name: str = "mir-1k",
reduction: str = "alwa",
num_chunks: int = 1,
convert_to_freq: bool = True,
inference_mode: bool = True,
no_grad: bool = True
):
no_grad: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
r"""Main prediction function.
Args:
x (torch.Tensor): input audio tensor,
shape (num_channels, num_samples) or (batch_size, num_channels, num_samples)
x (torch.Tensor): input audio tensor, can be provided as a batch but should be mono,
shape (num_samples) or (batch_size, num_samples)
sr (int, optional): sampling rate. If not specified, uses the current sampling rate of the model.
model: PESTO model. If a string is passed, it will load the model with the corresponding name.
Otherwise, the actual nn.Module will be used for doing predictions.
data_preprocessor: Module handling the data processing pipeline (waveform to CQT, cropping, etc.)
step_size (float, optional): step size between each CQT frame in milliseconds.
If the data_preprocessor is passed, its value will be used instead.
If a `PESTO` object is passed as `model`, this will be ignored.
model_name: name of PESTO model. Can be a path to a custom PESTO checkpoint or the name of a standard model.
reduction (str): reduction method for converting activation probabilities to log-frequencies.
num_chunks (int): number of chunks to split the input audios into.
Default is 1 (all CQT frames in parallel) but it can be increased to reduce memory usage
@@ -40,58 +63,25 @@ def predict(
convert_to_freq (bool): whether predictions should be converted to frequencies or not.
inference_mode (bool): whether to run with `torch.inference_mode`.
no_grad (bool): whether to run with `torch.no_grad`. If set to `False`, argument `inference_mode` is ignored.
Returns:
timesteps (torch.Tensor): timesteps corresponding to each pitch prediction, shape (num_timesteps)
preds (torch.Tensor): pitch predictions in SEMITONES, shape (batch_size?, num_timesteps)
where `num_timesteps` ~= `num_samples` / (`self.hop_size` * `sr`)
confidence (torch.Tensor): confidence of whether frame is voiced or unvoiced in [0, 1],
shape (batch_size?, num_timesteps)
activations (torch.Tensor): activations of the model, shape (batch_size?, num_timesteps, output_dim)
"""
# sanity checks
assert x.ndim <= 2, \
f"Audio file should have shape (num_samples) or (batch_size, num_samples), but found shape {x.size()}."
inference_mode = inference_mode and no_grad
with torch.no_grad() if no_grad and not inference_mode else torch.inference_mode(mode=inference_mode):
# convert to mono
assert 2 <= x.ndim <= 3, f"Audio file should have two dimensions, but found shape {x.size()}"
batch_size = x.size(0) if x.ndim == 3 else None
x = x.mean(dim=-2)
if data_preprocessor is None:
assert step_size is not None and sr is not None, \
"If you don't use a predefined data preprocessor, you must at least indicate a step size (in milliseconds)"
data_preprocessor = load_dataprocessor(step_size=step_size / 1000., sampling_rate=sr, device=x.device)
# If the sampling rate has changed, change the sampling rate accordingly
# It will automatically recompute the CQT kernels if needed
if sr is not None:
data_preprocessor.sampling_rate = sr
if isinstance(model, str):
model = load_model(model, device=x.device)
# apply model
cqt = data_preprocessor(x)
try:
activations = torch.cat([
model(chunk) for chunk in cqt.chunk(chunks=num_chunks)
])
except torch.cuda.OutOfMemoryError:
raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. "
"Please increase the number of chunks with option `-c`/`--chunks` "
"to reduce GPU memory usage.")
if batch_size:
total_batch_size, num_predictions = activations.size()
activations = activations.view(batch_size, total_batch_size // batch_size, num_predictions)
model = load_model(model_name, step_size, sampling_rate=sr).to(x.device)
model.reduction = reduction
# shift the activations accordingly (PESTO predicts pitches up to an additive constant)
activations = activations.roll(model.abs_shift.cpu().item(), dims=-1)
# convert model predictions to pitch values
pitch = reduce_activation(activations, reduction=reduction)
if convert_to_freq:
pitch = 440 * 2 ** ((pitch - 69) / 12)
# for now, confidence is computed very naively just based on volume
confidence = cqt.squeeze(1).max(dim=1).values.view_as(pitch)
conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values
confidence = (confidence - conf_min) / (conf_max - conf_min)
timesteps = torch.arange(pitch.size(-1), device=x.device) * data_preprocessor.step_size
return timesteps, pitch, confidence, activations
return _predict(x, sr, model, num_chunks=num_chunks, convert_to_freq=convert_to_freq)
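For illustration (the file path and chunk count below are arbitrary), a long mono file can be processed with a larger `num_chunks` to bound peak GPU memory:
```python
import torchaudio

import pesto

x, sr = torchaudio.load("long_recording.wav")  # hypothetical file
x = x.mean(dim=0)                              # convert to mono

# splitting the frames into 8 chunks lowers peak GPU memory at the cost of some speed
timesteps, pitch, confidence, activations = pesto.predict(x, sr, step_size=10., num_chunks=8)
```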
def predict_from_files(
@@ -113,12 +103,11 @@ def predict_from_files(
output:
step_size: hop length in milliseconds
reduction: reduction method for converting activation probabilities to log-frequencies.
export_format:
export_format (Sequence[str]): format to export the predictions to.
Currently supported formats are: ["csv", "npz", "json"].
no_convert_to_freq: whether to keep output values as fractional MIDI pitches instead of converting them to Hz
num_chunks (int): number of chunks to divide the inputs into. Increase this value if you encounter OOM errors.
gpu: index of GPU to use (-1 for CPU)
Returns:
Pitch predictions, see `predict` for more details.
"""
if isinstance(audio_files, str):
audio_files = [audio_files]
@@ -128,14 +117,13 @@ def predict_from_files(
gpu = -1
device = torch.device(f"cuda:{gpu:d}" if gpu >= 0 else "cpu")
# define data preprocessing
data_preprocessor = load_dataprocessor(step_size / 1000., device=device)
# define model
model = load_model(model_name, device=device)
predictions = None
model = load_model(model_name, step_size=step_size).to(device)
model.reduction = reduction
pbar = tqdm(audio_files)
with torch.inference_mode():  # here the purpose is to write results to disk, so there is no point storing gradients
for file in pbar:
pbar.set_description(file)
@@ -146,11 +134,10 @@ def predict_from_files(
print(e, f"Skipping {file}...")
continue
x = x.to(device)
x = x.mean(dim=0).to(device) # convert to mono then pass to the right device
# compute the predictions
predictions = predict(x, sr, model=model, data_preprocessor=data_preprocessor, reduction=reduction,
convert_to_freq=not no_convert_to_freq, num_chunks=num_chunks)
predictions = _predict(x, sr, model=model, convert_to_freq=not no_convert_to_freq, num_chunks=num_chunks)
output_file = file.rsplit('.', 1)[0] + "." + ("semitones" if no_convert_to_freq else "f0")
if output is not None:
@@ -161,4 +148,3 @@ def predict_from_files(
for fmt in export_format:
export(fmt, output_file, *predictions)
return predictions
@@ -3,79 +3,92 @@ from typing import Optional
import torch
import torch.nn as nn
from .cqt import CQT
from .utils import HarmonicCQT
class DataProcessor(nn.Module):
class ToLogMagnitude(nn.Module):
def __init__(self):
super(ToLogMagnitude, self).__init__()
self.eps = torch.finfo(torch.float32).eps
def forward(self, x):
x = x.abs()
x.clamp_(min=self.eps).log10_().mul_(20)
return x
class Preprocessor(nn.Module):
r"""
Args:
step_size (float): step size between consecutive CQT frames (in SECONDS)
hop_size (float): step size between consecutive CQT frames (in milliseconds)
"""
_sampling_rate: Optional[int] = None
def __init__(self,
step_size: float,
bins_per_semitone: int = 3,
hop_size: float,
sampling_rate: Optional[int] = None,
**cqt_kwargs):
super(DataProcessor, self).__init__()
self.step_size = step_size
self.bins_per_semitone = bins_per_semitone
**hcqt_kwargs):
super(Preprocessor, self).__init__()
# CQT-related stuff
self.cqt_kwargs = cqt_kwargs
self.cqt_kwargs["bins_per_octave"] = 12 * bins_per_semitone
self.cqt = None
# HCQT
self.hcqt_sr = None
self.hcqt_kernels = None
self.hop_size = hop_size
# log-magnitude
self.eps = torch.finfo(torch.float32).eps
self.hcqt_kwargs = hcqt_kwargs
# cropping
self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5)
self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
# log-magnitude
self.to_log = ToLogMagnitude()
# register a dummy tensor to get implicit access to the module's device
self.register_buffer("_device", torch.zeros(()), persistent=False)
# sampling rate is lazily initialized
# if the sampling rate is provided, instantiate the CQT kernels
if sampling_rate is not None:
self.sampling_rate = sampling_rate
self.hcqt_sr = sampling_rate
self._reset_hcqt_kernels()
def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor, sr: Optional[int] = None) -> torch.Tensor:
r"""
Args:
x: audio waveform, any sampling rate, shape (num_samples)
x (torch.Tensor): audio waveform or batch of audio waveforms, any sampling rate,
shape (batch_size?, num_samples)
sr (int, optional): sampling rate
Returns:
log-magnitude CQT, shape (
torch.Tensor: log-magnitude CQT or batch of CQTs,
shape (batch_size?, num_timesteps, num_harmonics, num_freqs)
"""
# compute CQT from input waveform, and invert dims for (batch_size, time_steps, freq_bins)
complex_cqt = torch.view_as_complex(self.cqt(x)).transpose(1, 2)
# reshape and crop borders to fit training input shape
complex_cqt = complex_cqt[..., self.lowest_bin: self.highest_bin]
# flatten any batch dimensions so that batched audios can be processed in parallel
complex_cqt = complex_cqt.flatten(0, 1).unsqueeze(1)
# compute CQT from input waveform, and invert dims for (time_steps, num_harmonics, freq_bins)
# in other words, time becomes the batch dimension, enabling efficient processing for long audios.
complex_cqt = torch.view_as_complex(self.hcqt(x, sr=sr)).permute(0, 3, 1, 2)
# convert to dB
log_cqt = complex_cqt.abs().clamp_(min=self.eps).log10_().mul_(20)
return log_cqt
def _init_cqt_layer(self, sr: int, hop_length: int):
self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self._device.device)
return self.to_log(complex_cqt)
@property
def sampling_rate(self) -> int:
return self._sampling_rate
def hcqt(self, audio: torch.Tensor, sr: Optional[int] = None) -> torch.Tensor:
r"""Compute the Harmonic CQT of the input audio after eventually recreating the kernels
(in case the sampling rate has changed).
@sampling_rate.setter
def sampling_rate(self, sr: int) -> None:
if sr == self._sampling_rate:
return
Args:
audio (torch.Tensor): mono audio waveform, shape (batch_size, num_samples)
sr (int): sampling rate of the audio waveform.
If not specified, it is assumed to be the same as for the previously processed audio waveform.
hop_length = int(self.step_size * sr + 0.5)
self._init_cqt_layer(sr, hop_length)
self._sampling_rate = sr
Returns:
torch.Tensor: Complex Harmonic CQT (HCQT) of the input,
shape (batch_size, num_harmonics, num_freqs, num_timesteps, 2)
"""
# recompute the HCQT kernels if they do not exist yet or if the sampling rate has changed
if sr is not None and sr != self.hcqt_sr:
self.hcqt_sr = sr
self._reset_hcqt_layer()
return self.hcqt_kernels(audio)
def _reset_hcqt_layer(self) -> None:
hop_length = int(self.hop_size * self.hcqt_sr / 1000 + 0.5)
self.hcqt_kernels = HarmonicCQT(sr=self.hcqt_sr,
hop_length=hop_length,
**self.hcqt_kwargs).to(self._device.device)
import os
from typing import Optional
import torch
from .data import Preprocessor
from .model import PESTO, Resnet1d
def load_model(checkpoint: str,
step_size: float,
sampling_rate: Optional[int] = None) -> PESTO:
r"""Load a trained model from a checkpoint file.
See https://github.com/SonyCSLParis/pesto-full/blob/master/src/models/pesto.py for the structure of the checkpoint.
Args:
checkpoint (str): path to the checkpoint or name of the checkpoint file (if using a provided checkpoint)
step_size (float): hop size in milliseconds
sampling_rate (int, optional): sampling rate of the audios.
If not provided, it will be set dynamically from the sampling rate passed to the first forward call.
Returns:
PESTO: instance of PESTO model
"""
if os.path.exists(checkpoint): # handle user-provided checkpoints
model_path = checkpoint
else:
model_path = os.path.join(os.path.dirname(__file__), "weights", checkpoint + ".pth")
if not os.path.exists(model_path):
raise FileNotFoundError(f"You passed an invalid checkpoint file: {checkpoint}.")
# load checkpoint
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
hparams = checkpoint["hparams"]
hcqt_params = checkpoint["hcqt_params"]
state_dict = checkpoint["state_dict"]
# instantiate preprocessor
preprocessor = Preprocessor(hop_size=step_size, sampling_rate=sampling_rate, **hcqt_params)
# instantiate PESTO encoder
encoder = Resnet1d(**hparams["encoder"])
# instantiate main PESTO module and load its weights
model = PESTO(encoder,
preprocessor=preprocessor,
crop_kwargs=hparams["pitch_shift"],
reduction=hparams["reduction"])
model.load_state_dict(state_dict)
model.eval()
return model
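As a usage sketch (checkpoint name and values are illustrative), fixing the sampling rate at load time builds the HCQT kernels once, so the first forward pass does not pay that cost; a path to a custom checkpoint file can be passed instead of a model name:
```python
import torch

from pesto import load_model

model = load_model("mir-1k", step_size=10., sampling_rate=44100)

audio = torch.randn(4, 44100)  # batch of four 1-second mono excerpts
pitch, confidence = model(audio, convert_to_freq=True)  # shapes: (4, num_timesteps)
```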
from .parser import parse_args
from pesto.utils.parser import parse_args
from .core import predict_from_files
......
from functools import partial
from typing import Any, Mapping, Optional, Tuple, Union
import torch
import torch.nn as nn
from .utils import CropCQT
from .utils import reduce_activations
OUTPUT_TYPE = Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
class ToeplitzLinear(nn.Conv1d):
def __init__(self, in_features, out_features):
@@ -18,7 +25,7 @@ class ToeplitzLinear(nn.Conv1d):
return super(ToeplitzLinear, self).forward(input.unsqueeze(-2)).squeeze(-2)
class PESTOEncoder(nn.Module):
class Resnet1d(nn.Module):
"""
Basic CNN similar to the one in Johannes Zeitler's report,
but for longer HCQT input (always stride 1 in time)
@@ -29,7 +36,8 @@ class PESTOEncoder(nn.Module):
not over time (in order to work with variable length input).
Outputs one channel with sigmoid activation.
Args (Defaults: BasicCNN by Johannes Zeitler but with 1 input channel):
Args (Defaults: BasicCNN by Johannes Zeitler but with 6 input channels):
n_chan_input: Number of input channels (harmonics in HCQT)
n_chan_layers: Number of channels in the hidden layers (list)
n_prefilt_layers: Number of repetitions of the prefiltering layer
residual: If True, use residual connections for prefiltering (default: False)
@@ -39,100 +47,188 @@ class PESTOEncoder(nn.Module):
p_dropout: Dropout probability
"""
def __init__(
self,
def __init__(self,
n_chan_input=1,
n_chan_layers=(20, 20, 10, 1),
n_prefilt_layers=1,
prefilt_kernel_size=15,
residual=False,
n_bins_in=216,
output_dim=128,
num_output_layers: int = 1
):
super(PESTOEncoder, self).__init__()
activation_layer = partial(nn.LeakyReLU, negative_slope=0.3)
activation_fn: str = "leaky",
a_lrelu=0.3,
p_dropout=0.2):
super(Resnet1d, self).__init__()
self.hparams = dict(n_chan_input=n_chan_input,
n_chan_layers=n_chan_layers,
n_prefilt_layers=n_prefilt_layers,
prefilt_kernel_size=prefilt_kernel_size,
residual=residual,
n_bins_in=n_bins_in,
output_dim=output_dim,
activation_fn=activation_fn,
a_lrelu=a_lrelu,
p_dropout=p_dropout)
if activation_fn == "relu":
activation_layer = nn.ReLU
elif activation_fn == "silu":
activation_layer = nn.SiLU
elif activation_fn == "leaky":
activation_layer = partial(nn.LeakyReLU, negative_slope=a_lrelu)
else:
raise ValueError
n_in = n_chan_input
n_ch = n_chan_layers
if len(n_ch) < 5:
n_ch.append(1)
# Layer normalization over frequency
self.layernorm = nn.LayerNorm(normalized_shape=[1, n_bins_in])
# Layer normalization over frequency and channels (harmonics of HCQT)
self.layernorm = nn.LayerNorm(normalized_shape=[n_in, n_bins_in])
# Prefiltering
prefilt_padding = prefilt_kernel_size // 2
self.conv1 = nn.Sequential(
nn.Conv1d(in_channels=1, out_channels=n_ch[0], kernel_size=15, padding=7, stride=1),
activation_layer()
nn.Conv1d(in_channels=n_in,
out_channels=n_ch[0],
kernel_size=prefilt_kernel_size,
padding=prefilt_padding,
stride=1),
activation_layer(),
nn.Dropout(p=p_dropout)
)
self.n_prefilt_layers = n_prefilt_layers
self.prefilt_list = nn.ModuleList()
for p in range(1, n_prefilt_layers):
self.prefilt_list.append(nn.Sequential(
nn.Conv1d(in_channels=n_ch[0], out_channels=n_ch[0], kernel_size=15, padding=7, stride=1),
activation_layer()
))
self.prefilt_layers = nn.ModuleList(*[
nn.Sequential(
nn.Conv1d(in_channels=n_ch[0],
out_channels=n_ch[0],
kernel_size=prefilt_kernel_size,
padding=prefilt_padding,
stride=1),
activation_layer(),
nn.Dropout(p=p_dropout)
)
for _ in range(n_prefilt_layers-1)
])
self.residual = residual
self.conv2 = nn.Sequential(
nn.Conv1d(
in_channels=n_ch[0],
out_channels=n_ch[1],
conv_layers = []
for i in range(len(n_chan_layers)-1):
conv_layers.extend([
nn.Conv1d(in_channels=n_ch[i],
out_channels=n_ch[i + 1],
kernel_size=1,
stride=1,
padding=0
),
activation_layer()
)
self.conv3 = nn.Sequential(
nn.Conv1d(in_channels=n_ch[1], out_channels=n_ch[2], kernel_size=1, padding=0, stride=1),
activation_layer()
)
self.conv4 = nn.Sequential(
nn.Conv1d(in_channels=n_ch[2], out_channels=n_ch[3], kernel_size=1, padding=0, stride=1),
padding=0,
stride=1),
activation_layer(),
nn.Dropout(),
nn.Conv1d(in_channels=n_ch[3], out_channels=n_ch[4], kernel_size=1, padding=0, stride=1)
)
nn.Dropout(p=p_dropout)
])
self.conv_layers = nn.Sequential(*conv_layers)
self.flatten = nn.Flatten(start_dim=1)
layers = []
pre_fc_dim = n_bins_in * n_ch[4]
for i in range(num_output_layers-1):
layers.extend([
ToeplitzLinear(pre_fc_dim, pre_fc_dim),
activation_layer()
])
self.pre_fc = nn.Sequential(*layers)
self.fc = ToeplitzLinear(pre_fc_dim, output_dim)
self.fc = ToeplitzLinear(n_bins_in * n_ch[-1], output_dim)
self.final_norm = nn.Softmax(dim=-1)
self.register_buffer("abs_shift", torch.zeros((), dtype=torch.long), persistent=True)
def forward(self, x):
r"""
Args:
x (torch.Tensor): shape (batch, channels, freq_bins)
"""
x_norm = self.layernorm(x)
x = self.layernorm(x)
x = self.conv1(x_norm)
x = self.conv1(x)
for p in range(0, self.n_prefilt_layers - 1):
prefilt_layer = self.prefilt_list[p]
prefilt_layer = self.prefilt_layers[p]
if self.residual:
x_new = prefilt_layer(x)
x = x_new + x
else:
x = prefilt_layer(x)
conv2_lrelu = self.conv2(x)
conv3_lrelu = self.conv3(conv2_lrelu)
y_pred = self.conv4(conv3_lrelu)
y_pred = self.flatten(y_pred)
y_pred = self.pre_fc(y_pred)
y_pred = self.fc(y_pred)
x = self.conv_layers(x)
x = self.flatten(x)
y_pred = self.fc(x)
return self.final_norm(y_pred)
class PESTO(nn.Module):
def __init__(self,
encoder: nn.Module,
preprocessor: nn.Module,
crop_kwargs: Mapping[str, Any] | None = None,
reduction: str = "alwa"):
super(PESTO, self).__init__()
self.encoder = encoder
self.preprocessor = preprocessor
# crop CQT
if crop_kwargs is None:
crop_kwargs = {}
self.crop_cqt = CropCQT(**crop_kwargs)
self.reduction = reduction
# constant shift to get absolute pitch from predictions
self.register_buffer('shift', torch.zeros((), dtype=torch.float), persistent=True)
def forward(self,
audio_waveforms: torch.Tensor,
sr: Optional[int] = None,
convert_to_freq: bool = False,
return_activations: bool = False) -> OUTPUT_TYPE:
r"""
Args:
audio_waveforms (torch.Tensor): mono audio waveform or batch of mono audio waveforms,
shape (batch_size?, num_samples)
sr (int, optional): sampling rate, defaults to the previously used sampling rate
convert_to_freq (bool): whether to convert the result to frequencies or return fractional semitones instead.
return_activations (bool): whether to return activations or pitch predictions only
Returns:
preds (torch.Tensor): pitch predictions in SEMITONES, shape (batch_size?, num_timesteps)
where `num_timesteps` ~= `num_samples` / (`self.hop_size` * `sr`)
confidence (torch.Tensor): confidence of whether frame is voiced or unvoiced in [0, 1],
shape (batch_size?, num_timesteps)
activations (torch.Tensor): activations of the model, shape (batch_size?, num_timesteps, output_dim)
"""
batch_size = audio_waveforms.size(0) if audio_waveforms.ndim == 2 else None
x = self.preprocessor(audio_waveforms, sr=sr)
x = self.crop_cqt(x) # the CQT has to be cropped beforehand
# for now, confidence is computed very naively just based on energy in the CQT
confidence = x.mean(dim=-2).max(dim=-1).values
conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values
confidence = (confidence - conf_min) / (conf_max - conf_min)
# flatten batch_size and time_steps, since predictions are made on CQT frames independently anyway
if batch_size:
x = x.flatten(0, 1)
activations = self.encoder(x)
if batch_size:
activations = activations.view(batch_size, -1, activations.size(-1))
preds = reduce_activations(activations, reduction=self.reduction)
# decrease by shift to get absolute pitch
preds.sub_(self.shift)
if convert_to_freq:
preds = 440 * 2 ** ((preds - 69) / 12)
if return_activations:
return preds, confidence, activations
return preds, confidence
@property
def hop_size(self) -> float:
r"""Returns the hop size of the model (in milliseconds)"""
return self.preprocessor.hop_size
import os
from typing import Optional
import torch
from .config import model_args, cqt_args, bins_per_semitone
from .data import DataProcessor
from .model import PESTOEncoder
def load_dataprocessor(step_size, sampling_rate: Optional[int] = None, device: Optional[torch.device] = None):
return DataProcessor(step_size=step_size, sampling_rate=sampling_rate, **cqt_args).to(device)
def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder:
model = PESTOEncoder(**model_args).to(device)
model.eval()
model_path = os.path.join(os.path.dirname(__file__), "weights", model_name + ".pth")
model.load_state_dict(torch.load(model_path, map_location=device))
return model
def reduce_activation(activations: torch.Tensor, reduction: str) -> torch.Tensor:
r"""Computes the pitch predictions from the activation outputs of the encoder.
Pitch predictions are returned in semitones, NOT in frequencies.
Args:
activations: tensor of probability activations, shape (*, num_bins)
reduction:
Returns:
torch.Tensor: pitch predictions, shape (*,)
"""
bps = bins_per_semitone
if reduction == "argmax":
pred = activations.argmax(dim=-1)
return pred.float() / bps
all_pitches = (torch.arange(activations.size(-1), dtype=torch.float, device=activations.device)) / bps
if reduction == "mean":
return activations @ all_pitches
if reduction == "alwa": # argmax-local weighted averaging, see https://github.com/marl/crepe
center_bin = activations.argmax(dim=-1, keepdim=True)
window = torch.arange(-bps+1, bps, device=activations.device)
indices = window + center_bin
cropped_activations = activations.gather(-1, indices)
cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices)
return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1)
raise ValueError
from .crop_cqt import CropCQT
from .export import export
from .hcqt import HarmonicCQT
from .reduce_activations import reduce_activations
\ No newline at end of file
import torch
import torch.nn as nn
class CropCQT(nn.Module):
def __init__(self, min_steps: int, max_steps: int):
super(CropCQT, self).__init__()
self.min_steps = min_steps
self.max_steps = max_steps
# lower bin
self.lower_bin = self.max_steps
def forward(self, spectrograms: torch.Tensor) -> torch.Tensor:
# WARNING: didn't check that it works, it may be dangerous
return spectrograms[..., self.max_steps: self.min_steps]
# old implementation
batch_size, _, input_height = spectrograms.size()
output_height = input_height - self.max_steps + self.min_steps
assert output_height > 0, \
f"With input height {input_height:d} and output height {output_height:d}, impossible " \
f"to have a range of {self.max_steps - self.min_steps:d} bins."
return spectrograms[..., self.lower_bin: self.lower_bin + output_height]
File moved
@@ -354,3 +354,39 @@ class CQT(nn.Module):
phase_real = torch.cos(torch.atan2(CQT_imag, CQT_real))
phase_imag = torch.sin(torch.atan2(CQT_imag, CQT_real))
return torch.stack((phase_real, phase_imag), -1)
class HarmonicCQT(nn.Module):
r"""Harmonic CQT layer, as described in Bittner et al. (20??)"""
def __init__(
self,
harmonics,
sr: int = 22050,
hop_length: int = 512,
fmin: float = 32.7,
fmax: float | None = None,
bins_per_semitone: int = 1,
n_bins: int = 84,
center_bins: bool = True
):
super(HarmonicCQT, self).__init__()
if center_bins:
fmin = fmin / 2 ** ((bins_per_semitone - 1) / (24 * bins_per_semitone))
self.cqt_kernels = nn.ModuleList([
CQT(sr=sr, hop_length=hop_length, fmin=h*fmin, fmax=fmax, n_bins=n_bins,
bins_per_octave=12*bins_per_semitone, output_format="Complex")
for h in harmonics
])
def forward(self, audio_waveforms: torch.Tensor):
r"""Converts a batch of waveforms into a batch of HCQTs.
Args:
audio_waveforms (torch.Tensor): Batch of waveforms, shape (batch_size, num_samples)
Returns:
Harmonic CQT, shape (batch_size, num_harmonics, num_freqs, num_timesteps, 2)
"""
return torch.stack([cqt(audio_waveforms) for cqt in self.cqt_kernels], dim=1)
File moved
import torch
def reduce_activations(activations: torch.Tensor, reduction: str = "alwa") -> torch.Tensor:
r"""
Args:
activations: tensor of probability activations, shape (*, num_bins)
reduction (str): reduction method to compute pitch out of activations,
choose between "argmax", "mean", "alwa".
Returns:
torch.Tensor: pitches as fractions of MIDI semitones, shape (*)
"""
device = activations.device
num_bins = activations.size(-1)
bps, r = divmod(num_bins, 128)
assert r == 0, f"Activations should have output size 128*bins_per_semitone, got {num_bins}."
if reduction == "argmax":
pred = activations.argmax(dim=-1)
return pred.float() / bps
all_pitches = torch.arange(num_bins, dtype=torch.float, device=device).div_(bps)
if reduction == "mean":
return torch.matmul(activations, all_pitches)
if reduction == "alwa": # argmax-local weighted averaging, see https://github.com/marl/crepe
center_bin = activations.argmax(dim=-1, keepdim=True)
window = torch.arange(1, 2 * bps, device=device) - bps # [-bps+1, -bps+2, ..., bps-2, bps-1]
indices = (center_bin + window).clip_(min=0, max=num_bins - 1)
cropped_activations = activations.gather(-1, indices)
cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices)
return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1)
raise ValueError
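A quick sanity check of the two main reductions on a toy distribution, assuming 3 bins per semitone (384 output bins) and the function above in scope:
```python
import torch

activations = torch.zeros(1, 384)  # 128 semitones * 3 bins per semitone
activations[0, 150] = 0.6          # peak at bin 150, i.e. 50 semitones
activations[0, 151] = 0.4          # some probability mass just above the peak

print(reduce_activations(activations, reduction="argmax"))  # tensor([50.])
print(reduce_activations(activations, reduction="alwa"))    # ~tensor([50.13]): pulled toward bin 151
```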
File deleted