From 752df3b7107807b4ec53cbfc5bd9341c857e635f Mon Sep 17 00:00:00 2001 From: Ben Hayes <Benjamin.Hayes@sony.com> Date: Tue, 28 Nov 2023 13:37:39 +0100 Subject: [PATCH] Persist CQT device --- pesto/data.py | 15 +++++++++++---- pesto/utils.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pesto/data.py b/pesto/data.py index ea05372..4c51baf 100644 --- a/pesto/data.py +++ b/pesto/data.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch import torch.nn as nn @@ -10,10 +12,12 @@ class DataProcessor(nn.Module): Args: step_size (float): step size between consecutive CQT frames (in SECONDS) """ + _sampling_rate: Optional[int] = None + def __init__(self, step_size: float, bins_per_semitone: int = 3, - device: torch.device = torch.device("cpu"), + sampling_rate: int = 44100, **cqt_kwargs): super(DataProcessor, self).__init__() self.bins_per_semitone = bins_per_semitone @@ -31,9 +35,8 @@ class DataProcessor(nn.Module): self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone # handling different sampling rates - self._sampling_rate = None self.step_size = step_size - self.device = device + self.sampling_rate = sampling_rate def forward(self, x: torch.Tensor): r""" @@ -58,7 +61,11 @@ class DataProcessor(nn.Module): return log_cqt def _init_cqt_layer(self, sr: int, hop_length: int): - self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self.device) + if self.cqt is not None: + device = self.cqt.cqt_kernels_real.device + else: + device = None + self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(device) @property def sampling_rate(self) -> int: diff --git a/pesto/utils.py b/pesto/utils.py index 46281eb..f3fbe10 100644 --- a/pesto/utils.py +++ b/pesto/utils.py @@ -9,7 +9,7 @@ from .model import PESTOEncoder def load_dataprocessor(step_size, device: Optional[torch.device] = None): - return DataProcessor(step_size=step_size, device=device, **cqt_args).to(device) + return DataProcessor(step_size=step_size, **cqt_args).to(device) def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder: -- GitLab