Commit be6c0b9f authored by Alain Riou

initial commit

Showing with 1277 additions and 0 deletions
import hashlib
import itertools
import json
import logging
from pathlib import Path
from typing import Sequence, Tuple, Any
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.utils.data
import torchaudio
from lightning import LightningDataModule
from src.data.hcqt import HarmonicCQT
log = logging.getLogger(__name__)
def hz_to_mid(freqs):
return np.where(freqs > 0, 12 * np.log2(freqs / 440) + 69, 0)
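# Editor's quick check (illustrative, not part of the original file): 440 Hz is MIDI 69, doubling a
# frequency adds 12 semitones, and unvoiced frames (0 Hz) are mapped to 0 rather than -inf
# (numpy may emit a harmless divide-by-zero warning for the 0 Hz entry).
print(hz_to_mid(np.array([440.0, 880.0, 0.0])))  # -> [69. 81.  0.]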
class NpyDataset(torch.utils.data.Dataset):
def __init__(self, inputs, labels=None, filter_unvoiced: bool = False) -> None:
assert labels is None or len(inputs) == len(labels), "Lengths of inputs and labels do not match"
if filter_unvoiced and labels is None:
log.warning("Cannnot filter out unvoiced frames without annotations.")
filter_unvoiced = False
if filter_unvoiced:
self.inputs = inputs[labels > 0]
self.labels = labels[labels > 0]
else:
self.inputs = inputs
self.labels = labels
def __getitem__(self, item) -> Tuple[torch.Tensor, torch.Tensor]:
label = self.labels[item] if self.labels is not None else 0
return torch.view_as_complex(torch.from_numpy(self.inputs[item])), label
def __len__(self):
return len(self.inputs)
class AudioDataModule(LightningDataModule):
def __init__(self,
audio_files: str,
annot_files: str | None = None,
val_audio_files: str | None = None,
val_annot_files: str | None = None,
harmonics: Sequence[float] = (1,),
hop_duration: float = 10.,
fmin: float = 27.5,
fmax: float | None = None,
bins_per_semitone: int = 1,
n_bins: int = 84,
center_bins: bool = False,
batch_size: int = 256,
num_workers: int = 0,
pin_memory: bool = False,
transforms: Sequence[torch.nn.Module] | None = None,
fold: int | None = None,
n_folds: int = 5,
cache_dir: str = "./cache",
filter_unvoiced: bool = False,
mmap_mode: str | None = None):
r"""
Args:
audio_files: path to csv file containing the list of audio files to process
"""
super(AudioDataModule, self).__init__()
# sanity checks
assert val_audio_files is None or val_annot_files is not None, "Validation set (if it exists) must be annotated"
assert val_audio_files is None or fold is None, "Specify `val_audio_files` OR cross-validation `fold`, not both"
assert annot_files is not None or fold is None, "Cannot perform cross-validation without any annotations."
self.audio_files = Path(audio_files)
self.annot_files = Path(annot_files) if annot_files is not None else None
if val_audio_files is not None:
self.val_audio_files = Path(val_audio_files)
self.val_annot_files = Path(val_annot_files)
else:
self.val_audio_files = None
self.val_annot_files = None
self.fold = fold
self.n_folds = n_folds
# HCQT
self.hcqt_sr = None
self.hcqt_kernels = None
self.hop_duration = hop_duration
self.hcqt_kwargs = dict(
harmonics=list(harmonics),
fmin=fmin,
fmax=fmax,
bins_per_semitone=bins_per_semitone,
n_bins=n_bins,
center_bins=center_bins
)
# dataloader keyword-arguments
self.dl_kwargs = dict(
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory
)
# transforms
self.transforms = nn.Sequential(*transforms) if transforms is not None else nn.Identity()
# misc
self.cache_dir = Path(cache_dir)
self.filter_unvoiced = filter_unvoiced
self.mmap_mode = mmap_mode
# placeholders for datasets and samplers
self.train_dataset = None
self.train_sampler = None
self.val_dataset = None
self.val_sampler = None
def prepare_data(self) -> None:
self.train_dataset = self.load_data(self.audio_files, self.annot_files)
if self.val_audio_files is not None:
self.val_dataset = self.load_data(self.val_audio_files, self.val_annot_files)
def setup(self, stage: str) -> None:
# If a separate validation set was provided, there is nothing to do here.
# Otherwise, if the dataset is labeled, we validate on a cross-validation fold
# (20% of the data with the default 5 folds) or, if no fold is specified, on the full training data.
# If the dataset is unlabeled, we train on the whole dataset and create a dummy validation set.
if self.val_dataset is not None:
return
if not self.annot_files:
# create dummy validation set
self.val_dataset = NpyDataset(np.zeros_like(self.train_dataset.inputs[:1]))
return
self.val_dataset = self.load_data(self.audio_files, self.annot_files)
if self.fold is not None:
# see https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-use-k-fold-cross-validation-with-pytorch.md
from sklearn.model_selection import KFold
# We fix random_state=0 for the train/val split to be consistent across runs, even if the global seed changes
kfold = KFold(n_splits=self.n_folds, shuffle=True, random_state=0)
iterator = kfold.split(self.train_dataset)
train_idx, val_idx = None, None # placeholder assignment to keep the linter happy
for _ in range(self.fold + 1):
train_idx, val_idx = next(iterator)
self.train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
self.val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)
else:
self.train_sampler = torch.utils.data.RandomSampler(self.train_dataset)
self.val_sampler = torch.utils.data.SequentialSampler(self.val_dataset)
def train_dataloader(self):
return torch.utils.data.DataLoader(self.train_dataset, sampler=self.train_sampler, **self.dl_kwargs)
def val_dataloader(self):
return torch.utils.data.DataLoader(self.val_dataset, sampler=self.val_sampler, **self.dl_kwargs)
def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
x, y = batch
return self.transforms(x), y
def load_data(self, audio_files: Path, annot_files: Path | None = None) -> torch.utils.data.Dataset:
cache_cqt = self.build_cqt_filename(audio_files)
if cache_cqt.exists():
inputs = np.load(cache_cqt, mmap_mode=self.mmap_mode)
cache_annot = cache_cqt.with_suffix(".csv")
annotations = np.loadtxt(cache_annot, dtype=np.float32) if cache_annot.exists() else None
else:
self.cache_dir.mkdir(parents=True, exist_ok=True)
inputs, annotations = self.precompute_hcqt(audio_files, annot_files)
np.save(cache_cqt, inputs, allow_pickle=False)
if annotations is not None:
np.savetxt(cache_cqt.with_suffix(".csv"), annotations)
return NpyDataset(inputs, labels=annotations, filter_unvoiced=self.filter_unvoiced)
def build_cqt_filename(self, audio_files) -> Path:
# build a hash
dict_str = json.dumps({
"audio_files": str(audio_files),
"hop_duration": self.hop_duration,
**self.hcqt_kwargs
}, sort_keys=True)
hash_id = hashlib.sha256(dict_str.encode()).hexdigest()[:8]
# build filename
fname = "hcqt_" + hash_id + ".npy"
return self.cache_dir / fname
def precompute_hcqt(self, audio_path: Path, annot_path: Path | None = None) -> Tuple[np.ndarray,np.ndarray]:
data_dir = audio_path.parent
cqt_list = []
with audio_path.open('r') as f:
audio_files = f.readlines()
if annot_path is not None:
with annot_path.open('r') as f:
annot_files = f.readlines()
annot_list = []
else:
annot_files = []
annot_list = None
log.info("Precomputing HCQT...")
pbar = tqdm(itertools.zip_longest(audio_files, annot_files, fillvalue=None),
total=len(audio_files),
leave=False)
for fname, annot in pbar:
fname = fname.strip()
pbar.set_description(fname)
x, sr = torchaudio.load(data_dir / fname)
out = self.hcqt(x.mean(dim=0), sr) # convert to mono and compute HCQT
if annot is not None:
annot = annot.strip()
timesteps, freqs = np.loadtxt(data_dir / annot, delimiter=',', dtype=np.float32).T
hop_duration = 1000 * (timesteps[1] - timesteps[0])
# Badly-aligned annotations are a nightmare to debug,
# so we double-check for each file that hop sizes and lengths do match.
# Since hop sizes are floats, we allow a tolerance of 1e-6 in the equality check.
assert abs(hop_duration - self.hop_duration) < 1e-6, \
(f"Inconsistency between {fname} and {annot}:\n"
f"the hop duration of the annotations ({hop_duration:.1f} ms) "
f"does not match the hop duration of the CQT frames ({self.hop_duration:.1f} ms), "
f"so the {len(freqs):d} annotation frames cannot be aligned with the {len(out):d} CQT frames. "
f"Please either adjust the hop duration of the CQT or resample the annotations.")
assert len(out) == len(freqs), \
(f"Inconsistency between {fname} and {annot}:\n"
f"the resolution of the annotations ({len(freqs):d}) "
f"does not match the number of CQT frames ({len(out):d}) "
f"even though the hop durations match. "
f"Please check that your annotations are correct.")
annot_list.append(hz_to_mid(freqs))
cqt_list.append(out.cpu().numpy())
return np.concatenate(cqt_list), np.concatenate(annot_list) if annot_list is not None else None
def hcqt(self, audio: torch.Tensor, sr: int):
# (re)compute the CQT kernels if they do not exist yet or if the sampling rate has changed
if sr != self.hcqt_sr:
self.hcqt_sr = sr
hop_length = int(self.hop_duration * sr / 1000 + 0.5)
self.hcqt_kernels = HarmonicCQT(sr=sr, hop_length=hop_length, **self.hcqt_kwargs)
return self.hcqt_kernels(audio).squeeze(0).permute(2, 0, 1, 3) # (time, harmonics, freq_bins, 2)
import torch
import torch.nn as nn
from nnAudio.features.cqt import CQT
class HarmonicCQT(nn.Module):
def __init__(
self,
harmonics,
sr: int = 22050,
hop_length: int = 512,
fmin: float = 32.7,
fmax: float | None = None,
bins_per_semitone: int = 1,
n_bins: int = 84,
center_bins: bool = True
):
super(HarmonicCQT, self).__init__()
if center_bins:
fmin = fmin / 2 ** ((bins_per_semitone - 1) / (24 * bins_per_semitone))
self.cqt_kernels = nn.ModuleList([
CQT(sr=sr, hop_length=hop_length, fmin=h*fmin, fmax=fmax, n_bins=n_bins,
bins_per_octave=12*bins_per_semitone, output_format="Complex", verbose=False)
for h in harmonics
])
def forward(self, audio_waveforms: torch.Tensor):
r"""
Returns:
Harmonic CQT, shape (num_channels, num_harmonics, num_freqs, num_timesteps, 2)
"""
return torch.stack([cqt(audio_waveforms) for cqt in self.cqt_kernels], dim=1)
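# Editor's usage sketch (not from the commit; the harmonics, hop length and bin counts below are
# assumptions made for the example). Shapes follow the docstring of `forward` above.
import torch

hcqt = HarmonicCQT(harmonics=(1, 2, 3), sr=22050, hop_length=512, fmin=27.5, bins_per_semitone=1, n_bins=84)
waveform = torch.randn(1, 2 * 22050)  # (num_channels, num_samples): two seconds of mono audio
spec = hcqt(waveform)                 # (num_channels, num_harmonics, num_freqs, num_timesteps, 2)
print(spec.shape)                     # torch.Size([1, 3, 84, num_timesteps, 2])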
import torch
import torch.nn as nn
def randint_sampling_fn(min_value, max_value):
def sample_randint(*size, **kwargs):
return torch.randint(min_value, max_value+1, size, **kwargs)
return sample_randint
def gaussint_sampling_fn(min_value, max_value):
mean = (min_value + max_value) / 2
std = (max_value - mean) / 2
def sample_gaussint(*size, **kwargs):
# scale the standard normal by `std` and then shift it by `mean` (not the other way around)
return torch.randn(size, **kwargs).mul_(std).add_(mean).long().clip(min=min_value, max=max_value)
return sample_gaussint
class PitchShiftCQT(nn.Module):
def __init__(self,
min_steps: int,
max_steps: int,
gaussian_sampling: bool = False):
super(PitchShiftCQT, self).__init__()
self.min_steps = min_steps
self.max_steps = max_steps
self.sample_random_steps = gaussint_sampling_fn(min_steps, max_steps) if gaussian_sampling \
else randint_sampling_fn(min_steps, max_steps)
# lower bin
self.lower_bin = self.max_steps
def forward(self, spectrograms: torch.Tensor):
batch_size, _, input_height = spectrograms.size()
output_height = input_height - self.max_steps + self.min_steps
assert output_height > 0, \
f"With input height {input_height:d} and output height {output_height:d}, impossible " \
f"to have a range of {self.max_steps - self.min_steps:d} bins."
n_steps = self.sample_random_steps(batch_size, device=spectrograms.device)
x = spectrograms[..., self.lower_bin: self.lower_bin + output_height]
xt = self.extract_bins(spectrograms, self.lower_bin - n_steps, output_height)
return x, xt, n_steps
def extract_bins(self, inputs: torch.Tensor, first_bin: torch.LongTensor, output_height: int):
r"""Efficiently extract subsegments of CQT of size `output_height`,
i.e. so that outputs[i, j] = inputs[i, ..., first_bin[j] : first_bin[j] + self.output_height]
Args:
inputs (torch.Tensor): tensor of CQTs, shape (batch_size, num_channels, input_height)
first_bin (torch.LongTensor): indices of the first bin of each segment, shape (batch_size)
output_height (int): output height of the cropped CQT
Returns:
segments of CQTs, shape (batch_size, num_channels, output_height)
"""
indices = first_bin.unsqueeze(-1) + torch.arange(output_height, device=inputs.device)
dims = inputs.size(0), 1, output_height
output_size = list(inputs.size())[:-1] + [output_height]
indices = indices.view(*dims).expand(output_size)
return inputs.gather(-1, indices)
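# Editor's usage sketch (not from the commit; batch and CQT sizes below are arbitrary assumptions).
# PitchShiftCQT crops a batch of CQT frames to a fixed output height and returns an aligned copy of
# each frame transposed by a random number of bins, together with the sampled shifts.
import torch

pitch_shift = PitchShiftCQT(min_steps=-12, max_steps=12)
cqt_frames = torch.randn(8, 2, 300)  # (batch_size, num_channels, input_height)
x, xt, n_steps = pitch_shift(cqt_frames)
print(x.shape, xt.shape, n_steps.shape)  # torch.Size([8, 2, 276]) torch.Size([8, 2, 276]) torch.Size([8])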
from typing import Optional
import torch
import torch.nn as nn
class ToLogMagnitude(nn.Module):
def __init__(self):
super(ToLogMagnitude, self).__init__()
self.eps = torch.finfo(torch.float32).eps
def forward(self, x):
if x.size(-1) == 2:
x = torch.view_as_complex(x)
if x.ndim == 2:
x.unsqueeze_(1)
x = x.abs()
x.clamp_(min=self.eps).log10_().mul_(20)
return x
class BatchRandomNoise(nn.Module):
def __init__(
self,
min_snr: float = 0.0001,
max_snr: float = 0.01,
p: Optional[float] = None,
):
super(BatchRandomNoise, self).__init__()
self.min_snr = min_snr
self.max_snr = max_snr
self.p = p
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size = x.size(0)
device = x.device
snr = torch.empty(batch_size, device=device)
snr.uniform_(self.min_snr, self.max_snr)
# when p is given, disable the noise for each sample with probability p by zeroing its noise level
if self.p is not None:
    mask = torch.rand_like(snr).le(self.p)
    snr[mask] = 0
# `snr` acts here as a noise-to-signal amplitude ratio: the noise std is proportional to the signal std
noise_std = snr * x.view(batch_size, -1).std(dim=-1)
noise_std = noise_std.unsqueeze(-1).expand_as(x.view(batch_size, -1)).view_as(x)
noise = noise_std * torch.randn_like(x)
return x + noise
class BatchRandomGain(nn.Module):
def __init__(
self,
min_gain: float = 0.5,
max_gain: float = 1.5,
p: Optional[float] = None,
):
super(BatchRandomGain, self).__init__()
self.min_gain = min_gain
self.max_gain = max_gain
self.p = p
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size = x.size(0)
device = x.device
vol = torch.empty(batch_size, device=device)
vol.uniform_(self.min_gain, self.max_gain)
# when p is given, leave each sample untouched with probability p by setting its gain to 1
if self.p is not None:
    mask = torch.rand_like(vol).le(self.p)
    vol[mask] = 1
vol = vol.unsqueeze(-1).expand_as(x.view(batch_size, -1)).view_as(x)
return vol * x
from .base import NullLoss
from .entropy import CrossEntropyLoss, ShiftCrossEntropy
from .equivariance import PowerSeries
import abc
from typing import Dict
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
class Loss(_Loss, metaclass=abc.ABCMeta):
@abc.abstractmethod
def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
pass
class ComposeLoss(Loss):
def __init__(self, losses: Dict[str, Loss], input_dims: Dict[str, int]):
super(ComposeLoss, self).__init__()
self.losses = nn.ModuleDict(losses)
self.input_dims = [input_dims[k] for k in self.losses.keys()]
def forward(self, inputs) -> Dict[str, torch.Tensor]:
chunks = inputs.split(self.input_dims, dim=-1)
loss_dict = {}
total_loss = None
for (k, loss_fn), chunk in zip(self.losses.items(), chunks):
# compute loss
aux_dict = loss_fn(chunk)
# retrieve total auxiliary loss
loss = aux_dict.pop("loss")
# add it to total loss
if total_loss is None:
total_loss = loss.clone()
else:
total_loss = total_loss + loss
# write quantities into dict
loss_dict.update(aux_dict)
loss_dict[k] = loss
loss_dict["loss"] = total_loss
return loss_dict
class NullLoss(nn.Module):
def forward(self, *args, **kwargs) -> torch.Tensor:
return args[0].mean().mul(0)
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossEntropyLoss(nn.Module):
def __init__(self,
symmetric: bool = False,
detach_targets: bool = False,
backend: nn.Module = nn.CrossEntropyLoss()) -> None:
super(CrossEntropyLoss, self).__init__()
self.symmetric = symmetric
self.detach_targets = detach_targets
self.backend = backend
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if self.symmetric:
return (self.compute_loss(input, target) + self.compute_loss(target, input)) / 2
return self.compute_loss(input, target)
def compute_loss(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return self.backend(input, target.detach() if self.detach_targets else target)
class ShiftCrossEntropy(nn.Module):
def __init__(self,
pad_length: int = 5,
criterion: nn.Module = CrossEntropyLoss()):
super(ShiftCrossEntropy, self).__init__()
self.criterion = criterion
self.pad_length = pad_length
def forward(self, x1: torch.Tensor, x2: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
r"""x2[i] is the pitch-shifted version of x1[i] by target[i] semitones, i.e.
if x1[i] is a C# and x2[i] is a C then target[i] = -1
"""
# pad x1 and x2
x1 = F.pad(x1, (self.pad_length, self.pad_length))
x2 = F.pad(x2, (2*self.pad_length, 2*self.pad_length))
# shift x2
idx = target.unsqueeze(1) + torch.arange(x1.size(-1), device=target.device) + self.pad_length
shift_x2 = torch.gather(x2, dim=1, index=idx)
# compute loss
return self.criterion(x1, shift_x2)
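# Editor's toy check (not from the commit): this mirrors the padding/gather logic of
# ShiftCrossEntropy.forward above. If x2 is x1 transposed upwards by `target` bins, the gather
# realigns x2 with x1 before the cross-entropy between the two distributions is computed.
import torch
import torch.nn.functional as F

pad = 5
x1 = torch.zeros(1, 12); x1[0, 3] = 1.0  # a one-hot "pitch" activation at bin 3
x2 = torch.zeros(1, 12); x2[0, 5] = 1.0  # the same activation, transposed 2 bins up
target = torch.tensor([2])               # the known shift in bins (positive = x2 is higher)
x1p = F.pad(x1, (pad, pad))
x2p = F.pad(x2, (2 * pad, 2 * pad))
idx = target.unsqueeze(1) + torch.arange(x1p.size(-1)) + pad
print(torch.equal(x1p, torch.gather(x2p, dim=1, index=idx)))  # True: the two views line up again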
from typing import Dict
import torch
import torch.nn as nn
class HuberLoss(nn.Module):
def __init__(self, tau: float):
super(HuberLoss, self).__init__()
self.register_buffer("tau", torch.tensor(tau), persistent=False)
def forward(self, x):
x = x.abs()
return torch.where(x.le(self.tau),
x ** 2 / 2,
self.tau ** 2 / 2 + self.tau * (x - self.tau))
class PowerSeries(nn.Module):
def __init__(self, value: float, power_min, power_max, tau: float = 1.):
super(PowerSeries, self).__init__()
self.value = value
# compute weights vector
powers = torch.arange(power_min, power_max)
self.register_buffer("weights", self.value ** powers, persistent=False)
self.dim = len(self.weights)
self.loss_fn = HuberLoss(tau)
def forward(self, x1: torch.Tensor, x2: torch.Tensor, target: torch.Tensor,
nlog_c1: torch.Tensor | None = None, nlog_c2: torch.Tensor | None = None) -> torch.Tensor:
r"""x2[i] is the pitch-shifted version of x1[i] by target[i] semitones, i.e.
if x1[i] is a C# and x2[i] is a C then target[i] = -1
"""
z1 = self.project(x1)
z2 = self.project(x2)
if nlog_c1 is not None:
z1 = z1 * torch.exp(-nlog_c1)
if nlog_c2 is not None:
z2 = z2 * torch.exp(-nlog_c2)
# compute frequency ratios out of semitones
freq_ratios = self.value ** target.float()
# compute equivariant loss
loss_12 = self.loss_fn(z2 / z1 - freq_ratios).mean()
loss_21 = self.loss_fn(z1 / z2 - 1/freq_ratios).mean()
return (loss_12 + loss_21) / 2
def project(self, x: torch.Tensor):
r"""Projects a batch of vectors into a batch of scalars
Args:
x (torch.Tensor): batch of input vectors, shape (batch_size, output_dim)
Returns:
torch.Tensor: batch of output scalars, shape (batch_size)
"""
return x.mv(self.weights)
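# Editor's illustrative sketch (not from the commit; value = 2 ** (1 / 12) and the 128-bin range are
# assumptions). With semitone-ratio weights, `project` maps a one-hot activation at bin p to 2 ** (p / 12),
# so transposing the activation by `target` bins multiplies the projection by exactly 2 ** (target / 12),
# which is the frequency ratio the equivariance loss above enforces.
import torch

power_series = PowerSeries(value=2 ** (1 / 12), power_min=0, power_max=128)
x1 = torch.zeros(1, 128); x1[0, 60] = 1.0  # one-hot activation at bin 60
x2 = torch.zeros(1, 128); x2[0, 62] = 1.0  # the same activation, 2 bins higher
ratio = power_series.project(x2) / power_series.project(x1)
print(torch.allclose(ratio, torch.tensor(2.0 ** (2 / 12))))  # True (up to float32 precision)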
from functools import partial
import torch
import torch.nn as nn
class ToeplitzLinear(nn.Conv1d):
def __init__(self, in_features, out_features):
super(ToeplitzLinear, self).__init__(
in_channels=1,
out_channels=1,
kernel_size=in_features+out_features-1,
padding=out_features-1,
bias=False
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
return super(ToeplitzLinear, self).forward(input.unsqueeze(-2)).squeeze(-2)
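# Editor's sketch (not from the commit; sizes are arbitrary): ToeplitzLinear behaves like a fully-connected
# layer whose weight matrix is constrained to be Toeplitz (constant along diagonals). It is implemented with
# a single Conv1d kernel, so it has in_features + out_features - 1 parameters instead of
# in_features * out_features.
import torch

toeplitz = ToeplitzLinear(in_features=216, out_features=128)
x = torch.randn(4, 216)  # (batch_size, in_features)
y = toeplitz(x)          # (batch_size, out_features)
print(y.shape, sum(p.numel() for p in toeplitz.parameters()))  # torch.Size([4, 128]) 343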
class Resnet1d(nn.Module):
"""
Basic CNN similar to the one in Johannes Zeitler's report,
but for longer HCQT input (always stride 1 in time)
Still with 75 (-1) context frames, i.e. 37 frames padded to each side
The number of input channels, channels in the hidden layers, and output
dimensions (e.g. for pitch output) can be parameterized.
Layer normalization is only performed over frequency and channel dimensions,
not over time (in order to work with variable length input).
Outputs one channel with sigmoid activation.
Args (Defaults: BasicCNN by Johannes Zeitler but with 6 input channels):
n_chan_input: Number of input channels (harmonics in HCQT)
n_chan_layers: Number of channels in the hidden layers (list)
n_prefilt_layers: Number of repetitions of the prefiltering layer
residual: If True, use residual connections for prefiltering (default: False)
n_bins_in: Number of input bins (12 * number of octaves)
n_bins_out: Number of output bins (12 for pitch class, 72 for pitch, num_octaves * 12)
a_lrelu: alpha parameter (slope) of LeakyReLU activation function
p_dropout: Dropout probability
"""
def __init__(self,
n_chan_input=1,
n_chan_layers=(20, 20, 10, 1),
n_prefilt_layers=1,
prefilt_kernel_size=15,
residual=False,
n_bins_in=216,
output_dim=128,
activation_fn: str = "leaky",
a_lrelu=0.3,
p_dropout=0.2):
super(Resnet1d, self).__init__()
self.hparams = dict(n_chan_input=n_chan_input,
n_chan_layers=n_chan_layers,
n_prefilt_layers=n_prefilt_layers,
prefilt_kernel_size=prefilt_kernel_size,
residual=residual,
n_bins_in=n_bins_in,
output_dim=output_dim,
activation_fn=activation_fn,
a_lrelu=a_lrelu,
p_dropout=p_dropout)
if activation_fn == "relu":
activation_layer = nn.ReLU
elif activation_fn == "silu":
activation_layer = nn.SiLU
elif activation_fn == "leaky":
activation_layer = partial(nn.LeakyReLU, negative_slope=a_lrelu)
else:
raise ValueError(f"Unsupported activation function: {activation_fn}. Choose 'relu', 'silu' or 'leaky'.")
n_in = n_chan_input
n_ch = list(n_chan_layers)  # copy to a mutable list (the default argument is a tuple)
if len(n_ch) < 5:
n_ch.append(1)
# Layer normalization over frequency and channels (harmonics of HCQT)
self.layernorm = nn.LayerNorm(normalized_shape=[n_in, n_bins_in])
# Prefiltering
prefilt_padding = prefilt_kernel_size // 2
self.conv1 = nn.Sequential(
nn.Conv1d(in_channels=n_in,
out_channels=n_ch[0],
kernel_size=prefilt_kernel_size,
padding=prefilt_padding,
stride=1),
activation_layer(),
nn.Dropout(p=p_dropout)
)
self.n_prefilt_layers = n_prefilt_layers
self.prefilt_layers = nn.ModuleList([
nn.Sequential(
nn.Conv1d(in_channels=n_ch[0],
out_channels=n_ch[0],
kernel_size=prefilt_kernel_size,
padding=prefilt_padding,
stride=1),
activation_layer(),
nn.Dropout(p=p_dropout)
)
for _ in range(n_prefilt_layers-1)
])
self.residual = residual
conv_layers = []
for i in range(len(n_chan_layers)-1):
conv_layers.extend([
nn.Conv1d(in_channels=n_ch[i],
out_channels=n_ch[i + 1],
kernel_size=1,
padding=0,
stride=1),
activation_layer(),
nn.Dropout(p=p_dropout)
])
self.conv_layers = nn.Sequential(*conv_layers)
self.flatten = nn.Flatten(start_dim=1)
self.fc = ToeplitzLinear(n_bins_in * n_ch[-1], output_dim)
self.final_norm = nn.Softmax(dim=-1)
def forward(self, x):
r"""
Args:
x (torch.Tensor): shape (batch, channels, freq_bins)
"""
x = self.layernorm(x)
x = self.conv1(x)
for p in range(0, self.n_prefilt_layers - 1):
prefilt_layer = self.prefilt_layers[p]
if self.residual:
x_new = prefilt_layer(x)
x = x_new + x
else:
x = prefilt_layer(x)
x = self.conv_layers(x)
x = self.flatten(x)
y_pred = self.fc(x)
return self.final_norm(y_pred)
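# Editor's illustrative sketch (not from the commit; all hyperparameters below are assumptions).
# The encoder maps a batch of HCQT frames (channels = harmonics, n_bins_in frequency bins) to a softmax
# distribution over output_dim pitch bins, here 128 semitones * 3 bins per semitone.
import torch

encoder = Resnet1d(n_chan_input=6, n_chan_layers=[40, 30, 30, 10, 3], n_prefilt_layers=2,
                   residual=True, n_bins_in=216, output_dim=384)
frames = torch.randn(8, 6, 216)  # (batch_size, harmonics, freq_bins)
probs = encoder(frames)          # (batch_size, output_dim); each row sums to 1
print(probs.shape)               # torch.Size([8, 384])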
import logging
from typing import Any, Dict, Mapping, Sequence, Tuple, Union
import torch
import torch.nn as nn
from lightning import LightningModule
from src.callbacks.loss_weighting import LossWeighting
from src.data.pitch_shift import PitchShiftCQT
from src.losses import NullLoss
from src.utils import reduce_activations, remove_omegaconf_dependencies
from src.utils.calibration import generate_synth_data
log = logging.getLogger(__name__)
class PESTO(LightningModule):
def __init__(self,
encoder: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
equiv_loss_fn: nn.Module | None = None,
sce_loss_fn: nn.Module | None = None,
inv_loss_fn: nn.Module | None = None,
pitch_shift_kwargs: Mapping[str, Any] | None = None,
transforms: Sequence[nn.Module] | None = None,
reduction: str = "alwa"):
super(PESTO, self).__init__()
self.encoder = encoder
self.optimizer_cls = optimizer
self.scheduler_cls = scheduler
# loss definitions
self.equiv_loss_fn = equiv_loss_fn or NullLoss()
self.sce_loss_fn = sce_loss_fn or NullLoss()
self.inv_loss_fn = inv_loss_fn or NullLoss()
# pitch-shift CQT
if pitch_shift_kwargs is None:
pitch_shift_kwargs = {}
self.pitch_shift = PitchShiftCQT(**pitch_shift_kwargs)
# preprocessing and transforms
self.transforms = nn.Sequential(*transforms) if transforms is not None else nn.Identity()
self.reduction = reduction
# loss weighting
self.loss_weighting = None
# predictions and labels
self.predictions = None
self.labels = None
# constant shift to get absolute pitch from predictions
self.register_buffer('shift', torch.zeros((), dtype=torch.float), persistent=True)
# save hparams
self.hyperparams = dict(encoder=encoder.hparams, pitch_shift=pitch_shift_kwargs)
def forward(self,
x: torch.Tensor,
shift: bool = True,
return_activations: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
x, *_ = self.pitch_shift(x) # the CQT has to be cropped beforehand
activations = self.encoder(x)
preds = reduce_activations(activations, reduction=self.reduction)
if shift:
preds.sub_(self.shift)
if return_activations:
return preds, activations
return preds
def on_fit_start(self) -> None:
r"""Search among Trainer's checkpoints if there is a `LossWeighting`.
If so, then identify it to use it for training.
Otherwise create a dummy one.
"""
for callback in self.trainer.callbacks:
if isinstance(callback, LossWeighting):
self.loss_weighting = callback
if self.loss_weighting is None:
self.loss_weighting = LossWeighting()
self.loss_weighting.last_layer = self.encoder.fc.weight
def on_validation_epoch_start(self) -> None:
self.predictions = []
self.labels = []
self.estimate_shift()
def on_validation_batch_end(self,
outputs,
batch,
batch_idx: int,
dataloader_idx: int = 0) -> None:
preds, labels = outputs
self.predictions.append(preds)
self.labels.append(labels)
def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
x, _ = batch # we do not use the labels (if any) during training
# pitch-shift
x, xt, n_steps = self.pitch_shift(x)
xa = x.clone()
xa = self.transforms(xa)
xt = self.transforms(xt)
# pass through network
y = self.encoder(x)
ya = self.encoder(xa)
yt = self.encoder(xt)
# invariance
inv_loss = self.inv_loss_fn(y, ya)
# shift-entropy
shift_entropy_loss = self.sce_loss_fn(ya, yt, n_steps)
# equivariance
equiv_loss = self.equiv_loss_fn(ya, yt, n_steps) # NOTE: both arguments are augmented views; yt is the transposed one
# weighting
total_loss = self.loss_weighting.combine_losses(invariance=inv_loss,
shift_entropy=shift_entropy_loss,
equivariance=equiv_loss)
# add elems to dict
loss_dict = dict(invariance=inv_loss,
equivariance=equiv_loss,
shift_entropy=shift_entropy_loss,
loss=total_loss)
self.log_dict({f"loss/{k}/train": v for k, v in loss_dict.items()}, sync_dist=False)
return total_loss
def validation_step(self, batch, batch_idx):
x, pitch = batch
return self.forward(x), pitch
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
r"""We store the hyperparameters of the encoder for inference from outside.
It is not used in this repo but enables to load the model from the pip-installable inference repository.
"""
checkpoint["hparams"] = remove_omegaconf_dependencies(self.hyperparams)
checkpoint['hcqt_params'] = remove_omegaconf_dependencies(self.trainer.datamodule.hcqt_kwargs)
def configure_optimizers(self) -> Mapping[str, Any]:
"""Choose what optimizers and learning-rate schedulers to use in your optimization.
Normally you'd need one. But in the case of GANs or similar you might have multiple.
Examples:
https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
:return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
"""
optimizer = self.optimizer_cls(params=self.encoder.parameters())
monitor = dict(optimizer=optimizer)
if self.scheduler_cls is not None:
monitor["lr_scheduler"] = self.scheduler_cls(optimizer=optimizer)
return monitor
def estimate_shift(self) -> None:
r"""Estimate the shift to predict absolute pitches from relative activations"""
# 0. Define labels
labels = torch.arange(60, 72)
# 1. Generate synthetic audio and convert it to HCQT
sr = 16000
dm = self.trainer.datamodule
batch = []
for p in labels:
audio = generate_synth_data(p, sr=sr)
hcqt = dm.hcqt(audio, sr)
batch.append(hcqt[0])
# 2. Stack batch and apply final transforms
x = torch.stack(batch, dim=0).to(self.device)
x = dm.transforms(torch.view_as_complex(x))
# 3. Pass it through the module
preds = self.forward(x, shift=False)
# 4. Compute the difference between the predictions and the expected values
diff = preds - labels.to(self.device)
# 5. Define the shift as the median difference and log the std as a sanity check
shift, std = diff.median(), diff.std()
log.info(f"Estimated shift: {shift.cpu().item():.3f} (std = {std.cpu().item():.3f})")
# 6. Update `self.shift` value
self.shift.fill_(shift)
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
import hydra
import lightning as L
import rootutils
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
# (so you don't need to force user to install project as a package)
# (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")
# (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #
from src.utils import (
extras,
get_metric_value,
instantiate_callbacks,
instantiate_loggers,
log_hyperparameters,
register_resolvers,
task_wrapper
)
log = logging.getLogger(__name__)
register_resolvers()
@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Trains the model. Can additionally evaluate on a testset, using best weights obtained during
training.
This method is wrapped in an optional @task_wrapper decorator, which controls the behavior during
failure. Useful for multiruns, saving info about the crash, etc.
:param cfg: A DictConfig configuration composed by Hydra.
:return: A tuple with metrics and dict with all instantiated objects.
"""
# set seed for random number generators in pytorch, numpy and python.random
if cfg.get("seed"):
L.seed_everything(cfg.seed, workers=True)
log.info(f"Instantiating datamodule <{cfg.data._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
log.info(f"Instantiating model <{cfg.model._target_}>")
model: LightningModule = hydra.utils.instantiate(cfg.model)
log.info("Instantiating callbacks...")
callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))
log.info("Instantiating loggers...")
logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
object_dict = {
"cfg": cfg,
"datamodule": datamodule,
"model": model,
"callbacks": callbacks,
"logger": logger,
"trainer": trainer,
}
if logger:
log.info("Logging hyperparameters!")
log_hyperparameters(object_dict)
# train the model, optionally resuming from the checkpoint given in cfg.ckpt_path
if cfg.get("train"):
log.info("Starting training!")
trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
train_metrics = trainer.callback_metrics
if cfg.get("test"):
log.info("Starting testing!")
ckpt_path = trainer.checkpoint_callback.best_model_path
if ckpt_path == "":
log.warning("Best ckpt not found! Using current weights for testing...")
ckpt_path = None
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
log.info(f"Best ckpt path: {ckpt_path}")
test_metrics = trainer.callback_metrics
# merge train and test metrics
metric_dict = {**train_metrics, **test_metrics}
return metric_dict, object_dict
@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
"""Main entry point for training.
:param cfg: DictConfig configuration composed by Hydra.
:return: Optional[float] with optimized metric value.
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
extras(cfg)
# train the model
metric_dict, _ = train(cfg)
# safely retrieve metric value for hydra-based hyperparameter optimization
metric_value = get_metric_value(
metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
)
# return optimized metric
return metric_value
if __name__ == "__main__":
main()
from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
from src.utils.logging_utils import log_hyperparameters
from src.utils.reduce_activations import reduce_activations
from src.utils.resolvers import register_resolvers
from src.utils.rich_utils import enforce_tags, print_config_tree
from src.utils.utils import extras, get_metric_value, task_wrapper, remove_omegaconf_dependencies
import torch
def mid_to_hz(pitch: int):
return 440 * 2 ** ((pitch - 69) / 12)
def generate_synth_data(pitch: int, num_harmonics: int = 5, duration=2, sr=16000):
f0 = mid_to_hz(pitch)
t = torch.arange(0, duration, 1/sr)
harmonics = torch.stack([
torch.cos(2 * torch.pi * k * f0 * t + torch.rand(()))
for k in range(1, num_harmonics+1)
], dim=1)
# volume = torch.randn(()) * torch.arange(num_harmonics).neg().div(0.5).exp()
volume = torch.rand(num_harmonics)
volume[0] = 1
volume *= torch.randn(())
audio = torch.sum(volume * harmonics, dim=1)
return audio
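# Editor's quick check (not from the commit): a two-second harmonic tone at MIDI pitch 69 (440 Hz),
# of the kind PESTO.estimate_shift uses to calibrate its constant pitch offset.
import torch

audio = generate_synth_data(pitch=69, num_harmonics=5, duration=2, sr=16000)
print(audio.shape)  # roughly duration * sr samples, i.e. two seconds of audio at 16 kHz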
import logging
from typing import List
import hydra
from lightning import Callback
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
log = logging.getLogger(__name__)
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
"""Instantiates callbacks from config.
:param callbacks_cfg: A DictConfig object containing callback configurations.
:return: A list of instantiated callbacks.
"""
callbacks: List[Callback] = []
if not callbacks_cfg:
log.warning("No callback configs found! Skipping..")
return callbacks
if not isinstance(callbacks_cfg, DictConfig):
raise TypeError("Callbacks config must be a DictConfig!")
for _, cb_conf in callbacks_cfg.items():
if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
log.info(f"Instantiating callback <{cb_conf._target_}>")
callbacks.append(hydra.utils.instantiate(cb_conf))
return callbacks
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
"""Instantiates loggers from config.
:param logger_cfg: A DictConfig object containing logger configurations.
:return: A list of instantiated loggers.
"""
logger: List[Logger] = []
if not logger_cfg:
log.warning("No logger configs found! Skipping...")
return logger
if not isinstance(logger_cfg, DictConfig):
raise TypeError("Logger config must be a DictConfig!")
for _, lg_conf in logger_cfg.items():
if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
log.info(f"Instantiating logger <{lg_conf._target_}>")
logger.append(hydra.utils.instantiate(lg_conf))
return logger
import logging
from typing import Any, Dict
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import OmegaConf
log = logging.getLogger(__name__)
@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
"""Controls which config parts are saved by Lightning loggers.
Additionally saves the task name, tags, checkpoint path and seed.
:param object_dict: A dictionary containing the following objects:
- `"cfg"`: A DictConfig object containing the main config.
- `"model"`: The Lightning model.
- `"trainer"`: The Lightning trainer.
"""
hparams = {}
cfg = OmegaConf.to_container(object_dict["cfg"])
trainer = object_dict["trainer"]
if not trainer.logger:
log.warning("Logger not found! Skipping hyperparameter logging...")
return
hparams["model"] = cfg["model"]
hparams["data"] = cfg["data"]
hparams["trainer"] = cfg["trainer"]
hparams["callbacks"] = cfg.get("callbacks")
hparams["extras"] = cfg.get("extras")
hparams["task_name"] = cfg.get("task_name")
hparams["tags"] = cfg.get("tags")
hparams["ckpt_path"] = cfg.get("ckpt_path")
hparams["seed"] = cfg.get("seed")
# send hparams to all loggers
for logger in trainer.loggers:
logger.log_hyperparams(hparams)
import torch
def reduce_activations(activations: torch.Tensor, reduction: str = "alwa") -> torch.Tensor:
r"""
Args:
activations: tensor of probability activations, shape (batch_size, num_bins)
reduction (str): reduction method to compute pitch out of activations,
choose between "argmax", "mean", "alwa".
Returns:
torch.Tensor: pitches as fractions of MIDI semitones, shape (batch_size)
"""
device = activations.device
num_bins = activations.size(1)
bps, r = divmod(num_bins, 128)
assert r == 0, "Activations should have output size 128*bins_per_semitone"
if reduction == "argmax":
pred = activations.argmax(dim=1)
return pred.float() / bps
all_pitches = torch.arange(num_bins, dtype=torch.float, device=device).div_(bps)
if reduction == "mean":
return torch.mv(activations, all_pitches)  # expected pitch under the activation distribution
if reduction == "alwa": # argmax-local weighted averaging, see https://github.com/marl/crepe
center_bin = activations.argmax(dim=1, keepdim=True)
window = torch.arange(1, 2 * bps, device=device) - bps
indices = (window + center_bin).clip_(min=0, max=num_bins - 1)
cropped_activations = activations.gather(1, indices)
cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(1, indices)
return (cropped_activations * cropped_pitches).sum(dim=1) / cropped_activations.sum(dim=1)
raise ValueError(f"Unknown reduction method: {reduction}. Choose between 'argmax', 'mean' and 'alwa'.")
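# Editor's toy example (not from the commit; 3 bins per semitone, i.e. 384 = 128 * 3 output bins, is an
# assumption). An activation peak spread over bins 180 and 181 is reduced to a fractional MIDI pitch
# slightly above 60, whereas plain argmax snaps it to exactly 60.
import torch

activations = torch.zeros(1, 384)
activations[0, 180] = 0.6
activations[0, 181] = 0.4
print(reduce_activations(activations, reduction="argmax"))  # tensor([60.])
print(reduce_activations(activations, reduction="alwa"))    # tensor([60.1333])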
r"""Custom functions to register by Hydra"""
from typing import Callable, Dict
from omegaconf import OmegaConf
def register_custom_resolvers(extra_resolvers: Dict[str, Callable] = None):
"""Wrap your main function with this.
You can pass extra kwargs, e.g. `version_base` introduced in 1.2.
"""
extra_resolvers = extra_resolvers or {}
for name, resolver in extra_resolvers.items():
OmegaConf.register_new_resolver(name, resolver)
def register_resolvers():
register_custom_resolvers({
"eval": eval,
"len": len
})
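# Editor's illustrative sketch (not from the commit; the config keys are made up). Once register_resolvers()
# has run, OmegaConf/Hydra configs can use `${len:...}` and `${eval:...}` interpolations:
from omegaconf import OmegaConf

register_resolvers()
cfg = OmegaConf.create({"harmonics": [0.5, 1, 2, 3, 4, 5], "num_harmonics": "${len:${harmonics}}"})
print(cfg.num_harmonics)  # 6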