From 1c57b6b169a403b6022c83b59139236b2b3ca731 Mon Sep 17 00:00:00 2001
From: Alain Riou <alain.riou14000@yahoo.com>
Date: Sun, 10 Dec 2023 22:06:17 +0100
Subject: [PATCH] handle inference mode

---
 README.md      |  15 +++-
 pesto/core.py  | 114 +++++++++++++++++++++++++++-----------------------
 pesto/data.py  |   7 ++--
 pesto/utils.py |   4 +-
 4 files changed, 83 insertions(+), 57 deletions(-)

diff --git a/README.md b/README.md
index 633b184..f703e23 100644
--- a/README.md
+++ b/README.md
@@ -146,9 +146,18 @@ By default, the function `pesto.predict` takes an audio waveform represented as
 However, one may want to estimate the pitch of batches of (cropped) waveforms within a training pipeline, e.g. for DDSP-related applications.
 `pesto.predict` therefore accepts Tensor inputs of shape `(batch_size, num_channels, num_samples)` and returns batch-wise pitch predictions accordingly.
 
-Note that batched predictions are available only from the Python API and not from the CLI because:
-- handling audios of different lengths is annoying, I don't want to bother with that
-- when estimating pitch on
+Note that batched predictions are available only from the Python API and not from the CLI, because handling audio files of different lengths is annoying and I don't want to bother with that.
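+
+A minimal sketch of batched usage (the shapes and sample rate are illustrative, and the call assumes the default model):
+
+```python
+import torch
+import pesto
+
+x = torch.randn(8, 2, 16000)  # batch of 8 stereo clips, 1 second at 16 kHz
+timesteps, pitch, confidence, activations = pesto.predict(x, sr=16000, step_size=10.)
+print(pitch.shape)  # expected: (8, num_frames)
+```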
 
 ## Performances
 
diff --git a/pesto/core.py b/pesto/core.py
index c55ac4a..30e2e98 100644
--- a/pesto/core.py
+++ b/pesto/core.py
@@ -6,11 +6,10 @@ import torch
 import torchaudio
 from tqdm import tqdm
 
-from .utils import load_model, load_dataprocessor, reduce_activation
 from .export import export
+from .utils import load_model, load_dataprocessor, reduce_activation
 
 
-@torch.inference_mode()
 def predict(
         x: torch.Tensor,
         sr: Optional[int] = None,
@@ -19,7 +18,9 @@ def predict(
         step_size: Optional[float] = None,
         reduction: str = "argmax",
         num_chunks: int = 1,
-        convert_to_freq: bool = False
+        convert_to_freq: bool = False,
+        inference_mode: bool = True,
+        no_grad: bool = True
 ):
     r"""Main prediction function.
 
@@ -37,53 +38,68 @@
             Default is 1 (all CQT frames in parallel) but it can be increased to reduce memory usage
             and prevent out-of-memory errors.
         convert_to_freq (bool): whether predictions should be converted to frequencies or not.
+        inference_mode (bool): whether to run with `torch.inference_mode`.
+        no_grad (bool): whether to run with `torch.no_grad`. If set to `False`, argument `inference_mode` is ignored.
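+
+    Example:
+        A sketch of gradient-enabled usage (shapes and default model arguments are assumed)::
+
+            x = torch.randn(4, 1, 16000, requires_grad=True)
+            _, _, _, activations = predict(x, sr=16000, step_size=10., no_grad=False)
+            activations.sum().backward()  # gradients can flow back to x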
     """
-    # convert to mono
-    assert 2 <= x.ndim <= 3, f"Audio file should have two dimensions, but found shape {x.size()}"
-    batch_size = x.size(0) if x.ndim == 3 else None
-    x = x.mean(dim=-2)
-
-    if data_preprocessor is None:
-        assert step_size is not None, \
-            "If you don't use a predefined data preprocessor, you must at least indicate a step size (in milliseconds)"
-        data_preprocessor = load_dataprocessor(step_size=step_size / 1000., device=x.device)
-
-    # If the sampling rate has changed, change the sampling rate accordingly
-    # It will automatically recompute the CQT kernels if needed
-    data_preprocessor.sampling_rate = sr
-
-    if isinstance(model, str):
-        model = load_model(model, device=x.device)
-
-    # apply model
-    cqt = data_preprocessor(x)
-    try:
-        activations = torch.cat([
-            model(chunk) for chunk in cqt.chunk(chunks=num_chunks)
-        ])
-    except torch.cuda.OutOfMemoryError:
-        raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. "
-                                          "Please increase the number of chunks with option `-c`/`--chunks` "
-                                          "to reduce GPU memory usage.")
-
-    if batch_size:
-        total_batch_size, num_predictions = activations.size()
-        activations = activations.view(batch_size, total_batch_size // batch_size, num_predictions)
-
-    # shift activations as it should (PESTO predicts pitches up to an additive constant)
-    activations = activations.roll(model.abs_shift.cpu().item(), dims=-1)
-
-    # convert model predictions to pitch values
-    pitch = reduce_activation(activations, reduction=reduction)
-    if convert_to_freq:
-        pitch = 440 * 2 ** ((pitch - 69) / 12)
-
-    # for now, confidence is computed very naively just based on volume
-    confidence = cqt.squeeze(1).max(dim=1).values.view_as(pitch)
-    conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values
-    confidence = (confidence - conf_min) / (conf_max - conf_min)
-
-    timesteps = torch.arange(pitch.size(-1), device=x.device) * data_preprocessor.step_size
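+    # inference_mode requires gradients to be disabled, so it is turned off whenever no_grad is False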
+    inference_mode = inference_mode and no_grad
+    with torch.no_grad() if no_grad and not inference_mode else torch.inference_mode(mode=inference_mode):
+        # convert to mono
+        assert 2 <= x.ndim <= 3, f"Audio should have 2 or 3 dimensions, but found shape {x.size()}"
+        batch_size = x.size(0) if x.ndim == 3 else None
+        x = x.mean(dim=-2)
+
+        if data_preprocessor is None:
+            assert step_size is not None and sr is not None, \
+                "If you don't use a predefined data preprocessor, you must at least indicate a step size (in milliseconds) and a sampling rate"
+            data_preprocessor = load_dataprocessor(step_size=step_size / 1000., sampling_rate=sr, device=x.device)
+
+        # If the sampling rate has changed, change the sampling rate accordingly
+        # It will automatically recompute the CQT kernels if needed
+        if sr is not None:
+            data_preprocessor.sampling_rate = sr
+
+        if isinstance(model, str):
+            model = load_model(model, device=x.device)
+
+        # apply model
+        cqt = data_preprocessor(x)
+        try:
+            activations = torch.cat([
+                model(chunk) for chunk in cqt.chunk(chunks=num_chunks)
+            ])
+        except torch.cuda.OutOfMemoryError:
+            raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. "
+                                              "Please increase the number of chunks with option `-c`/`--chunks` "
+                                              "to reduce GPU memory usage.")
+
+        if batch_size:
+            total_batch_size, num_predictions = activations.size()
+            activations = activations.view(batch_size, total_batch_size // batch_size, num_predictions)
+
+        # shift activations to compensate (PESTO predicts pitch only up to an additive constant)
+        activations = activations.roll(model.abs_shift.cpu().item(), dims=-1)
+
+        # convert model predictions to pitch values
+        pitch = reduce_activation(activations, reduction=reduction)
+        if convert_to_freq:
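+            # convert MIDI note numbers to frequencies in Hz (A4 = MIDI 69 = 440 Hz)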
+            pitch = 440 * 2 ** ((pitch - 69) / 12)
+
+        # for now, confidence is computed very naively just based on volume
+        confidence = cqt.squeeze(1).max(dim=1).values.view_as(pitch)
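+        # min-max normalize confidence to [0, 1] along the time axis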
+        conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values
+        confidence = (confidence - conf_min) / (conf_max - conf_min)
+
+        timesteps = torch.arange(pitch.size(-1), device=x.device) * data_preprocessor.step_size
 
     return timesteps, pitch, confidence, activations
 
diff --git a/pesto/data.py b/pesto/data.py
index 326d047..f5ed52e 100644
--- a/pesto/data.py
+++ b/pesto/data.py
@@ -35,13 +35,14 @@ class DataProcessor(nn.Module):
         self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5)
         self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
 
+        # register a dummy tensor to get implicit access to the module's device
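+        # (registered before the sampling-rate setter below, which presumably needs to know the device)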
+        self.register_buffer("_device", torch.zeros(()), persistent=False)
+
         # sampling rate is lazily initialized
         if sampling_rate is not None:
             self.sampling_rate = sampling_rate
 
-        # register a dummy tensor to get implicit access to the module's device
-        self.register_buffer("_device", torch.zeros(()), persistent=False)
-
     def forward(self, x: torch.Tensor):
         r"""
 
diff --git a/pesto/utils.py b/pesto/utils.py
index f3fbe10..1872260 100644
--- a/pesto/utils.py
+++ b/pesto/utils.py
@@ -8,8 +8,8 @@ from .data import DataProcessor
 from .model import PESTOEncoder
 
 
-def load_dataprocessor(step_size, device: Optional[torch.device] = None):
-    return DataProcessor(step_size=step_size, **cqt_args).to(device)
+def load_dataprocessor(step_size, sampling_rate: Optional[int] = None, device: Optional[torch.device] = None):
+    return DataProcessor(step_size=step_size, sampling_rate=sampling_rate, **cqt_args).to(device)
 
 
 def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder:
-- 
GitLab