Unverified commit 8201e4f7, authored by Alain Riou and committed by GitHub

Merge pull request #19 from ben-hayes/bugfix/dataprocessor-device

Persist CQT device
parents 5cac507c 752df3b7
```diff
@@ -6,7 +6,7 @@ This code is the implementation of the [PESTO paper](https://arxiv.org/abs/2309.
 that has been accepted at [ISMIR 2023](https://ismir2023.ismir.net/).
 
 **Disclaimer:** This repository contains minimal code and should be used for inference only.
-If you want full implementation details or want to use PESTO for research purposes, take a look at ~~[this repository](https://github.com/aRI0U/pesto-full)~~ (work in progress).
+If you want full implementation details or want to use PESTO for research purposes, take a look at ~~[this repository](https://github.com/aRI0U/pesto-full)~~ (coming soon!).
 
 ## Installation
```
```diff
@@ -59,7 +59,8 @@ Alternatively, one can save timesteps, pitch, confidence and activation outputs
 Finally, you can also visualize the pitch predictions by exporting them as a `png` file (you need `matplotlib` to be installed for PNG export). Here is an example:
 
-![example f0](https://github.com/SonyCSLParis/pesto/assets/36546630/2ad82c86-136a-4125-bf47-ea1b93408022)
+![example f0](https://github.com/SonyCSLParis/pesto/assets/36546630/5aa18c23-0154-4d2d-8021-2c23277b27a3)
 
 Multiple formats can be specified after the `-e` option.
```
```diff
@@ -81,7 +82,8 @@ Additionally, audio files can have any sampling rate; no resampling is required.
 By default, the model returns a probability distribution over all pitch bins.
 To convert it to an actual pitch, we use Argmax-Local Weighted Averaging by default, as in CREPE:
 
-![image](https://github.com/SonyCSLParis/pesto/assets/36546630/7d06bf85-585c-401f-a3c2-f2fab90dd1a7)
+![image](https://github.com/SonyCSLParis/pesto/assets/36546630/3138c33f-672a-477f-95a9-acaacf4418ab)
 
 Alternatively, one can use basic argmax or weighted average with the option `-r`/`--reduction`.
```
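The Argmax-Local Weighted Averaging reduction mentioned in this hunk can be sketched in a few lines. The function below is a framework-agnostic toy (plain Python lists, and the `radius` parameter is a hypothetical name), not the repository's implementation:

```python
def argmax_local_weighted_average(probs, radius=3):
    """Reduce a distribution over pitch bins to one fractional bin index.

    Take the argmax bin, then average the bin indices in a small window
    around it, weighted by their probabilities (as popularized by CREPE).
    """
    center = max(range(len(probs)), key=probs.__getitem__)
    lo = max(center - radius, 0)
    hi = min(center + radius + 1, len(probs))
    weighted = sum(i * probs[i] for i in range(lo, hi))
    total = sum(probs[lo:hi])
    return weighted / total

# A distribution peaked symmetrically around bin 5 reduces to 5.0
# (up to float rounding):
probs = [0.0, 0.0, 0.0, 0.1, 0.2, 0.4, 0.2, 0.1, 0.0, 0.0]
pitch_bin = argmax_local_weighted_average(probs)
```

The fractional bin index would then be mapped to a frequency using the model's bins-per-semitone resolution.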
```diff
@@ -150,11 +152,11 @@ Note that batched predictions are available only from the Python API and not fro
 ## Performances
 
-On [MIR-1K]() and [MDB-stem-synth](), PESTO outperforms other self-supervised baselines.
+On [MIR-1K](https://zenodo.org/record/3532216#.ZG0kWhlBxhE) and [MDB-stem-synth](https://zenodo.org/records/1481172), PESTO outperforms other self-supervised baselines.
 Its performances are close to CREPE's, which has 800x more parameters and was trained in a supervised way on a vast
 dataset containing MIR-1K and MDB-stem-synth, among others.
 
-![image](https://github.com/SonyCSLParis/pesto/assets/36546630/9fbf15ef-7af9-4cd5-8832-f8fc24d43f25)
+![image](https://github.com/SonyCSLParis/pesto/assets/36546630/d6ae0306-ba8b-465a-8ca7-f916479a0ba5)
 
 ## Speed benchmark
```
```diff
@@ -165,7 +167,8 @@ granularity of the predictions, which can be controlled with the `--step_size` p
 Here is a speed comparison between CREPE and PESTO, averaged over 10 runs on the same machine.
 
-![speed](https://github.com/SonyCSLParis/pesto/assets/36546630/8353c93d-e79f-497d-a09e-d8762e9a5cbc)
+![speed](https://github.com/SonyCSLParis/pesto/assets/36546630/612b1850-c2cf-4df1-9824-b8460a2f9148)
 
 - Audio file: `wav` format, 2m51s
 - Hardware: 11th Gen Intel(R) Core(TM) i7-1185G7 @ 3.00GHz, 8 cores
```
```diff
+from typing import Optional
+
 import torch
 import torch.nn as nn
```
```diff
@@ -10,10 +12,12 @@ class DataProcessor(nn.Module):
     Args:
         step_size (float): step size between consecutive CQT frames (in SECONDS)
     """
+    _sampling_rate: Optional[int] = None
+
     def __init__(self,
                  step_size: float,
                  bins_per_semitone: int = 3,
-                 device: torch.device = torch.device("cpu"),
                  sampling_rate: int = 44100,
                  **cqt_kwargs):
         super(DataProcessor, self).__init__()
         self.bins_per_semitone = bins_per_semitone
```
```diff
@@ -31,9 +35,8 @@ class DataProcessor(nn.Module):
         self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
 
         # handling different sampling rates
-        self._sampling_rate = None
         self.step_size = step_size
-        self.device = device
         self.sampling_rate = sampling_rate
 
     def forward(self, x: torch.Tensor):
         r"""
```
```diff
@@ -58,7 +61,11 @@ class DataProcessor(nn.Module):
         return log_cqt
 
     def _init_cqt_layer(self, sr: int, hop_length: int):
-        self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self.device)
+        if self.cqt is not None:
+            device = self.cqt.cqt_kernels_real.device
+        else:
+            device = None
+        self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(device)
 
     @property
     def sampling_rate(self) -> int:
```
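The core of the fix is visible in `_init_cqt_layer`: instead of storing a `device` attribute at construction time, the freshly built CQT layer inherits the device of the previous layer's kernels, so a processor moved with `.to(device)` stays on that device when the CQT is re-initialized for a new sampling rate. Below is a minimal sketch of the same pattern on a toy module (`Resizable` and `_rebuild` are hypothetical names, standing in for `DataProcessor` and `_init_cqt_layer`):

```python
import torch
import torch.nn as nn


class Resizable(nn.Module):
    """Toy module that re-creates an inner layer while keeping its device."""

    def __init__(self, size: int = 4):
        super().__init__()
        self.inner = None
        self._rebuild(size)

    def _rebuild(self, size: int):
        # Same idea as the fixed _init_cqt_layer: if a layer already exists,
        # reuse the device of its parameters; .to(None) is a no-op.
        if self.inner is not None:
            device = next(self.inner.parameters()).device
        else:
            device = None
        self.inner = nn.Linear(size, size).to(device)


m = Resizable()
m._rebuild(8)  # the rebuilt layer lands on the same device as the old one
```

Had the device been cached as a plain attribute instead, calling `.to(device)` on the outer module would silently leave any rebuilt submodule on the original device.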
```diff
@@ -9,7 +9,7 @@ from .model import PESTOEncoder
 
 
 def load_dataprocessor(step_size, device: Optional[torch.device] = None):
-    return DataProcessor(step_size=step_size, device=device, **cqt_args).to(device)
+    return DataProcessor(step_size=step_size, **cqt_args).to(device)
 
 
 def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder:
```
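With the `device` argument gone from the constructor, placement is handled entirely by `nn.Module.to`, and passing `device=None` simply leaves the module where it is. A sketch of that loading pattern, using a plain `nn.Linear` as a stand-in for the real `DataProcessor` (`load_module` is a hypothetical name):

```python
from typing import Optional

import torch
import torch.nn as nn


def load_module(device: Optional[torch.device] = None) -> nn.Module:
    # Mirrors the fixed load_dataprocessor: build first, then move once
    # with .to(device); .to(None) leaves parameters on their current device.
    module = nn.Linear(2, 2)
    return module.to(device)


mod = load_module()  # no device given: parameters stay on CPU
```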
```diff
@@ -9,7 +9,7 @@ def get_readme_text():
 
 setup(
     name='pesto-pitch',
-    version='1.0.0',
+    version='0.1.0',
     description='Efficient pitch estimation with self-supervised learning',
     long_description=get_readme_text(),
     long_description_content_type='text/markdown',
```