Skip to content
Snippets Groups Projects
Commit 83cc9f5c authored by paul.best's avatar paul.best
Browse files

fixes

parent 229f78bd
No related branches found
No related tags found
No related merge requests found
......@@ -59,6 +59,8 @@ class AudioDataModule(LightningDataModule):
bins_per_semitone: int = 1,
n_bins: int = 84,
center_bins: bool = False,
min_samples: int = 0,
downsample: float = 1,
batch_size: int = 256,
num_workers: int = 0,
pin_memory: bool = False,
......@@ -93,7 +95,7 @@ class AudioDataModule(LightningDataModule):
self.fold = fold
self.n_folds = n_folds
self.min_samples, self.downsample = min_samples, downsample
# HCQT
self.hcqt_sr = None
self.hcqt_kernels = None
......@@ -178,6 +180,7 @@ class AudioDataModule(LightningDataModule):
def load_data(self, audio_files: Path, annot_files: Path | None = None) -> torch.utils.data.Dataset:
cache_cqt = self.build_cqt_filename(audio_files)
if cache_cqt.exists():
print('loading cached CQT', cache_cqt)
inputs = np.load(cache_cqt, mmap_mode=self.mmap_mode)
cache_annot = cache_cqt.with_suffix(".csv")
annotations = np.loadtxt(cache_annot, dtype=np.float32) if cache_annot.exists() else None
......@@ -217,15 +220,22 @@ class AudioDataModule(LightningDataModule):
annot_files = []
annot_list = None
log.info("Precomputing HCQT...")
log.info(f"Precomputing HCQT for {len(audio_files)} files")
pbar = tqdm(itertools.zip_longest(audio_files, annot_files, fillvalue=None),
total=len(audio_files),
leave=False)
for fname, annot in pbar:
fname = fname.strip()
pbar.set_description(fname)
downsample = self.downsample[fname.split('/')[0]] if str(type(self.downsample)) == "<class 'omegaconf.dictconfig.DictConfig'>" else self.downsample
if torchaudio.info(data_dir / fname).num_frames < self.min_samples:
print(f'{fname} is too small :/')
continue
if torchaudio.info(data_dir / fname).sample_rate / downsample / 2 < self.hcqt_kwargs['fmax']:
print(f'{fname} has a too small sampling rate for the given fmax :/')
continue
x, sr = torchaudio.load(data_dir / fname)
out = self.hcqt(x.mean(dim=0), sr) # convert to mono and compute HCQT
out = self.hcqt(x.mean(dim=0), sr/downsample) # convert to mono and compute HCQT
if annot is not None:
annot = annot.strip()
......
......@@ -171,14 +171,14 @@ class PESTO(LightningModule):
def estimate_shift(self) -> None:
r"""Estimate the shift to predict absolute pitches from relative activations"""
# 0. Define labels
labels = torch.arange(60, 72)
labels = torch.arange(60, 96, 2)
# 1. Generate synthetic audio and convert it to HCQT
sr = 16000
dm = self.trainer.datamodule
batch = []
for p in labels:
audio = generate_synth_data(p, sr=sr)
audio = generate_synth_data(p, sr=sr, num_harmonics=2)
hcqt = dm.hcqt(audio, sr)
batch.append(hcqt[0])
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment