diff --git a/.github/workflows/test-workflow.yml b/.github/workflows/test-workflow.yml
index 7fc7156c6d6d3ee4b417cbd211d322afb807ebc3..43bda75133354d34cb8d02e4c7a5e27b44110171 100644
--- a/.github/workflows/test-workflow.yml
+++ b/.github/workflows/test-workflow.yml
@@ -2,7 +2,7 @@ name: Test Workflow
 
 on:
   pull_request:
-    branches: [ "master" ]
+    branches: [ "dev" ]
 
 jobs:
   build:
@@ -21,7 +21,10 @@ jobs:
         python-version: ${{ matrix.python-version }}
     - name: Install dependencies
       run: |
+        sudo apt-get install -y libsox-dev
         python -m pip install --upgrade pip
+        python -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
+        python -m pip install matplotlib
         python -m pip install pytest
         python -m pip install .
     - name: Test with pytest
diff --git a/.gitignore b/.gitignore
index cc684f882c3744dc83af39ef5e8cbfb7e934fcb9..daf4967a0f60ce293927d6e5b44629d6af654304 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,10 @@
 .idea
-*/__pycache__
+**/__pycache__
 *.egg-info/
 .DS_Store
 dist/
 build/
+
+**/*.csv
+**/*.mid
+**/*.png
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000000000000000000000000000000000000..b6fcb6ecf6f2179b378c9192e2ba3d3143d5cbe1
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,21 @@
+# Changelog
+
+## v1.0.0
+
+- Change API under the hood to make it more object-oriented
+  - store all utilities inside a `PESTO` object that is a subclass of `nn.Module`
+  - make the API compatible with the checkpoints generated by the training repo
+- add tests
+- replace `setup.py` with `pyproject.toml`
+- fix a few issues
+- improve README and documentation
+
+## v0.1.1
+
+- solve an issue when exporting to PNG
+- solve a device issue when changing the sampling rate (#17)
+
+
+## v0.1.0 - 2023-10-17
+
+Initial version
\ No newline at end of file
diff --git a/README.md b/README.md
index a41a69ccdbddcd652cd2f309d84248276834478a..93c9572fae10943ea42b9988d572bfcbb4747c41 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,8 @@ This repository is implemented in [PyTorch](https://pytorch.org/) and has the fo
 - [torchaudio](https://pytorch.org/audio/stable/) for audio loading
 - `matplotlib` for exporting pitch predictions as images (optional)
 
+**Warning:** When installing in a clean environment, it may be safer to first install PyTorch [the recommended way](https://pytorch.org/get-started/locally/) before installing PESTO.
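+
+In practice, this means running `pip install torch torchaudio` (picking the index URL that matches your
+hardware from the link above) before installing this package, e.g. with `pip install pesto-pitch`.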
+
 ## Usage
 
 ### Command-line interface
@@ -107,48 +109,64 @@ import pesto
 
 # predict the pitch of your audio tensors directly within your own Python code
 x, sr = torchaudio.load("my_file.wav")
-timesteps, pitch, confidence, activations = pesto.predict(x, sr, step_size=10.)
+x = x.mean(dim=0)  # PESTO takes mono audio as input
+timesteps, pitch, confidence, activations = pesto.predict(x, sr)
+
+# or predict using your own custom checkpoint
+predictions = pesto.predict(x, sr, model_name="/path/to/checkpoint.ckpt")
 
 # You can also predict pitches from audio files directly
-pesto.predict_from_files(["example1.wav", "example2.mp3", "example3.ogg"], step_size=10., export_format=["csv"])
+pesto.predict_from_files(["example1.wav", "example2.mp3", "example3.ogg"], export_format=["csv"])
 ```
+`pesto.predict` supports batched inputs, which should then have shape `(batch_size, num_samples)`.
+
+**Warning:** If you forget to convert stereo audio to mono, the channels will be treated as a batch
+dimension and you will get separate predictions for each channel.
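+
+For instance, a pair of equal-length mono signals can be stacked into a batch
+(a minimal sketch, assuming both files share the same sampling rate and duration):
+```python
+import torch
+import torchaudio
+
+import pesto
+
+x1, sr = torchaudio.load("example1.wav")
+x2, _ = torchaudio.load("example2.wav")
+
+batch = torch.stack([x1.mean(dim=0), x2.mean(dim=0)])  # convert to mono, shape (2, num_samples)
+timesteps, pitch, confidence, activations = pesto.predict(batch, sr)  # pitch has shape (2, num_timesteps)
+```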
 
 #### Advanced usage
 
-If not provided,  `pesto.predict` will first load the CQT kernels and the model before performing 
+`pesto.predict` will first load the CQT kernels and the model before performing 
 any pitch estimation. If you want to process a significant number of files, calling `predict` several times will then 
 re-initialize the same model for each tensor.
 
-To avoid this time-consuming step, one can manually instantiate  the model   and data processor, then pass them directly 
-as args to the `predict` function. To do so, one has to use the underlying methods from `pesto.utils`:
+To avoid this time-consuming step, one can manually instantiate the model with `load_model`,
+then call the model's forward method directly:
 
 ```python
 import torch
 
-from pesto import predict
-from pesto.utils import load_model, load_dataprocessor
+from pesto import load_model
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-model = load_model("mir-1k", device=device)
-data_processor = load_dataprocessor(step_size=0.01, device=device)
+pesto_model = load_model("mir-1k", step_size=20.).to(device)
 
 for x, sr in ...:
-    data_processor.sampling_rate = sr  # The data_processor handles waveform->CQT conversion so it must know the sampling rate
-    predictions = predict(x, sr, model=model)
+    x = x.to(device)
+    predictions, confidence, activations = pesto_model(x, sr, return_activations=True)
     ...
 ```
-Note that when passing a list of files to `pesto.predict_from_files(...)` or the CLI directly, the model  is loaded only
-once so you don't have to bother with that in general.
 
-#### Batched pitch estimation
+Note that the `PESTO` object returned by `load_model` is a subclass of `nn.Module` 
+and its `forward` method also supports batched inputs.
+One can therefore easily integrate PESTO into one's own architecture:
+```python
+import torch
+import torch.nn as nn
+
+from pesto import load_model
 
-By default, the function `pesto.predict` takes an audio waveform represented as a Tensor object of shape `(num_channels, num_samples)`.
-However, one may want to estimate the pitch of batches of (cropped) waveforms within a training pipeline, e.g. for DDSP-related applications.
-`pesto.predict` therefore accepts Tensor inputs of shape `(batch_size, num_channels, num_samples)` and returns batch-wise pitch predictions accordingly.
 
-Note that batched predictions are available only from the Python API and not from the CLI because:
-- handling audios of different lengths is annoying, I don't want to bother with that
-- when estimating pitch on
+class MyGreatModel(nn.Module):
+    def __init__(self, step_size, sr=44100, *args, **kwargs):
+        super(MyGreatModel, self).__init__()
+        self.f0_estimator = load_model("mir-1k", step_size, sampling_rate=sr)
+        ...
+
+    def forward(self, x):
+        with torch.no_grad():
+            f0, conf = self.f0_estimator(x, convert_to_freq=True, return_activations=False)
+        ...
+```
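+
+Here, `load_model` is given a fixed sampling rate, so the CQT kernels are built once at initialization
+rather than lazily when the first batch is processed.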
 
 ## Performances
 
diff --git a/pesto/__init__.py b/pesto/__init__.py
index 82a92f8908b9f434140e3c1f7f44899a11cb8655..167a098923d22ab94ffa03e186547d9c6f3c2bfb 100644
--- a/pesto/__init__.py
+++ b/pesto/__init__.py
@@ -1 +1,4 @@
-from .core import predict, predict_from_files
+from .core import load_model, predict, predict_from_files
+
+
+__version__ = '1.0.0'
diff --git a/pesto/core.py b/pesto/core.py
index c55ac4afa24b020ac45e84e7eed4b46380c4524f..dd28a5872e8caa8a1f49160743b9d9fe05138a42 100644
--- a/pesto/core.py
+++ b/pesto/core.py
@@ -1,91 +1,87 @@
 import os
 import warnings
-from typing import Optional, Sequence, Union
+from typing import Optional, Sequence, Tuple, Union
 
 import torch
 import torchaudio
 from tqdm import tqdm
 
-from .utils import load_model, load_dataprocessor, reduce_activation
-from .export import export
+from .loader import load_model
+from .model import PESTO
+from .utils import export
 
 
-@torch.inference_mode()
-def predict(
-        x: torch.Tensor,
-        sr: Optional[int] = None,
-        model: Union[torch.nn.Module, str] = "mir-1k",
-        data_preprocessor=None,
-        step_size: Optional[float] = None,
-        reduction: str = "argmax",
-        num_chunks: int = 1,
-        convert_to_freq: bool = False
-):
+def _predict(x: torch.Tensor,
+             sr: int,
+             model: PESTO,
+             num_chunks: int = 1,
+             convert_to_freq: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    preds, confidence, activations = [], [], []
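+    # split the input into chunks and process them sequentially to bound peak (GPU) memory usage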
+    try:
+        for chunk in x.chunk(chunks=num_chunks):
+            pred, conf, act = model(chunk, sr=sr, convert_to_freq=convert_to_freq, return_activations=True)
+            preds.append(pred)
+            confidence.append(conf)
+            activations.append(act)
+    except torch.cuda.OutOfMemoryError:
+        raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. "
+                                          "Please increase the number of chunks with option `-c`/`--chunks` "
+                                          "to reduce GPU memory usage.")
+
+    preds = torch.cat(preds, dim=0)
+    confidence = torch.cat(confidence, dim=0)
+    activations = torch.cat(activations, dim=0)
+
+    # compute timesteps (in milliseconds, since the model's hop size is in milliseconds)
+    timesteps = torch.arange(preds.size(-1), device=x.device) * model.hop_size
+
+    return timesteps, preds, confidence, activations
+
+
+def predict(x: torch.Tensor,
+            sr: int,
+            step_size: float = 10.,
+            model_name: str = "mir-1k",
+            reduction: str = "alwa",
+            num_chunks: int = 1,
+            convert_to_freq: bool = True,
+            inference_mode: bool = True,
+            no_grad: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     r"""Main prediction function.
 
     Args:
-        x (torch.Tensor): input audio tensor,
-            shape (num_channels, num_samples) or (batch_size, num_channels, num_samples)
+        x (torch.Tensor): input audio tensor, can be provided as a batch but should be mono,
+            shape (num_samples) or (batch_size, num_samples)
         sr (int, optional): sampling rate. If not specified, uses the current sampling rate of the model.
-        model: PESTO model. If a string is passed, it will load the model with the corresponding name.
-            Otherwise, the actual nn.Module will be used for doing predictions.
-        data_preprocessor: Module handling the data processing pipeline (waveform to CQT, cropping, etc.)
         step_size (float, optional): step size between each CQT frame in milliseconds.
-            If the data_preprocessor is passed, its value will be used instead.
+            Defaults to 10 milliseconds.
+        model_name: name of PESTO model. Can be a path to a custom PESTO checkpoint or the name of a standard model.
         reduction (str): reduction method for converting activation probabilities to log-frequencies.
         num_chunks (int): number of chunks to split the input audios in.
             Default is 1 (all CQT frames in parallel) but it can be increased to reduce memory usage
             and prevent out-of-memory errors.
         convert_to_freq (bool): whether predictions should be converted to frequencies or not.
-    """
-    # convert to mono
-    assert 2 <= x.ndim <= 3, f"Audio file should have two dimensions, but found shape {x.size()}"
-    batch_size = x.size(0) if x.ndim == 3 else None
-    x = x.mean(dim=-2)
-
-    if data_preprocessor is None:
-        assert step_size is not None, \
-            "If you don't use a predefined data preprocessor, you must at least indicate a step size (in milliseconds)"
-        data_preprocessor = load_dataprocessor(step_size=step_size / 1000., device=x.device)
-
-    # If the sampling rate has changed, change the sampling rate accordingly
-    # It will automatically recompute the CQT kernels if needed
-    data_preprocessor.sampling_rate = sr
-
-    if isinstance(model, str):
-        model = load_model(model, device=x.device)
+        inference_mode (bool): whether to run with `torch.inference_mode`.
+        no_grad (bool): whether to run with `torch.no_grad`. If set to `False`, argument `inference_mode` is ignored.
 
-    # apply model
-    cqt = data_preprocessor(x)
-    try:
-        activations = torch.cat([
-            model(chunk) for chunk in cqt.chunk(chunks=num_chunks)
-        ])
-    except torch.cuda.OutOfMemoryError:
-        raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. "
-                                          "Please increase the number of chunks with option `-c`/`--chunks` "
-                                          "to reduce GPU memory usage.")
-
-    if batch_size:
-        total_batch_size, num_predictions = activations.size()
-        activations = activations.view(batch_size, total_batch_size // batch_size, num_predictions)
-
-    # shift activations as it should (PESTO predicts pitches up to an additive constant)
-    activations = activations.roll(model.abs_shift.cpu().item(), dims=-1)
-
-    # convert model predictions to pitch values
-    pitch = reduce_activation(activations, reduction=reduction)
-    if convert_to_freq:
-        pitch = 440 * 2 ** ((pitch - 69) / 12)
-
-    # for now, confidence is computed very naively just based on volume
-    confidence = cqt.squeeze(1).max(dim=1).values.view_as(pitch)
-    conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values
-    confidence = (confidence - conf_min) / (conf_max - conf_min)
+    Returns:
+        timesteps (torch.Tensor): timesteps corresponding to each pitch prediction, shape (num_timesteps)
+        preds (torch.Tensor): pitch predictions in SEMITONES, shape (batch_size?, num_timesteps)
+            where `num_timesteps` ~= 1000 * `num_samples` / (`step_size` * `sr`)
+        confidence (torch.Tensor): confidence of whether frame is voiced or unvoiced in [0, 1],
+            shape (batch_size?, num_timesteps)
+        activations (torch.Tensor): activations of the model, shape (batch_size?, num_timesteps, output_dim)
+    """
+    # sanity checks
+    assert x.ndim <= 2, \
+        f"Audio file should have shape (num_samples) or (batch_size, num_samples), but found shape {x.size()}."
 
-    timesteps = torch.arange(pitch.size(-1), device=x.device) * data_preprocessor.step_size
+    inference_mode = inference_mode and no_grad
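+    # inference_mode is a stricter, faster variant of no_grad: tensors created inside it can never require grad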
+    with torch.no_grad() if no_grad and not inference_mode else torch.inference_mode(mode=inference_mode):
+        model = load_model(model_name, step_size, sampling_rate=sr).to(x.device)
+        model.reduction = reduction
 
-    return timesteps, pitch, confidence, activations
+        return _predict(x, sr, model, num_chunks=num_chunks, convert_to_freq=convert_to_freq)
 
 
 def predict_from_files(
@@ -97,8 +93,7 @@ def predict_from_files(
         export_format: Sequence[str] = ("csv",),
         no_convert_to_freq: bool = False,
         num_chunks: int = 1,
-        gpu: int = -1
-):
+        gpu: int = -1):
     r"""
 
     Args:
@@ -107,12 +102,11 @@ def predict_from_files(
         output:
         step_size: hop length in milliseconds
         reduction:
-        export_format:
+        export_format (Sequence[str]): format to export the predictions to.
+            Currently formats supported are: ["csv", "npz", "json"].
         no_convert_to_freq: whether convert output values to Hz or keep fractional MIDI pitches
+        num_chunks (int): number of chunks to divide the inputs into. Increase this value if you encounter OOM errors.
         gpu: index of GPU to use (-1 for CPU)
-
-    Returns:
-        Pitch predictions, see `predict` for more details.
     """
     if isinstance(audio_files, str):
         audio_files = [audio_files]
@@ -122,37 +116,33 @@ def predict_from_files(
         gpu = -1
     device = torch.device(f"cuda:{gpu:d}" if gpu >= 0 else "cpu")
 
-    # define data preprocessing
-    data_preprocessor = load_dataprocessor(step_size / 1000., device=device)
-
     # define model
-    model = load_model(model_name, device=device)
-    predictions = None
+    model = load_model(model_name, step_size=step_size).to(device)
+    model.reduction = reduction
 
     pbar = tqdm(audio_files)
-    for file in pbar:
-        pbar.set_description(file)
 
-        # load audio file
-        try:
-            x, sr = torchaudio.load(file)
-        except Exception as e:
-            print(e, f"Skipping {file}...")
-            continue
+    with torch.inference_mode():  # the goal here is to write results to disk, so there is no point in tracking gradients
+        for file in pbar:
+            pbar.set_description(file)
 
-        x = x.to(device)
+            # load audio file
+            try:
+                x, sr = torchaudio.load(file)
+            except Exception as e:
+                print(e, f"Skipping {file}...")
+                continue
 
-        # compute the predictions
-        predictions = predict(x, sr, model=model, data_preprocessor=data_preprocessor, reduction=reduction,
-                              convert_to_freq=not no_convert_to_freq, num_chunks=num_chunks)
+            x = x.mean(dim=0).to(device)  # convert to mono, then move to the appropriate device
 
-        output_file = file.rsplit('.', 1)[0] + "." + ("semitones" if no_convert_to_freq else "f0")
-        if output is not None:
-            os.makedirs(output, exist_ok=True)
-            output_file = os.path.join(output, os.path.basename(output_file))
+            # compute the predictions
+            predictions = _predict(x, sr, model=model, convert_to_freq=not no_convert_to_freq, num_chunks=num_chunks)
 
-        predictions = [p.cpu().numpy() for p in predictions]
-        for fmt in export_format:
-            export(fmt, output_file, *predictions)
+            output_file = file.rsplit('.', 1)[0] + "." + ("semitones" if no_convert_to_freq else "f0")
+            if output is not None:
+                os.makedirs(output, exist_ok=True)
+                output_file = os.path.join(output, os.path.basename(output_file))
 
-    return predictions
+            predictions = [p.cpu().numpy() for p in predictions]
+            for fmt in export_format:
+                export(fmt, output_file, *predictions)
diff --git a/pesto/data.py b/pesto/data.py
index ea053726694aa78a6645a7d0c2223b02d81a2f2a..3044bbdf3fcfc0818b198555f7711623517d8e62 100644
--- a/pesto/data.py
+++ b/pesto/data.py
@@ -1,74 +1,95 @@
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
-from .cqt import CQT
+from .utils import HarmonicCQT
+
+
+class ToLogMagnitude(nn.Module):
+    def __init__(self):
+        super(ToLogMagnitude, self).__init__()
+        self.eps = torch.finfo(torch.float32).eps
+
+    def forward(self, x):
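+        # convert magnitude to dB: 20 * log10(max(|x|, eps))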
+        x = x.abs()
+        x.clamp_(min=self.eps).log10_().mul_(20)
+        return x
+
 
 
-class DataProcessor(nn.Module):
+class Preprocessor(nn.Module):
     r"""
 
     Args:
-        step_size (float): step size between consecutive CQT frames (in SECONDS)
+        hop_size (float): step size between consecutive CQT frames (in milliseconds)
     """
     def __init__(self,
-                 step_size: float,
-                 bins_per_semitone: int = 3,
-                 device: torch.device = torch.device("cpu"),
-                 **cqt_kwargs):
-        super(DataProcessor, self).__init__()
-        self.bins_per_semitone = bins_per_semitone
-
-        # CQT-related stuff
-        self.cqt_kwargs = cqt_kwargs
-        self.cqt_kwargs["bins_per_octave"] = 12 * bins_per_semitone
-        self.cqt = None
+                 hop_size: float,
+                 sampling_rate: Optional[int] = None,
+                 **hcqt_kwargs):
+        super(Preprocessor, self).__init__()
+
+        # HCQT
+        self.hcqt_sr = None
+        self.hcqt_kernels = None
+        self.hop_size = hop_size
+
+        self.hcqt_kwargs = hcqt_kwargs
 
         # log-magnitude
-        self.eps = torch.finfo(torch.float32).eps
+        self.to_log = ToLogMagnitude()
 
-        # cropping
-        self.lowest_bin = int(11 * self.bins_per_semitone / 2 + 0.5)
-        self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
+        # register a dummy tensor to get implicit access to the module's device
+        self.register_buffer("_device", torch.zeros(()), persistent=False)
 
-        # handling different sampling rates
-        self._sampling_rate = None
-        self.step_size = step_size
-        self.device = device
+        # if the sampling rate is provided, instantiate the CQT kernels
+        if sampling_rate is not None:
+            self.hcqt_sr = sampling_rate
+            self._reset_hcqt_kernels()
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, sr: Optional[int] = None) -> torch.Tensor:
         r"""
 
         Args:
-            x: audio waveform, any sampling rate, shape (num_samples)
+            x (torch.Tensor): audio waveform or batch of audio waveforms, any sampling rate,
+                shape (batch_size?, num_samples)
+            sr (int, optional): sampling rate
 
         Returns:
-            log-magnitude CQT, shape (
+            torch.Tensor: log-magnitude CQT or batch of CQTs,
+                shape (batch_size?, num_timesteps, num_harmonics, num_freqs)
         """
-        # compute CQT from input waveform, and invert dims for (batch_size, time_steps, freq_bins)
-        complex_cqt = torch.view_as_complex(self.cqt(x)).transpose(1, 2)
-
-        # reshape and crop borders to fit training input shape
-        complex_cqt = complex_cqt[..., self.lowest_bin: self.highest_bin]
-
-        # flatten eventual batch dimensions so that batched audios can be processed in parallel
-        complex_cqt = complex_cqt.flatten(0, 1).unsqueeze(1)
+        # compute CQT from input waveform, and invert dims for (time_steps, num_harmonics, freq_bins)
+        # in other words, time becomes the batch dimension, enabling efficient processing for long audios.
+        complex_cqt = torch.view_as_complex(self.hcqt(x, sr=sr)).permute(0, 3, 1, 2)
+        complex_cqt.squeeze_(0)
 
         # convert to dB
-        log_cqt = complex_cqt.abs().clamp_(min=self.eps).log10_().mul_(20)
-        return log_cqt
-
-    def _init_cqt_layer(self, sr: int, hop_length: int):
-        self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self.device)
+        return self.to_log(complex_cqt)
 
-    @property
-    def sampling_rate(self) -> int:
-        return self._sampling_rate
+    def hcqt(self, audio: torch.Tensor, sr: Optional[int] = None) -> torch.Tensor:
+        r"""Compute the Harmonic CQT of the input audio after eventually recreating the kernels
+        (in case the sampling rate has changed).
 
-    @sampling_rate.setter
-    def sampling_rate(self, sr: int):
-        if sr == self._sampling_rate:
-            return
+        Args:
+            audio (torch.Tensor): mono audio waveform, shape (batch_size, num_samples)
+            sr (int): sampling rate of the audio waveform.
+                If not specified, we assume it is the same as the previous processed audio waveform.
 
-        hop_length = int(self.step_size * sr + 0.5)
-        self._init_cqt_layer(sr, hop_length)
-        self._sampling_rate = sr
+        Returns:
+            torch.Tensor: Complex Harmonic CQT (HCQT) of the input,
+                shape (batch_size, num_harmonics, num_freqs, num_timesteps, 2)
+        """
+        # compute the HCQT kernels if they do not exist yet or if the sampling rate has changed
+        if sr is not None and sr != self.hcqt_sr:
+            self.hcqt_sr = sr
+            self._reset_hcqt_kernels()
+
+        return self.hcqt_kernels(audio)
+
+    def _reset_hcqt_kernels(self) -> None:
+        hop_length = int(self.hop_size * self.hcqt_sr / 1000 + 0.5)
+        self.hcqt_kernels = HarmonicCQT(sr=self.hcqt_sr,
+                                        hop_length=hop_length,
+                                        **self.hcqt_kwargs).to(self._device.device)
diff --git a/pesto/loader.py b/pesto/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..a270615997197cfa4abdcfeef30de6a68d67be98
--- /dev/null
+++ b/pesto/loader.py
@@ -0,0 +1,51 @@
+import os
+from typing import Optional
+
+import torch
+
+from .data import Preprocessor
+from .model import PESTO, Resnet1d
+
+
+def load_model(checkpoint: str,
+               step_size: float,
+               sampling_rate: Optional[int] = None) -> PESTO:
+    r"""Load a trained model from a checkpoint file.
+    See https://github.com/SonyCSLParis/pesto-full/blob/master/src/models/pesto.py for the structure of the checkpoint.
+
+    Args:
+        checkpoint (str): path to the checkpoint or name of the checkpoint file (if using a provided checkpoint)
+        step_size (float): hop size in milliseconds
+        sampling_rate (int, optional): sampling rate of the audios.
+            If not provided, it will be inferred from the `sr` argument at the first forward pass.
+    Returns:
+        PESTO: instance of PESTO model
+    """
+    if os.path.exists(checkpoint):  # handle user-provided checkpoints
+        model_path = checkpoint
+    else:
+        model_path = os.path.join(os.path.dirname(__file__), "weights", checkpoint + ".ckpt")
+        if not os.path.exists(model_path):
+            raise FileNotFoundError(f"You passed an invalid checkpoint file: {checkpoint}.")
+
+    # load checkpoint
+    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
+    hparams = checkpoint["hparams"]
+    hcqt_params = checkpoint["hcqt_params"]
+    state_dict = checkpoint["state_dict"]
+
+    # instantiate preprocessor
+    preprocessor = Preprocessor(hop_size=step_size, sampling_rate=sampling_rate, **hcqt_params)
+
+    # instantiate PESTO encoder
+    encoder = Resnet1d(**hparams["encoder"])
+
+    # instantiate main PESTO module and load its weights
+    model = PESTO(encoder,
+                  preprocessor=preprocessor,
+                  crop_kwargs=hparams["pitch_shift"],
+                  reduction=hparams["reduction"])
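+    # non-strict loading: checkpoints from the training repo may contain extra keys that are unused at inference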
+    model.load_state_dict(state_dict, strict=False)
+    model.eval()
+
+    return model
diff --git a/pesto/main.py b/pesto/main.py
index bc52df6cb0374c6798e23fdfe117e224c7983369..d117941f84eed1535970c16f5f4fa53029bf91f5 100644
--- a/pesto/main.py
+++ b/pesto/main.py
@@ -1,4 +1,4 @@
-from .parser import parse_args
+from pesto.utils.parser import parse_args
 from .core import predict_from_files
 
 
diff --git a/pesto/model.py b/pesto/model.py
index 94337ec84570b1cd3d5aaf1609b06f3f49a30ae6..b9cf5eda18349e38a2d4de1f739798bd5b1df0e6 100644
--- a/pesto/model.py
+++ b/pesto/model.py
@@ -1,8 +1,15 @@
 from functools import partial
+from typing import Any, Mapping, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
+from .utils import CropCQT
+from .utils import reduce_activations
+
+
+OUTPUT_TYPE = Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
+
 
 class ToeplitzLinear(nn.Conv1d):
     def __init__(self, in_features, out_features):
@@ -18,7 +25,7 @@ class ToeplitzLinear(nn.Conv1d):
         return super(ToeplitzLinear, self).forward(input.unsqueeze(-2)).squeeze(-2)
 
 
-class PESTOEncoder(nn.Module):
+class Resnet1d(nn.Module):
     """
     Basic CNN similar to the one in Johannes Zeitler's report,
     but for longer HCQT input (always stride 1 in time)
@@ -29,7 +36,8 @@ class PESTOEncoder(nn.Module):
     not over time (in order to work with variable length input).
     Outputs one channel with sigmoid activation.
 
-    Args (Defaults: BasicCNN by Johannes Zeitler but with 1 input channel):
+    Args (Defaults: BasicCNN by Johannes Zeitler but with 6 input channels):
+        n_chan_input:     Number of input channels (harmonics in HCQT)
         n_chan_layers:    Number of channels in the hidden layers (list)
         n_prefilt_layers: Number of repetitions of the prefiltering layer
         residual:         If True, use residual connections for prefiltering (default: False)
@@ -39,100 +47,191 @@ class PESTOEncoder(nn.Module):
         p_dropout:        Dropout probability
     """
 
-    def __init__(
-            self,
-            n_chan_layers=(20, 20, 10, 1),
-            n_prefilt_layers=1,
-            residual=False,
-            n_bins_in=216,
-            output_dim=128,
-            num_output_layers: int = 1
-    ):
-        super(PESTOEncoder, self).__init__()
-
-        activation_layer = partial(nn.LeakyReLU, negative_slope=0.3)
-
+    def __init__(self,
+                 n_chan_input=1,
+                 n_chan_layers=(20, 20, 10, 1),
+                 n_prefilt_layers=1,
+                 prefilt_kernel_size=15,
+                 residual=False,
+                 n_bins_in=216,
+                 output_dim=128,
+                 activation_fn: str = "leaky",
+                 a_lrelu=0.3,
+                 p_dropout=0.2):
+        super(Resnet1d, self).__init__()
+
+        self.hparams = dict(n_chan_input=n_chan_input,
+                            n_chan_layers=n_chan_layers,
+                            n_prefilt_layers=n_prefilt_layers,
+                            prefilt_kernel_size=prefilt_kernel_size,
+                            residual=residual,
+                            n_bins_in=n_bins_in,
+                            output_dim=output_dim,
+                            activation_fn=activation_fn,
+                            a_lrelu=a_lrelu,
+                            p_dropout=p_dropout)
+
+        if activation_fn == "relu":
+            activation_layer = nn.ReLU
+        elif activation_fn == "silu":
+            activation_layer = nn.SiLU
+        elif activation_fn == "leaky":
+            activation_layer = partial(nn.LeakyReLU, negative_slope=a_lrelu)
+        else:
+            raise ValueError(f"Unknown activation function: {activation_fn}.")
+
+        n_in = n_chan_input
         n_ch = n_chan_layers
         if len(n_ch) < 5:
             n_ch.append(1)
 
-        # Layer normalization over frequency
-        self.layernorm = nn.LayerNorm(normalized_shape=[1, n_bins_in])
+        # Layer normalization over frequency and channels (harmonics of HCQT)
+        self.layernorm = nn.LayerNorm(normalized_shape=[n_in, n_bins_in])
 
         # Prefiltering
+        prefilt_padding = prefilt_kernel_size // 2
         self.conv1 = nn.Sequential(
-            nn.Conv1d(in_channels=1, out_channels=n_ch[0], kernel_size=15, padding=7, stride=1),
-            activation_layer()
+            nn.Conv1d(in_channels=n_in,
+                      out_channels=n_ch[0],
+                      kernel_size=prefilt_kernel_size,
+                      padding=prefilt_padding,
+                      stride=1),
+            activation_layer(),
+            nn.Dropout(p=p_dropout)
         )
         self.n_prefilt_layers = n_prefilt_layers
-        self.prefilt_list = nn.ModuleList()
-        for p in range(1, n_prefilt_layers):
-            self.prefilt_list.append(nn.Sequential(
-                nn.Conv1d(in_channels=n_ch[0], out_channels=n_ch[0], kernel_size=15, padding=7, stride=1),
-                activation_layer()
-            ))
+        self.prefilt_layers = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv1d(in_channels=n_ch[0],
+                          out_channels=n_ch[0],
+                          kernel_size=prefilt_kernel_size,
+                          padding=prefilt_padding,
+                          stride=1),
+                activation_layer(),
+                nn.Dropout(p=p_dropout)
+            )
+            for _ in range(n_prefilt_layers-1)
+        ])
         self.residual = residual
 
-        self.conv2 = nn.Sequential(
-            nn.Conv1d(
-                in_channels=n_ch[0],
-                out_channels=n_ch[1],
-                kernel_size=1,
-                stride=1,
-                padding=0
-            ),
-            activation_layer()
-        )
-
-        self.conv3 = nn.Sequential(
-            nn.Conv1d(in_channels=n_ch[1], out_channels=n_ch[2], kernel_size=1, padding=0, stride=1),
-            activation_layer()
-        )
-
-        self.conv4 = nn.Sequential(
-            nn.Conv1d(in_channels=n_ch[2], out_channels=n_ch[3], kernel_size=1, padding=0, stride=1),
-            activation_layer(),
-            nn.Dropout(),
-            nn.Conv1d(in_channels=n_ch[3], out_channels=n_ch[4], kernel_size=1, padding=0, stride=1)
-        )
+        conv_layers = []
+        for i in range(len(n_chan_layers)-1):
+            conv_layers.extend([
+                nn.Conv1d(in_channels=n_ch[i],
+                          out_channels=n_ch[i + 1],
+                          kernel_size=1,
+                          padding=0,
+                          stride=1),
+                activation_layer(),
+                nn.Dropout(p=p_dropout)
+            ])
+        self.conv_layers = nn.Sequential(*conv_layers)
 
         self.flatten = nn.Flatten(start_dim=1)
-
-        layers = []
-        pre_fc_dim = n_bins_in * n_ch[4]
-        for i in range(num_output_layers-1):
-            layers.extend([
-                ToeplitzLinear(pre_fc_dim, pre_fc_dim),
-                activation_layer()
-            ])
-        self.pre_fc = nn.Sequential(*layers)
-        self.fc = ToeplitzLinear(pre_fc_dim, output_dim)
+        self.fc = ToeplitzLinear(n_bins_in * n_ch[-1], output_dim)
 
         self.final_norm = nn.Softmax(dim=-1)
 
-        self.register_buffer("abs_shift", torch.zeros((), dtype=torch.long), persistent=True)
-
     def forward(self, x):
         r"""
 
         Args:
             x (torch.Tensor): shape (batch, channels, freq_bins)
         """
-        x_norm = self.layernorm(x)
+        x = self.layernorm(x)
 
-        x = self.conv1(x_norm)
+        x = self.conv1(x)
         for p in range(0, self.n_prefilt_layers - 1):
-            prefilt_layer = self.prefilt_list[p]
+            prefilt_layer = self.prefilt_layers[p]
             if self.residual:
                 x_new = prefilt_layer(x)
                 x = x_new + x
             else:
                 x = prefilt_layer(x)
-        conv2_lrelu = self.conv2(x)
-        conv3_lrelu = self.conv3(conv2_lrelu)
 
-        y_pred = self.conv4(conv3_lrelu)
-        y_pred = self.flatten(y_pred)
-        y_pred = self.pre_fc(y_pred)
-        y_pred = self.fc(y_pred)
+        x = self.conv_layers(x)
+        x = self.flatten(x)
+
+        y_pred = self.fc(x)
+
         return self.final_norm(y_pred)
+
+
+class PESTO(nn.Module):
+    def __init__(self,
+                 encoder: nn.Module,
+                 preprocessor: nn.Module,
+                 crop_kwargs: Optional[Mapping[str, Any]] = None,
+                 reduction: str = "alwa"):
+        super(PESTO, self).__init__()
+        self.encoder = encoder
+        self.preprocessor = preprocessor
+
+        # crop CQT
+        if crop_kwargs is None:
+            crop_kwargs = {}
+        self.crop_cqt = CropCQT(**crop_kwargs)
+
+        self.reduction = reduction
+
+        # constant shift to get absolute pitch from predictions
+        self.register_buffer('shift', torch.zeros((), dtype=torch.float), persistent=True)
+
+    def forward(self,
+                audio_waveforms: torch.Tensor,
+                sr: Optional[int] = None,
+                convert_to_freq: bool = False,
+                return_activations: bool = False) -> OUTPUT_TYPE:
+        r"""
+
+        Args:
+            audio_waveforms (torch.Tensor): mono audio waveform or batch of mono audio waveforms,
+                shape (batch_size?, num_samples)
+            sr (int, optional): sampling rate, defaults to the previously used sampling rate
+            convert_to_freq (bool): whether to convert the result to frequencies or return fractional semitones instead.
+            return_activations (bool): whether to return activations or pitch predictions only
+
+        Returns:
+            preds (torch.Tensor): pitch predictions in SEMITONES, shape (batch_size?, num_timesteps)
+                where `num_timesteps` ~= 1000 * `num_samples` / (`self.hop_size` * `sr`)
+            confidence (torch.Tensor): confidence of whether frame is voiced or unvoiced in [0, 1],
+                shape (batch_size?, num_timesteps)
+            activations (torch.Tensor): activations of the model, shape (batch_size?, num_timesteps, output_dim)
+        """
+        batch_size = audio_waveforms.size(0) if audio_waveforms.ndim == 2 else None
+        x = self.preprocessor(audio_waveforms, sr=sr)
+        x = self.crop_cqt(x)  # the CQT has to be cropped beforehand
+
+        # for now, confidence is computed very naively just based on energy in the CQT
+        confidence = x.mean(dim=-2).max(dim=-1).values
+        conf_min, conf_max = confidence.min(dim=-1, keepdim=True).values, confidence.max(dim=-1, keepdim=True).values
+        confidence = (confidence - conf_min) / (conf_max - conf_min)
+
+        # flatten batch_size and time_steps since anyway predictions are made on CQT frames independently
+        if batch_size:
+            x = x.flatten(0, 1)
+
+        activations = self.encoder(x)
+        if batch_size:
+            activations = activations.view(batch_size, -1, activations.size(-1))
+
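+        # shift the activations, since PESTO predicts pitches only up to an additive constant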
+        activations = activations.roll(-round(self.shift.cpu().item() * self.bins_per_semitone), -1)
+
+        preds = reduce_activations(activations, reduction=self.reduction)
+
+        if convert_to_freq:
+            preds = 440 * 2 ** ((preds - 69) / 12)
+
+        if return_activations:
+            return preds, confidence, activations
+
+        return preds, confidence
+
+    @property
+    def bins_per_semitone(self) -> int:
+        return self.preprocessor.hcqt_kwargs["bins_per_semitone"]
+
+    @property
+    def hop_size(self) -> float:
+        r"""Returns the hop size of the model (in milliseconds)"""
+        return self.preprocessor.hop_size
diff --git a/pesto/utils.py b/pesto/utils.py
deleted file mode 100644
index 46281ebc1f4fff05c5c10fd14bd056244482039b..0000000000000000000000000000000000000000
--- a/pesto/utils.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import os
-from typing import Optional
-
-import torch
-
-from .config import model_args, cqt_args, bins_per_semitone
-from .data import DataProcessor
-from .model import PESTOEncoder
-
-
-def load_dataprocessor(step_size, device: Optional[torch.device] = None):
-    return DataProcessor(step_size=step_size, device=device, **cqt_args).to(device)
-
-
-def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder:
-    model = PESTOEncoder(**model_args).to(device)
-    model.eval()
-
-    model_path = os.path.join(os.path.dirname(__file__), "weights", model_name + ".pth")
-    model.load_state_dict(torch.load(model_path, map_location=device))
-
-    return model
-
-
-def reduce_activation(activations: torch.Tensor, reduction: str) -> torch.Tensor:
-    r"""Computes the pitch predictions from the activation outputs of the encoder.
-    Pitch predictions are returned in semitones, NOT in frequencies.
-
-    Args:
-        activations: tensor of probability activations, shape (*, num_bins)
-        reduction:
-
-    Returns:
-        torch.Tensor: pitch predictions, shape (*,)
-    """
-    bps = bins_per_semitone
-    if reduction == "argmax":
-        pred = activations.argmax(dim=-1)
-        return pred.float() / bps
-
-    all_pitches = (torch.arange(activations.size(-1), dtype=torch.float, device=activations.device)) / bps
-    if reduction == "mean":
-        return activations @ all_pitches
-
-    if reduction == "alwa":  # argmax-local weighted averaging, see https://github.com/marl/crepe
-        center_bin = activations.argmax(dim=-1, keepdim=True)
-        window = torch.arange(-bps+1, bps, device=activations.device)
-        indices = window + center_bin
-        cropped_activations = activations.gather(-1, indices)
-        cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices)
-        return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1)
-
-    raise ValueError
diff --git a/pesto/utils/__init__.py b/pesto/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9be9b1b72be5579424e7411fe2291823838cd3d4
--- /dev/null
+++ b/pesto/utils/__init__.py
@@ -0,0 +1,4 @@
+from .crop_cqt import CropCQT
+from .export import export
+from .hcqt import HarmonicCQT
+from .reduce_activations import reduce_activations
\ No newline at end of file
diff --git a/pesto/utils/crop_cqt.py b/pesto/utils/crop_cqt.py
new file mode 100644
index 0000000000000000000000000000000000000000..65829c70e8ee90554acee1c027178f6292cd4230
--- /dev/null
+++ b/pesto/utils/crop_cqt.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
+
+class CropCQT(nn.Module):
+    def __init__(self, min_steps: int, max_steps: int):
+        super(CropCQT, self).__init__()
+        self.min_steps = min_steps
+        self.max_steps = max_steps
+
+        # lower bin
+        self.lower_bin = self.max_steps
+
+    def forward(self, spectrograms: torch.Tensor) -> torch.Tensor:
+        input_height = spectrograms.size(-1)
+
+        output_height = input_height - self.max_steps + self.min_steps
+        assert output_height > 0, \
+            f"With input height {input_height:d} and output height {output_height:d}, impossible " \
+            f"to have a range of {self.max_steps - self.min_steps:d} bins."
+
+        return spectrograms[..., self.lower_bin: self.lower_bin + output_height]
diff --git a/pesto/export.py b/pesto/utils/export.py
similarity index 89%
rename from pesto/export.py
rename to pesto/utils/export.py
index daab6a96f5d474f18e75cb914c3e7171ccdc3a49..70aeb54861b9bff67661fa58793be8e3230ae507 100644
--- a/pesto/export.py
+++ b/pesto/utils/export.py
@@ -41,11 +41,13 @@ def export_png(output_file: str, timesteps, confidence, activations, lims=(21, 1
 
     bps = activations.shape[1] // 128
     activations = activations[:, bps*lims[0]: bps*lims[1]]
-    activations = activations * confidence.unsqueeze(1)
+    activations = activations * confidence[:, None]
     plt.imshow(activations.T,
                aspect='auto', origin='lower', cmap='inferno',
-               extent=(timesteps[0], timesteps[-1]) + lims)
+               extent=(timesteps[0] / 1000, timesteps[-1] / 1000) + lims)
 
+    plt.xlabel("Time (s)")
+    plt.ylabel("Pitch (semitones)")
     plt.title(output_file.rsplit('.', 2)[0])
     plt.tight_layout()
     plt.savefig(output_file)
diff --git a/pesto/cqt.py b/pesto/utils/hcqt.py
similarity index 91%
rename from pesto/cqt.py
rename to pesto/utils/hcqt.py
index 6fdfccb6372077fd1f94b645fded1178c90983d7..f7ff5610feb4b3c8d3bb29affc9a4ea5d2661538 100644
--- a/pesto/cqt.py
+++ b/pesto/utils/hcqt.py
@@ -354,3 +354,39 @@ class CQT(nn.Module):
             phase_real = torch.cos(torch.atan2(CQT_imag, CQT_real))
             phase_imag = torch.sin(torch.atan2(CQT_imag, CQT_real))
             return torch.stack((phase_real, phase_imag), -1)
+
+
+class HarmonicCQT(nn.Module):
+    r"""Harmonic CQT layer, as described in Bittner et al. (20??)"""
+    def __init__(
+            self,
+            harmonics,
+            sr: int = 22050,
+            hop_length: int = 512,
+            fmin: float = 32.7,
+            fmax: Optional[float] = None,
+            bins_per_semitone: int = 1,
+            n_bins: int = 84,
+            center_bins: bool = True
+    ):
+        super(HarmonicCQT, self).__init__()
+
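+        # lower fmin slightly so that target frequencies fall at the center of a bin rather than on its edge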
+        if center_bins:
+            fmin = fmin / 2 ** ((bins_per_semitone - 1) / (24 * bins_per_semitone))
+
+        self.cqt_kernels = nn.ModuleList([
+            CQT(sr=sr, hop_length=hop_length, fmin=h*fmin, fmax=fmax, n_bins=n_bins,
+                bins_per_octave=12*bins_per_semitone, output_format="Complex")
+            for h in harmonics
+        ])
+
+    def forward(self, audio_waveforms: torch.Tensor):
+        r"""Converts a batch of waveforms into a batch of HCQTs.
+
+        Args:
+            audio_waveforms (torch.Tensor): Batch of waveforms, shape (batch_size, num_samples)
+
+        Returns:
+            Harmonic CQT, shape (batch_size, num_harmonics, num_freqs, num_timesteps, 2)
+        """
+        return torch.stack([cqt(audio_waveforms) for cqt in self.cqt_kernels], dim=1)
diff --git a/pesto/parser.py b/pesto/utils/parser.py
similarity index 100%
rename from pesto/parser.py
rename to pesto/utils/parser.py
diff --git a/pesto/utils/reduce_activations.py b/pesto/utils/reduce_activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..b22dc508c2c0299b03c3038c9c50bf282bf1b76f
--- /dev/null
+++ b/pesto/utils/reduce_activations.py
@@ -0,0 +1,36 @@
+import torch
+
+
+def reduce_activations(activations: torch.Tensor, reduction: str = "alwa") -> torch.Tensor:
+    r"""
+
+    Args:
+        activations: tensor of probability activations, shape (*, num_bins)
+        reduction (str): reduction method to compute pitch out of activations,
+            choose between "argmax", "mean", "alwa".
+
+    Returns:
+        torch.Tensor: pitches as fractions of MIDI semitones, shape (*)
+    """
+    device = activations.device
+    num_bins = activations.size(-1)
+    bps, r = divmod(num_bins, 128)
+    assert r == 0, f"Activations should have output size 128*bins_per_semitone, got {num_bins}."
+
+    if reduction == "argmax":
+        pred = activations.argmax(dim=-1)
+        return pred.float() / bps
+
+    all_pitches = torch.arange(num_bins, dtype=torch.float, device=device).div_(bps)
+    if reduction == "mean":
+        return torch.matmul(activations, all_pitches)
+
+    if reduction == "alwa":  # argmax-local weighted averaging, see https://github.com/marl/crepe
+        center_bin = activations.argmax(dim=-1, keepdim=True)
+        window = torch.arange(1, 2 * bps, device=device) - bps  # [-bps+1, -bps+2, ..., bps-2, bps-1]
+        indices = (center_bin + window).clip_(min=0, max=num_bins - 1)
+        cropped_activations = activations.gather(-1, indices)
+        cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices)
+        return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1)
+
+    raise ValueError(f"Unknown reduction method: {reduction}.")
diff --git a/pesto/weights/mir-1k.ckpt b/pesto/weights/mir-1k.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..6fa22686ecb3f763c98d42f290527797da2bd3ae
Binary files /dev/null and b/pesto/weights/mir-1k.ckpt differ
diff --git a/pesto/weights/mir-1k.pth b/pesto/weights/mir-1k.pth
deleted file mode 100644
index aadf78ceeff86ae4e86ac305053c5bcc9791ec7f..0000000000000000000000000000000000000000
Binary files a/pesto/weights/mir-1k.pth and /dev/null differ
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..b2fb7d05dcc35b499ede3b0562769602baaab66d
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,48 @@
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "pesto-pitch"
+dynamic = ["version"]
+authors = [
+    {name = "Alain Riou", email = "alain.riou@sony.com"}
+]
+description = "Efficient pitch estimation with self-supervised learning"
+readme = {file = "README.md", content-type = "text/markdown"}
+requires-python = ">=3.8"
+classifiers = [
+    'Development Status :: 4 - Beta',
+    'Intended Audience :: Developers',
+    'License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)',  # if a license is provided, it must be in the repository
+    'Programming Language :: Python :: 3.8',
+    'Programming Language :: Python :: 3.9',
+    'Programming Language :: Python :: 3.10',
+    'Programming Language :: Python :: 3.11',
+]
+dependencies = [
+    'numpy>=1.21.5',
+    'scipy>=1.8.1',
+    'tqdm>=4.66.1',
+    'torch>=2.0.1',
+    'torchaudio>=2.0.2'
+]
+
+[project.optional-dependencies]
+matplotlib = ["matplotlib"]
+test = ["pytest"]
+
+[project.scripts]
+pesto = "pesto.main:pesto"
+
+[project.urls]
+source = "https://github.com/SonyCSLParis/pesto"
+
+[tool.pytest.ini_options]
+testpaths = "tests/"
+
+[tool.setuptools.dynamic]
+version = {attr = "pesto.__version__"}
+
+[tool.setuptools.package-data]
+pesto = ["weights/*.ckpt"]
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 6eafee00df5b97574f0ba784f53a51978fc303c3..0000000000000000000000000000000000000000
--- a/setup.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from pathlib import Path
-from setuptools import setup, find_packages
-
-def get_readme_text():
-    root_dir = Path(__file__).parent
-    readme_path = root_dir / "README.md"
-    return readme_path.read_text()
-
-
-setup(
-    name='pesto-pitch',
-    version='0.1.0',
-    description='Efficient pitch estimation with self-supervised learning',
-    long_description=get_readme_text(),
-    long_description_content_type='text/markdown',
-    author='Alain Riou',
-    url='https://github.com/SonyCSLParis/pesto',
-    license='LGPL-3.0',
-    packages=find_packages(),
-    include_package_data=True,
-    package_data={
-        'pesto': ['weights/*'],  # To include the .pth
-    },
-    install_requires=[
-        'numpy>=1.21.5',
-        'scipy>=1.8.1',
-        'tqdm>=4.66.1',
-        'torch>=2.0.1',
-        'torchaudio>=2.0.2'
-    ],
-    classifiers=[
-        # 'Development Status :: 1 - Planning',
-        # 'Development Status :: 2 - Pre-Alpha',
-        # 'Development Status :: 3 - Alpha',
-        # 'Development Status :: 4 - Beta',
-        'Development Status :: 5 - Production/Stable',
-        # 'Development Status :: 6 - Mature',
-        # 'Development Status :: 7 - Inactive',
-        'Intended Audience :: Developers',
-        'License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)',  # If licence is provided must be on the repository
-        'Programming Language :: Python :: 3.8',
-        'Programming Language :: Python :: 3.9',
-        'Programming Language :: Python :: 3.10',
-        'Programming Language :: Python :: 3.11',
-    ],
-    entry_points={
-        'console_scripts': [
-            'pesto=pesto.main:pesto',  # For the command line, executes function pesto() in pesto/main as 'pesto'
-        ],
-    },
-    python_requires='>=3.8',
-)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/audios/example.wav b/tests/audios/example.wav
new file mode 100644
index 0000000000000000000000000000000000000000..9f592085cff7d80a91c27b0f5920df629b36cd36
Binary files /dev/null and b/tests/audios/example.wav differ
diff --git a/tests/test_basic.py b/tests/test_basic.py
deleted file mode 100644
index 1f4096eab0b3ed39a4b697104f2f9d33ddad8dc8..0000000000000000000000000000000000000000
--- a/tests/test_basic.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import unittest
-import pesto
-
-
-class MyTestCase(unittest.TestCase):
-    def test_something(self):
-        self.assertEqual(True, True)  # add assertion here
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/test_cli.py b/tests/test_cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef40055cf51781976bb5f07757cbf48406ecde78
--- /dev/null
+++ b/tests/test_cli.py
@@ -0,0 +1,41 @@
+import glob
+import itertools
+import os
+
+import pytest
+
+import torch
+
+
+AUDIOS_DIR = os.path.join(os.path.dirname(__file__), "audios")
+
+
+@pytest.mark.parametrize("file, fmt, convert_to_freq",
+                         itertools.product(glob.glob(AUDIOS_DIR + "/*.wav"), ["csv", "npz", "png"], [True, False]))
+def test_cli(file, fmt, convert_to_freq):
+    if convert_to_freq:
+        suffix = ".f0." + fmt
+        option = ""
+    else:
+        suffix = ".semitones." + fmt
+        option = " -F"
+    os.system(f"pesto {file} --export_format " + fmt + option)
+    out_file = file.rsplit('.', 1)[0] + suffix
+    assert os.path.isfile(out_file)
+    os.unlink(out_file)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available")
+@pytest.mark.parametrize("file, fmt, convert_to_freq",
+                         itertools.product(glob.glob(AUDIOS_DIR + "/*.wav"), ["csv", "npz", "png"], [True, False]))
+def test_cli_gpu(file, fmt, convert_to_freq):
+    if convert_to_freq:
+        suffix = ".f0." + fmt
+        option = ""
+    else:
+        suffix = ".semitones." + fmt
+        option = " -F"
+    os.system(f"pesto {file} --gpu 0 --export_format " + fmt + option)
+    out_file = file.rsplit('.', 1)[0] + suffix
+    assert os.path.isfile(out_file)
+    os.unlink(out_file)
diff --git a/tests/test_performances.py b/tests/test_performances.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aed8de0b4172658236ddbfce55f15549346b7d2
--- /dev/null
+++ b/tests/test_performances.py
@@ -0,0 +1,23 @@
+import itertools
+
+import pytest
+
+import torch
+
+from pesto import load_model
+from .utils import generate_synth_data
+
+
+@pytest.fixture
+def model():
+    return load_model('mir-1k', step_size=10.)
+
+
+@pytest.mark.parametrize('pitch, sr, reduction',
+                         itertools.product(range(50, 80), [16000, 44100, 48000], ["argmax", "alwa"]))
+def test_performances(model, pitch, sr, reduction):
+    x = generate_synth_data(pitch, sr=sr)
+
+    preds, conf = model(x, sr=sr, return_activations=False)
+
+    torch.testing.assert_close(preds, torch.full_like(preds, pitch), atol=0.1, rtol=0.1)
diff --git a/tests/test_predict.py b/tests/test_predict.py
deleted file mode 100644
index 464090415c47109523e91779d4f40e19495c9cf1..0000000000000000000000000000000000000000
--- a/tests/test_predict.py
+++ /dev/null
@@ -1 +0,0 @@
-# TODO
diff --git a/tests/test_shape.py b/tests/test_shape.py
new file mode 100644
index 0000000000000000000000000000000000000000..eed58bb8e4179c5f91c242fefe4e398f574512c6
--- /dev/null
+++ b/tests/test_shape.py
@@ -0,0 +1,97 @@
+import itertools
+
+import pytest
+
+import torch
+
+from pesto import load_model, predict
+from .utils import generate_synth_data
+
+
+@pytest.fixture
+def model():
+    return load_model('mir-1k', step_size=10.)
+
+
+@pytest.fixture
+def synth_data_16k():
+    return generate_synth_data(pitch=69, duration=5., sr=16000), 16000
+
+
+@pytest.mark.parametrize('reduction', ["argmax", "mean", "alwa"])
+def test_shape_no_batch(model, synth_data_16k, reduction):
+    x, sr = synth_data_16k
+
+    model.reduction = reduction
+
+    num_samples = x.size(-1)
+
+    num_timesteps = int(num_samples * 1000 / (model.hop_size * sr)) + 1
+
+    preds, conf, activations = model(x, sr=sr, return_activations=True)
+
+    assert preds.shape == (num_timesteps,)
+    assert conf.shape == (num_timesteps,)
+    assert activations.shape == (num_timesteps, 128 * model.bins_per_semitone)
+
+
+@pytest.mark.parametrize('sr, reduction',
+                         itertools.product([16000, 44100, 48000], ["argmax", "mean", "alwa"]))
+def test_shape_batch(model, sr, reduction):
+    model.reduction = reduction
+
+    batch_size = 13
+
+    batch = torch.stack([
+        generate_synth_data(pitch=p, duration=5., sr=sr)
+        for p in range(50, 50+batch_size)
+    ])
+
+    num_timesteps = int(batch.size(-1) * 1000 / (model.hop_size * sr)) + 1
+
+    preds, conf, activations = model(batch, sr=sr, return_activations=True)
+
+    assert preds.shape == (batch_size, num_timesteps)
+    assert conf.shape == (batch_size, num_timesteps)
+    assert activations.shape == (batch_size, num_timesteps, 128 * model.bins_per_semitone)
+
+
+@pytest.mark.parametrize('step_size, reduction',
+                         itertools.product([10., 20., 50., 100], ["argmax", "mean", "alwa"]))
+def test_predict_shape_no_batch(synth_data_16k, step_size, reduction):
+    x, sr = synth_data_16k
+
+    num_samples = x.size(-1)
+
+    num_timesteps = int(num_samples * 1000 / (step_size * sr)) + 1
+
+    timesteps, preds, conf, activations = predict(x,
+                                                  sr,
+                                                  step_size=step_size,
+                                                  reduction=reduction)
+
+    assert timesteps.shape == (num_timesteps,)
+    assert preds.shape == (num_timesteps,)
+    assert conf.shape == (num_timesteps,)
+
+
+@pytest.mark.parametrize('sr, step_size, reduction',
+                         itertools.product([16000, 44100, 48000], [10., 20., 50., 100.], ["argmax", "mean", "alwa"]))
+def test_predict_shape_batch(sr, step_size, reduction):
+    batch_size = 13
+
+    batch = torch.stack([
+        generate_synth_data(pitch=p, duration=5., sr=sr)
+        for p in range(50, 50+batch_size)
+    ])
+
+    num_timesteps = int(batch.size(-1) * 1000 / (step_size * sr)) + 1
+
+    timesteps, preds, conf, activations = predict(batch,
+                                                  sr=sr,
+                                                  step_size=step_size,
+                                                  reduction=reduction)
+
+    assert timesteps.shape == (num_timesteps,)
+    assert preds.shape == (batch_size, num_timesteps)
+    assert conf.shape == (batch_size, num_timesteps)
diff --git a/tests/test_timesteps.py b/tests/test_timesteps.py
new file mode 100644
index 0000000000000000000000000000000000000000..546ee09beb2b00cfd3a0c316c88d5167c69d9c72
--- /dev/null
+++ b/tests/test_timesteps.py
@@ -0,0 +1,18 @@
+import pytest
+
+import torch
+
+from pesto import predict
+from .utils import generate_synth_data
+
+
+@pytest.fixture
+def synth_data_16k():
+    return generate_synth_data(pitch=69, duration=5., sr=16000), 16000
+
+
+@pytest.mark.parametrize('step_size', [10., 20., 50., 100])
+def test_build_timesteps(synth_data_16k, step_size):
+    timesteps, *_ = predict(*synth_data_16k, step_size=step_size)
+    diff = timesteps[1:] - timesteps[:-1]
+    torch.testing.assert_close(diff, torch.full_like(diff, step_size))
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..71210e340044aa7e463908c841b45d8e719d457e
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,20 @@
+import torch
+
+
+def mid_to_hz(pitch: int):
+    return 440 * 2 ** ((pitch - 69) / 12)
+
+
+def generate_synth_data(pitch: int, num_harmonics: int = 5, duration=2, sr=16000):
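+    r"""Generate a random mixture of harmonics of the given (MIDI) pitch, for testing purposes."""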
+    f0 = mid_to_hz(pitch)
+    t = torch.arange(0, duration, 1/sr)
+    harmonics = torch.stack([
+        torch.cos(2 * torch.pi * k * f0 * t + torch.rand(()))
+        for k in range(1, num_harmonics+1)
+    ], dim=1)
+    volume = torch.rand(num_harmonics)
+    volume[0] = 1
+    volume *= torch.randn(())
+    audio = torch.sum(volume * harmonics, dim=1)
+    return audio
\ No newline at end of file