Skip to content
Snippets Groups Projects
Unverified Commit 8dc3badf authored by Alain Riou's avatar Alain Riou Committed by GitHub
Browse files

Merge pull request #25 from SonyCSLParis/roll_activations

Roll activations instead of substracting shift, so that activations and predictions are correct
parents 2d110510 834ffa8b
Branches
No related tags found
No related merge requests found
from .core import load_model, predict, predict_from_files from .core import load_model, predict, predict_from_files
__version__ = '1.1.0'
...@@ -146,4 +146,3 @@ def predict_from_files( ...@@ -146,4 +146,3 @@ def predict_from_files(
predictions = [p.cpu().numpy() for p in predictions] predictions = [p.cpu().numpy() for p in predictions]
for fmt in export_format: for fmt in export_format:
export(fmt, output_file, *predictions) export(fmt, output_file, *predictions)
...@@ -215,10 +215,9 @@ class PESTO(nn.Module): ...@@ -215,10 +215,9 @@ class PESTO(nn.Module):
if batch_size: if batch_size:
activations = activations.view(batch_size, -1, activations.size(-1)) activations = activations.view(batch_size, -1, activations.size(-1))
preds = reduce_activations(activations, reduction=self.reduction) activations = activations.roll(-round(self.shift.cpu().item() * self.bins_per_semitone), -1)
# decrease by shift to get absolute pitch preds = reduce_activations(activations, reduction=self.reduction)
preds.sub_(self.shift)
if convert_to_freq: if convert_to_freq:
preds = 440 * 2 ** ((preds - 69) / 12) preds = 440 * 2 ** ((preds - 69) / 12)
......
...@@ -44,8 +44,10 @@ def export_png(output_file: str, timesteps, confidence, activations, lims=(21, 1 ...@@ -44,8 +44,10 @@ def export_png(output_file: str, timesteps, confidence, activations, lims=(21, 1
activations = activations * confidence[:, None] activations = activations * confidence[:, None]
plt.imshow(activations.T, plt.imshow(activations.T,
aspect='auto', origin='lower', cmap='inferno', aspect='auto', origin='lower', cmap='inferno',
extent=(timesteps[0], timesteps[-1]) + lims) extent=(timesteps[0] / 1000, timesteps[-1] / 1000) + lims)
plt.xlabel("Time (s)")
plt.ylabel("Pitch (semitones)")
plt.title(output_file.rsplit('.', 2)[0]) plt.title(output_file.rsplit('.', 2)[0])
plt.tight_layout() plt.tight_layout()
plt.savefig(output_file) plt.savefig(output_file)
...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" ...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "pesto-pitch" name = "pesto-pitch"
version = "1.0.0" dynamic = ["version"]
authors = [ authors = [
{name = "Alain Riou", email = "alain.riou@sony.com"} {name = "Alain Riou", email = "alain.riou@sony.com"}
] ]
...@@ -41,5 +41,8 @@ source = "https://github.com/SonyCSLParis/pesto" ...@@ -41,5 +41,8 @@ source = "https://github.com/SonyCSLParis/pesto"
[tool.pytest.ini_options] [tool.pytest.ini_options]
testpaths = "tests/" testpaths = "tests/"
[tool.setuptools.dynamic]
version = {attr = "pesto.__version__"}
[tool.setuptools.package-data] [tool.setuptools.package-data]
pesto = ["weights/*.ckpt"] pesto = ["weights/*.ckpt"]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment