Skip to content
Snippets Groups Projects
Commit 83cc9f5c authored by paul.best's avatar paul.best
Browse files

fixes

parent 229f78bd
No related branches found
No related tags found
No related merge requests found
...@@ -59,6 +59,8 @@ class AudioDataModule(LightningDataModule): ...@@ -59,6 +59,8 @@ class AudioDataModule(LightningDataModule):
bins_per_semitone: int = 1, bins_per_semitone: int = 1,
n_bins: int = 84, n_bins: int = 84,
center_bins: bool = False, center_bins: bool = False,
min_samples: int = 0,
downsample: float = 1,
batch_size: int = 256, batch_size: int = 256,
num_workers: int = 0, num_workers: int = 0,
pin_memory: bool = False, pin_memory: bool = False,
...@@ -93,7 +95,7 @@ class AudioDataModule(LightningDataModule): ...@@ -93,7 +95,7 @@ class AudioDataModule(LightningDataModule):
self.fold = fold self.fold = fold
self.n_folds = n_folds self.n_folds = n_folds
self.min_samples, self.downsample = min_samples, downsample
# HCQT # HCQT
self.hcqt_sr = None self.hcqt_sr = None
self.hcqt_kernels = None self.hcqt_kernels = None
...@@ -178,6 +180,7 @@ class AudioDataModule(LightningDataModule): ...@@ -178,6 +180,7 @@ class AudioDataModule(LightningDataModule):
def load_data(self, audio_files: Path, annot_files: Path | None = None) -> torch.utils.data.Dataset: def load_data(self, audio_files: Path, annot_files: Path | None = None) -> torch.utils.data.Dataset:
cache_cqt = self.build_cqt_filename(audio_files) cache_cqt = self.build_cqt_filename(audio_files)
if cache_cqt.exists(): if cache_cqt.exists():
print('loading cached CQT', cache_cqt)
inputs = np.load(cache_cqt, mmap_mode=self.mmap_mode) inputs = np.load(cache_cqt, mmap_mode=self.mmap_mode)
cache_annot = cache_cqt.with_suffix(".csv") cache_annot = cache_cqt.with_suffix(".csv")
annotations = np.loadtxt(cache_annot, dtype=np.float32) if cache_annot.exists() else None annotations = np.loadtxt(cache_annot, dtype=np.float32) if cache_annot.exists() else None
...@@ -217,15 +220,22 @@ class AudioDataModule(LightningDataModule): ...@@ -217,15 +220,22 @@ class AudioDataModule(LightningDataModule):
annot_files = [] annot_files = []
annot_list = None annot_list = None
log.info("Precomputing HCQT...") log.info(f"Precomputing HCQT for {len(audio_files)} files")
pbar = tqdm(itertools.zip_longest(audio_files, annot_files, fillvalue=None), pbar = tqdm(itertools.zip_longest(audio_files, annot_files, fillvalue=None),
total=len(audio_files), total=len(audio_files),
leave=False) leave=False)
for fname, annot in pbar: for fname, annot in pbar:
fname = fname.strip() fname = fname.strip()
pbar.set_description(fname) pbar.set_description(fname)
downsample = self.downsample[fname.split('/')[0]] if str(type(self.downsample)) == "<class 'omegaconf.dictconfig.DictConfig'>" else self.downsample
if torchaudio.info(data_dir / fname).num_frames < self.min_samples:
print(f'{fname} is too small :/')
continue
if torchaudio.info(data_dir / fname).sample_rate / downsample / 2 < self.hcqt_kwargs['fmax']:
print(f'{fname} has a too small sampling rate for the given fmax :/')
continue
x, sr = torchaudio.load(data_dir / fname) x, sr = torchaudio.load(data_dir / fname)
out = self.hcqt(x.mean(dim=0), sr) # convert to mono and compute HCQT out = self.hcqt(x.mean(dim=0), sr/downsample) # convert to mono and compute HCQT
if annot is not None: if annot is not None:
annot = annot.strip() annot = annot.strip()
......
...@@ -171,14 +171,14 @@ class PESTO(LightningModule): ...@@ -171,14 +171,14 @@ class PESTO(LightningModule):
def estimate_shift(self) -> None: def estimate_shift(self) -> None:
r"""Estimate the shift to predict absolute pitches from relative activations""" r"""Estimate the shift to predict absolute pitches from relative activations"""
# 0. Define labels # 0. Define labels
labels = torch.arange(60, 72) labels = torch.arange(60, 96, 2)
# 1. Generate synthetic audio and convert it to HCQT # 1. Generate synthetic audio and convert it to HCQT
sr = 16000 sr = 16000
dm = self.trainer.datamodule dm = self.trainer.datamodule
batch = [] batch = []
for p in labels: for p in labels:
audio = generate_synth_data(p, sr=sr) audio = generate_synth_data(p, sr=sr, num_harmonics=2)
hcqt = dm.hcqt(audio, sr) hcqt = dm.hcqt(audio, sr)
batch.append(hcqt[0]) batch.append(hcqt[0])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment