Skip to content
Snippets Groups Projects
Commit 1c57b6b1 authored by Alain Riou's avatar Alain Riou
Browse files

handle inference mode

parent 153c6751
Branches
No related tags found
No related merge requests found
......@@ -146,9 +146,7 @@ By default, the function `pesto.predict` takes an audio waveform represented as
However, one may want to estimate the pitch of batches of (cropped) waveforms within a training pipeline, e.g. for DDSP-related applications.
`pesto.predict` therefore accepts Tensor inputs of shape `(batch_size, num_channels, num_samples)` and returns batch-wise pitch predictions accordingly.
Note that batched predictions are available only from the Python API and not from the CLI because:
- handling audios of different lengths is annoying, I don't want to bother with that
- when estimating pitch on
Note that batched predictions are available only from the Python API and not from the CLI because handling audios of different lengths is annoying, and I don't want to bother with that.
## Performances
......
......@@ -6,11 +6,10 @@ import torch
import torchaudio
from tqdm import tqdm
from .utils import load_model, load_dataprocessor, reduce_activation
from .export import export
from .utils import load_model, load_dataprocessor, reduce_activation
@torch.inference_mode()
def predict(
x: torch.Tensor,
sr: Optional[int] = None,
......@@ -19,7 +18,9 @@ def predict(
step_size: Optional[float] = None,
reduction: str = "argmax",
num_chunks: int = 1,
convert_to_freq: bool = False
convert_to_freq: bool = False,
inference_mode: bool = True,
no_grad: bool = True
):
r"""Main prediction function.
......@@ -37,19 +38,24 @@ def predict(
Default is 1 (all CQT frames in parallel) but it can be increased to reduce memory usage
and prevent out-of-memory errors.
convert_to_freq (bool): whether predictions should be converted to frequencies or not.
inference_mode (bool): whether to run with `torch.inference_mode`.
no_grad (bool): whether to run with `torch.no_grad`. If set to `False`, argument `inference_mode` is ignored.
"""
inference_mode = inference_mode and no_grad
with torch.no_grad() if no_grad and not inference_mode else torch.inference_mode(mode=inference_mode):
# convert to mono
assert 2 <= x.ndim <= 3, f"Audio file should have two dimensions, but found shape {x.size()}"
batch_size = x.size(0) if x.ndim == 3 else None
x = x.mean(dim=-2)
if data_preprocessor is None:
assert step_size is not None, \
assert step_size is not None and sr is not None, \
"If you don't use a predefined data preprocessor, you must at least indicate a step size (in milliseconds)"
data_preprocessor = load_dataprocessor(step_size=step_size / 1000., device=x.device)
data_preprocessor = load_dataprocessor(step_size=step_size / 1000., sampling_rate=sr, device=x.device)
# If the sampling rate has changed, change the sampling rate accordingly
# It will automatically recompute the CQT kernels if needed
if sr is not None:
data_preprocessor.sampling_rate = sr
if isinstance(model, str):
......
......@@ -35,13 +35,13 @@ class DataProcessor(nn.Module):
self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5)
self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
# register a dummy tensor to get implicit access to the module's device
self.register_buffer("_device", torch.zeros(()), persistent=False)
# sampling rate is lazily initialized
if sampling_rate is not None:
self.sampling_rate = sampling_rate
# register a dummy tensor to get implicit access to the module's device
self.register_buffer("_device", torch.zeros(()), persistent=False)
def forward(self, x: torch.Tensor):
r"""
......
......@@ -8,8 +8,8 @@ from .data import DataProcessor
from .model import PESTOEncoder
def load_dataprocessor(step_size, device: Optional[torch.device] = None):
return DataProcessor(step_size=step_size, **cqt_args).to(device)
def load_dataprocessor(step_size, sampling_rate: Optional[int] = None, device: Optional[torch.device] = None):
return DataProcessor(step_size=step_size, sampling_rate=sampling_rate, **cqt_args).to(device)
def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment