diff --git a/pesto/model.py b/pesto/model.py
index a728e3522000b07e07e305d615b6c7f2f66785fd..9fbaddf6e7323a2b9007ba7e5171f2e5b1989b11 100644
--- a/pesto/model.py
+++ b/pesto/model.py
@@ -215,10 +215,9 @@ class PESTO(nn.Module):
         if batch_size:
             activations = activations.view(batch_size, -1, activations.size(-1))
 
-        preds = reduce_activations(activations, reduction=self.reduction)
+        activations = activations.roll(-self.shift.cpu().item(), dims=-1)
 
-        # decrease by shift to get absolute pitch
-        preds.sub_(self.shift)
+        preds = reduce_activations(activations, reduction=self.reduction)
 
         if convert_to_freq:
             preds = 440 * 2 ** ((preds - 69) / 12)