Skip to content
Snippets Groups Projects
Commit 752df3b7 authored by Ben Hayes's avatar Ben Hayes
Browse files

Persist CQT device

parent afa44099
No related branches found
No related tags found
No related merge requests found
from typing import Optional
import torch
import torch.nn as nn
......@@ -10,10 +12,12 @@ class DataProcessor(nn.Module):
Args:
step_size (float): step size between consecutive CQT frames (in SECONDS)
"""
_sampling_rate: Optional[int] = None
def __init__(self,
step_size: float,
bins_per_semitone: int = 3,
device: torch.device = torch.device("cpu"),
sampling_rate: int = 44100,
**cqt_kwargs):
super(DataProcessor, self).__init__()
self.bins_per_semitone = bins_per_semitone
......@@ -31,9 +35,8 @@ class DataProcessor(nn.Module):
self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
# handling different sampling rates
self._sampling_rate = None
self.step_size = step_size
self.device = device
self.sampling_rate = sampling_rate
def forward(self, x: torch.Tensor):
r"""
......@@ -58,7 +61,11 @@ class DataProcessor(nn.Module):
return log_cqt
def _init_cqt_layer(self, sr: int, hop_length: int):
self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self.device)
if self.cqt is not None:
device = self.cqt.cqt_kernels_real.device
else:
device = None
self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(device)
@property
def sampling_rate(self) -> int:
......
......@@ -9,7 +9,7 @@ from .model import PESTOEncoder
def load_dataprocessor(step_size, device: Optional[torch.device] = None):
return DataProcessor(step_size=step_size, device=device, **cqt_args).to(device)
return DataProcessor(step_size=step_size, **cqt_args).to(device)
def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment