Skip to content
Snippets Groups Projects
Commit 752df3b7 authored by Ben Hayes's avatar Ben Hayes
Browse files

Persist CQT device

parent afa44099
No related branches found
No related tags found
No related merge requests found
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -10,10 +12,12 @@ class DataProcessor(nn.Module): ...@@ -10,10 +12,12 @@ class DataProcessor(nn.Module):
Args: Args:
step_size (float): step size between consecutive CQT frames (in SECONDS) step_size (float): step size between consecutive CQT frames (in SECONDS)
""" """
_sampling_rate: Optional[int] = None
def __init__(self, def __init__(self,
step_size: float, step_size: float,
bins_per_semitone: int = 3, bins_per_semitone: int = 3,
device: torch.device = torch.device("cpu"), sampling_rate: int = 44100,
**cqt_kwargs): **cqt_kwargs):
super(DataProcessor, self).__init__() super(DataProcessor, self).__init__()
self.bins_per_semitone = bins_per_semitone self.bins_per_semitone = bins_per_semitone
...@@ -31,9 +35,8 @@ class DataProcessor(nn.Module): ...@@ -31,9 +35,8 @@ class DataProcessor(nn.Module):
self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
# handling different sampling rates # handling different sampling rates
self._sampling_rate = None
self.step_size = step_size self.step_size = step_size
self.device = device self.sampling_rate = sampling_rate
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
r""" r"""
...@@ -58,7 +61,11 @@ class DataProcessor(nn.Module): ...@@ -58,7 +61,11 @@ class DataProcessor(nn.Module):
return log_cqt return log_cqt
def _init_cqt_layer(self, sr: int, hop_length: int): def _init_cqt_layer(self, sr: int, hop_length: int):
self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self.device) if self.cqt is not None:
device = self.cqt.cqt_kernels_real.device
else:
device = None
self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(device)
@property @property
def sampling_rate(self) -> int: def sampling_rate(self) -> int:
......
...@@ -9,7 +9,7 @@ from .model import PESTOEncoder ...@@ -9,7 +9,7 @@ from .model import PESTOEncoder
def load_dataprocessor(step_size, device: Optional[torch.device] = None): def load_dataprocessor(step_size, device: Optional[torch.device] = None):
return DataProcessor(step_size=step_size, device=device, **cqt_args).to(device) return DataProcessor(step_size=step_size, **cqt_args).to(device)
def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder: def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment