From 153c6751a4df9e9e702d98d49f9eb767ef541bc9 Mon Sep 17 00:00:00 2001 From: Alain Riou <alain.riou14000@yahoo.com> Date: Fri, 1 Dec 2023 19:34:34 +0100 Subject: [PATCH] lazy device initialization --- CHANGELOG.md | 1 + pesto/data.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index af27849..baa8bcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## v0.1.1 - solve issue when exporting in PNG +- solve device issue when changing sampling rate (#17) ## v0.1.0 - 2023-10-17 diff --git a/pesto/data.py b/pesto/data.py index 4c51baf..326d047 100644 --- a/pesto/data.py +++ b/pesto/data.py @@ -17,9 +17,10 @@ class DataProcessor(nn.Module): def __init__(self, step_size: float, bins_per_semitone: int = 3, - sampling_rate: int = 44100, + sampling_rate: Optional[int] = None, **cqt_kwargs): super(DataProcessor, self).__init__() + self.step_size = step_size self.bins_per_semitone = bins_per_semitone # CQT-related stuff @@ -34,9 +35,12 @@ class DataProcessor(nn.Module): self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5) self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone - # handling different sampling rates - self.step_size = step_size - self.sampling_rate = sampling_rate + # sampling rate is lazily initialized + if sampling_rate is not None: + self.sampling_rate = sampling_rate + + # register a dummy tensor to get implicit access to the module's device + self.register_buffer("_device", torch.zeros(()), persistent=False) def forward(self, x: torch.Tensor): r""" @@ -61,18 +65,14 @@ class DataProcessor(nn.Module): return log_cqt def _init_cqt_layer(self, sr: int, hop_length: int): - if self.cqt is not None: - device = self.cqt.cqt_kernels_real.device - else: - device = None - self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(device) + self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self._device.device) @property def sampling_rate(self) -> int: return self._sampling_rate @sampling_rate.setter - def sampling_rate(self, sr: int): + def sampling_rate(self, sr: int) -> None: if sr == self._sampling_rate: return -- GitLab