From 83cc9f5c3e2524b5ac5fd9646fd869670e8aeb81 Mon Sep 17 00:00:00 2001
From: "paul.best" <paul.best@sms-cluster.lis-lab.fr>
Date: Wed, 12 Jun 2024 12:52:54 +0200
Subject: [PATCH] Add min_samples/downsample filtering to AudioDataModule; widen pitch range in PESTO shift estimation

---
 src/data/audio_datamodule.py | 16 +++++++++++++---
 src/models/pesto.py          |  4 ++--
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/src/data/audio_datamodule.py b/src/data/audio_datamodule.py
index 77c0297..d0b077d 100644
--- a/src/data/audio_datamodule.py
+++ b/src/data/audio_datamodule.py
@@ -59,6 +59,8 @@ class AudioDataModule(LightningDataModule):
                  bins_per_semitone: int = 1,
                  n_bins: int = 84,
                  center_bins: bool = False,
+                 min_samples: int = 0,
+                 downsample: float = 1,
                  batch_size: int = 256,
                  num_workers: int = 0,
                  pin_memory: bool = False,
@@ -93,7 +95,7 @@ class AudioDataModule(LightningDataModule):
 
         self.fold = fold
         self.n_folds = n_folds
-
+        self.min_samples, self.downsample = min_samples, downsample
         # HCQT
         self.hcqt_sr = None
         self.hcqt_kernels = None
@@ -178,6 +180,7 @@ class AudioDataModule(LightningDataModule):
     def load_data(self, audio_files: Path, annot_files: Path | None = None) -> torch.utils.data.Dataset:
         cache_cqt = self.build_cqt_filename(audio_files)
         if cache_cqt.exists():
+            print('loading cached CQT', cache_cqt)
             inputs = np.load(cache_cqt, mmap_mode=self.mmap_mode)
             cache_annot = cache_cqt.with_suffix(".csv")
             annotations = np.loadtxt(cache_annot, dtype=np.float32) if cache_annot.exists() else None
@@ -217,15 +220,22 @@ class AudioDataModule(LightningDataModule):
             annot_files = []
             annot_list = None
 
-        log.info("Precomputing HCQT...")
+        log.info(f"Precomputing HCQT for {len(audio_files)} files")
         pbar = tqdm(itertools.zip_longest(audio_files, annot_files, fillvalue=None),
                     total=len(audio_files),
                     leave=False)
         for fname, annot in pbar:
             fname = fname.strip()
             pbar.set_description(fname)
+            downsample = self.downsample[fname.split('/')[0]] if str(type(self.downsample)) == "<class 'omegaconf.dictconfig.DictConfig'>" else self.downsample
+            if torchaudio.info(data_dir / fname).num_frames < self.min_samples:
+                print(f'{fname} is too small :/')
+                continue
+            if torchaudio.info(data_dir / fname).sample_rate / downsample / 2 < self.hcqt_kwargs['fmax']:
+                print(f'{fname} has a too small sampling rate for the given fmax :/')
+                continue
             x, sr = torchaudio.load(data_dir / fname)
-            out = self.hcqt(x.mean(dim=0), sr)  # convert to mono and compute HCQT
+            out = self.hcqt(x.mean(dim=0), sr/downsample)  # convert to mono and compute HCQT
 
             if annot is not None:
                 annot = annot.strip()
diff --git a/src/models/pesto.py b/src/models/pesto.py
index 638c543..bcd3830 100644
--- a/src/models/pesto.py
+++ b/src/models/pesto.py
@@ -171,14 +171,14 @@ class PESTO(LightningModule):
     def estimate_shift(self) -> None:
         r"""Estimate the shift to predict absolute pitches from relative activations"""
         # 0. Define labels
-        labels = torch.arange(60, 72)
+        labels = torch.arange(60, 96, 2)
 
         # 1. Generate synthetic audio and convert it to HCQT
         sr = 16000
         dm = self.trainer.datamodule
         batch = []
         for p in labels:
-            audio = generate_synth_data(p, sr=sr)
+            audio = generate_synth_data(p, sr=sr, num_harmonics=2)
             hcqt = dm.hcqt(audio, sr)
             batch.append(hcqt[0])
 
-- 
GitLab