From d69c336a684abd4a923ce9b59974ce0606801a71 Mon Sep 17 00:00:00 2001 From: Alain Riou <alain.riou14000@yahoo.com> Date: Mon, 15 Jan 2024 19:50:50 +0100 Subject: [PATCH] fix pyproject.toml --- pesto/data.py | 1 + pesto/loader.py | 2 +- pyproject.toml | 10 +++++----- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pesto/data.py b/pesto/data.py index 7d8b81d..652ebb4 100644 --- a/pesto/data.py +++ b/pesto/data.py @@ -63,6 +63,7 @@ class Preprocessor(nn.Module): # compute CQT from input waveform, and invert dims for (time_steps, num_harmonics, freq_bins) # in other words, time becomes the batch dimension, enabling efficient processing for long audios. complex_cqt = torch.view_as_complex(self.hcqt(x, sr=sr)).permute(0, 3, 1, 2) + complex_cqt.squeeze_(0) # convert to dB return self.to_log(complex_cqt) diff --git a/pesto/loader.py b/pesto/loader.py index 609d4b5..94c519f 100644 --- a/pesto/loader.py +++ b/pesto/loader.py @@ -24,7 +24,7 @@ def load_model(checkpoint: str, if os.path.exists(checkpoint): # handle user-provided checkpoints model_path = checkpoint else: - model_path = os.path.join(os.path.dirname(__file__), "weights", checkpoint + ".pth") + model_path = os.path.join(os.path.dirname(__file__), "weights", checkpoint + ".ckpt") if not os.path.exists(model_path): raise FileNotFoundError(f"You passed an invalid checkpoint file: {checkpoint}.") diff --git a/pyproject.toml b/pyproject.toml index 8a3e416..6464d32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ version = "1.0.0" authors = [ {name = "Alain Riou", email = "alain.riou@sony.com"} ] -description = "Efficient pitch estimation with self-supervised learning", +description = "Efficient pitch estimation with self-supervised learning" readme = {file = "README.md", content-type = "text/markdown"} requires-python = ">=3.8" classifiers = [ @@ -33,13 +33,13 @@ matplotlib = ["matplotlib"] test = ["pytest"] [project.scripts] -pesto = "pesto.main.py:pesto" +pesto = "pesto.main:pesto" [project.urls] source = "https://github.com/SonyCSLParis/pesto" -[tool.setuptools.packages.find] -where = ["pesto"] +[tool.pytest.ini_options] +testpaths = "tests/" [tool.setuptools.package-data] -weights = ["*.ckpt"] +pesto = ["weights/*.ckpt"] -- GitLab