diff --git a/README.md b/README.md index 306b0c595bf9ff25ea7b343a492018bbb186272c..efaddd27f8793f7b216c371e06f499963acafc88 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,6 @@ For example : `python run_CNN_HB.py file1.wav file2.wav -outfn predictions.pkl` This script relies on torch, pandas, numpy, scipy, and tqdm to run. Install dependencies with pip or conda. -If a GPU and cuda are available on the current machine, process will run on GPU for faster computation. +If a GPU and cuda are available on the current machine, processes will run on GPU for faster computation. paul.best@univ-tln.fr for more information diff --git a/models.py b/models.py index abf66e23d20ead2b787fd2b7e7531ac4d89e1a84..75660662c698d0c32e6f7f09ea74a560f8caafa6 100644 --- a/models.py +++ b/models.py @@ -2,8 +2,59 @@ from torch import nn from frontend import STFT, MelFilter, PCENLayer, Log1p +class depthwise_separable_conv1d(nn.Module): + def __init__(self, nin, nout, kernel, padding=0, stride=1): + super(depthwise_separable_conv1d, self).__init__() + self.depthwise = nn.Conv1d(nin, nin, kernel_size=kernel, padding=padding, stride=stride, groups=nin) + self.pointwise = nn.Conv1d(nin, nout, kernel_size=1) + def forward(self, x): + out = self.depthwise(x) + out = self.pointwise(out) + return out + +class Dropout1d(nn.Module): + def __init__(self, pdropout=.25): + super(Dropout1d, self).__init__() + self.dropout = nn.Dropout2d(pdropout) + def forward(self, x): + x = x.unsqueeze(-1) + x = self.dropout(x) + return x.squeeze(-1) + +PHYSETER_NFEAT = 128 +PHYSETER_KERNEL = 7 +BALAENOPTERA_NFEAT = 128 +BALAENOPTERA_KERNEL = 5 get = { + 'physeter' : nn.Sequential( + STFT(512, 256), + MelFilter(50000, 512, 64, 2000, 25000), + Log1p(), + depthwise_separable_conv1d(64, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2), + nn.BatchNorm1d(PHYSETER_NFEAT), + nn.LeakyReLU(), + Dropout1d(), + depthwise_separable_conv1d(PHYSETER_NFEAT, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2), + nn.BatchNorm1d(PHYSETER_NFEAT), + nn.LeakyReLU(), + Dropout1d(), + depthwise_separable_conv1d(PHYSETER_NFEAT, 1, PHYSETER_KERNEL, stride=2) + ), + 'balaenoptera': nn.Sequential( + STFT(256, 32), + MelFilter(200, 256, 128, 0, 100), + Log1p(), + depthwise_separable_conv1d(128, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2), + nn.BatchNorm1d(BALAENOPTERA_NFEAT), + nn.LeakyReLU(), + Dropout1d(), + depthwise_separable_conv1d(BALAENOPTERA_NFEAT, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2), + nn.BatchNorm1d(BALAENOPTERA_NFEAT), + nn.LeakyReLU(), + Dropout1d(), + depthwise_separable_conv1d(BALAENOPTERA_NFEAT, 1, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2) + ), 'megaptera' : nn.Sequential( nn.Sequential( STFT(512, 64), diff --git a/run_CNN.py b/run_CNN.py index dcd92d060c63da859ca2b35161acb4194b960395..52606ad333192aa2202e8a369edeaab3217dd7a5 100644 --- a/run_CNN.py +++ b/run_CNN.py @@ -20,19 +20,24 @@ parser.add_argument('-no-maxPool', dest='maxPool', action='store_false') parser.set_defaults(maxPool=True) args = parser.parse_args() - meta_model = { 'delphinid': { - 'stdc':'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc', + 'stdc': 'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc', 'fs': 96000 }, 'megaptera': { - 'stdc':'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc', - 'fs':11025 + 'stdc': 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc', + 'fs': 11025 }, 'orcinus': '', - 'physeter': '', - 'balaenoptera': '' + 'physeter': { + 'stdc': 'stft_depthwise_ovs_128_k7_r1.stdc', + 'fs': 50000 + }, + 'balaenoptera': { + 'stdc': 'dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc', + 'fs': 200 + } }[args.specie] @@ -42,33 +47,6 @@ def collate_fn(batch): norm = lambda arr: (arr - np.mean(arr) ) / np.std(arr) - -def run(folder, stdcfile, model, fs, lensample, batch_size, maxPool): - model.load_state_dict(torch.load(stdcfile)) - model.eval() - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - model.to(device) - - out = pd.DataFrame(columns=['fn', 'offset', 'pred']) - fns, offsets, preds = [], [], [] - loader = data.DataLoader(Dataset(folder, fs, lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4) - if len(loader) == 0: - print('Unable to open any audio file in the given folder') - with torch.no_grad(): - for x, meta in tqdm(loader): - x = x.to(device) - pred = model(x).cpu().detach().numpy() - if maxPool: - pred = pred.max(axis=-1).reshape(len(x)) - else: - pred = pred.reshape(len(x), -1) - fns.extend(meta['fn']) - offsets.extend(meta['offset'].numpy()) - preds.extend(pred) - out.fn, out.offset, out.pred = fns, offsets, preds - return out - - class Dataset(data.Dataset): def __init__(self, folder, fs, lensample): super(Dataset, self) @@ -78,10 +56,9 @@ class Dataset(data.Dataset): try: duration = sf.info(folder+fn).duration except: - print(f'Skipping {fn} (unable to read)') + print(f'Skipping {fn} (unable to read as audio)') continue - for offset in np.arange(0, duration+.01-lensample, lensample): - self.samples.append({'fn':fn, 'offset':offset}) + self.samples.extend([{'fn':fn, 'offset':offset} for offset in np.arange(0, duration+.01-lensample, lensample)]) self.fs, self.folder, self.lensample = fs, folder, lensample def __len__(self): @@ -101,13 +78,33 @@ class Dataset(data.Dataset): sig = norm(sig) return torch.tensor(sig).float(), sample -preds = run(args.audio_folder, - meta_model['stdc'], - models.get[args.specie], - meta_model['fs'], - args.lensample, - args.batch_size, - args.maxPool - ) -preds.to_pickle(args.pred_fn) +# prepare model +model = models.get[args.specie] +model.load_state_dict(torch.load(f"weights/{meta_model['stdc']}")) +model.eval() +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +model.to(device) + +# prepare data loader and output storage for predictions +loader = data.DataLoader(Dataset(args.audio_folder, meta_model['fs'], args.lensample), batch_size=args.batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4) +out = pd.DataFrame(columns=['filename', 'offset', 'prediction']) +fns, offsets, preds = [], [], [] +if len(loader) == 0: + print('Unable to open any audio file in the given folder') + exit() + +with torch.no_grad(): + for x, meta in tqdm(loader): + x = x.to(device) + pred = model(x).cpu().detach().numpy() + if args.maxPool: + pred = pred.max(axis=-1).reshape(len(x)) + else: + pred = pred.reshape(len(x), -1) + preds.extend(pred) + fns.extend(meta['fn']) + offsets.extend(meta['offset'].numpy()) + +out.filename, out.offset, out.prediction = fns, offsets, preds +out.to_pickle(args.pred_fn) diff --git a/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc b/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc new file mode 100644 index 0000000000000000000000000000000000000000..dd62d9ee17bfda5a56610aaa10a36473a59034cc Binary files /dev/null and b/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc differ diff --git a/sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc b/weights/sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc similarity index 100% rename from sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc rename to weights/sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc diff --git a/sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc b/weights/sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc similarity index 100% rename from sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc rename to weights/sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc diff --git a/weights/stft_depthwise_ovs_128_k7_r1.stdc b/weights/stft_depthwise_ovs_128_k7_r1.stdc new file mode 100644 index 0000000000000000000000000000000000000000..e131f2b536060353a50b14963462863205981944 Binary files /dev/null and b/weights/stft_depthwise_ovs_128_k7_r1.stdc differ