From 83cc9f5c3e2524b5ac5fd9646fd869670e8aeb81 Mon Sep 17 00:00:00 2001
From: "paul.best" <paul.best@sms-cluster.lis-lab.fr>
Date: Wed, 12 Jun 2024 12:52:54 +0200
Subject: [PATCH] fixes

---
 src/data/audio_datamodule.py | 16 +++++++++++++---
 src/models/pesto.py          |  4 ++--
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/src/data/audio_datamodule.py b/src/data/audio_datamodule.py
index 77c0297..d0b077d 100644
--- a/src/data/audio_datamodule.py
+++ b/src/data/audio_datamodule.py
@@ -59,6 +59,8 @@
         bins_per_semitone: int = 1,
         n_bins: int = 84,
         center_bins: bool = False,
+        min_samples: int = 0,
+        downsample: float = 1,
         batch_size: int = 256,
         num_workers: int = 0,
         pin_memory: bool = False,
@@ -93,7 +95,7 @@
 
         self.fold = fold
         self.n_folds = n_folds
-
+        self.min_samples, self.downsample = min_samples, downsample
         # HCQT
         self.hcqt_sr = None
         self.hcqt_kernels = None
@@ -178,6 +180,7 @@
     def load_data(self, audio_files: Path, annot_files: Path | None = None) -> torch.utils.data.Dataset:
         cache_cqt = self.build_cqt_filename(audio_files)
         if cache_cqt.exists():
+            print('loading cached CQT', cache_cqt)
             inputs = np.load(cache_cqt, mmap_mode=self.mmap_mode)
             cache_annot = cache_cqt.with_suffix(".csv")
             annotations = np.loadtxt(cache_annot, dtype=np.float32) if cache_annot.exists() else None
@@ -217,15 +220,22 @@
 
             annot_files = []
             annot_list = None
 
-        log.info("Precomputing HCQT...")
+        log.info(f"Precomputing HCQT for {len(audio_files)} files")
         pbar = tqdm(itertools.zip_longest(audio_files, annot_files, fillvalue=None), total=len(audio_files), leave=False)
         for fname, annot in pbar:
             fname = fname.strip()
             pbar.set_description(fname)
+            downsample = self.downsample[fname.split('/')[0]] if str(type(self.downsample)) == "<class 'omegaconf.dictconfig.DictConfig'>" else self.downsample
+            if torchaudio.info(data_dir / fname).num_frames < self.min_samples:
+                print(f'{fname} is too small :/')
+                continue
+            if torchaudio.info(data_dir / fname).sample_rate / downsample / 2 < self.hcqt_kwargs['fmax']:
+                print(f'{fname} has a too small sampling rate for the given fmax :/')
+                continue
 
             x, sr = torchaudio.load(data_dir / fname)
-            out = self.hcqt(x.mean(dim=0), sr) # convert to mono and compute HCQT
+            out = self.hcqt(x.mean(dim=0), sr/downsample) # convert to mono and compute HCQT
 
             if annot is not None:
                 annot = annot.strip()
diff --git a/src/models/pesto.py b/src/models/pesto.py
index 638c543..bcd3830 100644
--- a/src/models/pesto.py
+++ b/src/models/pesto.py
@@ -171,14 +171,14 @@
     def estimate_shift(self) -> None:
         r"""Estimate the shift to predict absolute pitches from relative activations"""
         # 0. Define labels
-        labels = torch.arange(60, 72)
+        labels = torch.arange(60, 96, 2)
 
         # 1. Generate synthetic audio and convert it to HCQT
         sr = 16000
         dm = self.trainer.datamodule
         batch = []
         for p in labels:
-            audio = generate_synth_data(p, sr=sr)
+            audio = generate_synth_data(p, sr=sr, num_harmonics=2)
             hcqt = dm.hcqt(audio, sr)
             batch.append(hcqt[0])
 
-- 
GitLab