diff --git a/CHANGELOG.md b/CHANGELOG.md
index af278493ef973c9fa6baf2e4f6a09ab1ed3e6456..baa8bcd3c4c5b5283eb4f04cb4be66a519af07a2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,7 @@
 ## v0.1.1
 
 - solve issue when exporting in PNG
+- solve device issue when changing sampling rate (#17)
 
 
 ## v0.1.0 - 2023-10-17
diff --git a/pesto/data.py b/pesto/data.py
index 4c51bafa830fdaa26b3645ecf7668c22e5049c31..326d04751e3d86939ba02fcb8e5126704fe4c99c 100644
--- a/pesto/data.py
+++ b/pesto/data.py
@@ -17,9 +17,10 @@ class DataProcessor(nn.Module):
     def __init__(self,
                  step_size: float,
                  bins_per_semitone: int = 3,
-                 sampling_rate: int = 44100,
+                 sampling_rate: Optional[int] = None,
                  **cqt_kwargs):
         super(DataProcessor, self).__init__()
+        self.step_size = step_size
         self.bins_per_semitone = bins_per_semitone
 
         # CQT-related stuff
@@ -34,9 +35,12 @@ class DataProcessor(nn.Module):
         self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5)
         self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
 
-        # handling different sampling rates
-        self.step_size = step_size
-        self.sampling_rate = sampling_rate
+        # sampling rate is lazily initialized
+        if sampling_rate is not None:
+            self.sampling_rate = sampling_rate
+
+        # register a dummy tensor to get implicit access to the module's device
+        self.register_buffer("_device", torch.zeros(()), persistent=False)
 
     def forward(self, x: torch.Tensor):
         r"""
@@ -61,18 +65,14 @@ class DataProcessor(nn.Module):
         return log_cqt
 
     def _init_cqt_layer(self, sr: int, hop_length: int):
-        if self.cqt is not None:
-            device = self.cqt.cqt_kernels_real.device
-        else:
-            device = None
-        self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(device)
+        self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self._device.device)
 
     @property
     def sampling_rate(self) -> int:
         return self._sampling_rate
 
     @sampling_rate.setter
-    def sampling_rate(self, sr: int):
+    def sampling_rate(self, sr: int) -> None:
         if sr == self._sampling_rate:
             return