diff --git a/pesto/data.py b/pesto/data.py index 7d8b81d77b2a982c559e258d3f7c8cde301f6d75..652ebb4d8edbc2a80b207b5f19c5bf47bcd281c1 100644 --- a/pesto/data.py +++ b/pesto/data.py @@ -63,6 +63,7 @@ class Preprocessor(nn.Module): # compute CQT from input waveform, and invert dims for (time_steps, num_harmonics, freq_bins) # in other words, time becomes the batch dimension, enabling efficient processing for long audios. complex_cqt = torch.view_as_complex(self.hcqt(x, sr=sr)).permute(0, 3, 1, 2) + complex_cqt.squeeze_(0) # convert to dB return self.to_log(complex_cqt) diff --git a/pesto/loader.py b/pesto/loader.py index 609d4b5df3cc4b9e2e1c891034b9479379cb774e..94c519fb5f3c3b1c6f7bafabba35b0bb659dbfa4 100644 --- a/pesto/loader.py +++ b/pesto/loader.py @@ -24,7 +24,7 @@ def load_model(checkpoint: str, if os.path.exists(checkpoint): # handle user-provided checkpoints model_path = checkpoint else: - model_path = os.path.join(os.path.dirname(__file__), "weights", checkpoint + ".pth") + model_path = os.path.join(os.path.dirname(__file__), "weights", checkpoint + ".ckpt") if not os.path.exists(model_path): raise FileNotFoundError(f"You passed an invalid checkpoint file: {checkpoint}.") diff --git a/pyproject.toml b/pyproject.toml index 8a3e4168f6ec4a67997d8e9f5e948442584acf71..6464d322af62df07aeccc2ccf85ecdbde83a636b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ version = "1.0.0" authors = [ {name = "Alain Riou", email = "alain.riou@sony.com"} ] -description = "Efficient pitch estimation with self-supervised learning", +description = "Efficient pitch estimation with self-supervised learning" readme = {file = "README.md", content-type = "text/markdown"} requires-python = ">=3.8" classifiers = [ @@ -33,13 +33,13 @@ matplotlib = ["matplotlib"] test = ["pytest"] [project.scripts] -pesto = "pesto.main.py:pesto" +pesto = "pesto.main:pesto" [project.urls] source = "https://github.com/SonyCSLParis/pesto" -[tool.setuptools.packages.find] -where = ["pesto"] +[tool.pytest.ini_options] +testpaths = "tests/" [tool.setuptools.package-data] -weights = ["*.ckpt"] +pesto = ["weights/*.ckpt"]