diff --git a/pesto/core.py b/pesto/core.py
index 899baa6f933fb7d8403e8a5b0adfaa3b2aaa3f0e..7c3522510d142d98d7df2310e6f337d468f6dd6f 100644
--- a/pesto/core.py
+++ b/pesto/core.py
@@ -93,8 +93,7 @@ def predict_from_files(
         export_format: Sequence[str] = ("csv",),
         no_convert_to_freq: bool = False,
         num_chunks: int = 1,
-        gpu: int = -1
-):
+        gpu: int = -1):
     r"""

     Args:
diff --git a/pesto/data.py b/pesto/data.py
index 652ebb4d8edbc2a80b207b5f19c5bf47bcd281c1..3044bbdf3fcfc0818b198555f7711623517d8e62 100644
--- a/pesto/data.py
+++ b/pesto/data.py
@@ -84,11 +84,11 @@ class Preprocessor(nn.Module):
         # compute HCQT kernels if it does not exist or if the sampling rate has changed
         if sr is not None and sr != self.hcqt_sr:
             self.hcqt_sr = sr
-            self._reset_hcqt_layer()
+            self._reset_hcqt_kernels()

         return self.hcqt_kernels(audio)

-    def _reset_hcqt_layer(self) -> None:
+    def _reset_hcqt_kernels(self) -> None:
         hop_length = int(self.hop_size * self.hcqt_sr / 1000 + 0.5)
         self.hcqt_kernels = HarmonicCQT(sr=self.hcqt_sr,
                                         hop_length=hop_length,
diff --git a/pesto/loader.py b/pesto/loader.py
index 94c519fb5f3c3b1c6f7bafabba35b0bb659dbfa4..a270615997197cfa4abdcfeef30de6a68d67be98 100644
--- a/pesto/loader.py
+++ b/pesto/loader.py
@@ -45,7 +45,7 @@ def load_model(checkpoint: str,
                   preprocessor=preprocessor,
                   crop_kwargs=hparams["pitch_shift"],
                   reduction=hparams["reduction"])
-    model.load_state_dict(state_dict)
+    model.load_state_dict(state_dict, strict=False)
     model.eval()

     return model
diff --git a/pesto/model.py b/pesto/model.py
index 27edd932faa90d51e8d23f5a55dde498a33fbb0d..27241271f15211799e8573bac077f56fa95f400d 100644
--- a/pesto/model.py
+++ b/pesto/model.py
@@ -228,6 +228,10 @@ class PESTO(nn.Module):

         return preds, confidence

+    @property
+    def bins_per_semitone(self) -> int:
+        return self.preprocessor.hcqt_kwargs["bins_per_semitone"]
+
     @property
     def hop_size(self) -> float:
         r"""Returns the hop size of the model (in milliseconds)"""
diff --git a/pesto/utils/reduce_activations.py b/pesto/utils/reduce_activations.py
index 191cb4f060ce53e1081b8696b5408ab0af1d6b28..b22dc508c2c0299b03c3038c9c50bf282bf1b76f 100644
--- a/pesto/utils/reduce_activations.py
+++ b/pesto/utils/reduce_activations.py
@@ -30,7 +30,7 @@ def reduce_activations(activations: torch.Tensor, reduction: str = "alwa") -> to
         window = torch.arange(1, 2 * bps, device=device) - bps  # [-bps+1, -bps+2, ..., bps-2, bps-1]
         indices = (center_bin + window).clip_(min=0, max=num_bins - 1)
         cropped_activations = activations.gather(-1, indices)
-        cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(1, indices)
-        return (cropped_activations * cropped_pitches).sum(dim=1) / cropped_activations.sum(dim=1)
+        cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices)
+        return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1)

     raise ValueError
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/test_basic.py b/tests/test_basic.py
deleted file mode 100644
index 1f4096eab0b3ed39a4b697104f2f9d33ddad8dc8..0000000000000000000000000000000000000000
--- a/tests/test_basic.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import unittest
-import pesto
-
-
-class MyTestCase(unittest.TestCase):
-    def test_something(self):
-        self.assertEqual(True, True)  # add assertion here
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/test_performances.py b/tests/test_performances.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aed8de0b4172658236ddbfce55f15549346b7d2
--- /dev/null
+++ b/tests/test_performances.py
@@ -0,0 +1,23 @@
+import itertools
+
+import pytest
+
+import torch
+
+from pesto import load_model
+from .utils import generate_synth_data
+
+
+@pytest.fixture
+def model():
+    return load_model('mir-1k', step_size=10.)
+
+
+@pytest.mark.parametrize('pitch, sr, reduction',
+                         itertools.product(range(50, 80), [16000, 44100, 48000], ["argmax", "alwa"]))
+def test_performances(model, pitch, sr, reduction):
+    x = generate_synth_data(pitch, sr=sr)
+
+    preds, conf = model(x, sr=sr, return_activations=False)
+
+    torch.testing.assert_close(preds, torch.full_like(preds, pitch), atol=0.1, rtol=0.1)
diff --git a/tests/test_predict.py b/tests/test_predict.py
deleted file mode 100644
index 464090415c47109523e91779d4f40e19495c9cf1..0000000000000000000000000000000000000000
--- a/tests/test_predict.py
+++ /dev/null
@@ -1 +0,0 @@
-# TODO
diff --git a/tests/test_shape.py b/tests/test_shape.py
new file mode 100644
index 0000000000000000000000000000000000000000..eed58bb8e4179c5f91c242fefe4e398f574512c6
--- /dev/null
+++ b/tests/test_shape.py
@@ -0,0 +1,97 @@
+import itertools
+
+import pytest
+
+import torch
+
+from pesto import load_model, predict
+from .utils import generate_synth_data
+
+
+@pytest.fixture
+def model():
+    return load_model('mir-1k', step_size=10.)
+
+
+@pytest.fixture
+def synth_data_16k():
+    return generate_synth_data(pitch=69, duration=5., sr=16000), 16000
+
+
+@pytest.mark.parametrize('reduction', ["argmax", "mean", "alwa"])
+def test_shape_no_batch(model, synth_data_16k, reduction):
+    x, sr = synth_data_16k
+
+    model.reduction = reduction
+
+    num_samples = x.size(-1)
+
+    num_timesteps = int(num_samples * 1000 / (model.hop_size * sr)) + 1
+
+    preds, conf, activations = model(x, sr=sr, return_activations=True)
+
+    assert preds.shape == (num_timesteps,)
+    assert conf.shape == (num_timesteps,)
+    assert activations.shape == (num_timesteps, 128 * model.bins_per_semitone)
+
+
+@pytest.mark.parametrize('sr, reduction',
+                         itertools.product([16000, 44100, 48000], ["argmax", "mean", "alwa"]))
+def test_shape_batch(model, sr, reduction):
+    model.reduction = reduction
+
+    batch_size = 13
+
+    batch = torch.stack([
+        generate_synth_data(pitch=p, duration=5., sr=sr)
+        for p in range(50, 50+batch_size)
+    ])
+
+    num_timesteps = int(batch.size(-1) * 1000 / (model.hop_size * sr)) + 1
+
+    preds, conf, activations = model(batch, sr=sr, return_activations=True)
+
+    assert preds.shape == (batch_size, num_timesteps)
+    assert conf.shape == (batch_size, num_timesteps)
+    assert activations.shape == (batch_size, num_timesteps, 128 * model.bins_per_semitone)
+
+
+@pytest.mark.parametrize('step_size, reduction',
+                         itertools.product([10., 20., 50., 100], ["argmax", "mean", "alwa"]))
+def test_predict_shape_no_batch(synth_data_16k, step_size, reduction):
+    x, sr = synth_data_16k
+
+    num_samples = x.size(-1)
+
+    num_timesteps = int(num_samples * 1000 / (step_size * sr)) + 1
+
+    timesteps, preds, conf, activations = predict(x,
+                                                  sr,
+                                                  step_size=step_size,
+                                                  reduction=reduction)
+
+    assert timesteps.shape == (num_timesteps,)
+    assert preds.shape == (num_timesteps,)
+    assert conf.shape == (num_timesteps,)
+
+
+@pytest.mark.parametrize('sr, step_size, reduction',
+                         itertools.product([16000, 44100, 48000], [10., 20., 50., 100.], ["argmax", "mean", "alwa"]))
+def test_predict_shape_batch(sr, step_size, reduction):
+    batch_size = 13
+
+    batch = torch.stack([
+        generate_synth_data(pitch=p, duration=5., sr=sr)
+        for p in range(50, 50+batch_size)
+    ])
+
+    num_timesteps = int(batch.size(-1) * 1000 / (step_size * sr)) + 1
+
+    timesteps, preds, conf, activations = predict(batch,
+                                                  sr=sr,
+                                                  step_size=step_size,
+                                                  reduction=reduction)
+
+    assert timesteps.shape == (num_timesteps,)
+    assert preds.shape == (batch_size, num_timesteps)
+    assert conf.shape == (batch_size, num_timesteps)
diff --git a/tests/test_timesteps.py b/tests/test_timesteps.py
new file mode 100644
index 0000000000000000000000000000000000000000..546ee09beb2b00cfd3a0c316c88d5167c69d9c72
--- /dev/null
+++ b/tests/test_timesteps.py
@@ -0,0 +1,18 @@
+import pytest
+
+import torch
+
+from pesto import predict
+from .utils import generate_synth_data
+
+
+@pytest.fixture
+def synth_data_16k():
+    return generate_synth_data(pitch=69, duration=5., sr=16000), 16000
+
+
+@pytest.mark.parametrize('step_size', [10., 20., 50., 100])
+def test_build_timesteps(synth_data_16k, step_size):
+    timesteps, *_ = predict(*synth_data_16k, step_size=step_size)
+    diff = timesteps[1:] - timesteps[:-1]
+    torch.testing.assert_close(diff, torch.full_like(diff, step_size))
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..71210e340044aa7e463908c841b45d8e719d457e
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,20 @@
+import torch
+
+
+def mid_to_hz(pitch: int):
+    return 440 * 2 ** ((pitch - 69) / 12)
+
+
+def generate_synth_data(pitch: int, num_harmonics: int = 5, duration=2, sr=16000):
+    f0 = mid_to_hz(pitch)
+    t = torch.arange(0, duration, 1/sr)
+    harmonics = torch.stack([
+        torch.cos(2 * torch.pi * k * f0 * t + torch.rand(()))
+        for k in range(1, num_harmonics+1)
+    ], dim=1)
+    # volume = torch.randn(()) * torch.arange(num_harmonics).neg().div(0.5).exp()
+    volume = torch.rand(num_harmonics)
+    volume[0] = 1
+    volume *= torch.randn(())
+    audio = torch.sum(volume * harmonics, dim=1)
+    return audio
\ No newline at end of file
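
Note on the reduce_activations.py change above: switching gather/sum from dim=1 to dim=-1 makes the "alwa" (argmax-local weighted averaging) reduction indifferent to a leading batch dimension, which is what the new batched shape tests exercise. Below is a minimal standalone sketch of that idea, using hypothetical shapes and a hypothetical fixed 5-bin window rather than the library's bins_per_semitone-based window; it is an illustration, not the package's actual code.

import torch

# Hypothetical sizes for illustration: 2 batch items, 4 frames, 12 pitch bins.
batch, time, bins = 2, 4, 12
activations = torch.rand(batch, time, bins)
all_pitches = torch.linspace(0, 11, bins)

# Crop a small window around the argmax bin of each frame.
center_bin = activations.argmax(dim=-1, keepdim=True)       # (batch, time, 1)
window = torch.arange(-2, 3)                                 # hypothetical 5-bin window
indices = (center_bin + window).clip_(min=0, max=bins - 1)   # (batch, time, 5)

# Gathering and summing along the last dim works for any number of leading dims,
# whereas dim=1 only matched the unbatched (time, bins) layout.
cropped_activations = activations.gather(-1, indices)
cropped_pitches = all_pitches.expand_as(activations).gather(-1, indices)
weighted = (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1)
print(weighted.shape)  # torch.Size([2, 4])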