diff --git a/pesto/core.py b/pesto/core.py
index 899baa6f933fb7d8403e8a5b0adfaa3b2aaa3f0e..7c3522510d142d98d7df2310e6f337d468f6dd6f 100644
--- a/pesto/core.py
+++ b/pesto/core.py
@@ -93,8 +93,7 @@ def predict_from_files(
         export_format: Sequence[str] = ("csv",),
         no_convert_to_freq: bool = False,
         num_chunks: int = 1,
-        gpu: int = -1
-):
+        gpu: int = -1):
     r"""
 
     Args:
diff --git a/pesto/data.py b/pesto/data.py
index 652ebb4d8edbc2a80b207b5f19c5bf47bcd281c1..3044bbdf3fcfc0818b198555f7711623517d8e62 100644
--- a/pesto/data.py
+++ b/pesto/data.py
@@ -84,11 +84,11 @@ class Preprocessor(nn.Module):
         # compute HCQT kernels if it does not exist or if the sampling rate has changed
         if sr is not None and sr != self.hcqt_sr:
             self.hcqt_sr = sr
-            self._reset_hcqt_layer()
+            self._reset_hcqt_kernels()
 
         return self.hcqt_kernels(audio)
 
-    def _reset_hcqt_layer(self) -> None:
+    def _reset_hcqt_kernels(self) -> None:
         hop_length = int(self.hop_size * self.hcqt_sr / 1000 + 0.5)
         self.hcqt_kernels = HarmonicCQT(sr=self.hcqt_sr,
                                         hop_length=hop_length,
diff --git a/pesto/loader.py b/pesto/loader.py
index 94c519fb5f3c3b1c6f7bafabba35b0bb659dbfa4..a270615997197cfa4abdcfeef30de6a68d67be98 100644
--- a/pesto/loader.py
+++ b/pesto/loader.py
@@ -45,7 +45,7 @@ def load_model(checkpoint: str,
                   preprocessor=preprocessor,
                   crop_kwargs=hparams["pitch_shift"],
                   reduction=hparams["reduction"])
-    model.load_state_dict(state_dict)
+    model.load_state_dict(state_dict, strict=False)
     model.eval()
 
     return model
diff --git a/pesto/model.py b/pesto/model.py
index 27edd932faa90d51e8d23f5a55dde498a33fbb0d..27241271f15211799e8573bac077f56fa95f400d 100644
--- a/pesto/model.py
+++ b/pesto/model.py
@@ -228,6 +228,10 @@ class PESTO(nn.Module):
 
         return preds, confidence
 
+    @property
+    def bins_per_semitone(self) -> int:
+        return self.preprocessor.hcqt_kwargs["bins_per_semitone"]
+
     @property
     def hop_size(self) -> float:
         r"""Returns the hop size of the model (in milliseconds)"""
diff --git a/pesto/utils/reduce_activations.py b/pesto/utils/reduce_activations.py
index 191cb4f060ce53e1081b8696b5408ab0af1d6b28..b22dc508c2c0299b03c3038c9c50bf282bf1b76f 100644
--- a/pesto/utils/reduce_activations.py
+++ b/pesto/utils/reduce_activations.py
@@ -30,7 +30,7 @@ def reduce_activations(activations: torch.Tensor, reduction: str = "alwa") -> to
         window = torch.arange(1, 2 * bps, device=device) - bps  # [-bps+1, -bps+2, ..., bps-2, bps-1]
         indices = (center_bin + window).clip_(min=0, max=num_bins - 1)
         cropped_activations = activations.gather(-1, indices)
-        cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(1, indices)
-        return (cropped_activations * cropped_pitches).sum(dim=1) / cropped_activations.sum(dim=1)
+        cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices)
+        return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1)
 
     raise ValueError
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/test_basic.py b/tests/test_basic.py
deleted file mode 100644
index 1f4096eab0b3ed39a4b697104f2f9d33ddad8dc8..0000000000000000000000000000000000000000
--- a/tests/test_basic.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import unittest
-import pesto
-
-
-class MyTestCase(unittest.TestCase):
-    def test_something(self):
-        self.assertEqual(True, True)  # add assertion here
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/test_performances.py b/tests/test_performances.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aed8de0b4172658236ddbfce55f15549346b7d2
--- /dev/null
+++ b/tests/test_performances.py
@@ -0,0 +1,23 @@
+import itertools
+
+import pytest
+
+import torch
+
+from pesto import load_model
+from .utils import generate_synth_data
+
+
+@pytest.fixture
+def model():
+    return load_model('mir-1k', step_size=10.)
+
+
+@pytest.mark.parametrize('pitch, sr, reduction',
+                         itertools.product(range(50, 80), [16000, 44100, 48000], ["argmax", "alwa"]))
+def test_performances(model, pitch, sr, reduction):
+    x = generate_synth_data(pitch, sr=sr)
+
+    preds, conf = model(x, sr=sr, return_activations=False)
+
+    torch.testing.assert_close(preds, torch.full_like(preds, pitch), atol=0.1, rtol=0.1)
diff --git a/tests/test_predict.py b/tests/test_predict.py
deleted file mode 100644
index 464090415c47109523e91779d4f40e19495c9cf1..0000000000000000000000000000000000000000
--- a/tests/test_predict.py
+++ /dev/null
@@ -1 +0,0 @@
-# TODO
diff --git a/tests/test_shape.py b/tests/test_shape.py
new file mode 100644
index 0000000000000000000000000000000000000000..eed58bb8e4179c5f91c242fefe4e398f574512c6
--- /dev/null
+++ b/tests/test_shape.py
@@ -0,0 +1,97 @@
+import itertools
+
+import pytest
+
+import torch
+
+from pesto import load_model, predict
+from .utils import generate_synth_data
+
+
+@pytest.fixture
+def model():
+    return load_model('mir-1k', step_size=10.)
+
+
+@pytest.fixture
+def synth_data_16k():
+    return generate_synth_data(pitch=69, duration=5., sr=16000), 16000
+
+
+@pytest.mark.parametrize('reduction', ["argmax", "mean", "alwa"])
+def test_shape_no_batch(model, synth_data_16k, reduction):
+    x, sr = synth_data_16k
+
+    model.reduction = reduction
+
+    num_samples = x.size(-1)
+
+    num_timesteps = int(num_samples * 1000 / (model.hop_size * sr)) + 1
+
+    preds, conf, activations = model(x, sr=sr, return_activations=True)
+
+    assert preds.shape == (num_timesteps,)
+    assert conf.shape == (num_timesteps,)
+    assert activations.shape == (num_timesteps, 128 * model.bins_per_semitone)
+
+
+@pytest.mark.parametrize('sr, reduction',
+                         itertools.product([16000, 44100, 48000], ["argmax", "mean", "alwa"]))
+def test_shape_batch(model, sr, reduction):
+    model.reduction = reduction
+
+    batch_size = 13
+
+    batch = torch.stack([
+        generate_synth_data(pitch=p, duration=5., sr=sr)
+        for p in range(50, 50+batch_size)
+    ])
+
+    num_timesteps = int(batch.size(-1) * 1000 / (model.hop_size * sr)) + 1
+
+    preds, conf, activations = model(batch, sr=sr, return_activations=True)
+
+    assert preds.shape == (batch_size, num_timesteps)
+    assert conf.shape == (batch_size, num_timesteps)
+    assert activations.shape == (batch_size, num_timesteps, 128 * model.bins_per_semitone)
+
+
+@pytest.mark.parametrize('step_size, reduction',
+                         itertools.product([10., 20., 50., 100], ["argmax", "mean", "alwa"]))
+def test_predict_shape_no_batch(synth_data_16k, step_size, reduction):
+    x, sr = synth_data_16k
+
+    num_samples = x.size(-1)
+
+    num_timesteps = int(num_samples * 1000 / (step_size * sr)) + 1
+
+    timesteps, preds, conf, activations = predict(x,
+                                                  sr,
+                                                  step_size=step_size,
+                                                  reduction=reduction)
+
+    assert timesteps.shape == (num_timesteps,)
+    assert preds.shape == (num_timesteps,)
+    assert conf.shape == (num_timesteps,)
+
+
+@pytest.mark.parametrize('sr, step_size, reduction',
+                         itertools.product([16000, 44100, 48000], [10., 20., 50., 100.], ["argmax", "mean", "alwa"]))
+def test_predict_shape_batch(sr, step_size, reduction):
+    batch_size = 13
+
+    batch = torch.stack([
+        generate_synth_data(pitch=p, duration=5., sr=sr)
+        for p in range(50, 50+batch_size)
+    ])
+
+    num_timesteps = int(batch.size(-1) * 1000 / (step_size * sr)) + 1
+
+    timesteps, preds, conf, activations = predict(batch,
+                                                  sr=sr,
+                                                  step_size=step_size,
+                                                  reduction=reduction)
+
+    assert timesteps.shape == (num_timesteps,)
+    assert preds.shape == (batch_size, num_timesteps)
+    assert conf.shape == (batch_size, num_timesteps)
diff --git a/tests/test_timesteps.py b/tests/test_timesteps.py
new file mode 100644
index 0000000000000000000000000000000000000000..546ee09beb2b00cfd3a0c316c88d5167c69d9c72
--- /dev/null
+++ b/tests/test_timesteps.py
@@ -0,0 +1,18 @@
+import pytest
+
+import torch
+
+from pesto import predict
+from .utils import generate_synth_data
+
+
+@pytest.fixture
+def synth_data_16k():
+    return generate_synth_data(pitch=69, duration=5., sr=16000), 16000
+
+
+@pytest.mark.parametrize('step_size', [10., 20., 50., 100])
+def test_build_timesteps(synth_data_16k, step_size):
+    timesteps, *_ = predict(*synth_data_16k, step_size=step_size)
+    diff = timesteps[1:] - timesteps[:-1]
+    torch.testing.assert_close(diff, torch.full_like(diff, step_size))
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..71210e340044aa7e463908c841b45d8e719d457e
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,20 @@
+import torch
+
+
+def mid_to_hz(pitch: int):
+    return 440 * 2 ** ((pitch - 69) / 12)
+
+
+def generate_synth_data(pitch: int, num_harmonics: int = 5, duration=2, sr=16000):
+    f0 = mid_to_hz(pitch)
+    t = torch.arange(0, duration, 1/sr)
+    harmonics = torch.stack([
+        torch.cos(2 * torch.pi * k * f0 * t + torch.rand(()))
+        for k in range(1, num_harmonics+1)
+    ], dim=1)
+    # volume = torch.randn(()) * torch.arange(num_harmonics).neg().div(0.5).exp()
+    volume = torch.rand(num_harmonics)
+    volume[0] = 1
+    volume *= torch.randn(())
+    audio = torch.sum(volume * harmonics, dim=1)
+    return audio
\ No newline at end of file