Skip to content
Snippets Groups Projects
Commit 153c6751 authored by Alain Riou's avatar Alain Riou
Browse files

lazy device initialization

parent 8201e4f7
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@
## v0.1.1
- solve issue when exporting in PNG
- solve device issue when changing sampling rate (#17)
## v0.1.0 - 2023-10-17
......
......@@ -17,9 +17,10 @@ class DataProcessor(nn.Module):
def __init__(self,
step_size: float,
bins_per_semitone: int = 3,
sampling_rate: int = 44100,
sampling_rate: Optional[int] = None,
**cqt_kwargs):
super(DataProcessor, self).__init__()
self.step_size = step_size
self.bins_per_semitone = bins_per_semitone
# CQT-related stuff
......@@ -34,10 +35,13 @@ class DataProcessor(nn.Module):
self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5)
self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
# handling different sampling rates
self.step_size = step_size
# sampling rate is lazily initialized
if sampling_rate is not None:
self.sampling_rate = sampling_rate
# register a dummy tensor to get implicit access to the module's device
self.register_buffer("_device", torch.zeros(()), persistent=False)
def forward(self, x: torch.Tensor):
r"""
......@@ -61,18 +65,14 @@ class DataProcessor(nn.Module):
return log_cqt
def _init_cqt_layer(self, sr: int, hop_length: int):
if self.cqt is not None:
device = self.cqt.cqt_kernels_real.device
else:
device = None
self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(device)
self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self._device.device)
@property
def sampling_rate(self) -> int:
return self._sampling_rate
@sampling_rate.setter
def sampling_rate(self, sr: int):
def sampling_rate(self, sr: int) -> None:
if sr == self._sampling_rate:
return
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment