diff --git a/pesto/model.py b/pesto/model.py index a728e3522000b07e07e305d615b6c7f2f66785fd..9fbaddf6e7323a2b9007ba7e5171f2e5b1989b11 100644 --- a/pesto/model.py +++ b/pesto/model.py @@ -215,10 +215,9 @@ class PESTO(nn.Module): if batch_size: activations = activations.view(batch_size, -1, activations.size(-1)) - preds = reduce_activations(activations, reduction=self.reduction) + activations = activations.roll(-self.shift.cpu().item(), dims=-1) - # decrease by shift to get absolute pitch - preds.sub_(self.shift) + preds = reduce_activations(activations, reduction=self.reduction) if convert_to_freq: preds = 440 * 2 ** ((preds - 69) / 12)