diff --git a/pesto/data.py b/pesto/data.py
index ea053726694aa78a6645a7d0c2223b02d81a2f2a..4c51bafa830fdaa26b3645ecf7668c22e5049c31 100644
--- a/pesto/data.py
+++ b/pesto/data.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
@@ -10,10 +12,12 @@ class DataProcessor(nn.Module):
     Args:
         step_size (float): step size between consecutive CQT frames (in SECONDS)
     """
+    _sampling_rate: Optional[int] = None
+
     def __init__(self,
                  step_size: float,
                  bins_per_semitone: int = 3,
-                 device: torch.device = torch.device("cpu"),
+                 sampling_rate: int = 44100,
                  **cqt_kwargs):
         super(DataProcessor, self).__init__()
         self.bins_per_semitone = bins_per_semitone
@@ -31,9 +35,8 @@ class DataProcessor(nn.Module):
         self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
 
         # handling different sampling rates
-        self._sampling_rate = None
         self.step_size = step_size
-        self.device = device
+        self.sampling_rate = sampling_rate
 
     def forward(self, x: torch.Tensor):
         r"""
@@ -58,7 +61,11 @@ class DataProcessor(nn.Module):
         return log_cqt
 
     def _init_cqt_layer(self, sr: int, hop_length: int):
-        self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self.device)
+        if self.cqt is not None:
+            device = self.cqt.cqt_kernels_real.device
+        else:
+            device = None
+        self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(device)
 
     @property
     def sampling_rate(self) -> int:
diff --git a/pesto/utils.py b/pesto/utils.py
index 46281ebc1f4fff05c5c10fd14bd056244482039b..f3fbe10ce1d630ae41153ffa4612692e15ef06a1 100644
--- a/pesto/utils.py
+++ b/pesto/utils.py
@@ -9,7 +9,7 @@ from .model import PESTOEncoder
 
 
 def load_dataprocessor(step_size, device: Optional[torch.device] = None):
-    return DataProcessor(step_size=step_size, device=device, **cqt_args).to(device)
+    return DataProcessor(step_size=step_size, **cqt_args).to(device)
 
 
 def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder: