Skip to content
Snippets Groups Projects
Commit 5fa2a7b0 authored by Maxence Ferrari's avatar Maxence Ferrari
Browse files

Upload New File

parent 0943a7d0
No related branches found
No related tags found
No related merge requests found
from pathlib import Path
import random
from collections import defaultdict
from PIL import Image
import torch
import torchvision as tv
import torchelie.recipes
import torchelie as tch
import torchelie.callbacks.callbacks as tcb
import argparse
import numpy as np
import soundfile as sf
import os
import scipy.signal as sg
from tqdm import tqdm, trange
from math import ceil
def main(args):
batch_size = 64
num_feature = 4096
num_classes = 10
rng = np.random.RandomState(42)
class UpDimV2(torch.nn.Module):
def __init__(self, num_class):
super(UpDimV2, self).__init__()
self.activation = torch.nn.LeakyReLU(0.001, inplace=True)
# Block 1D 1
self.conv11 = torch.nn.Conv1d(1, 32, 3, 1, 1)
self.norm11 = torch.nn.BatchNorm1d(32)
self.conv21 = torch.nn.Conv1d(32, 32, 3, 2, 1)
self.norm21 = torch.nn.BatchNorm1d(32)
self.skip11 = torch.nn.Conv1d(1, 32, 1, 2)
# Block 1D 2
self.conv12 = torch.nn.Conv1d(32, 64, 3, 2, 1)
self.norm12 = torch.nn.BatchNorm1d(64)
self.conv22 = torch.nn.Conv1d(64, 128, 3, 2, 1)
self.norm22 = torch.nn.BatchNorm1d(128)
self.skip12 = torch.nn.Conv1d(32, 128, 1, 4)
# Block 2D 1
self.conv31 = torch.nn.Conv2d(1, 32, 3, 1, 1)
self.norm31 = torch.nn.BatchNorm2d(32)
self.conv41 = torch.nn.Conv2d(32, 32, 3, 2, 1)
self.norm41 = torch.nn.BatchNorm2d(32)
self.skip21 = torch.nn.Conv2d(1, 32, 1, 2)
# Block 2D 2
self.conv32 = torch.nn.Conv2d(32, 64, 3, 2, 1)
self.norm32 = torch.nn.BatchNorm2d(64)
self.conv42 = torch.nn.Conv2d(64, 128, 3, 2, 1)
self.norm42 = torch.nn.BatchNorm2d(128)
self.skip22 = torch.nn.Conv2d(32, 128, 1, 4)
# Block 3D 1
self.conv51 = torch.nn.Conv3d(1, 32, 3, (1, 2, 1), 1)
self.norm51 = torch.nn.BatchNorm3d(32)
self.conv61 = torch.nn.Conv3d(32, 64, 3, 2, 1)
self.norm61 = torch.nn.BatchNorm3d(64)
self.skip31 = torch.nn.Conv3d(1, 64, 1, (2, 4, 2))
# Block 3D 2
self.conv52 = torch.nn.Conv3d(64, 128, 3, 2, 1)
self.norm52 = torch.nn.BatchNorm3d(128)
self.conv62 = torch.nn.Conv3d(128, 256, 3, 2, 1)
self.norm62 = torch.nn.BatchNorm3d(256)
self.skip32 = torch.nn.Conv3d(64, 256, 1, 4)
# Fully connected
self.soft_max = torch.nn.Softmax(-1) # If the time stride is too big, the softmax will be done on a singleton
# which always ouput a 1
self.fc1 = torch.nn.Linear(4096, 1024)
self.fc2 = torch.nn.Linear(1024, 512)
self.fc3 = torch.nn.Linear(512, num_class)
def forward(self, x):
# Block 1D 1
out = self.conv11(x)
out = self.norm11(out)
out = self.activation(out)
out = self.conv21(out)
out = self.norm21(out)
skip = self.skip11(x)
out = self.activation(out + skip)
# Block 1D 2
skip = self.skip12(out)
out = self.conv12(out)
out = self.norm12(out)
out = self.activation(out)
out = self.conv22(out)
out = self.norm22(out)
out = self.activation(out + skip)
# Block 2D 1
out = out.reshape((lambda b, c, h: (b, 1, c, h))(*out.shape))
skip = self.skip21(out)
out = self.conv31(out)
out = self.norm31(out)
out = self.activation(out)
out = self.conv41(out)
out = self.norm41(out)
out = self.activation(out + skip)
# Block 2D 2
skip = self.skip22(out)
out = self.conv32(out)
out = self.norm32(out)
out = self.activation(out)
out = self.conv42(out)
out = self.norm42(out)
out = self.activation(out + skip)
# Block 3D 1
out = out.reshape((lambda b, c, w, h: (b, 1, c, w, h))(*out.shape))
skip = self.skip31(out)
out = self.conv51(out)
out = self.norm51(out)
out = self.activation(out)
out = self.conv61(out)
out = self.norm61(out)
out = self.activation(out + skip)
# Block 3D 2
skip = self.skip32(out)
out = self.conv52(out)
out = self.norm52(out)
out = self.activation(out)
out = self.conv62(out)
out = self.norm62(out)
out = self.activation(out + skip)
# Fully connected
out = torch.max(self.soft_max(out), -1)[0].reshape(-1, 4096)
out = self.activation(self.fc1(out))
out = self.activation(self.fc2(out))
return self.fc3(out)
model = torch.nn.DataParallel(UpDimV2(num_classes))
model.load_state_dict((torch.load(args.weight)['model']))
model.to('cuda')
model.eval()
if os.path.isfile(args.input_path):
if args.input_path.endswith('.npy'):
click_data = np.load(args.input_path)
click_data = click_data[:, click_data.shape[1]//2 - num_feature//2:click_data.shape[1]//2+num_feature//2]
with torch.no_grad():
preds = np.empty((len(click_data), num_classes))
for i in trange(len(click_data)//batch_size, desc=f'file: {args.input_path}'):
clicks = click_data[i*batch_size:(i+1)*batch_size]
clicks = torch.from_numpy(((clicks - clicks.mean(-1, keepdims=True))/(clicks.std(-1, keepdims=True) + 1e-18))[:, np.newaxis]).to('cuda').float()
preds[i*batch_size:(i+1)*batch_size] = model(clicks).cpu().numpy()
if not (len(click_data) % batch_size):
clicks = click_data[-(len(click_data) % batch_size):]
clicks = torch.from_numpy(((clicks - clicks.mean(-1, keepdims=True))/(clicks.std(-1, keepdims=True) + 1e-18))[:, np.newaxis]).to('cuda').float()
preds[-(len(click_data) % batch_size):] = model(clicks).cpu().numpy()
np.savetxt(args.input_path.rsplit('.',1)[0] + args.suffix, preds)
else:
song, sr = sf.read(args.input_path, always_2d=True)
song = song[:, args.channel]
sos = sg.butter(3, 200_000/sr, 'lowpass', output='sos')
song = sg.sosfiltfilt(sos, song)
song = sg.resample(song, int(200_000/sr*len(song)))
batch_pos = np.linspace(0, len(song) - num_feature, args.overlap * batch_size * ceil((len(song)//num_feature + 1)/batch_size)).astype(int)
with torch.no_grad():
preds = np.empty((len(batch_pos)//batch_size, batch_size, num_classes))
for i, pos in enumerate(tqdm(batch_pos.reshape(-1, batch_size), desc=f'file: {args.input_path}')):
clicks = np.array([song[p:p+num_feature] for p in pos])
clicks = torch.from_numpy(((clicks - clicks.mean(-1, keepdims=True))/(clicks.std(-1, keepdims=True) + 1e-18))[:, np.newaxis]).to('cuda').float()
preds[i] = model(clicks).cpu().numpy()
np.savetxt(args.input_path.rsplit('.',1)[0] + args.suffix, preds.reshape(-1, num_classes))
else:
for d, _, dire in os.walk(args.input_path):
if args.output_path is not None:
dout = os.path.join(args.output_path, d[len(args.input_path):])
os.makedirs(dout, exist_ok=True)
for f in tqdm(dire, desc=f'directory: {d}'):
if f.rsplit('.', 1)[-1].lower() not in ['wav', 'mp3', 'ogg', 'flac']:
continue
try:
current_file = os.path.join(d, f)
if args.output_path is None:
out_file = os.path.join(d, f).rsplit('.',1)[0] + args.suffix
else:
out_file = os.path.join(dout, f).rsplit('.',1)[0] + args.suffix
if os.path.isfile(out_file) and args.erase:
continue
if args.undersample is not None:
if np.random.random_sample() > args.undersample/100:
continue
song, sr = sf.read(current_file, always_2d=True)
song = song[:, args.channel]
sos = sg.butter(3, 200_000/sr, 'lowpass', output='sos')
song = sg.sosfiltfilt(sos, song)
song = sg.resample(song, int(200_000/sr*len(song)))
batch_pos = np.linspace(0, len(song) - num_feature, args.overlap * batch_size * ceil((len(song)//num_feature + 1)/batch_size)).astype(int)
with torch.no_grad():
preds = np.empty((len(batch_pos)//batch_size, batch_size, num_classes))
for i, pos in enumerate(tqdm(batch_pos.reshape(-1, batch_size), desc=f'file: {current_file}')):
clicks = np.array([song[p:p+num_feature] for p in pos])
clicks = torch.from_numpy(((clicks - clicks.mean(-1, keepdims=True))/(clicks.std(-1, keepdims=True) + 1e-18))[:, np.newaxis]).to('cuda').float()
preds[i] = model(clicks).cpu().numpy()
np.savetxt(out_file, preds.reshape(-1, num_classes))
except Exception as e:
print(f'error with file {current_file}: {e}')
if __name__ == '__main__':
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="Analyse wav(s) and return logits prediction. Use a softmax to have probabilities. The classes order is Gg, Gma, La, Mb, Me, Pm, Ssp, UDA, UDB, Zc")
parser.add_argument("input_path", type=str, help="Folder or path")
parser.add_argument("--weight", type=str, default='best_acc_updimv2_3dlong.pth', help="Model weight")
parser.add_argument("--suffix", type=str, default='.pred', help="Suffix of the output file")
parser.add_argument("--channel", type=int, default=0, help="Channel used for prediction")
parser.add_argument("--overlap", type=int, default=2, help="Overlap factor of prediction windows (win_size/hop_size)")
parser.add_argument("--undersample", type=float, default=None, help="In case of folders, only undersample percent of files will be computed")
parser.add_argument("--output_path", type=str, help="Path to root dir of ouput. Only used if input is folder. Default to input_path")
parser.add_argument("--erase", action='store_false', help="If out_file exist and erase not specified, file will be skip. (Only for folder input)")
args = parser.parse_args()
main(args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment