diff --git a/forward_UpDimV2_long.py b/forward_UpDimV2_long.py
new file mode 100644
index 0000000000000000000000000000000000000000..945056c3e0307740066aff175d7f63957e1790cc
--- /dev/null
+++ b/forward_UpDimV2_long.py
@@ -0,0 +1,233 @@
+from pathlib import Path
+import random
+from collections import defaultdict
+
+from PIL import Image
+import torch
+import torchvision as tv
+import torchelie.recipes
+import torchelie as tch
+import torchelie.callbacks.callbacks as tcb
+import argparse
+import numpy as np
+import soundfile as sf
+import os
+import scipy.signal as sg
+from tqdm import tqdm, trange
+from math import ceil
+
+
+def main(args):
+    batch_size = 64
+    num_feature = 4096
+    num_classes = 10
+    rng = np.random.RandomState(42)
+
+    class UpDimV2(torch.nn.Module):
+
+        def __init__(self, num_class):
+            super(UpDimV2, self).__init__()
+            self.activation = torch.nn.LeakyReLU(0.001, inplace=True)
+
+            # Block 1D 1
+            self.conv11 = torch.nn.Conv1d(1, 32, 3, 1, 1)
+            self.norm11 = torch.nn.BatchNorm1d(32)
+            self.conv21 = torch.nn.Conv1d(32, 32, 3, 2, 1)
+            self.norm21 = torch.nn.BatchNorm1d(32)
+            self.skip11 = torch.nn.Conv1d(1, 32, 1, 2)
+
+            # Block 1D 2
+            self.conv12 = torch.nn.Conv1d(32, 64, 3, 2, 1)
+            self.norm12 = torch.nn.BatchNorm1d(64)
+            self.conv22 = torch.nn.Conv1d(64, 128, 3, 2, 1)
+            self.norm22 = torch.nn.BatchNorm1d(128)
+            self.skip12 = torch.nn.Conv1d(32, 128, 1, 4)
+
+            # Block 2D 1
+            self.conv31 = torch.nn.Conv2d(1, 32, 3, 1, 1)
+            self.norm31 = torch.nn.BatchNorm2d(32)
+            self.conv41 = torch.nn.Conv2d(32, 32, 3, 2, 1)
+            self.norm41 = torch.nn.BatchNorm2d(32)
+            self.skip21 = torch.nn.Conv2d(1, 32, 1, 2)
+
+            # Block 2D 2
+            self.conv32 = torch.nn.Conv2d(32, 64, 3, 2, 1)
+            self.norm32 = torch.nn.BatchNorm2d(64)
+            self.conv42 = torch.nn.Conv2d(64, 128, 3, 2, 1)
+            self.norm42 = torch.nn.BatchNorm2d(128)
+            self.skip22 = torch.nn.Conv2d(32, 128, 1, 4)
+
+            # Block 3D 1
+            self.conv51 = torch.nn.Conv3d(1, 32, 3, (1, 2, 1), 1)
+            self.norm51 = torch.nn.BatchNorm3d(32)
+            self.conv61 = torch.nn.Conv3d(32, 64, 3, 2, 1)
+            self.norm61 = torch.nn.BatchNorm3d(64)
+            self.skip31 = torch.nn.Conv3d(1, 64, 1, (2, 4, 2))
+
+            # Block 3D 2
+            self.conv52 = torch.nn.Conv3d(64, 128, 3, 2, 1)
+            self.norm52 = torch.nn.BatchNorm3d(128)
+            self.conv62 = torch.nn.Conv3d(128, 256, 3, 2, 1)
+            self.norm62 = torch.nn.BatchNorm3d(256)
+            self.skip32 = torch.nn.Conv3d(64, 256, 1, 4)
+
+            # Fully connected
+            # If the time stride is too large, the softmax is applied to a
+            # singleton dimension and always outputs 1.
+            self.soft_max = torch.nn.Softmax(-1)
+            self.fc1 = torch.nn.Linear(4096, 1024)
+            self.fc2 = torch.nn.Linear(1024, 512)
+            self.fc3 = torch.nn.Linear(512, num_class)
+
+        def forward(self, x):
+            # Block 1D 1
+            out = self.conv11(x)
+            out = self.norm11(out)
+            out = self.activation(out)
+            out = self.conv21(out)
+            out = self.norm21(out)
+            skip = self.skip11(x)
+            out = self.activation(out + skip)
+
+            # Block 1D 2
+            skip = self.skip12(out)
+            out = self.conv12(out)
+            out = self.norm12(out)
+            out = self.activation(out)
+            out = self.conv22(out)
+            out = self.norm22(out)
+            out = self.activation(out + skip)
+
+            # Block 2D 1: lift the (B, C, T) activations to a 2D map (B, 1, C, T)
+            out = out.reshape((lambda b, c, h: (b, 1, c, h))(*out.shape))
+            skip = self.skip21(out)
+            out = self.conv31(out)
+            out = self.norm31(out)
+            out = self.activation(out)
+            out = self.conv41(out)
+            out = self.norm41(out)
+            out = self.activation(out + skip)
+
+            # Block 3D 1: lift the (B, C, H, W) activations to a 3D volume (B, 1, C, H, W)
+            out = out.reshape((lambda b, c, w, h: (b, 1, c, w, h))(*out.shape))
+            skip = self.skip31(out)
+            out = self.conv51(out)
+            out = self.norm51(out)
+            out = self.activation(out)
+            out = self.conv61(out)
+            out = self.norm61(out)
+            out = self.activation(out + skip)
+
+            # Block 3D 2
+            skip = self.skip32(out)
+            out = self.conv52(out)
+            out = self.norm52(out)
+            out = self.activation(out)
+            out = self.conv62(out)
+            out = self.norm62(out)
+            out = self.activation(out + skip)
+
+            # Fully connected
+            # max of the softmax over the last (time) axis, flattened to (B, 4096)
+            out = torch.max(self.soft_max(out), -1)[0].reshape(-1, 4096)
+            out = self.activation(self.fc1(out))
+            out = self.activation(self.fc2(out))
+            return self.fc3(out)
+
+    model = torch.nn.DataParallel(UpDimV2(num_classes))
+    model.load_state_dict(torch.load(args.weight)['model'])
+    model.to('cuda')
+    model.eval()
+
+    if os.path.isfile(args.input_path):
+        if args.input_path.endswith('.npy'):
+            # Pre-extracted clicks: keep the num_feature samples centred in each row.
+            click_data = np.load(args.input_path)
+            click_data = click_data[:, click_data.shape[1]//2 - num_feature//2:click_data.shape[1]//2 + num_feature//2]
+            with torch.no_grad():
+                preds = np.empty((len(click_data), num_classes))
+                for i in trange(len(click_data)//batch_size, desc=f'file: {args.input_path}'):
+                    clicks = click_data[i*batch_size:(i+1)*batch_size]
+                    clicks = torch.from_numpy(((clicks - clicks.mean(-1, keepdims=True))/(clicks.std(-1, keepdims=True) + 1e-18))[:, np.newaxis]).to('cuda').float()
+                    preds[i*batch_size:(i+1)*batch_size] = model(clicks).cpu().numpy()
+                if len(click_data) % batch_size:  # process the remainder batch
+                    clicks = click_data[-(len(click_data) % batch_size):]
+                    clicks = torch.from_numpy(((clicks - clicks.mean(-1, keepdims=True))/(clicks.std(-1, keepdims=True) + 1e-18))[:, np.newaxis]).to('cuda').float()
+                    preds[-(len(click_data) % batch_size):] = model(clicks).cpu().numpy()
+            np.savetxt(args.input_path.rsplit('.', 1)[0] + args.suffix, preds)
+
+        else:
+            song, sr = sf.read(args.input_path, always_2d=True)
+            song = song[:, args.channel]
+            sos = sg.butter(3, 200_000/sr, 'lowpass', output='sos')
+            song = sg.sosfiltfilt(sos, song)
+            song = sg.resample(song, int(200_000/sr*len(song)))
+            batch_pos = np.linspace(0, len(song) - num_feature, args.overlap * batch_size * ceil((len(song)//num_feature + 1)/batch_size)).astype(int)
+            with torch.no_grad():
+                preds = np.empty((len(batch_pos)//batch_size, batch_size, num_classes))
+                for i, pos in enumerate(tqdm(batch_pos.reshape(-1, batch_size), desc=f'file: {args.input_path}')):
+                    clicks = np.array([song[p:p+num_feature] for p in pos])
+                    clicks = torch.from_numpy(((clicks - clicks.mean(-1, keepdims=True))/(clicks.std(-1, keepdims=True) + 1e-18))[:, np.newaxis]).to('cuda').float()
+                    preds[i] = model(clicks).cpu().numpy()
+            np.savetxt(args.input_path.rsplit('.', 1)[0] + args.suffix, preds.reshape(-1, num_classes))
+    else:
+        for d, _, dire in os.walk(args.input_path):
+            if args.output_path is not None:
+                dout = os.path.join(args.output_path, d[len(args.input_path):])
+                os.makedirs(dout, exist_ok=True)
+            for f in tqdm(dire, desc=f'directory: {d}'):
+                if f.rsplit('.', 1)[-1].lower() not in ['wav', 'mp3', 'ogg', 'flac']:
+                    continue
+                try:
+                    current_file = os.path.join(d, f)
+                    if args.output_path is None:
+                        out_file = os.path.join(d, f).rsplit('.', 1)[0] + args.suffix
+                    else:
+                        out_file = os.path.join(dout, f).rsplit('.', 1)[0] + args.suffix
+                    if os.path.isfile(out_file) and args.erase:
+                        continue
+                    if args.undersample is not None:
+                        if np.random.random_sample() > args.undersample/100:
+                            continue
+                    song, sr = sf.read(current_file, always_2d=True)
+                    song = song[:, args.channel]
+                    # Lowpass at the target Nyquist (100 kHz), then resample to 200 kHz.
+                    sos = sg.butter(3, 200_000/sr, 'lowpass', output='sos')
+                    song = sg.sosfiltfilt(sos, song)
+                    song = sg.resample(song, int(200_000/sr*len(song)))
+                    batch_pos = np.linspace(0, len(song) - num_feature, args.overlap * batch_size * ceil((len(song)//num_feature + 1)/batch_size)).astype(int)
+                    with torch.no_grad():
+                        preds = np.empty((len(batch_pos)//batch_size, batch_size, num_classes))
+                        for i, pos in enumerate(tqdm(batch_pos.reshape(-1, batch_size), desc=f'file: {current_file}')):
+                            clicks = np.array([song[p:p+num_feature] for p in pos])
+                            clicks = torch.from_numpy(((clicks - clicks.mean(-1, keepdims=True))/(clicks.std(-1, keepdims=True) + 1e-18))[:, np.newaxis]).to('cuda').float()
+                            preds[i] = model(clicks).cpu().numpy()
+                    np.savetxt(out_file, preds.reshape(-1, num_classes))
+                except Exception as e:
+                    print(f'error with file {current_file}: {e}')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+                                     description="Analyse wav file(s) and write logit predictions. "
+                                                 "Apply a softmax to obtain probabilities. The class order is "
+                                                 "Gg, Gma, La, Mb, Me, Pm, Ssp, UDA, UDB, Zc")
+    parser.add_argument("input_path", type=str, help="Input file (.npy or audio) or folder")
+    parser.add_argument("--weight", type=str, default='best_acc_updimv2_3dlong.pth', help="Model weights")
+    parser.add_argument("--suffix", type=str, default='.pred', help="Suffix of the output file")
+    parser.add_argument("--channel", type=int, default=0, help="Channel used for prediction")
+    parser.add_argument("--overlap", type=int, default=2, help="Overlap factor of prediction windows (win_size/hop_size)")
+    parser.add_argument("--undersample", type=float, default=None, help="When the input is a folder, only this percentage of files (drawn at random) is processed")
+    parser.add_argument("--output_path", type=str, help="Root directory for output files. Only used when the input is a folder. Defaults to input_path")
+    parser.add_argument("--erase", action='store_false', help="If out_file exists and --erase is not given, the file is skipped. (Only for folder input)")
+
+    args = parser.parse_args()
+    main(args)
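
Note: the audio preprocessing the script applies inline can be read as two steps: resampling to the 200 kHz rate the model was trained at, and per-window standardization into a (batch, 1, 4096) tensor. A minimal sketch of those two steps, factored into hypothetical helpers (`resample_200k` and `to_batch` are illustrative names, not part of the script); the normalized Butterworth cutoff 200_000/sr corresponds to 100 kHz, the Nyquist of the target rate, so this assumes recordings sampled above 200 kHz:

    import numpy as np
    import scipy.signal as sg
    import torch

    def resample_200k(song, sr):
        # anti-aliasing lowpass at 100 kHz, then resample to 200 kHz
        sos = sg.butter(3, 200_000 / sr, 'lowpass', output='sos')
        song = sg.sosfiltfilt(sos, song)
        return sg.resample(song, int(200_000 / sr * len(song)))

    def to_batch(windows):
        # windows: (N, 4096) float array of raw windows; standardize each
        # window independently (1e-18 guards against silent, all-zero windows)
        mean = windows.mean(-1, keepdims=True)
        std = windows.std(-1, keepdims=True)
        normed = (windows - mean) / (std + 1e-18)
        return torch.from_numpy(normed[:, np.newaxis]).float()  # (N, 1, 4096)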
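
The output files contain raw logits, one row per analysis window (written with np.savetxt). A minimal post-processing sketch, assuming a hypothetical output file recording.pred produced with the default suffix; the class order follows the parser description:

    import numpy as np

    CLASSES = ['Gg', 'Gma', 'La', 'Mb', 'Me', 'Pm', 'Ssp', 'UDA', 'UDB', 'Zc']

    logits = np.loadtxt('recording.pred')                 # (num_windows, 10); hypothetical path
    exp = np.exp(logits - logits.max(-1, keepdims=True))  # numerically stable softmax
    probs = exp / exp.sum(-1, keepdims=True)
    for i, (k, p) in enumerate(zip(probs.argmax(-1), probs.max(-1))):
        print(f'window {i}: {CLASSES[k]} (p={p:.2f})')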