From 1c57b6b169a403b6022c83b59139236b2b3ca731 Mon Sep 17 00:00:00 2001
From: Alain Riou <alain.riou14000@yahoo.com>
Date: Sun, 10 Dec 2023 22:06:17 +0100
Subject: [PATCH] handle inference mode

---
 README.md      |   4 +-
 pesto/core.py  | 104 ++++++++++++++++++++++++++-----------------------
 pesto/data.py  |   6 +--
 pesto/utils.py |   4 +-
 4 files changed, 61 insertions(+), 57 deletions(-)

diff --git a/README.md b/README.md
index 633b184..f703e23 100644
--- a/README.md
+++ b/README.md
@@ -146,9 +146,7 @@ By default, the function `pesto.predict` takes an audio waveform represented as

 However, one may want to estimate the pitch of batches of (cropped) waveforms within a training pipeline, e.g. for DDSP-related applications. `pesto.predict` therefore accepts Tensor inputs of shape `(batch_size, num_channels, num_samples)` and returns batch-wise pitch predictions accordingly.

-Note that batched predictions are available only from the Python API and not from the CLI because:
-- handling audios of different lengths is annoying, I don't want to bother with that
-- when estimating pitch on
+Note that batched predictions are available only from the Python API and not from the CLI because handling audio files of different lengths is annoying, and I don't want to bother with that.

 ## Performances

diff --git a/pesto/core.py b/pesto/core.py
index c55ac4a..30e2e98 100644
--- a/pesto/core.py
+++ b/pesto/core.py
@@ -6,11 +6,10 @@ import torch
 import torchaudio
 from tqdm import tqdm

-from .utils import load_model, load_dataprocessor, reduce_activation
 from .export import export
+from .utils import load_model, load_dataprocessor, reduce_activation


-@torch.inference_mode()
 def predict(
         x: torch.Tensor,
         sr: Optional[int] = None,
@@ -19,7 +18,9 @@
         step_size: Optional[float] = None,
         reduction: str = "argmax",
         num_chunks: int = 1,
-        convert_to_freq: bool = False
+        convert_to_freq: bool = False,
+        inference_mode: bool = True,
+        no_grad: bool = True
 ):
     r"""Main prediction function.

@@ -37,53 +38,58 @@
            Default is 1 (all CQT frames in parallel) but it can be increased to reduce memory usage and prevent out-of-memory errors.
         convert_to_freq (bool): whether predictions should be converted to frequencies or not.
+        inference_mode (bool): whether to run with `torch.inference_mode`.
+        no_grad (bool): whether to run with `torch.no_grad`. If set to `False`, argument `inference_mode` is ignored.
     """
-    # convert to mono
-    assert 2 <= x.ndim <= 3, f"Audio file should have two dimensions, but found shape {x.size()}"
-    batch_size = x.size(0) if x.ndim == 3 else None
-    x = x.mean(dim=-2)
-
-    if data_preprocessor is None:
-        assert step_size is not None, \
-            "If you don't use a predefined data preprocessor, you must at least indicate a step size (in milliseconds)"
-        data_preprocessor = load_dataprocessor(step_size=step_size / 1000., device=x.device)
-
-    # If the sampling rate has changed, change the sampling rate accordingly
-    # It will automatically recompute the CQT kernels if needed
-    data_preprocessor.sampling_rate = sr
-
-    if isinstance(model, str):
-        model = load_model(model, device=x.device)
-
-    # apply model
-    cqt = data_preprocessor(x)
-    try:
-        activations = torch.cat([
-            model(chunk) for chunk in cqt.chunk(chunks=num_chunks)
-        ])
-    except torch.cuda.OutOfMemoryError:
-        raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. "
-                                          "Please increase the number of chunks with option `-c`/`--chunks` "
-                                          "to reduce GPU memory usage.")
-
-    if batch_size:
-        total_batch_size, num_predictions = activations.size()
-        activations = activations.view(batch_size, total_batch_size // batch_size, num_predictions)
-
-    # shift activations as it should (PESTO predicts pitches up to an additive constant)
-    activations = activations.roll(model.abs_shift.cpu().item(), dims=-1)
-
-    # convert model predictions to pitch values
-    pitch = reduce_activation(activations, reduction=reduction)
-    if convert_to_freq:
-        pitch = 440 * 2 ** ((pitch - 69) / 12)
-
-    # for now, confidence is computed very naively just based on volume
-    confidence = cqt.squeeze(1).max(dim=1).values.view_as(pitch)
-    conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values
-    confidence = (confidence - conf_min) / (conf_max - conf_min)
-
-    timesteps = torch.arange(pitch.size(-1), device=x.device) * data_preprocessor.step_size
+    inference_mode = inference_mode and no_grad
+    with torch.no_grad() if no_grad and not inference_mode else torch.inference_mode(mode=inference_mode):
+        # convert to mono
+        assert 2 <= x.ndim <= 3, f"Audio file should have two or three dimensions, but found shape {x.size()}"
+        batch_size = x.size(0) if x.ndim == 3 else None
+        x = x.mean(dim=-2)
+
+        if data_preprocessor is None:
+            assert step_size is not None and sr is not None, \
+                "If you don't use a predefined data preprocessor, you must at least indicate a step size (in milliseconds) and a sampling rate (in Hz)"
+            data_preprocessor = load_dataprocessor(step_size=step_size / 1000., sampling_rate=sr, device=x.device)
+
+        # If the sampling rate has changed, change the sampling rate accordingly
+        # It will automatically recompute the CQT kernels if needed
+        if sr is not None:
+            data_preprocessor.sampling_rate = sr
+
+        if isinstance(model, str):
+            model = load_model(model, device=x.device)
+
+        # apply model
+        cqt = data_preprocessor(x)
+        try:
+            activations = torch.cat([
+                model(chunk) for chunk in cqt.chunk(chunks=num_chunks)
+            ])
+        except torch.cuda.OutOfMemoryError:
+            raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. "
+                                              "Please increase the number of chunks with option `-c`/`--chunks` "
+                                              "to reduce GPU memory usage.")
+
+        if batch_size:
+            total_batch_size, num_predictions = activations.size()
+            activations = activations.view(batch_size, total_batch_size // batch_size, num_predictions)
+
+        # shift activations as it should (PESTO predicts pitches up to an additive constant)
+        activations = activations.roll(model.abs_shift.cpu().item(), dims=-1)
+
+        # convert model predictions to pitch values
+        pitch = reduce_activation(activations, reduction=reduction)
+        if convert_to_freq:
+            pitch = 440 * 2 ** ((pitch - 69) / 12)
+
+        # for now, confidence is computed very naively just based on volume
+        confidence = cqt.squeeze(1).max(dim=1).values.view_as(pitch)
+        conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values
+        confidence = (confidence - conf_min) / (conf_max - conf_min)
+
+        timesteps = torch.arange(pitch.size(-1), device=x.device) * data_preprocessor.step_size

     return timesteps, pitch, confidence, activations


diff --git a/pesto/data.py b/pesto/data.py
index 326d047..f5ed52e 100644
--- a/pesto/data.py
+++ b/pesto/data.py
@@ -35,13 +35,13 @@ class DataProcessor(nn.Module):
         self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5)
         self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone

+        # register a dummy tensor to get implicit access to the module's device
+        self.register_buffer("_device", torch.zeros(()), persistent=False)
+
         # sampling rate is lazily initialized
         if sampling_rate is not None:
             self.sampling_rate = sampling_rate

-        # register a dummy tensor to get implicit access to the module's device
-        self.register_buffer("_device", torch.zeros(()), persistent=False)
-
     def forward(self, x: torch.Tensor):
         r"""

diff --git a/pesto/utils.py b/pesto/utils.py
index f3fbe10..1872260 100644
--- a/pesto/utils.py
+++ b/pesto/utils.py
@@ -8,8 +8,8 @@ from .data import DataProcessor
 from .model import PESTOEncoder


-def load_dataprocessor(step_size, device: Optional[torch.device] = None):
-    return DataProcessor(step_size=step_size, **cqt_args).to(device)
+def load_dataprocessor(step_size, sampling_rate: Optional[int] = None, device: Optional[torch.device] = None):
+    return DataProcessor(step_size=step_size, sampling_rate=sampling_rate, **cqt_args).to(device)


 def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder:
-- 
GitLab