Skip to content
Snippets Groups Projects
Commit 794b4711 authored by Alain Riou
Browse files

tests + minor fixes

parent d69c336a
Branches
No related tags found
No related merge requests found
...@@ -93,8 +93,7 @@ def predict_from_files( ...@@ -93,8 +93,7 @@ def predict_from_files(
export_format: Sequence[str] = ("csv",), export_format: Sequence[str] = ("csv",),
no_convert_to_freq: bool = False, no_convert_to_freq: bool = False,
num_chunks: int = 1, num_chunks: int = 1,
gpu: int = -1 gpu: int = -1):
):
r""" r"""
Args: Args:
......
...@@ -84,11 +84,11 @@ class Preprocessor(nn.Module): ...@@ -84,11 +84,11 @@ class Preprocessor(nn.Module):
# compute HCQT kernels if it does not exist or if the sampling rate has changed # compute HCQT kernels if it does not exist or if the sampling rate has changed
if sr is not None and sr != self.hcqt_sr: if sr is not None and sr != self.hcqt_sr:
self.hcqt_sr = sr self.hcqt_sr = sr
self._reset_hcqt_layer() self._reset_hcqt_kernels()
return self.hcqt_kernels(audio) return self.hcqt_kernels(audio)
def _reset_hcqt_layer(self) -> None: def _reset_hcqt_kernels(self) -> None:
hop_length = int(self.hop_size * self.hcqt_sr / 1000 + 0.5) hop_length = int(self.hop_size * self.hcqt_sr / 1000 + 0.5)
self.hcqt_kernels = HarmonicCQT(sr=self.hcqt_sr, self.hcqt_kernels = HarmonicCQT(sr=self.hcqt_sr,
hop_length=hop_length, hop_length=hop_length,
......
...@@ -45,7 +45,7 @@ def load_model(checkpoint: str, ...@@ -45,7 +45,7 @@ def load_model(checkpoint: str,
preprocessor=preprocessor, preprocessor=preprocessor,
crop_kwargs=hparams["pitch_shift"], crop_kwargs=hparams["pitch_shift"],
reduction=hparams["reduction"]) reduction=hparams["reduction"])
model.load_state_dict(state_dict) model.load_state_dict(state_dict, strict=False)
model.eval() model.eval()
return model return model
...@@ -228,6 +228,10 @@ class PESTO(nn.Module): ...@@ -228,6 +228,10 @@ class PESTO(nn.Module):
return preds, confidence return preds, confidence
@property
def bins_per_semitone(self) -> int:
return self.preprocessor.hcqt_kwargs["bins_per_semitone"]
@property @property
def hop_size(self) -> float: def hop_size(self) -> float:
r"""Returns the hop size of the model (in milliseconds)""" r"""Returns the hop size of the model (in milliseconds)"""
......
...@@ -30,7 +30,7 @@ def reduce_activations(activations: torch.Tensor, reduction: str = "alwa") -> to ...@@ -30,7 +30,7 @@ def reduce_activations(activations: torch.Tensor, reduction: str = "alwa") -> to
window = torch.arange(1, 2 * bps, device=device) - bps # [-bps+1, -bps+2, ..., bps-2, bps-1] window = torch.arange(1, 2 * bps, device=device) - bps # [-bps+1, -bps+2, ..., bps-2, bps-1]
indices = (center_bin + window).clip_(min=0, max=num_bins - 1) indices = (center_bin + window).clip_(min=0, max=num_bins - 1)
cropped_activations = activations.gather(-1, indices) cropped_activations = activations.gather(-1, indices)
cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(1, indices) cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices)
return (cropped_activations * cropped_pitches).sum(dim=1) / cropped_activations.sum(dim=1) return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1)
raise ValueError raise ValueError
import unittest
import pesto
class MyTestCase(unittest.TestCase):
    """Placeholder test case; it does not exercise any ``pesto`` behavior yet."""

    def test_something(self):
        """Trivially passing check, meant to be replaced by a real assertion."""
        self.assertIs(True, True)  # add assertion here


# Allow running this file directly as a script through the unittest runner.
if __name__ == '__main__':
    unittest.main()
import itertools
import pytest
import torch
from pesto import load_model
from .utils import generate_synth_data
@pytest.fixture
def model():
    """Load the 'mir-1k' checkpoint with ``step_size=10.`` for each test that requests it."""
    pesto_model = load_model('mir-1k', step_size=10.)
    return pesto_model
@pytest.mark.parametrize('pitch, sr, reduction',
                         itertools.product(range(50, 80), [16000, 44100, 48000], ["argmax", "alwa"]))
def test_performances(model, pitch, sr, reduction):
    """Check that the model recovers the pitch of a synthetic harmonic tone.

    Args:
        model: model provided by the ``model`` fixture.
        pitch: pitch passed to ``generate_synth_data`` and expected back from the model.
        sr: sampling rate used both to synthesize the signal and to run the model.
        reduction: reduction method applied to the model before inference.
    """
    # Apply the parametrized reduction; without this assignment every
    # `reduction` value would exercise the exact same model configuration.
    model.reduction = reduction

    x = generate_synth_data(pitch, sr=sr)
    preds, _ = model(x, sr=sr, return_activations=False)

    expected = torch.full_like(preds, pitch)
    # NOTE(review): assert_close accepts |preds - expected| <= atol + rtol * |expected|,
    # so with rtol=0.1 and pitches in 50..79 the tolerance is several pitch units wide.
    # Confirm whether rtol=0 (absolute tolerance only) was intended.
    torch.testing.assert_close(preds, expected, atol=0.1, rtol=0.1)
# TODO
import itertools
import pytest
import torch
from pesto import load_model, predict
from .utils import generate_synth_data
@pytest.fixture
def model():
    """Provide a model loaded from the 'mir-1k' checkpoint with ``step_size=10.``."""
    loaded = load_model('mir-1k', step_size=10.)
    return loaded
@pytest.fixture
def synth_data_16k():
    """Return a ``(signal, sampling_rate)`` pair: a synthetic tone at pitch 69, duration 5, sr 16000."""
    sample_rate = 16000
    signal = generate_synth_data(pitch=69, duration=5., sr=sample_rate)
    return signal, sample_rate
@pytest.mark.parametrize('reduction', ["argmax", "mean", "alwa"])
def test_shape_no_batch(model, synth_data_16k, reduction):
    """Check the shapes returned by the model for a single, unbatched signal."""
    audio, sample_rate = synth_data_16k
    model.reduction = reduction

    # Expected number of frames given the model hop size (in milliseconds).
    expected_frames = int(audio.size(-1) * 1000 / (model.hop_size * sample_rate)) + 1

    outputs = model(audio, sr=sample_rate, return_activations=True)
    preds, conf, activations = outputs

    expected_bins = 128 * model.bins_per_semitone
    assert preds.shape == (expected_frames,)
    assert conf.shape == (expected_frames,)
    assert activations.shape == (expected_frames, expected_bins)
@pytest.mark.parametrize('sr, reduction',
                         itertools.product([16000, 44100, 48000], ["argmax", "mean", "alwa"]))
def test_shape_batch(model, sr, reduction):
    """Check the shapes returned by the model for a batch of equally long signals."""
    model.reduction = reduction

    batch_size = 13
    # One synthetic tone per consecutive pitch, starting at 50.
    signals = [
        generate_synth_data(pitch=p, duration=5., sr=sr)
        for p in range(50, 50+batch_size)
    ]
    batch = torch.stack(signals)

    expected_frames = int(batch.size(-1) * 1000 / (model.hop_size * sr)) + 1

    outputs = model(batch, sr=sr, return_activations=True)
    preds, conf, activations = outputs

    expected_bins = 128 * model.bins_per_semitone
    assert preds.shape == (batch_size, expected_frames)
    assert conf.shape == (batch_size, expected_frames)
    assert activations.shape == (batch_size, expected_frames, expected_bins)
@pytest.mark.parametrize('step_size, reduction',
                         itertools.product([10., 20., 50., 100], ["argmax", "mean", "alwa"]))
def test_predict_shape_no_batch(synth_data_16k, step_size, reduction):
    """Check the shapes returned by ``predict`` for a single, unbatched signal."""
    audio, sample_rate = synth_data_16k

    expected_frames = int(audio.size(-1) * 1000 / (step_size * sample_rate)) + 1

    # The fourth returned value (activations) is unpacked but not checked here.
    timesteps, preds, conf, _ = predict(audio,
                                        sample_rate,
                                        step_size=step_size,
                                        reduction=reduction)

    for output in (timesteps, preds, conf):
        assert output.shape == (expected_frames,)
@pytest.mark.parametrize('sr, step_size, reduction',
                         itertools.product([16000, 44100, 48000], [10., 20., 50., 100.], ["argmax", "mean", "alwa"]))
def test_predict_shape_batch(sr, step_size, reduction):
    """Check the shapes returned by ``predict`` for a batch of equally long signals."""
    batch_size = 13
    # One synthetic tone per consecutive pitch, starting at 50.
    signals = [
        generate_synth_data(pitch=p, duration=5., sr=sr)
        for p in range(50, 50+batch_size)
    ]
    batch = torch.stack(signals)

    expected_frames = int(batch.size(-1) * 1000 / (step_size * sr)) + 1

    # The fourth returned value (activations) is unpacked but not checked here.
    timesteps, preds, conf, _ = predict(batch,
                                        sr=sr,
                                        step_size=step_size,
                                        reduction=reduction)

    # Timesteps are shared across the batch; predictions and confidences are per item.
    assert timesteps.shape == (expected_frames,)
    assert preds.shape == (batch_size, expected_frames)
    assert conf.shape == (batch_size, expected_frames)
import pytest
import torch
from pesto import predict
from .utils import generate_synth_data
@pytest.fixture
def synth_data_16k():
    """Return a ``(signal, sampling_rate)`` pair: a synthetic tone at pitch 69, duration 5, sr 16000."""
    sample_rate = 16000
    tone = generate_synth_data(pitch=69, duration=5., sr=sample_rate)
    return tone, sample_rate
@pytest.mark.parametrize('step_size', [10., 20., 50., 100])
def test_build_timesteps(synth_data_16k, step_size):
    """Check that consecutive timesteps returned by ``predict`` are ``step_size`` apart."""
    audio, sample_rate = synth_data_16k
    timesteps, *_ = predict(audio, sample_rate, step_size=step_size)

    # Difference between each timestep and its predecessor.
    intervals = timesteps[1:] - timesteps[:-1]
    expected_intervals = torch.full_like(intervals, step_size)
    torch.testing.assert_close(intervals, expected_intervals)
import torch
def mid_to_hz(pitch: int):
    """Convert a pitch number to a frequency, with pitch 69 mapped to 440 and 12 steps per octave."""
    steps_from_reference = pitch - 69
    octaves = steps_from_reference / 12
    return 440 * 2 ** octaves
def generate_synth_data(pitch: int, num_harmonics: int = 5, duration=2, sr=16000):
    r"""Synthesize a random harmonic tone whose fundamental is at the given pitch.

    The signal is a weighted sum of ``num_harmonics`` cosines at integer multiples
    of the fundamental frequency, each with a random phase offset.

    Args:
        pitch: pitch of the fundamental, converted to a frequency by ``mid_to_hz``.
        num_harmonics: number of partials, fundamental included.
        duration: length of the signal, in the time unit matching ``sr`` (seconds for Hz).
        sr: sampling rate.

    Returns:
        1-D tensor with ``round(duration * sr)`` samples.
    """
    f0 = mid_to_hz(pitch)

    # Build the time axis from an integer sample index: `torch.arange` with a
    # non-integer step (1/sr) is subject to floating-point rounding, which can
    # make the number of samples off by one for some (duration, sr) pairs.
    num_samples = int(round(duration * sr))
    t = torch.arange(num_samples) / sr

    # One column per harmonic, each with its own random phase drawn from [0, 1).
    harmonics = torch.stack([
        torch.cos(2 * torch.pi * k * f0 * t + torch.rand(()))
        for k in range(1, num_harmonics+1)
    ], dim=1)

    # Random relative amplitudes, with the fundamental fixed to 1 ...
    volume = torch.rand(num_harmonics)
    volume[0] = 1
    # ... then a global gain drawn from a standard normal distribution
    # (it may be negative or close to zero).
    volume *= torch.randn(())

    audio = torch.sum(volume * harmonics, dim=1)
    return audio
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment