From 752df3b7107807b4ec53cbfc5bd9341c857e635f Mon Sep 17 00:00:00 2001
From: Ben Hayes <Benjamin.Hayes@sony.com>
Date: Tue, 28 Nov 2023 13:37:39 +0100
Subject: [PATCH] Persist CQT device

---
 pesto/data.py  | 15 +++++++++++----
 pesto/utils.py |  2 +-
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/pesto/data.py b/pesto/data.py
index ea05372..4c51baf 100644
--- a/pesto/data.py
+++ b/pesto/data.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
@@ -10,10 +12,12 @@ class DataProcessor(nn.Module):
     Args:
         step_size (float): step size between consecutive CQT frames (in SECONDS)
     """
+    _sampling_rate: Optional[int] = None
+
     def __init__(self,
                  step_size: float,
                  bins_per_semitone: int = 3,
-                 device: torch.device = torch.device("cpu"),
+                 sampling_rate: int = 44100,
                  **cqt_kwargs):
         super(DataProcessor, self).__init__()
         self.bins_per_semitone = bins_per_semitone
@@ -31,9 +35,8 @@ class DataProcessor(nn.Module):
         self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
 
         # handling different sampling rates
-        self._sampling_rate = None
         self.step_size = step_size
-        self.device = device
+        self.sampling_rate = sampling_rate
 
     def forward(self, x: torch.Tensor):
         r"""
@@ -58,7 +61,11 @@ class DataProcessor(nn.Module):
         return log_cqt
 
     def _init_cqt_layer(self, sr: int, hop_length: int):
-        self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self.device)
+        if self.cqt is not None:
+            device = self.cqt.cqt_kernels_real.device
+        else:
+            device = None
+        self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(device)
 
     @property
     def sampling_rate(self) -> int:
diff --git a/pesto/utils.py b/pesto/utils.py
index 46281eb..f3fbe10 100644
--- a/pesto/utils.py
+++ b/pesto/utils.py
@@ -9,7 +9,7 @@ from .model import PESTOEncoder
 
 
 def load_dataprocessor(step_size, device: Optional[torch.device] = None):
-    return DataProcessor(step_size=step_size, device=device, **cqt_args).to(device)
+    return DataProcessor(step_size=step_size, **cqt_args).to(device)
 
 
 def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder:
-- 
GitLab