diff --git a/pesto/data.py b/pesto/data.py index ea053726694aa78a6645a7d0c2223b02d81a2f2a..4c51bafa830fdaa26b3645ecf7668c22e5049c31 100644 --- a/pesto/data.py +++ b/pesto/data.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch import torch.nn as nn @@ -10,10 +12,12 @@ class DataProcessor(nn.Module): Args: step_size (float): step size between consecutive CQT frames (in SECONDS) """ + _sampling_rate: Optional[int] = None + def __init__(self, step_size: float, bins_per_semitone: int = 3, - device: torch.device = torch.device("cpu"), + sampling_rate: int = 44100, **cqt_kwargs): super(DataProcessor, self).__init__() self.bins_per_semitone = bins_per_semitone @@ -31,9 +35,8 @@ class DataProcessor(nn.Module): self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone # handling different sampling rates - self._sampling_rate = None self.step_size = step_size - self.device = device + self.sampling_rate = sampling_rate def forward(self, x: torch.Tensor): r""" @@ -58,7 +61,11 @@ class DataProcessor(nn.Module): return log_cqt def _init_cqt_layer(self, sr: int, hop_length: int): - self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self.device) + if self.cqt is not None: + device = self.cqt.cqt_kernels_real.device + else: + device = None + self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(device) @property def sampling_rate(self) -> int: diff --git a/pesto/utils.py b/pesto/utils.py index 46281ebc1f4fff05c5c10fd14bd056244482039b..f3fbe10ce1d630ae41153ffa4612692e15ef06a1 100644 --- a/pesto/utils.py +++ b/pesto/utils.py @@ -9,7 +9,7 @@ from .model import PESTOEncoder def load_dataprocessor(step_size, device: Optional[torch.device] = None): - return DataProcessor(step_size=step_size, device=device, **cqt_args).to(device) + return DataProcessor(step_size=step_size, **cqt_args).to(device) def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder: