Skip to content
Snippets Groups Projects
Commit 153c6751 authored by Alain Riou's avatar Alain Riou
Browse files

lazy device initialization

parent 8201e4f7
Branches
No related tags found
No related merge requests found
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
## v0.1.1 ## v0.1.1
- solve issue when exporting in PNG - solve issue when exporting in PNG
- solve device issue when changing sampling rate (#17)
## v0.1.0 - 2023-10-17 ## v0.1.0 - 2023-10-17
......
...@@ -17,9 +17,10 @@ class DataProcessor(nn.Module): ...@@ -17,9 +17,10 @@ class DataProcessor(nn.Module):
def __init__(self, def __init__(self,
step_size: float, step_size: float,
bins_per_semitone: int = 3, bins_per_semitone: int = 3,
sampling_rate: int = 44100, sampling_rate: Optional[int] = None,
**cqt_kwargs): **cqt_kwargs):
super(DataProcessor, self).__init__() super(DataProcessor, self).__init__()
self.step_size = step_size
self.bins_per_semitone = bins_per_semitone self.bins_per_semitone = bins_per_semitone
# CQT-related stuff # CQT-related stuff
...@@ -34,10 +35,13 @@ class DataProcessor(nn.Module): ...@@ -34,10 +35,13 @@ class DataProcessor(nn.Module):
self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5) self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5)
self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
# handling different sampling rates # sampling rate is lazily initialized
self.step_size = step_size if sampling_rate is not None:
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
# register a dummy tensor to get implicit access to the module's device
self.register_buffer("_device", torch.zeros(()), persistent=False)
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
r""" r"""
...@@ -61,18 +65,14 @@ class DataProcessor(nn.Module): ...@@ -61,18 +65,14 @@ class DataProcessor(nn.Module):
return log_cqt return log_cqt
def _init_cqt_layer(self, sr: int, hop_length: int): def _init_cqt_layer(self, sr: int, hop_length: int):
if self.cqt is not None: self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self._device.device)
device = self.cqt.cqt_kernels_real.device
else:
device = None
self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(device)
@property @property
def sampling_rate(self) -> int: def sampling_rate(self) -> int:
return self._sampling_rate return self._sampling_rate
@sampling_rate.setter @sampling_rate.setter
def sampling_rate(self, sr: int): def sampling_rate(self, sr: int) -> None:
if sr == self._sampling_rate: if sr == self._sampling_rate:
return return
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment