diff --git a/CHANGELOG.md b/CHANGELOG.md index af278493ef973c9fa6baf2e4f6a09ab1ed3e6456..baa8bcd3c4c5b5283eb4f04cb4be66a519af07a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## v0.1.1 - solve issue when exporting in PNG +- solve device issue when changing sampling rate (#17) ## v0.1.0 - 2023-10-17 diff --git a/pesto/data.py b/pesto/data.py index 4c51bafa830fdaa26b3645ecf7668c22e5049c31..326d04751e3d86939ba02fcb8e5126704fe4c99c 100644 --- a/pesto/data.py +++ b/pesto/data.py @@ -17,9 +17,10 @@ class DataProcessor(nn.Module): def __init__(self, step_size: float, bins_per_semitone: int = 3, - sampling_rate: int = 44100, + sampling_rate: Optional[int] = None, **cqt_kwargs): super(DataProcessor, self).__init__() + self.step_size = step_size self.bins_per_semitone = bins_per_semitone # CQT-related stuff @@ -34,9 +35,12 @@ class DataProcessor(nn.Module): self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5) self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone - # handling different sampling rates - self.step_size = step_size - self.sampling_rate = sampling_rate + # sampling rate is lazily initialized + if sampling_rate is not None: + self.sampling_rate = sampling_rate + + # register a dummy tensor to get implicit access to the module's device + self.register_buffer("_device", torch.zeros(()), persistent=False) def forward(self, x: torch.Tensor): r""" @@ -61,18 +65,14 @@ class DataProcessor(nn.Module): return log_cqt def _init_cqt_layer(self, sr: int, hop_length: int): - if self.cqt is not None: - device = self.cqt.cqt_kernels_real.device - else: - device = None - self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(device) + self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self._device.device) @property def sampling_rate(self) -> int: return self._sampling_rate @sampling_rate.setter - def sampling_rate(self, sr: int): + def sampling_rate(self, sr: int) -> None: if sr == self._sampling_rate: return