diff --git a/README.md b/README.md
index f703e23ae3c086fe4150d364527ebbc89b32d56b..ac4d40043b01c9d10152d6a10327a61d3f040847 100644
--- a/README.md
+++ b/README.md
@@ -107,46 +107,64 @@ import pesto
 
 # predict the pitch of your audio tensors directly within your own Python code
 x, sr = torchaudio.load("my_file.wav")
-timesteps, pitch, confidence, activations = pesto.predict(x, sr, step_size=10.)
+x = x.mean(dim=0)  # PESTO takes mono audio as input
+timesteps, pitch, confidence, activations = pesto.predict(x, sr)
+
+# or predict using your own custom checkpoint
+predictions = pesto.predict(x, sr, model_name="/path/to/checkpoint.ckpt")
 
 # You can also predict pitches from audio files directly
-pesto.predict_from_files(["example1.wav", "example2.mp3", "example3.ogg"], step_size=10., export_format=["csv"])
+pesto.predict_from_files(["example1.wav", "example2.mp3", "example3.ogg"], export_format=["csv"])
 ```
+`pesto.predict` supports batched inputs, which should then have shape `(batch_size, num_samples)`.
+
+**Warning:** If you forget to convert a stereo audio in mono, channels will be treated as batch dimensions and you will 
+get predictions for each channel separately.
 
 #### Advanced usage
 
-If not provided,  `pesto.predict` will first load the CQT kernels and the model before performing 
+`pesto.predict` will first load the CQT kernels and the model before performing 
 any pitch estimation. If you want to process a significant number of files, calling `predict` several times will then 
 re-initialize the same model for each tensor.
 
-To avoid this time-consuming step, one can manually instantiate  the model   and data processor, then pass them directly 
-as args to the `predict` function. To do so, one has to use the underlying methods from `pesto.utils`:
+To avoid this time-consuming step, one can manually instantiate  the model with `load_model`,
+then call the forward method of the model instead:
 
 ```python
 import torch
 
-from pesto import predict
-from pesto.utils import load_model, load_dataprocessor
+from pesto import load_model
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-model = load_model("mir-1k", device=device)
-data_processor = load_dataprocessor(step_size=0.01, device=device)
+pesto_model = load_model("mir-1k", step_size=20.).to(device)
 
 for x, sr in ...:
-    data_processor.sampling_rate = sr  # The data_processor handles waveform->CQT conversion so it must know the sampling rate
-    predictions = predict(x, sr, model=model)
+    x = x.to(device)
+    predictions, confidence, activations = pesto_model(x, sr)
     ...
 ```
-Note that when passing a list of files to `pesto.predict_from_files(...)` or the CLI directly, the model  is loaded only
-once so you don't have to bother with that in general.
 
-#### Batched pitch estimation
+Note that the `PESTO` object returned by `load_model` is a subclass of `nn.Module` 
+and its `forward` method also supports batched inputs.
+One can therefore easily integrate PESTO within their own architecture by doing:
+```python
+import torch
+import torch.nn as nn
 
-By default, the function `pesto.predict` takes an audio waveform represented as a Tensor object of shape `(num_channels, num_samples)`.
-However, one may want to estimate the pitch of batches of (cropped) waveforms within a training pipeline, e.g. for DDSP-related applications.
-`pesto.predict` therefore accepts Tensor inputs of shape `(batch_size, num_channels, num_samples)` and returns batch-wise pitch predictions accordingly.
+from pesto import load_model
 
-Note that batched predictions are available only from the Python API and not from the CLI because handling audios of different lengths is annoying, and I don't want to bother with that.
+
+class MyGreatModel(nn.Module):
+    def __init__(self, step_size, sr=44100, *args, **kwargs):
+        super(MyGreatModel, self).__init__()
+        self.f0_estimator = load_model("mir-1k", step_size, sampling_rate=sr)
+        ...
+
+    def forward(self, x):
+        with torch.no_grad():
+            f0, conf = self.f0_estimator(x, convert_to_freq=True, return_activations=False)
+        ...
+```
 
 ## Performances
 
diff --git a/pesto/__init__.py b/pesto/__init__.py
index 82a92f8908b9f434140e3c1f7f44899a11cb8655..cff2aebc8285d3f463252869e8f432cb0189bade 100644
--- a/pesto/__init__.py
+++ b/pesto/__init__.py
@@ -1 +1 @@
-from .core import predict, predict_from_files
+from .core import load_model, predict, predict_from_files
diff --git a/pesto/core.py b/pesto/core.py
index 30e2e98a6a6ed409ff8c003214273c4aadab4799..899baa6f933fb7d8403e8a5b0adfaa3b2aaa3f0e 100644
--- a/pesto/core.py
+++ b/pesto/core.py
@@ -1,38 +1,61 @@
 import os
 import warnings
-from typing import Optional, Sequence, Union
+from typing import Optional, Sequence, Tuple, Union
 
 import torch
 import torchaudio
 from tqdm import tqdm
 
-from .export import export
-from .utils import load_model, load_dataprocessor, reduce_activation
-
-
-def predict(
-        x: torch.Tensor,
-        sr: Optional[int] = None,
-        model: Union[torch.nn.Module, str] = "mir-1k",
-        data_preprocessor=None,
-        step_size: Optional[float] = None,
-        reduction: str = "argmax",
-        num_chunks: int = 1,
-        convert_to_freq: bool = False,
-        inference_mode: bool = True,
-        no_grad: bool = True
-):
+from .loader import load_model
+from .model import PESTO
+from .utils import export
+
+
+def _predict(x: torch.Tensor,
+             sr: int,
+             model: PESTO,
+             num_chunks: int = 1,
+             convert_to_freq: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    preds, confidence, activations = [], [], []
+    try:
+        for chunk in x.chunk(chunks=num_chunks):
+            pred, conf, act = model(chunk, sr=sr, convert_to_freq=convert_to_freq, return_activations=True)
+            preds.append(pred)
+            confidence.append(conf)
+            activations.append(act)
+    except torch.cuda.OutOfMemoryError:
+        raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. "
+                                          "Please increase the number of chunks with option `-c`/`--chunks` "
+                                          "to reduce GPU memory usage.")
+
+    preds = torch.cat(preds, dim=0)
+    confidence = torch.cat(confidence, dim=0)
+    activations = torch.cat(activations, dim=0)
+
+    # compute timesteps
+    timesteps = torch.arange(preds.size(-1), device=x.device) * model.hop_size
+
+    return timesteps, preds, confidence, activations
+
+
+def predict(x: torch.Tensor,
+            sr: int,
+            step_size: float = 10.,
+            model_name: str = "mir-1k",
+            reduction: str = "alwa",
+            num_chunks: int = 1,
+            convert_to_freq: bool = True,
+            inference_mode: bool = True,
+            no_grad: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     r"""Main prediction function.
 
     Args:
-        x (torch.Tensor): input audio tensor,
-            shape (num_channels, num_samples) or (batch_size, num_channels, num_samples)
+        x (torch.Tensor): input audio tensor, can be provided as a batch but should be mono,
+            shape (num_samples) or (batch_size, num_samples)
         sr (int, optional): sampling rate. If not specified, uses the current sampling rate of the model.
-        model: PESTO model. If a string is passed, it will load the model with the corresponding name.
-            Otherwise, the actual nn.Module will be used for doing predictions.
-        data_preprocessor: Module handling the data processing pipeline (waveform to CQT, cropping, etc.)
         step_size (float, optional): step size between each CQT frame in milliseconds.
-            If the data_preprocessor is passed, its value will be used instead.
+            If a `PESTO` object is passed as `model`, this will be ignored.
+        model_name: name of PESTO model. Can be a path to a custom PESTO checkpoint or the name of a standard model.
         reduction (str): reduction method for converting activation probabilities to log-frequencies.
         num_chunks (int): number of chunks to split the input audios in.
             Default is 1 (all CQT frames in parallel) but it can be increased to reduce memory usage
@@ -40,58 +63,25 @@ def predict(
         convert_to_freq (bool): whether predictions should be converted to frequencies or not.
         inference_mode (bool): whether to run with `torch.inference_mode`.
         no_grad (bool): whether to run with `torch.no_grad`. If set to `False`, argument `inference_mode` is ignored.
+
+    Returns:
+        timesteps (torch.Tensor): timesteps corresponding to each pitch prediction, shape (num_timesteps)
+        preds (torch.Tensor): pitch predictions in SEMITONES, shape (batch_size?, num_timesteps)
+            where `num_timesteps` ~= `num_samples` / (`self.hop_size` * `sr`)
+        confidence (torch.Tensor): confidence of whether frame is voiced or unvoiced in [0, 1],
+            shape (batch_size?, num_timesteps)
+        activations (torch.Tensor): activations of the model, shape (batch_size?, num_timesteps, output_dim)
     """
+    # sanity checks
+    assert x.ndim <= 2, \
+        f"Audio file should have shape (num_samples) or (batch_size, num_samples), but found shape {x.size()}."
+
     inference_mode = inference_mode and no_grad
     with torch.no_grad() if no_grad and not inference_mode else torch.inference_mode(mode=inference_mode):
-        # convert to mono
-        assert 2 <= x.ndim <= 3, f"Audio file should have two dimensions, but found shape {x.size()}"
-        batch_size = x.size(0) if x.ndim == 3 else None
-        x = x.mean(dim=-2)
-
-        if data_preprocessor is None:
-            assert step_size is not None and sr is not None, \
-                "If you don't use a predefined data preprocessor, you must at least indicate a step size (in milliseconds)"
-            data_preprocessor = load_dataprocessor(step_size=step_size / 1000., sampling_rate=sr, device=x.device)
-
-        # If the sampling rate has changed, change the sampling rate accordingly
-        # It will automatically recompute the CQT kernels if needed
-        if sr is not None:
-            data_preprocessor.sampling_rate = sr
-
-        if isinstance(model, str):
-            model = load_model(model, device=x.device)
+        model = load_model(model_name, step_size, sampling_rate=sr).to(x.device)
+        model.reduction = reduction
 
-        # apply model
-        cqt = data_preprocessor(x)
-        try:
-            activations = torch.cat([
-                model(chunk) for chunk in cqt.chunk(chunks=num_chunks)
-            ])
-        except torch.cuda.OutOfMemoryError:
-            raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. "
-                                              "Please increase the number of chunks with option `-c`/`--chunks` "
-                                              "to reduce GPU memory usage.")
-
-        if batch_size:
-            total_batch_size, num_predictions = activations.size()
-            activations = activations.view(batch_size, total_batch_size // batch_size, num_predictions)
-
-        # shift activations as it should (PESTO predicts pitches up to an additive constant)
-        activations = activations.roll(model.abs_shift.cpu().item(), dims=-1)
-
-        # convert model predictions to pitch values
-        pitch = reduce_activation(activations, reduction=reduction)
-        if convert_to_freq:
-            pitch = 440 * 2 ** ((pitch - 69) / 12)
-
-        # for now, confidence is computed very naively just based on volume
-        confidence = cqt.squeeze(1).max(dim=1).values.view_as(pitch)
-        conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values
-        confidence = (confidence - conf_min) / (conf_max - conf_min)
-
-        timesteps = torch.arange(pitch.size(-1), device=x.device) * data_preprocessor.step_size
-
-    return timesteps, pitch, confidence, activations
+        return _predict(x, sr, model, num_chunks=num_chunks, convert_to_freq=convert_to_freq)
 
 
 def predict_from_files(
@@ -113,12 +103,11 @@ def predict_from_files(
         output:
         step_size: hop length in milliseconds
         reduction:
-        export_format:
+        export_format (Sequence[str]): format to export the predictions to.
+            Currently formats supported are: ["csv", "npz", "json"].
         no_convert_to_freq: whether convert output values to Hz or keep fractional MIDI pitches
+        num_chunks (int): number of chunks to divide the inputs into. Increase this value if you encounter OOM errors.
         gpu: index of GPU to use (-1 for CPU)
-
-    Returns:
-        Pitch predictions, see `predict` for more details.
     """
     if isinstance(audio_files, str):
         audio_files = [audio_files]
@@ -128,37 +117,34 @@ def predict_from_files(
         gpu = -1
     device = torch.device(f"cuda:{gpu:d}" if gpu >= 0 else "cpu")
 
-    # define data preprocessing
-    data_preprocessor = load_dataprocessor(step_size / 1000., device=device)
-
     # define model
-    model = load_model(model_name, device=device)
-    predictions = None
+    model = load_model(model_name, step_size=step_size).to(device)
+    model.reduction = reduction
 
     pbar = tqdm(audio_files)
-    for file in pbar:
-        pbar.set_description(file)
 
-        # load audio file
-        try:
-            x, sr = torchaudio.load(file)
-        except Exception as e:
-            print(e, f"Skipping {file}...")
-            continue
+    with torch.inference_mode():  # here the purpose is to write results in disk, so there is no point storing gradients
+        for file in pbar:
+            pbar.set_description(file)
+
+            # load audio file
+            try:
+                x, sr = torchaudio.load(file)
+            except Exception as e:
+                print(e, f"Skipping {file}...")
+                continue
 
-        x = x.to(device)
+            x = x.mean(dim=0).to(device)  # convert to mono then pass to the right device
 
-        # compute the predictions
-        predictions = predict(x, sr, model=model, data_preprocessor=data_preprocessor, reduction=reduction,
-                              convert_to_freq=not no_convert_to_freq, num_chunks=num_chunks)
+            # compute the predictions
+            predictions = _predict(x, sr, model=model, convert_to_freq=not no_convert_to_freq, num_chunks=num_chunks)
 
-        output_file = file.rsplit('.', 1)[0] + "." + ("semitones" if no_convert_to_freq else "f0")
-        if output is not None:
-            os.makedirs(output, exist_ok=True)
-            output_file = os.path.join(output, os.path.basename(output_file))
+            output_file = file.rsplit('.', 1)[0] + "." + ("semitones" if no_convert_to_freq else "f0")
+            if output is not None:
+                os.makedirs(output, exist_ok=True)
+                output_file = os.path.join(output, os.path.basename(output_file))
 
-        predictions = [p.cpu().numpy() for p in predictions]
-        for fmt in export_format:
-            export(fmt, output_file, *predictions)
+            predictions = [p.cpu().numpy() for p in predictions]
+            for fmt in export_format:
+                export(fmt, output_file, *predictions)
 
-    return predictions
diff --git a/pesto/data.py b/pesto/data.py
index f5ed52e7fe383d4b3d54a3fbf19fd627b6e99b67..7d8b81d77b2a982c559e258d3f7c8cde301f6d75 100644
--- a/pesto/data.py
+++ b/pesto/data.py
@@ -3,79 +3,92 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from .cqt import CQT
+from .utils import HarmonicCQT
 
 
-class DataProcessor(nn.Module):
+class ToLogMagnitude(nn.Module):
+    def __init__(self):
+        super(ToLogMagnitude, self).__init__()
+        self.eps = torch.finfo(torch.float32).eps
+
+    def forward(self, x):
+        x = x.abs()
+        x.clamp_(min=self.eps).log10_().mul_(20)
+        return x
+
+
+
+class Preprocessor(nn.Module):
     r"""
 
     Args:
-        step_size (float): step size between consecutive CQT frames (in SECONDS)
+        hop_size (float): step size between consecutive CQT frames (in milliseconds)
     """
-    _sampling_rate: Optional[int] = None
-
     def __init__(self,
-                 step_size: float,
-                 bins_per_semitone: int = 3,
+                 hop_size: float,
                  sampling_rate: Optional[int] = None,
-                 **cqt_kwargs):
-        super(DataProcessor, self).__init__()
-        self.step_size = step_size
-        self.bins_per_semitone = bins_per_semitone
+                 **hcqt_kwargs):
+        super(Preprocessor, self).__init__()
 
-        # CQT-related stuff
-        self.cqt_kwargs = cqt_kwargs
-        self.cqt_kwargs["bins_per_octave"] = 12 * bins_per_semitone
-        self.cqt = None
+        # HCQT
+        self.hcqt_sr = None
+        self.hcqt_kernels = None
+        self.hop_size = hop_size
 
-        # log-magnitude
-        self.eps = torch.finfo(torch.float32).eps
+        self.hcqt_kwargs = hcqt_kwargs
 
-        # cropping
-        self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5)
-        self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
+        # log-magnitude
+        self.to_log = ToLogMagnitude()
 
         # register a dummy tensor to get implicit access to the module's device
         self.register_buffer("_device", torch.zeros(()), persistent=False)
 
-        # sampling rate is lazily initialized
+        # if the sampling rate is provided, instantiate the CQT kernels
         if sampling_rate is not None:
-            self.sampling_rate = sampling_rate
+            self.hcqt_sr = sampling_rate
+            self._reset_hcqt_kernels()
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, sr: Optional[int] = None) -> torch.Tensor:
         r"""
 
         Args:
-            x: audio waveform, any sampling rate, shape (num_samples)
+            x (torch.Tensor): audio waveform or batch of audio waveforms, any sampling rate,
+                shape (batch_size?, num_samples)
+            sr (int, optional): sampling rate
 
         Returns:
-            log-magnitude CQT, shape (
+            torch.Tensor: log-magnitude CQT of batch of CQTs,
+                shape (batch_size?, num_timesteps, num_harmonics, num_freqs)
         """
-        # compute CQT from input waveform, and invert dims for (batch_size, time_steps, freq_bins)
-        complex_cqt = torch.view_as_complex(self.cqt(x)).transpose(1, 2)
-
-        # reshape and crop borders to fit training input shape
-        complex_cqt = complex_cqt[..., self.lowest_bin: self.highest_bin]
-
-        # flatten eventual batch dimensions so that batched audios can be processed in parallel
-        complex_cqt = complex_cqt.flatten(0, 1).unsqueeze(1)
+        # compute CQT from input waveform, and invert dims for (time_steps, num_harmonics, freq_bins)
+        # in other words, time becomes the batch dimension, enabling efficient processing for long audios.
+        complex_cqt = torch.view_as_complex(self.hcqt(x, sr=sr)).permute(0, 3, 1, 2)
 
         # convert to dB
-        log_cqt = complex_cqt.abs().clamp_(min=self.eps).log10_().mul_(20)
-        return log_cqt
-
-    def _init_cqt_layer(self, sr: int, hop_length: int):
-        self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self._device.device)
+        return self.to_log(complex_cqt)
 
-    @property
-    def sampling_rate(self) -> int:
-        return self._sampling_rate
+    def hcqt(self, audio: torch.Tensor, sr: Optional[int] = None) -> torch.Tensor:
+        r"""Compute the Harmonic CQT of the input audio after eventually recreating the kernels
+        (in case the sampling rate has changed).
 
-    @sampling_rate.setter
-    def sampling_rate(self, sr: int) -> None:
-        if sr == self._sampling_rate:
-            return
+        Args:
+            audio (torch.Tensor): mono audio waveform, shape (batch_size, num_samples)
+            sr (int): sampling rate of the audio waveform.
+                If not specified, we assume it is the same as the previous processed audio waveform.
 
-        hop_length = int(self.step_size * sr + 0.5)
-        self._init_cqt_layer(sr, hop_length)
-        self._sampling_rate = sr
+        Returns:
+            torch.Tensor: Complex Harmonic CQT (HCQT) of the input,
+                shape (batch_size, num_harmonics, num_freqs, num_timesteps, 2)
+        """
+        # compute HCQT kernels if it does not exist or if the sampling rate has changed
+        if sr is not None and sr != self.hcqt_sr:
+            self.hcqt_sr = sr
+            self._reset_hcqt_layer()
+
+        return self.hcqt_kernels(audio)
+
+    def _reset_hcqt_layer(self) -> None:
+        hop_length = int(self.hop_size * self.hcqt_sr / 1000 + 0.5)
+        self.hcqt_kernels = HarmonicCQT(sr=self.hcqt_sr,
+                                        hop_length=hop_length,
+                                        **self.hcqt_kwargs).to(self._device.device)
diff --git a/pesto/loader.py b/pesto/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..609d4b5df3cc4b9e2e1c891034b9479379cb774e
--- /dev/null
+++ b/pesto/loader.py
@@ -0,0 +1,51 @@
+import os
+from typing import Optional
+
+import torch
+
+from .data import Preprocessor
+from .model import PESTO, Resnet1d
+
+
+def load_model(checkpoint: str,
+               step_size: float,
+               sampling_rate: Optional[int] = None) -> PESTO:
+    r"""Load a trained model from a checkpoint file.
+    See https://github.com/SonyCSLParis/pesto-full/blob/master/src/models/pesto.py for the structure of the checkpoint.
+
+    Args:
+        checkpoint (str): path to the checkpoint or name of the checkpoint file (if using a provided checkpoint)
+        step_size (float): hop size in milliseconds
+        sampling_rate (int, optional): sampling rate of the audios.
+            If not provided, it can be inferred dynamically as well.
+    Returns:
+        PESTO: instance of PESTO model
+    """
+    if os.path.exists(checkpoint):  # handle user-provided checkpoints
+        model_path = checkpoint
+    else:
+        model_path = os.path.join(os.path.dirname(__file__), "weights", checkpoint + ".pth")
+        if not os.path.exists(model_path):
+            raise FileNotFoundError(f"You passed an invalid checkpoint file: {checkpoint}.")
+
+    # load checkpoint
+    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
+    hparams = checkpoint["hparams"]
+    hcqt_params = checkpoint["hcqt_params"]
+    state_dict = checkpoint["state_dict"]
+
+    # instantiate preprocessor
+    preprocessor = Preprocessor(hop_size=step_size, sampling_rate=sampling_rate, **hcqt_params)
+
+    # instantiate PESTO encoder
+    encoder = Resnet1d(**hparams["encoder"])
+
+    # instantiate main PESTO module and load its weights
+    model = PESTO(encoder,
+                  preprocessor=preprocessor,
+                  crop_kwargs=hparams["pitch_shift"],
+                  reduction=hparams["reduction"])
+    model.load_state_dict(state_dict)
+    model.eval()
+
+    return model
diff --git a/pesto/main.py b/pesto/main.py
index bc52df6cb0374c6798e23fdfe117e224c7983369..d117941f84eed1535970c16f5f4fa53029bf91f5 100644
--- a/pesto/main.py
+++ b/pesto/main.py
@@ -1,4 +1,4 @@
-from .parser import parse_args
+from pesto.utils.parser import parse_args
 from .core import predict_from_files
 
 
diff --git a/pesto/model.py b/pesto/model.py
index 94337ec84570b1cd3d5aaf1609b06f3f49a30ae6..27edd932faa90d51e8d23f5a55dde498a33fbb0d 100644
--- a/pesto/model.py
+++ b/pesto/model.py
@@ -1,8 +1,15 @@
 from functools import partial
+from typing import Any, Mapping, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
+from .utils import CropCQT
+from .utils import reduce_activations
+
+
+OUTPUT_TYPE = Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
+
 
 class ToeplitzLinear(nn.Conv1d):
     def __init__(self, in_features, out_features):
@@ -18,7 +25,7 @@ class ToeplitzLinear(nn.Conv1d):
         return super(ToeplitzLinear, self).forward(input.unsqueeze(-2)).squeeze(-2)
 
 
-class PESTOEncoder(nn.Module):
+class Resnet1d(nn.Module):
     """
     Basic CNN similar to the one in Johannes Zeitler's report,
     but for longer HCQT input (always stride 1 in time)
@@ -29,7 +36,8 @@ class PESTOEncoder(nn.Module):
     not over time (in order to work with variable length input).
     Outputs one channel with sigmoid activation.
 
-    Args (Defaults: BasicCNN by Johannes Zeitler but with 1 input channel):
+    Args (Defaults: BasicCNN by Johannes Zeitler but with 6 input channels):
+        n_chan_input:     Number of input channels (harmonics in HCQT)
         n_chan_layers:    Number of channels in the hidden layers (list)
         n_prefilt_layers: Number of repetitions of the prefiltering layer
         residual:         If True, use residual connections for prefiltering (default: False)
@@ -39,100 +47,188 @@ class PESTOEncoder(nn.Module):
         p_dropout:        Dropout probability
     """
 
-    def __init__(
-            self,
-            n_chan_layers=(20, 20, 10, 1),
-            n_prefilt_layers=1,
-            residual=False,
-            n_bins_in=216,
-            output_dim=128,
-            num_output_layers: int = 1
-    ):
-        super(PESTOEncoder, self).__init__()
-
-        activation_layer = partial(nn.LeakyReLU, negative_slope=0.3)
-
+    def __init__(self,
+                 n_chan_input=1,
+                 n_chan_layers=(20, 20, 10, 1),
+                 n_prefilt_layers=1,
+                 prefilt_kernel_size=15,
+                 residual=False,
+                 n_bins_in=216,
+                 output_dim=128,
+                 activation_fn: str = "leaky",
+                 a_lrelu=0.3,
+                 p_dropout=0.2):
+        super(Resnet1d, self).__init__()
+
+        self.hparams = dict(n_chan_input=n_chan_input,
+                            n_chan_layers=n_chan_layers,
+                            n_prefilt_layers=n_prefilt_layers,
+                            prefilt_kernel_size=prefilt_kernel_size,
+                            residual=residual,
+                            n_bins_in=n_bins_in,
+                            output_dim=output_dim,
+                            activation_fn=activation_fn,
+                            a_lrelu=a_lrelu,
+                            p_dropout=p_dropout)
+
+        if activation_fn == "relu":
+            activation_layer = nn.ReLU
+        elif activation_fn == "silu":
+            activation_layer = nn.SiLU
+        elif activation_fn == "leaky":
+            activation_layer = partial(nn.LeakyReLU, negative_slope=a_lrelu)
+        else:
+            raise ValueError
+
+        n_in = n_chan_input
         n_ch = n_chan_layers
         if len(n_ch) < 5:
             n_ch.append(1)
 
-        # Layer normalization over frequency
-        self.layernorm = nn.LayerNorm(normalized_shape=[1, n_bins_in])
+        # Layer normalization over frequency and channels (harmonics of HCQT)
+        self.layernorm = nn.LayerNorm(normalized_shape=[n_in, n_bins_in])
 
         # Prefiltering
+        prefilt_padding = prefilt_kernel_size // 2
         self.conv1 = nn.Sequential(
-            nn.Conv1d(in_channels=1, out_channels=n_ch[0], kernel_size=15, padding=7, stride=1),
-            activation_layer()
+            nn.Conv1d(in_channels=n_in,
+                      out_channels=n_ch[0],
+                      kernel_size=prefilt_kernel_size,
+                      padding=prefilt_padding,
+                      stride=1),
+            activation_layer(),
+            nn.Dropout(p=p_dropout)
         )
         self.n_prefilt_layers = n_prefilt_layers
-        self.prefilt_list = nn.ModuleList()
-        for p in range(1, n_prefilt_layers):
-            self.prefilt_list.append(nn.Sequential(
-                nn.Conv1d(in_channels=n_ch[0], out_channels=n_ch[0], kernel_size=15, padding=7, stride=1),
-                activation_layer()
-            ))
+        self.prefilt_layers = nn.ModuleList(*[
+            nn.Sequential(
+                nn.Conv1d(in_channels=n_ch[0],
+                          out_channels=n_ch[0],
+                          kernel_size=prefilt_kernel_size,
+                          padding=prefilt_padding,
+                          stride=1),
+                activation_layer(),
+                nn.Dropout(p=p_dropout)
+            )
+            for _ in range(n_prefilt_layers-1)
+        ])
         self.residual = residual
 
-        self.conv2 = nn.Sequential(
-            nn.Conv1d(
-                in_channels=n_ch[0],
-                out_channels=n_ch[1],
-                kernel_size=1,
-                stride=1,
-                padding=0
-            ),
-            activation_layer()
-        )
-
-        self.conv3 = nn.Sequential(
-            nn.Conv1d(in_channels=n_ch[1], out_channels=n_ch[2], kernel_size=1, padding=0, stride=1),
-            activation_layer()
-        )
-
-        self.conv4 = nn.Sequential(
-            nn.Conv1d(in_channels=n_ch[2], out_channels=n_ch[3], kernel_size=1, padding=0, stride=1),
-            activation_layer(),
-            nn.Dropout(),
-            nn.Conv1d(in_channels=n_ch[3], out_channels=n_ch[4], kernel_size=1, padding=0, stride=1)
-        )
+        conv_layers = []
+        for i in range(len(n_chan_layers)-1):
+            conv_layers.extend([
+                nn.Conv1d(in_channels=n_ch[i],
+                          out_channels=n_ch[i + 1],
+                          kernel_size=1,
+                          padding=0,
+                          stride=1),
+                activation_layer(),
+                nn.Dropout(p=p_dropout)
+            ])
+        self.conv_layers = nn.Sequential(*conv_layers)
 
         self.flatten = nn.Flatten(start_dim=1)
-
-        layers = []
-        pre_fc_dim = n_bins_in * n_ch[4]
-        for i in range(num_output_layers-1):
-            layers.extend([
-                ToeplitzLinear(pre_fc_dim, pre_fc_dim),
-                activation_layer()
-            ])
-        self.pre_fc = nn.Sequential(*layers)
-        self.fc = ToeplitzLinear(pre_fc_dim, output_dim)
+        self.fc = ToeplitzLinear(n_bins_in * n_ch[-1], output_dim)
 
         self.final_norm = nn.Softmax(dim=-1)
 
-        self.register_buffer("abs_shift", torch.zeros((), dtype=torch.long), persistent=True)
-
     def forward(self, x):
         r"""
 
         Args:
             x (torch.Tensor): shape (batch, channels, freq_bins)
         """
-        x_norm = self.layernorm(x)
+        x = self.layernorm(x)
 
-        x = self.conv1(x_norm)
+        x = self.conv1(x)
         for p in range(0, self.n_prefilt_layers - 1):
-            prefilt_layer = self.prefilt_list[p]
+            prefilt_layer = self.prefilt_layers[p]
             if self.residual:
                 x_new = prefilt_layer(x)
                 x = x_new + x
             else:
                 x = prefilt_layer(x)
-        conv2_lrelu = self.conv2(x)
-        conv3_lrelu = self.conv3(conv2_lrelu)
 
-        y_pred = self.conv4(conv3_lrelu)
-        y_pred = self.flatten(y_pred)
-        y_pred = self.pre_fc(y_pred)
-        y_pred = self.fc(y_pred)
+        x = self.conv_layers(x)
+        x = self.flatten(x)
+
+        y_pred = self.fc(x)
+
         return self.final_norm(y_pred)
+
+
+class PESTO(nn.Module):
+    def __init__(self,
+                 encoder: nn.Module,
+                 preprocessor: nn.Module,
+                 crop_kwargs: Mapping[str, Any] | None = None,
+                 reduction: str = "alwa"):
+        super(PESTO, self).__init__()
+        self.encoder = encoder
+        self.preprocessor = preprocessor
+
+        # crop CQT
+        if crop_kwargs is None:
+            crop_kwargs = {}
+        self.crop_cqt = CropCQT(**crop_kwargs)
+
+        self.reduction = reduction
+
+        # constant shift to get absolute pitch from predictions
+        self.register_buffer('shift', torch.zeros((), dtype=torch.float), persistent=True)
+
+    def forward(self,
+                audio_waveforms: torch.Tensor,
+                sr: Optional[int] = None,
+                convert_to_freq: bool = False,
+                return_activations: bool = False) -> OUTPUT_TYPE:
+        r"""
+
+        Args:
+            audio_waveforms (torch.Tensor): mono audio waveform or batch of mono audio waveforms,
+                shape (batch_size?, num_samples)
+            sr (int, optional): sampling rate, defaults to the previously used sampling rate
+            convert_to_freq (bool): whether to convert the result to frequencies or return fractional semitones instead.
+            return_activations (bool): whether to return activations or pitch predictions only
+
+        Returns:
+            preds (torch.Tensor): pitch predictions in SEMITONES, shape (batch_size?, num_timesteps)
+                where `num_timesteps` ~= `num_samples` / (`self.hop_size` * `sr`)
+            confidence (torch.Tensor): confidence of whether frame is voiced or unvoiced in [0, 1],
+                shape (batch_size?, num_timesteps)
+            activations (torch.Tensor): activations of the model, shape (batch_size?, num_timesteps, output_dim)
+        """
+        batch_size = audio_waveforms.size(0) if audio_waveforms.ndim == 2 else None
+        x = self.preprocessor(audio_waveforms, sr=sr)
+        x = self.crop_cqt(x)  # the CQT has to be cropped beforehand
+
+        # for now, confidence is computed very naively just based on energy in the CQT
+        confidence = x.mean(dim=-2).max(dim=-1).values
+        conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values
+        confidence = (confidence - conf_min) / (conf_max - conf_min)
+
+        # flatten batch_size and time_steps since anyway predictions are made on CQT frames independently
+        if batch_size:
+            x = x.flatten(0, 1)
+
+        activations = self.encoder(x)
+        if batch_size:
+            activations = activations.view(batch_size, -1, activations.size(-1))
+
+        preds = reduce_activations(activations, reduction=self.reduction)
+
+        # decrease by shift to get absolute pitch
+        preds.sub_(self.shift)
+
+        if convert_to_freq:
+            preds = 440 * 2 ** ((preds - 69) / 12)
+
+        if return_activations:
+            return preds, confidence, activations
+
+        return preds, confidence
+
+    @property
+    def hop_size(self) -> float:
+        r"""Returns the hop size of the model (in milliseconds)"""
+        return self.preprocessor.hop_size
diff --git a/pesto/utils.py b/pesto/utils.py
deleted file mode 100644
index 18722603ce06cde714ff3030405e904d1c16b71e..0000000000000000000000000000000000000000
--- a/pesto/utils.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import os
-from typing import Optional
-
-import torch
-
-from .config import model_args, cqt_args, bins_per_semitone
-from .data import DataProcessor
-from .model import PESTOEncoder
-
-
-def load_dataprocessor(step_size, sampling_rate: Optional[int] = None, device: Optional[torch.device] = None):
-    return DataProcessor(step_size=step_size, sampling_rate=sampling_rate, **cqt_args).to(device)
-
-
-def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder:
-    model = PESTOEncoder(**model_args).to(device)
-    model.eval()
-
-    model_path = os.path.join(os.path.dirname(__file__), "weights", model_name + ".pth")
-    model.load_state_dict(torch.load(model_path, map_location=device))
-
-    return model
-
-
-def reduce_activation(activations: torch.Tensor, reduction: str) -> torch.Tensor:
-    r"""Computes the pitch predictions from the activation outputs of the encoder.
-    Pitch predictions are returned in semitones, NOT in frequencies.
-
-    Args:
-        activations: tensor of probability activations, shape (*, num_bins)
-        reduction:
-
-    Returns:
-        torch.Tensor: pitch predictions, shape (*,)
-    """
-    bps = bins_per_semitone
-    if reduction == "argmax":
-        pred = activations.argmax(dim=-1)
-        return pred.float() / bps
-
-    all_pitches = (torch.arange(activations.size(-1), dtype=torch.float, device=activations.device)) / bps
-    if reduction == "mean":
-        return activations @ all_pitches
-
-    if reduction == "alwa":  # argmax-local weighted averaging, see https://github.com/marl/crepe
-        center_bin = activations.argmax(dim=-1, keepdim=True)
-        window = torch.arange(-bps+1, bps, device=activations.device)
-        indices = window + center_bin
-        cropped_activations = activations.gather(-1, indices)
-        cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices)
-        return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1)
-
-    raise ValueError
diff --git a/pesto/utils/__init__.py b/pesto/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9be9b1b72be5579424e7411fe2291823838cd3d4
--- /dev/null
+++ b/pesto/utils/__init__.py
@@ -0,0 +1,4 @@
+from .crop_cqt import CropCQT
+from .export import export
+from .hcqt import HarmonicCQT
+from .reduce_activations import reduce_activations
\ No newline at end of file
diff --git a/pesto/utils/crop_cqt.py b/pesto/utils/crop_cqt.py
new file mode 100644
index 0000000000000000000000000000000000000000..65829c70e8ee90554acee1c027178f6292cd4230
--- /dev/null
+++ b/pesto/utils/crop_cqt.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
+
+class CropCQT(nn.Module):
+    def __init__(self, min_steps: int, max_steps: int):
+        super(CropCQT, self).__init__()
+        self.min_steps = min_steps
+        self.max_steps = max_steps
+
+        # lower bin
+        self.lower_bin = self.max_steps
+
+    def forward(self, spectrograms: torch.Tensor) -> torch.Tensor:
+        # WARNING: didn't check that it works, it may be dangerous
+        return spectrograms[..., self.max_steps: self.min_steps]
+
+        # old implementation
+        batch_size, _, input_height = spectrograms.size()
+
+        output_height = input_height - self.max_steps + self.min_steps
+        assert output_height > 0, \
+            f"With input height {input_height:d} and output height {output_height:d}, impossible " \
+            f"to have a range of {self.max_steps - self.min_steps:d} bins."
+
+        return spectrograms[..., self.lower_bin: self.lower_bin + output_height]
diff --git a/pesto/export.py b/pesto/utils/export.py
similarity index 100%
rename from pesto/export.py
rename to pesto/utils/export.py
diff --git a/pesto/cqt.py b/pesto/utils/hcqt.py
similarity index 91%
rename from pesto/cqt.py
rename to pesto/utils/hcqt.py
index 6fdfccb6372077fd1f94b645fded1178c90983d7..4c92564ce46c05002f0d7c9ffca548818b7db0ca 100644
--- a/pesto/cqt.py
+++ b/pesto/utils/hcqt.py
@@ -354,3 +354,39 @@ class CQT(nn.Module):
             phase_real = torch.cos(torch.atan2(CQT_imag, CQT_real))
             phase_imag = torch.sin(torch.atan2(CQT_imag, CQT_real))
             return torch.stack((phase_real, phase_imag), -1)
+
+
+class HarmonicCQT(nn.Module):
+    r"""Harmonic CQT layer, as described in Bittner et al. (20??)"""
+    def __init__(
+            self,
+            harmonics,
+            sr: int = 22050,
+            hop_length: int = 512,
+            fmin: float = 32.7,
+            fmax: float | None = None,
+            bins_per_semitone: int = 1,
+            n_bins: int = 84,
+            center_bins: bool = True
+    ):
+        super(HarmonicCQT, self).__init__()
+
+        if center_bins:
+            fmin = fmin / 2 ** ((bins_per_semitone - 1) / (24 * bins_per_semitone))
+
+        self.cqt_kernels = nn.ModuleList([
+            CQT(sr=sr, hop_length=hop_length, fmin=h*fmin, fmax=fmax, n_bins=n_bins,
+                bins_per_octave=12*bins_per_semitone, output_format="Complex")
+            for h in harmonics
+        ])
+
+    def forward(self, audio_waveforms: torch.Tensor):
+        r"""Converts a batch of waveforms into a batch of HCQTs.
+
+        Args:
+            audio_waveforms (torch.Tensor): Batch of waveforms, shape (batch_size, num_samples)
+
+        Returns:
+            Harmonic CQT, shape (batch_size, num_harmonics, num_freqs, num_timesteps, 2)
+        """
+        return torch.stack([cqt(audio_waveforms) for cqt in self.cqt_kernels], dim=1)
diff --git a/pesto/parser.py b/pesto/utils/parser.py
similarity index 100%
rename from pesto/parser.py
rename to pesto/utils/parser.py
diff --git a/pesto/utils/reduce_activations.py b/pesto/utils/reduce_activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..191cb4f060ce53e1081b8696b5408ab0af1d6b28
--- /dev/null
+++ b/pesto/utils/reduce_activations.py
@@ -0,0 +1,36 @@
+import torch
+
+
+def reduce_activations(activations: torch.Tensor, reduction: str = "alwa") -> torch.Tensor:
+    r"""
+
+    Args:
+        activations: tensor of probability activations, shape (*, num_bins)
+        reduction (str): reduction method to compute pitch out of activations,
+            choose between "argmax", "mean", "alwa".
+
+    Returns:
+        torch.Tensor: pitches as fractions of MIDI semitones, shape (*)
+    """
+    device = activations.device
+    num_bins = activations.size(-1)
+    bps, r = divmod(num_bins, 128)
+    assert r == 0, f"Activations should have output size 128*bins_per_semitone, got {num_bins}."
+
+    if reduction == "argmax":
+        pred = activations.argmax(dim=-1)
+        return pred.float() / bps
+
+    all_pitches = torch.arange(num_bins, dtype=torch.float, device=device).div_(bps)
+    if reduction == "mean":
+        return torch.matmul(activations, all_pitches)
+
+    if reduction == "alwa":  # argmax-local weighted averaging, see https://github.com/marl/crepe
+        center_bin = activations.argmax(dim=-1, keepdim=True)
+        window = torch.arange(1, 2 * bps, device=device) - bps  # [-bps+1, -bps+2, ..., bps-2, bps-1]
+        indices = (center_bin + window).clip_(min=0, max=num_bins - 1)
+        cropped_activations = activations.gather(-1, indices)
+        cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(1, indices)
+        return (cropped_activations * cropped_pitches).sum(dim=1) / cropped_activations.sum(dim=1)
+
+    raise ValueError
diff --git a/pesto/weights/mir-1k.pth b/pesto/weights/mir-1k.pth
deleted file mode 100644
index aadf78ceeff86ae4e86ac305053c5bcc9791ec7f..0000000000000000000000000000000000000000
Binary files a/pesto/weights/mir-1k.pth and /dev/null differ