Commit be6c0b9f authored by Alain Riou

initial commit
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /data: mnist
  - override /model: mnist
  - override /callbacks: default
  - override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["mnist", "simple_dense_net"]

seed: 12345

trainer:
  min_epochs: 10
  max_epochs: 10
  gradient_clip_val: 0.5

model:
  optimizer:
    lr: 0.002
  net:
    lin1_size: 128
    lin2_size: 256
    lin3_size: 64
  compile: false

data:
  batch_size: 64

logger:
  wandb:
    tags: ${tags}
    group: "mnist"
  aim:
    experiment: "mnist"
# disable python warnings if they annoy you
ignore_warnings: False
# ask user for tags if none are provided in the config
enforce_tags: True
# pretty print config tree at the start of the run using Rich library
print_config: True
save_config: true
# https://hydra.cc/docs/configure_hydra/intro/
# enable color logging
defaults:
  - override hydra_logging: colorlog
  - override job_logging: colorlog

# output directory, generated dynamically on each run
run:
  dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
sweep:
  dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
  subdir: ${hydra.job.num}

job_logging:
  handlers:
    file:
      # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
      filename: ${hydra.runtime.output_dir}/${task_name}.log
# csv logger built in lightning
csv:
  _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
  save_dir: "${paths.output_dir}"
  name: "csv/"
  prefix: ""

# train with many loggers at once
defaults:
  # - comet
  - csv
  # - mlflow
  # - neptune
  - tensorboard
  - wandb

# https://www.tensorflow.org/tensorboard/
tensorboard:
  _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
  save_dir: "${paths.output_dir}/tensorboard/"
  name: null
  log_graph: False
  default_hp_metric: True
  prefix: ""
  # version: ""

# https://wandb.ai
wandb:
  _target_: lightning.pytorch.loggers.wandb.WandbLogger
  # name: "" # name of the run (normally generated by wandb)
  save_dir: "${paths.output_dir}"
  offline: false
  project: "PESTO"
  log_model: False # upload lightning ckpts
  save_code: true
  prefix: "" # a string to put at the beginning of metric keys
  # entity: "" # set to name of your wandb team
  group: ""
  tags: []
  job_type: ""
_target_: src.models.pesto.PESTO

encoder:
  _target_: src.models.networks.resnet1d.Resnet1d
  n_chan_input: ${len:${data.harmonics}}
  n_chan_layers:
    - 40
    - 30
    - 30
    - 10
    - 3
  n_prefilt_layers: 2
  prefilt_kernel_size: 15
  residual: true
  n_bins_in: ${eval:88 * ${data.bins_per_semitone}}
  output_dim: ${eval:128 * ${data.bins_per_semitone}}

equiv_loss_fn:
  _target_: src.losses.equivariance.PowerSeries
  value: 1.019440644 # 2 ** (1/36)
  power_min: ${eval:${model.equiv_loss_fn.power_max} - ${model.encoder.output_dim}}
  power_max: 1
  tau: 0.122462048 # 2 ** (1/6) - 1

inv_loss_fn:
  _target_: src.losses.entropy.CrossEntropyLoss
  symmetric: true
  detach_targets: true

sce_loss_fn:
  _target_: src.losses.entropy.ShiftCrossEntropy
  pad_length: ${model.pitch_shift_kwargs.max_steps}
  criterion: ${model.inv_loss_fn}

optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  lr: 1e-4
  weight_decay: 0

scheduler:
  _target_: torch.optim.lr_scheduler.CosineAnnealingLR
  _partial_: true
  T_max: ${trainer.max_epochs}

pitch_shift_kwargs:
  min_steps: ${eval:-${model.pitch_shift_kwargs.max_steps}}
  max_steps: ${eval:${data.bins_per_semitone} * 11 // 2}

transforms:
  - _target_: src.data.transforms.BatchRandomNoise
    min_snr: 0.1
    max_snr: 2.
    p: 0.7
  - _target_: src.data.transforms.BatchRandomGain
    min_gain: 0.5
    max_gain: 1.5
    p: 0.7
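The `${eval:...}` and `${len:...}` interpolations used above are custom OmegaConf resolvers, not built-in Hydra syntax, so they have to be registered before the config is composed. A minimal sketch of such a registration, assuming it happens in a utility module imported by the training script (the actual registration in this repository may differ):

from omegaconf import OmegaConf

# assumed registration of the custom resolvers used by these configs
OmegaConf.register_new_resolver("eval", eval)                  # e.g. ${eval:88 * ${data.bins_per_semitone}}
OmegaConf.register_new_resolver("len", lambda seq: len(seq))   # e.g. ${len:${data.harmonics}}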
# path to root directory
# the upstream template reads it from the PROJECT_ROOT environment variable;
# here it is set to "." so that the root is simply the current working directory
root_dir: .
# path to data directory
data_dir: ${paths.root_dir}/data
# path to logging directory
log_dir: ${paths.root_dir}/logs/
# path to output directory, created dynamically by hydra
# path generation pattern is specified in `configs/hydra/default.yaml`
# use it to store all files generated during the run, like ckpts and metrics
output_dir: ${hydra:runtime.output_dir}
# where checkpoints should be stored
ckpt_dir: ${paths.output_dir}/checkpoints
# @package _global_
# specify here default configuration
# order of defaults determines the order in which configs override each other
defaults:
- _self_
- data: default
- model: default
- callbacks: default
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- trainer: default
- paths: default
- extras: default
- hydra: default
# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: null
# debugging config (enable through command line, e.g. `python train.py debug=default`)
- debug: null
# task name, determines output directory path
task_name: "train"
# tags to help you identify your experiments
# you can overwrite this in experiment configs
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
tags: ["dev"]
# set False to skip model training
train: true
# evaluate on test set, using best model weights achieved during training
# lightning chooses best weights based on the metric specified in checkpoint callback
test: false
# simply provide checkpoint path to resume training
ckpt_path: null
# seed for random number generators in pytorch, numpy and python.random
seed: 0
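This root config is meant to be consumed by a Hydra entry point. A minimal sketch of what such an entry point looks like, following the lightning-hydra-template convention this project is based on (the actual src/train.py is not part of this excerpt and its config path may differ):

import hydra
from omegaconf import DictConfig


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> None:
    # instantiate datamodule, model, callbacks, loggers and trainer from cfg,
    # then fit if cfg.train and evaluate on the test set if cfg.test
    ...


if __name__ == "__main__":
    main()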
defaults:
- default
accelerator: cpu
devices: 1
_target_: lightning.pytorch.trainer.Trainer
default_root_dir: ${paths.output_dir}
log_every_n_steps: 200
max_epochs: 50
accelerator: gpu
devices: 1
defaults:
- default
accelerator: gpu
devices: 1
name: pesto-full
channels:
- pytorch
- nvidia
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- antlr-python-runtime=4.9.3=pyhd8ed1ab_1
- appdirs=1.4.4=pyh9f0ad1d_0
- blas=1.0=mkl
- brotli-python=1.0.9=py311h6a678d5_7
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2023.12.12=h06a4308_0
- certifi=2023.11.17=pyhd8ed1ab_0
- charset-normalizer=3.3.2=pyhd8ed1ab_0
- click=8.1.7=unix_pyh707e725_0
- colorama=0.4.6=pyhd8ed1ab_0
- colorlog=6.8.0=py311h38be061_0
- cuda-cudart=11.8.89=0
- cuda-cupti=11.8.87=0
- cuda-libraries=11.8.0=0
- cuda-nvrtc=11.8.89=0
- cuda-nvtx=11.8.86=0
- cuda-runtime=11.8.0=0
- docker-pycreds=0.4.0=py_0
- filelock=3.13.1=py311h06a4308_0
- fsspec=2023.12.2=pyhca7485f_0
- future=0.18.3=pyhd8ed1ab_0
- gitdb=4.0.11=pyhd8ed1ab_0
- gitpython=3.1.41=pyhd8ed1ab_0
- gmp=6.2.1=h295c915_3
- gmpy2=2.1.2=py311hc9b5ff0_0
- hydra-colorlog=1.2.0=pyhd8ed1ab_1
- hydra-core=1.3.2=pyhd8ed1ab_0
- idna=3.6=pyhd8ed1ab_0
- importlib_resources=6.1.1=pyhd8ed1ab_0
- intel-openmp=2023.1.0=hdb19cb5_46306
- jinja2=3.1.2=py311h06a4308_0
- joblib=1.3.2=pyhd8ed1ab_0
- ld_impl_linux-64=2.38=h1181459_1
- libblas=3.9.0=1_h86c2bf4_netlib
- libcblas=3.9.0=5_h92ddd45_netlib
- libcublas=11.11.3.6=0
- libcufft=10.9.0.58=0
- libcufile=1.8.1.2=0
- libcurand=10.3.4.107=0
- libcusolver=11.4.1.48=0
- libcusparse=11.7.5.86=0
- libffi=3.4.4=h6a678d5_0
- libgcc-ng=13.2.0=h807b86a_3
- libgfortran-ng=13.2.0=h69a702a_3
- libgfortran5=13.2.0=ha4646dd_3
- libgomp=13.2.0=h807b86a_3
- liblapack=3.9.0=5_h92ddd45_netlib
- libnpp=11.8.0.86=0
- libnvjpeg=11.9.0.86=0
- libprotobuf=3.20.3=he621ea3_0
- libstdcxx-ng=13.2.0=h7e041cc_3
- libuuid=1.41.5=h5eee18b_0
- lightning=2.1.3=pyhd8ed1ab_1
- lightning-utilities=0.10.0=pyhd8ed1ab_0
- llvm-openmp=14.0.6=h9e868ea_0
- markdown-it-py=3.0.0=pyhd8ed1ab_0
- markupsafe=2.1.3=py311h5eee18b_0
- mdurl=0.1.2=pyhd8ed1ab_0
- mir_eval=0.6=pyh9f0ad1d_0
- mkl=2023.1.0=h213fc3f_46344
- mkl-service=2.4.0=py311h5eee18b_1
- mkl_fft=1.3.8=py311h5eee18b_0
- mkl_random=1.2.4=py311hdb19cb5_0
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpmath=1.3.0=py311h06a4308_0
- ncurses=6.4=h6a678d5_0
- networkx=3.1=py311h06a4308_0
- numpy=1.26.3=py311h08b1b3b_0
- numpy-base=1.26.3=py311hf175353_0
- omegaconf=2.3.0=pyhd8ed1ab_0
- openssl=3.2.0=hd590300_1
- packaging=23.2=pyhd8ed1ab_0
- pathtools=0.1.2=py_1
- pip=23.3.1=py311h06a4308_0
- protobuf=3.20.3=py311hcafe171_1
- psutil=5.9.7=py311h459d7ec_0
- pygments=2.17.2=pyhd8ed1ab_0
- pysocks=1.7.1=pyha2e5f31_6
- python=3.11.7=h955ad1f_0
- python_abi=3.11=2_cp311
- pytorch=2.1.2=py3.11_cuda11.8_cudnn8.7.0_0
- pytorch-cuda=11.8=h7e8668a_5
- pytorch-lightning=2.1.3=pyhd8ed1ab_0
- pytorch-mutex=1.0=cuda
- pyyaml=6.0.1=py311h5eee18b_0
- readline=8.2=h5eee18b_0
- requests=2.31.0=pyhd8ed1ab_0
- rich=13.7.0=pyhd8ed1ab_0
- scikit-learn=1.3.2=py311hc009520_2
- scipy=1.11.4=py311h64a7726_0
- sentry-sdk=1.39.2=pyhd8ed1ab_0
- setproctitle=1.3.3=py311h459d7ec_0
- setuptools=68.2.2=py311h06a4308_0
- six=1.16.0=pyh6c4a22f_0
- smmap=5.0.0=pyhd8ed1ab_0
- sqlite=3.41.2=h5eee18b_0
- sympy=1.12=pyh04b8f61_3
- tbb=2021.8.0=hdb19cb5_0
- threadpoolctl=3.2.0=pyha21a80b_0
- tk=8.6.12=h1ccaba5_0
- torchaudio=2.1.2=py311_cu118
- torchmetrics=1.2.1=pyhd8ed1ab_0
- torchtriton=2.1.0=py311
- tqdm=4.66.1=pyhd8ed1ab_0
- typing-extensions=4.9.0=pyhd3eb1b0_0
- typing_extensions=4.9.0=py311h06a4308_0
- tzdata=2023d=h04d1e81_0
- urllib3=2.1.0=pyhd8ed1ab_0
- wandb=0.16.2=pyhd8ed1ab_0
- wheel=0.41.2=py311h06a4308_0
- xz=5.4.5=h5eee18b_0
- yaml=0.2.5=h7b6447c_0
- zipp=3.17.0=pyhd8ed1ab_0
- zlib=1.2.13=h5eee18b_0
  - pip:
      - nnaudio==0.3.2
      - python-dotenv==1.0.0
      - rootutils==1.0.7
prefix: /home/alain/miniconda3/envs/pesto-full
# pytorch
torch>=2.1.2
torchaudio>=2.1.2
# lightning
lightning>=2.1.3
# hydra
hydra-core>=1.3.2
hydra-colorlog
# nnAudio
nnAudio>=0.3.2
# utils
rich>=13.7.0
rootutils>=1.0.7
from collections import defaultdict
from math import cos, pi
from typing import Any, Mapping, Optional

import torch

import lightning.pytorch as pl


class LossWeighting(pl.Callback):
    def __init__(self, weights: Mapping[str, float] | None = None) -> None:
        super().__init__()
        self.weights = weights if weights is not None else defaultdict(lambda: 1.)

    def on_train_batch_end(self,
                           trainer: pl.Trainer,
                           pl_module: pl.LightningModule,
                           outputs: Any,
                           batch: Any,
                           batch_idx: int) -> None:
        pl_module.log_dict({f"hparams/{k}_weight": v for k, v in self.weights.items()}, prog_bar=False, logger=True)

    def combine_losses(self, **losses):
        self.update_weights(losses)
        return sum([self.weights[key] * losses[key] for key in self.weights.keys()])

    def update_weights(self, losses):
        pass

    def __str__(self):
        params = '\n'.join(f"\t{k}: {v}" for k, v in vars(self).items())
        return self.__class__.__name__ + "(\n" + params + "\n)"


class WarmupLossWeighting(LossWeighting):
    def __init__(
            self,
            weights: Mapping[str, float],
            warmup_term: str,
            warmup_epochs: int = 10,
            initial_weight: float = 0.,
            warmup_fn: str = "linear"
    ):
        super(WarmupLossWeighting, self).__init__(weights)
        self.key = warmup_term
        self.warmup_epochs = warmup_epochs
        self.initial_weight = initial_weight
        self.target_weight = self.weights[self.key]
        self.warmup_fn = warmup_fn

    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        epoch = trainer.current_epoch
        if epoch >= self.warmup_epochs:
            self.weights[self.key] = self.target_weight
            return

        # compute new value for the weight
        if self.warmup_fn == "linear":
            weight = (self.target_weight - self.initial_weight) * epoch / self.warmup_epochs + self.initial_weight
        elif self.warmup_fn == "cosine":
            weight = 0.5 * (1 - cos(pi * epoch / self.warmup_epochs)) * (self.target_weight - self.initial_weight) + self.initial_weight
        else:
            raise NotImplementedError(f"This warmup schedule is not supported: `{self.warmup_fn}`.")

        self.weights[self.key] = weight


class GradientsLossWeighting(LossWeighting):
    def __init__(self,
                 weights: Mapping[str, float] | None = None,
                 last_layer: Optional[torch.Tensor] = None,
                 ema_rate: float = 0.):
        super(GradientsLossWeighting, self).__init__(weights)
        self.last_layer = last_layer
        self.ema_rate = ema_rate

        self.grads = {k: 1 - v for k, v in self.weights.items()}

        self.weights_tensor = None

    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        self.weights_tensor = torch.zeros(len(self.weights.keys()), device=pl_module.device)

    def update_weights(self, losses):
        # compute gradient w.r.t. the last layer for each loss term
        for i, (k, loss) in enumerate(losses.items()):
            if not loss.requires_grad:
                return
            grads = torch.autograd.grad(loss, self.get_last_layer(k), retain_graph=True)[0].norm().detach()
            old_grads = self.grads[k]
            if old_grads is not None:
                grads = self.ema_rate * old_grads + (1 - self.ema_rate) * grads
            self.grads[k] = grads
            self.weights_tensor[i] = grads

        # compute the weight of each loss based on these gradients
        self.weights_tensor = 1 - self.weights_tensor / self.weights_tensor.sum().clip(min=1e-7)

        # associate each weight with the right loss
        for i, k in enumerate(losses.keys()):
            self.weights[k] = self.weights_tensor[i]

    def get_last_layer(self, key: str) -> torch.Tensor:
        if torch.is_tensor(self.last_layer):
            return self.last_layer
        return self.last_layer[key]
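A rough sketch of how one of these callbacks might be used to weight a model's loss terms; the loss names and module wiring below are illustrative, not taken from this commit:

# hypothetical usage sketch
weighting = WarmupLossWeighting(
    weights={"invariance": 1.0, "shift_entropy": 1.0, "equivariance": 1.0},
    warmup_term="equivariance",
    warmup_epochs=10,
    warmup_fn="cosine",
)
# inside the LightningModule's training_step, the individual terms would then be combined as:
# total_loss = weighting.combine_losses(invariance=inv_loss, shift_entropy=sce_loss, equivariance=equiv_loss)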
import logging
from typing import Mapping

import numpy as np
import torch

import mir_eval.melody as mir_eval

import lightning.pytorch as pl
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.utilities.rank_zero import rank_zero_only

try:
    import wandb
    WANDB_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    WANDB_AVAILABLE = False

log = logging.getLogger(__name__)


def wandb_only(func):
    def wrapper(*args, **kwargs):
        if WANDB_AVAILABLE:
            return func(*args, **kwargs)
        log.warning(f"Method {func.__name__} can be used only with wandb.")
        return None
    return wrapper


class MIREvalCallback(pl.Callback):
    def __init__(self,
                 bins_per_semitone: int = 1,
                 reduction: str = "alwa",
                 cdf_resolution: int = 0):
        super(MIREvalCallback, self).__init__()
        self.bps = bins_per_semitone
        self.reduction = reduction
        self.logger = None
        self.cdf_resolution = cdf_resolution

    @rank_zero_only
    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        for logger in pl_module.loggers:
            if isinstance(logger, WandbLogger):
                self.logger = logger
                break
        if self.logger is None:
            global WANDB_AVAILABLE
            WANDB_AVAILABLE = False
            log.warning(f"As of now, `{self.__class__.__name__}` is only compatible with `WandbLogger`. "
                        f"Loggers: {pl_module.loggers}.")

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        predictions = torch.cat(pl_module.predictions).cpu().numpy()
        labels = torch.cat(pl_module.labels).cpu().numpy()

        log_path = "accuracy/{}"
        metrics = self.compute_metrics(predictions, labels)
        pl_module.log_dict({log_path.format(k): v for k, v in metrics.items()}, sync_dist=True)

        self.plot_pitch_error_cdf(predictions, labels, labels > 0)

    @staticmethod
    def compute_metrics(predictions: np.ndarray, labels: np.ndarray) -> Mapping[str, float]:
        # convert semitones to cents and infer voicing
        ref_cent, ref_voicing = mir_eval.freq_to_voicing(100 * labels)
        est_cent, est_voicing = mir_eval.freq_to_voicing(100 * predictions)

        # compute mir_eval metrics
        metrics = {}
        metrics["RPA"] = mir_eval.raw_pitch_accuracy(ref_voicing, ref_cent, est_voicing, est_cent)
        metrics["RCA"] = mir_eval.raw_chroma_accuracy(ref_voicing, ref_cent, est_voicing, est_cent)
        metrics["OA"] = mir_eval.overall_accuracy(ref_voicing, ref_cent, est_voicing, est_cent)

        return metrics

    @wandb_only
    def plot_pitch_error_cdf(self, predictions: np.ndarray, labels: np.ndarray, voiced: np.ndarray):
        sorted_errors = np.sort(np.abs(predictions[voiced] - labels[voiced]))
        total = len(sorted_errors)
        cumul_probs = np.arange(1, total + 1) / total
        cols = ["Pitch error (semitones)", "Cumulative Density Function"]
        fig = wandb.Table(data=list(zip(sorted_errors[::self.cdf_resolution], cumul_probs[::self.cdf_resolution])),
                          columns=cols)
        self.logger.experiment.log({"pitch_error": wandb.plot.line(fig, *cols)})
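Since predictions and labels are expressed in semitones, multiplying by 100 converts them to cents, so mir_eval's default 50-cent tolerance corresponds to half a semitone. A rough sanity check of the raw pitch accuracy without mir_eval, assuming voicing is simply "label greater than zero" (illustrative only, not the exact mir_eval computation):

import numpy as np

def approx_raw_pitch_accuracy(predictions: np.ndarray, labels: np.ndarray) -> float:
    # fraction of voiced frames whose pitch error is within half a semitone (50 cents)
    voiced = labels > 0
    return float(np.mean(np.abs(predictions[voiced] - labels[voiced]) <= 0.5))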
import logging

import numpy as np
import torch

import lightning.pytorch as pl
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.utilities.rank_zero import rank_zero_only

try:
    import wandb
    WANDB_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    WANDB_AVAILABLE = False

log = logging.getLogger(__name__)


def wandb_only(func):
    def wrapper(*args, **kwargs):
        if WANDB_AVAILABLE:
            return func(*args, **kwargs)
        log.warning(f"Method {func.__name__} can be used only with wandb.")
        return None
    return wrapper


class PitchHistogramCallback(pl.Callback):
    def __init__(self):
        super(PitchHistogramCallback, self).__init__()
        self.logger = None

    @rank_zero_only
    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        for logger in pl_module.loggers:
            if isinstance(logger, WandbLogger):
                self.logger = logger
                break
        if self.logger is None:
            global WANDB_AVAILABLE
            WANDB_AVAILABLE = False
            log.warning(f"As of now, `{self.__class__.__name__}` is only compatible with `WandbLogger`. "
                        f"Loggers: {pl_module.loggers}.")

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        predictions = torch.cat(pl_module.predictions)
        self.plot_pitch_histogram(predictions + pl_module.shift)  # we unshift the distributions here to see better what's going on

    @wandb_only
    def plot_pitch_histogram(self, predictions: np.ndarray):
        fig = wandb.Table(data=[[p] for p in predictions], columns=["predictions"])
        self.logger.experiment.log({"pitch_histogram": wandb.plot.histogram(fig,
                                                                            "predictions",
                                                                            title="Pitch histogram")})
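In practice these callbacks are presumably instantiated by Hydra from the callbacks config, which is not part of this excerpt. A hypothetical manual equivalent, for illustration only (import paths and argument values are assumptions):

import lightning.pytorch as pl

# hypothetical import paths; the actual module layout of this commit is not shown here
from src.callbacks.mir_eval import MIREvalCallback
from src.callbacks.pitch_histogram import PitchHistogramCallback

# illustrative argument values only
trainer = pl.Trainer(callbacks=[
    MIREvalCallback(bins_per_semitone=3, cdf_resolution=10),
    PitchHistogramCallback(),
])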