diff --git a/README.md b/README.md
index 0c7f37366b16e02c8ef856a54265635b0f0b08bb..633b184259a97f5b14d7102a8c747b581c6120c4 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ This code is the implementation of the [PESTO paper](https://arxiv.org/abs/2309.
 that has been accepted at [ISMIR 2023](https://ismir2023.ismir.net/).
 
 **Disclaimer:** This repository contains minimal code and should be used for inference only.
-If you want full implementation details or want to use PESTO for research purposes, take a look at ~~[this repository](https://github.com/aRI0U/pesto-full)~~ (work in progress).
+If you want full implementation details or want to use PESTO for research purposes, take a look at ~~[this repository](https://github.com/aRI0U/pesto-full)~~ (coming soon!).
 
 
 ## Installation
@@ -59,7 +59,8 @@ Alternatively, one can save timesteps, pitch, confidence and activation outputs
 Finally, you can also visualize the pitch predictions by exporting them as a `png` file (you need `matplotlib` to be installed for PNG export). Here is an example:
 
-
+
+
 
 Multiple formats can be specified after the `-e` option.
@@ -81,7 +82,8 @@ Additionally, audio files can have any sampling rate; no resampling is required.
 By default, the model returns a probability distribution over all pitch bins.
 To convert it to a proper pitch, by default, we use Argmax-Local Weighted Averaging as in CREPE:
 
-
+
+
 
 Alternatively, one can use basic argmax of weighted average with option `-r`/`--reduction`.
@@ -150,11 +152,11 @@ Note that batched predictions are available only from the Python API and not fro
 ## Performances
 
-On [MIR-1K]() and [MDB-stem-synth](), PESTO outperforms other self-supervised baselines.
+On [MIR-1K](https://zenodo.org/record/3532216#.ZG0kWhlBxhE) and [MDB-stem-synth](https://zenodo.org/records/1481172), PESTO outperforms other self-supervised baselines.
 Its performances are close to CREPE's, which has 800x more parameters and was trained in a supervised way on a vast dataset containing MIR-1K and MDB-stem-synth, among others.
 
-
+
 
 ## Speed benchmark
@@ -165,7 +167,8 @@ granularity of the predictions, which can be controlled with the `--step_size` p
 Here is a speed comparison between CREPE and PESTO, averaged over 10 runs on the same machine.
 
-
+
+
 
 - Audio file: `wav` format, 2m51s
 - Hardware: 11th Gen Intel(R) Core(TM) i7-1185G7 @ 3.00GHz, 8 cores
diff --git a/pesto/data.py b/pesto/data.py
index ea053726694aa78a6645a7d0c2223b02d81a2f2a..4c51bafa830fdaa26b3645ecf7668c22e5049c31 100644
--- a/pesto/data.py
+++ b/pesto/data.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
@@ -10,10 +12,12 @@ class DataProcessor(nn.Module):
     Args:
         step_size (float): step size between consecutive CQT frames (in SECONDS)
     """
+    _sampling_rate: Optional[int] = None
+
     def __init__(self,
                  step_size: float,
                  bins_per_semitone: int = 3,
-                 device: torch.device = torch.device("cpu"),
+                 sampling_rate: int = 44100,
                  **cqt_kwargs):
         super(DataProcessor, self).__init__()
         self.bins_per_semitone = bins_per_semitone
@@ -31,9 +35,8 @@ class DataProcessor(nn.Module):
         self.highest_bin = self.lowest_bin + 88 * self.bins_per_semitone
 
         # handling different sampling rates
-        self._sampling_rate = None
         self.step_size = step_size
-        self.device = device
+        self.sampling_rate = sampling_rate
 
     def forward(self, x: torch.Tensor):
         r"""
@@ -58,7 +61,11 @@ class DataProcessor(nn.Module):
         return log_cqt
 
     def _init_cqt_layer(self, sr: int, hop_length: int):
-        self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(self.device)
+        if self.cqt is not None:
+            device = self.cqt.cqt_kernels_real.device
+        else:
+            device = None
+        self.cqt = CQT(sr=sr, hop_length=hop_length, **self.cqt_kwargs).to(device)
 
     @property
     def sampling_rate(self) -> int:
diff --git a/pesto/utils.py b/pesto/utils.py
index 46281ebc1f4fff05c5c10fd14bd056244482039b..f3fbe10ce1d630ae41153ffa4612692e15ef06a1 100644
--- a/pesto/utils.py
+++ b/pesto/utils.py
@@ -9,7 +9,7 @@ from .model import PESTOEncoder
 
 
 def load_dataprocessor(step_size, device: Optional[torch.device] = None):
-    return DataProcessor(step_size=step_size, device=device, **cqt_args).to(device)
+    return DataProcessor(step_size=step_size, **cqt_args).to(device)
 
 
 def load_model(model_name: str, device: Optional[torch.device] = None) -> PESTOEncoder:
diff --git a/setup.py b/setup.py
index 600fa5e2bee0ae1c63d1c61474facd64abafe7f5..6eafee00df5b97574f0ba784f53a51978fc303c3 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@ def get_readme_text():
 setup(
     name='pesto-pitch',
-    version='1.0.0',
+    version='0.1.0',
     description='Efficient pitch estimation with self-supervised learning',
     long_description=get_readme_text(),
     long_description_content_type='text/markdown',
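
For context on the `DataProcessor` change above, here is a minimal sketch of how the reworked interface could be exercised after this patch. It relies only on what the diff itself shows (`load_dataprocessor`, the new `sampling_rate` handling, and device selection via `.to()`); the dummy waveform, the chosen `step_size`, and the assumption that `forward` accepts a mono 1-D tensor are illustrative guesses, not guarantees from the repository.

```python
# Sketch only: assumes the pesto package from this PR is importable and that
# DataProcessor.forward accepts a mono 1-D waveform tensor (not verified here).
import torch

from pesto.utils import load_dataprocessor

# The `device` constructor argument is gone: the processor is now moved with
# .to(device) inside load_dataprocessor, like any other nn.Module.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = load_dataprocessor(step_size=0.01, device=device)

# Changing the sampling rate re-creates the CQT layer; _init_cqt_layer now reads
# the device from the existing CQT kernels, so the processor stays on `device`.
processor.sampling_rate = 16000

x = torch.randn(16000, device=device)  # one second of dummy audio at 16 kHz
log_cqt = processor(x)                 # log-magnitude CQT frames
print(log_cqt.shape)
```

The design point this illustrates is that the device is no longer stored as a separate attribute: `_init_cqt_layer` infers it from `self.cqt.cqt_kernels_real.device`, so a plain `.to(...)` on the module keeps the data processor and its (re)created CQT kernels on the same device.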