Commit 48f4a3a3 authored by ferrari's avatar ferrari
Browse files

initial commit

parents
UpDimV3
==================
The repository contain example codes to train the UpDim model
This project require the following libraries:
* torch
* [torchelie](https://github.com/Vermeille/Torchelie)
* scipy
* soundfile
* numpy as np
* [tqdm](https://github.com/tqdm/tqdm)
The files in this project are
* sam.py contains the code for the Sharpness aware minimizer (SAM)
* UpDimV3.py contains the UpDimV3 class
* UpDimV3_quebec.py contains the codes to train the model with Adam optimizer
* UpDimV3_quebec_SAM.py contains the codes to train the model with SAM optimizer
Usage
-----
For the main scripts, the options can be displayed with `python script_name -h`
You can train a network with
```shell
python UpDimV3_quebec_SAM.py --train xeno_train_1500.npy mcaulay_train_2200.npy
```
The train and test files should be array containing the path to the sound files that will be used.
The abiotic file `sons_abiotiques.npy` is also requiered.
The database `xeno_train_1500.npy` and`mcaulay_train_2200.npy` are private.
\ No newline at end of file
import torch
class UpDimV3(torch.nn.Module):
def __init__(self, num_class):
super(UpDimV3, self).__init__()
self.activation = torch.nn.LeakyReLU(0.001, inplace=True)
self.dropout = torch.nn.Dropout()
# Block 1D 1
self.seq1 = torch.nn.Sequential(torch.nn.Conv1d(1, 32, 3, 1, 1),
torch.nn.MaxPool1d(3, 2, 1),
torch.nn.BatchNorm1d(32),
self.activation,
torch.nn.Conv1d(32, 32, 3, 1, 1),
torch.nn.MaxPool1d(3, 2, 1),
torch.nn.BatchNorm1d(32))
self.skip11 = torch.nn.Conv1d(1, 32, 1, 1)
self.skip_pool11 = torch.nn.MaxPool1d(5, 4, 2)
# Block 1D 2
self.seq2 = torch.nn.Sequential(torch.nn.Conv1d(32, 64, 3, 1, 1),
torch.nn.MaxPool1d(3, 2, 1),
torch.nn.BatchNorm1d(64),
self.activation,
torch.nn.Conv1d(64, 128, 5, 1, 2),
torch.nn.MaxPool1d(5, 4, 2),
torch.nn.BatchNorm1d(128))
self.skip12 = torch.nn.Conv1d(32, 128, 1, 1)
self.skip_pool12 = torch.nn.MaxPool1d(9, 8, 4)
# Block 2D 1
self.seq3 = torch.nn.Sequential(torch.nn.Conv2d(1, 32, (3, 5), 1, (1, 2)),
torch.nn.MaxPool2d((1, 3), (1, 2), (0, 1)),
torch.nn.BatchNorm2d(32),
self.activation,
torch.nn.Conv2d(32, 32, (3, 5), 1, (1, 2)),
torch.nn.MaxPool2d((3, 5), (2, 4), (1, 2)),
torch.nn.BatchNorm2d(32))
self.skip21 = torch.nn.Conv2d(1, 32, 1)
self.skip_pool21 = torch.nn.MaxPool2d((3, 9), (2, 8), (1, 4))
# Block 2D 2
self.seq4 = torch.nn.Sequential(torch.nn.Conv2d(32, 64, (3, 5), 1, (1, 2)),
torch.nn.MaxPool2d((3, 5), (2, 4), (1, 2)),
torch.nn.BatchNorm2d(64),
self.activation,
torch.nn.Conv2d(64, 128, (3, 5), 1, (1, 2)),
torch.nn.MaxPool2d(3, 2, 1),
torch.nn.BatchNorm2d(128))
self.skip22 = torch.nn.Conv2d(32, 128, 1)
self.skip_pool22 = torch.nn.MaxPool2d((5, 9), (4, 8), (2, 4))
# Block 3D 1
self.seq5 = torch.nn.Sequential(torch.nn.Conv3d(1, 32, (3, 5, 9), 1, (1, 2, 4)),
torch.nn.MaxPool3d((1,1,3), (1, 1, 2), (0,0,1)),
torch.nn.BatchNorm3d(32),
self.activation,
torch.nn.Conv3d(32, 64, (3, 5, 9), 1, (1, 2, 4)),
torch.nn.MaxPool3d(3, 2, 1),
torch.nn.BatchNorm3d(64))
self.skip31 = torch.nn.Conv3d(1, 64, 1)
self.skip_pool31 = torch.nn.MaxPool3d((3, 3, 5), (2, 2, 4), (1, 1, 2))
# Block 3D 2
self.seq6 = torch.nn.Sequential(torch.nn.Conv3d(64, 128, (3, 5, 15), 1, (1, 2, 7)),
torch.nn.MaxPool3d(3, 2, 1),
torch.nn.BatchNorm3d(128),
self.activation,
torch.nn.Conv3d(128, 256, (3, 5, 15), 1, (1, 2, 7)),
torch.nn.MaxPool3d(3, 2, 1),
torch.nn.BatchNorm3d(256))
self.skip32 = torch.nn.Conv3d(64, 256, 1)
self.skip_pool32 = torch.nn.MaxPool3d(5, 4, 2)
# Fully connected
self.obo = torch.nn.Sequential(torch.nn.Conv1d(8192, 1024, 1, 1),
torch.nn.Conv1d(1024, 1024, 1, 1),
torch.nn.Conv1d(1024, 1024, 1, 1),
self.activation,
torch.nn.Conv1d(1024, 1024, 1, 1),
torch.nn.Conv1d(1024, 1024, 1, 1),
torch.nn.Conv1d(1024, 1, 1, 1))
self.soft_max = torch.nn.Softmax(-1) # If the time stride is too big, the softmax will be done on a singleton
# which always ouput a 1
self.fc1 = torch.nn.Linear(8192, 1024)
self.fc2 = torch.nn.Linear(1024, 512)
self.fc3 = torch.nn.Linear(512, num_class)
def forward(self, x):
# Block 1D 1
out = self.seq1(x)
skip = self.skip_pool11(self.skip11(x))
out = self.activation(out + skip)
# Block 1D 2
skip = self.skip_pool12(self.skip12(out))
out = self.seq2(out)
out = self.activation(out + skip)
# Block 2D 1
out = out.reshape((lambda b, c, h: (b, 1, c, h))(*out.shape))
skip = self.skip_pool21(self.skip21(out))
out = self.seq3(out)
out = self.activation(out + skip)
# Block 2D 2
skip = self.skip_pool22(self.skip22(out))
out = self.seq4(out)
out = self.activation(out + skip)
# Block 3D 1
out = out.reshape((lambda b, c, w, h: (b, 1, c, w, h))(*out.shape))
skip = self.skip_pool31(self.skip31(out))
out = self.seq5(out)
out = self.activation(out + skip)
# Block 3D 2
skip = self.skip_pool32(self.skip32(out))
out = self.seq6(out)
out = self.activation(out + skip)
# Fully connected
out = out.reshape(len(out), 8192, -1)
out = torch.mean(out*self.soft_max(self.obo(out)), -1)
out = self.dropout(self.activation(self.fc1(out)))
out = self.dropout(self.activation(self.fc2(out)))
return self.fc3(out)
import torch
import torchelie as tch
import torchelie.callbacks.callbacks as tcb
import scipy.signal as sg
import soundfile as sf
import argparse
import numpy as np
from tqdm import tqdm, trange
from UpDimV3 import UpDimV3
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="""Train an UpDimV3 model with ADAM optimizer""")
parser.add_argument("--train",
default=['xeno_train_1500.npy'],
type=str, nargs='*', help="Paths to train file(s)")
parser.add_argument("--test",
default=['JF_propres_43.npy'],
type=str, nargs='*', help="Paths to test file(s)")
parser.add_argument("--rng", action='store_false', help='unset rng state')
parser.add_argument("--weight", type=str, default='', help="Model weight for resume")
parser.add_argument("--hot_weight", type=str, default='', help="Model weight for hot restart")
parser.add_argument("--run_name", type=str, default='', help='suffix of the run name')
args = parser.parse_args()
batch_size = 32
sound_len = 7
sound_sr = 22050
num_feature = sound_len * sound_sr
num_classes = 43
if args.rng:
rng = np.random.RandomState(42)
else:
rng = np.random.RandomState()
def pink_noise(size, rng, ncols=16, axis=-1):
"""Generates pink noise using the Voss-McCartney algorithm.
size: either a tuple of int or an int. If an int : number of sample to generate. If a tuple: shape of the return array.
ncols: number of random sources to add. Should be high enough so that num_samples*0.5**(ncols-2) is near zero.
axis: axis which contains the sound samples. Generate white noise otherwise.
returns: NumPy array of shape size
"""
if type(size) is not tuple:
size = (size,)
array = rng.rand(*size)
assert -len(size) <= axis < len(size)
axis %= len(size)
axis += 1
# the total number of changes is nrows
cols = rng.geometric(0.5, size)
cols[cols >= ncols] = 0
cols = (1. * (np.arange(1, ncols).reshape((-1,) + len(size) * (1,)) == cols)).swapaxes(axis, -1)
cols[..., 0] = 1.
cols = np.cumsum(cols).reshape(cols.shape).astype(int).swapaxes(axis, -1)
array = np.concatenate([array[np.newaxis], rng.rand(cols.max() + 1)[cols]], axis=0).sum(0)
return array
class Dataset:
def __init__(self, paths, train, rank, rng=None):
self.path = np.concatenate([np.load(p) for p in paths])
self.wav = len(self.path) * [None]
self.labels = len(self.path) * [None]
for i, p in enumerate(tqdm(self.path)):
self.wav[i] = p
self.labels[i] = p.split('/')[-1][:4]
self.wav = np.array([x for x in self.wav if x is not None])
self.labels = np.array([x for x in self.labels if x is not None])
self.names, self.labels = np.unique(self.labels, return_inverse=True)
assert len(self.wav) == len(self.labels)
print(f'Found {len(self.wav)} wav')
assert len(self.names) == num_classes
self.labels = np.eye(num_classes)[self.labels]
self.classes = self.names
self.train = train
self.rank = rank
self.rng = rng
self.wind = np.load('sons_abiotiques.npy').squeeze()
self.test_loaded = False
if not self.train:
print(f'Loading test dataset')
self.test_tuples = self.__len__() * [(None, None)]
for i in trange(self.__len__()):
self.test_tuples[i] = self.__getitem__(i)
self.test_loaded = True
def __len__(self):
return len(self.wav)
def get_class_names(self):
return self.names
def __getitem__(self, i):
if self.test_loaded and not self.train:
return self.test_tuples[i]
p = self.wav[i]
label = self.labels[i]
sample, sr = sf.read(p, always_2d=True)
sample = sample[:, 0]
if sr != sound_sr:
sample = sg.resample(sample, int(len(sample)*sound_sr/sr))
if self.train:
sample = sg.resample(sample, int(np.clip(len(sample) * rng.normal(1, 0.025, 1),
num_feature+10, 1.10*len(sample))))
p = self.rng.randint(0, len(sample) - num_feature)
r = self.rng.randint(0, 100)
if r < 20:
rev = np.zeros(sound_sr)
if r % 3:
rev += self.rng.normal(0, 0.001, len(rev))
rev[0] = 1
for k in range(r%4):
rev[self.rng.randint(0.001*sound_sr, len(rev))] = self.rng.uniform(0, 1)
sample = torch.nn.functional.conv1d(torch.Tensor(sample[None, None]).to(self.rank),
torch.Tensor(np.flip(rev).copy()[None, None]).to(self.rank),
padding=len(rev)//2+1).cpu().numpy().squeeze()
sample = sample[p:p+num_feature]
else:
sample = sample[:num_feature]
sample -= sample.mean(-1,keepdims=True)
sample /= sample.std(-1,keepdims=True) + 1e-18
if self.train:
sample += self.rng.normal(0, 1, num_feature) * 10**self.rng.uniform(-2.5, -0.12, 1)
sample += (lambda x: x-x.mean(-1,keepdims=True)/x.std(-1,keepdims=True))\
(np.cumsum(self.rng.normal(0, 1, num_feature))) * 10**self.rng.uniform(-2.5, -0.5, 1)
sample += ((lambda x: x-x.mean(-1,keepdims=True)/x.std(-1,keepdims=True))
(pink_noise(num_feature, self.rng))) * 10**self.rng.uniform(-4, -1, 1)
wind, sr_wind = sf.read(self.rng.choice(self.wind), always_2d=True)
wind = wind[:, 0]
if sr != sound_sr:
wind = sg.resample(wind, int(len(wind) * sound_sr / sr))
wind = sg.resample(wind, int(np.clip(len(wind) * rng.normal(1, 0.025, 1),
num_feature + 10, 1.10 * len(wind))))
p_wind = self.rng.randint(0, len(wind) - num_feature)
sample += (lambda x: x-x.mean(-1,keepdims=True)/x.std(-1,keepdims=True))(wind[p_wind:p_wind + num_feature])\
* 10**self.rng.uniform(-2, -0.05, 1) * (-1)**self.rng.randint(0, 10)
sample -= sample.mean(-1, keepdims=True)
sample /= sample.std(-1, keepdims=True)
p = self.rng.randint(0, len(sample) - int(sound_sr * 0.5))
sample[p:p+self.rng.randint(0, sound_sr * 0.5)] = 0
sample *= (-1)**self.rng.randint(0, 10)
return torch.from_numpy(sample[np.newaxis]).float(), np.argmax(label)
def train(rank, world_size):
ds = torch.utils.data.DataLoader(tch.datasets.MixUpDataset(Dataset(args.train, True, rank, rng), 0.3),
batch_size//world_size,
shuffle=True,
drop_last=True,
num_workers=6,
pin_memory=True)
test_dst = Dataset(args.test, False, rank)
dst = torch.utils.data.DataLoader(test_dst,
batch_size=batch_size//world_size,
num_workers=8,
shuffle=True)
print('Dataset loaded')
model = torch.nn.parallel.DistributedDataParallel(UpDimV3(num_classes).to(rank), [rank], rank)
def accuracy(preds, labels):
with torch.no_grad():
id_preds = torch.max(preds, 1)[1]
id_labels = torch.max(labels, 1)[1]
return 100 * (id_preds == id_labels).sum().item() / len(labels)
def train(batch):
x, y = batch
pred = model(x)
loss = tch.loss.continuous_cross_entropy(pred, y)
loss.backward()
return {'loss': loss, 'pred': pred, 'accuracy': accuracy(pred, y)}
def test(batch):
x, y = batch
pred = model(x)
y = torch.eye(num_classes)[y].to(y.device)
loss = tch.loss.continuous_cross_entropy(pred, y)
return {'loss': loss, 'pred': pred, 'accuracy': accuracy(pred, y)}
recipe = tch.recipes.TrainAndTest(model,
train,
test,
ds,
dst,
test_every=1000,
log_every=64,
checkpoint='model3_quebec' + args.run_name if rank == 0 else None,
visdom_env=None,
key_best=(lambda x: x['test_loop']['callbacks']['state']['metrics']['accuracy']))
opt = tch.optim.RAdamW(model.parameters(),
5e-3,
weight_decay=0.005)
recipe.callbacks.add_callbacks([
tcb.Optimizer(opt, log_lr=True, clip_grad_norm=5),
tcb.LRSched(torch.optim.lr_scheduler.MultiStepLR(opt, list(range(20,80)), (1-1/np.exp(1))/2), metric=None),
tcb.WindowedMetricAvg('loss'),
tcb.WindowedMetricAvg('accuracy'),
])
recipe.callbacks.add_epilogues([
tcb.TensorboardLogger(log_dir='UpDimV3_quebec' + args.run_name + '/train' if rank == 0 else None, log_every=64),])
recipe.test_loop.callbacks.add_callbacks([
tcb.EpochMetricAvg('loss', False),
tcb.EpochMetricAvg('accuracy', False),
# tcb.ConfusionMatrix(test_dst.get_class_names(), True)
])
recipe.test_loop.callbacks.add_epilogues([
tcb.TensorboardLogger(log_dir='UpDimV3_quebec' + args.run_name + '/test' if rank == 0 else None,
log_every=-1),])
recipe.to(rank)
if args.hot_weight != '':
ckpt = torch.load(args.hot_weight)
del ckpt['callbacks']['callbacks']['Counter_prologue_0']
del ckpt['callbacks']['callbacks']['LRSched_middle_0']
del ckpt['test_loop']
del ckpt['callbacks']['state']['metrics']['lr_0']
ckpt['callbacks']['callbacks']['Optimizer_middle_0']['opt']['param_groups'][0]['lr'] = opt.param_groups[0]['lr']
ckpt['callbacks']['callbacks']['Optimizer_middle_0']['opt']['param_groups'][0]['eps'] = opt.param_groups[0]['eps']
ckpt['callbacks']['callbacks']['Optimizer_middle_0']['opt']['param_groups'][0]['weight_decay'] = opt.param_groups[0]['weight_decay']
ckpt['callbacks']['callbacks']['Optimizer_middle_0']['opt']['param_groups'][0]['betas'] = opt.param_groups[0]['betas']
recipe.load_state_dict(ckpt)
if args.weight != '':
ckpt = torch.load(args.weight)
recipe.load_state_dict(ckpt)
recipe.run(80)
if __name__ == '__main__':
tch.utils.parallel_run(train)
import torch
import torchelie as tch
import torchelie.callbacks.callbacks as tcb
import scipy.signal as sg
import soundfile as sf
import argparse
import numpy as np
from tqdm import tqdm, trange
from sam import SAMSGD
import torchelie.utils as tu
from UpDimV3 import UpDimV3
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="""Train an UpDimV3 model with SAM optimizer""")
parser.add_argument("--train",
default=['xeno_train_1500.npy'],
type=str, nargs='*', help="Paths to train file(s)")
parser.add_argument("--test",
default=['JF_propres_43.npy'],
type=str, nargs='*', help="Paths to test file(s)")
parser.add_argument("--rng", action='store_false', help='unset rng state')
parser.add_argument("--weight", type=str, default='', help="recipy weight for resume")
parser.add_argument("--from_ckpt", type=str, default='', help="Model weight for resume")
parser.add_argument("--hot_weight", type=str, default='', help="Model weight for hot restart")
parser.add_argument("--run_name", type=str, default='', help='suffix of the run name')
args = parser.parse_args()
batch_size = 32
sound_len = 7
sound_sr = 22050
num_feature = sound_len * sound_sr
num_classes = 43
if args.rng:
rng = np.random.RandomState(42)
else:
rng = np.random.RandomState()
def pink_noise(size, rng, ncols=16, axis=-1):
"""Generates pink noise using the Voss-McCartney algorithm.
size: either a tuple of int or an int. If an int : number of sample to generate. If a tuple: shape of the return array.
ncols: number of random sources to add. Should be high enough so that num_samples*0.5**(ncols-2) is near zero.
axis: axis which contains the sound samples. Generate white noise otherwise.
returns: NumPy array of shape size
"""
if type(size) is not tuple:
size = (size,)
array = rng.rand(*size)
assert -len(size) <= axis < len(size)
axis %= len(size)
axis += 1
# the total number of changes is nrows
cols = rng.geometric(0.5, size)
cols[cols >= ncols] = 0
cols = (1. * (np.arange(1, ncols).reshape((-1,) + len(size) * (1,)) == cols)).swapaxes(axis, -1)
cols[..., 0] = 1.
cols = np.cumsum(cols).reshape(cols.shape).astype(int).swapaxes(axis, -1)
array = np.concatenate([array[np.newaxis], rng.rand(cols.max() + 1)[cols]], axis=0).sum(0)
return array
class Dataset:
def __init__(self, paths, train, rank, rng=None):
self.path = np.concatenate([np.load(p) for p in paths])
self.wav = len(self.path) * [None]
self.labels = len(self.path) * [None]
for i, p in enumerate(tqdm(self.path)):
self.wav[i] = p
self.labels[i] = p.split('/')[-1][:4]
self.wav = np.array([x for x in self.wav if x is not None])
self.labels = np.array([x for x in self.labels if x is not None])
self.names, self.labels = np.unique(self.labels, return_inverse=True)
assert len(self.wav) == len(self.labels)
print(f'Found {len(self.wav)} wav')
assert len(self.names) == num_classes
self.labels = np.eye(num_classes)[self.labels]
self.classes = self.names
self.train = train
self.rank = rank
self.rng = rng
self.wind = np.load('/nfs/NAS4/anatole/data/quebec/sons_abiotiques/sons_abiotiques.npy').squeeze()
self.test_loaded = False
if not self.train:
print(f'Loading test dataset')
self.test_tuples = self.__len__() * [(None, None)]
for i in trange(self.__len__()):
self.test_tuples[i] = self.__getitem__(i)
self.test_loaded = True
def __len__(self):
return len(self.wav)
def get_class_names(self):
return self.names
def __getitem__(self, i):
if self.test_loaded and not self.train:
return self.test_tuples[i]
p = self.wav[i]
label = self.labels[i]
sample, sr = sf.read(p, always_2d=True)
sample = sample[:, 0]
if sr != sound_sr:
sample = sg.resample(sample, int(len(sample)*sound_sr/sr))
if self.train:
sample = sg.resample(sample, int(np.clip(len(sample) * rng.normal(1, 0.025, 1),
num_feature+10, 1.10*len(sample))))
p = self.rng.randint(0, len(sample) - num_feature)
r = self.rng.randint(0, 100)
if r < 20:
rev = np.zeros(sound_sr)
if r % 3:
rev += self.rng.normal(0, 0.001, len(rev))
rev[0] = 1
for k in range(r%4):
rev[self.rng.randint(0.001*sound_sr, len(rev))] = self.rng.uniform(0, 1)
sample = torch.nn.functional.conv1d(torch.Tensor(sample[None, None]).to(self.rank),
torch.Tensor(np.flip(rev).copy()[None, None]).to(self.rank),
padding=len(rev)//2+1).cpu().numpy().squeeze()
sample = sample[p:p+num_feature]
else:
sample = sample[:num_feature]
sample -= sample.mean(-1,keepdims=True)
sample /= sample.std(-1,keepdims=True) + 1e-18
if self.train:
sample += self.rng.normal(0, 1, num_feature) * 10**self.rng.uniform(-2.5, -0.12, 1)
sample += (lambda x: x-x.mean(-1,keepdims=True)/x.std(-1,keepdims=True))\
(np.cumsum(self.rng.normal(0, 1, num_feature))) * 10**self.rng.uniform(-2.5, -0.5, 1)
sample += ((lambda x: x-x.mean(-1,keepdims=True)/x.std(-1,keepdims=True))
(pink_noise(num_feature, self.rng))) * 10**self.rng.uniform(-4, -1, 1)
wind, sr_wind = sf.read(self.rng.choice(self.wind), always_2d=True)
wind = wind[:, 0]
if sr != sound_sr:
wind = sg.resample(wind, int(len(wind) * sound_sr / sr))
wind = sg.resample(wind, int(np.clip(len(wind) * rng.normal(1, 0.025, 1),
num_feature + 10, 1.10 * len(wind))))
p_wind = self.rng.randint(0, len(wind) - num_feature)
sample += (lambda x: x-x.mean(-1,keepdims=True)/x.std(-1,keepdims=True))(wind[p_wind:p_wind + num_feature])\
* 10**self.rng.uniform(-2, -0.05, 1) * (-1)**self.rng.randint(0, 10)
sample -= sample.mean(-1, keepdims=True)
sample /= sample.std(-1, keepdims=True)
p = self.rng.randint(0, len(sample) - int(sound_sr * 0.5))
sample[p:p+self.rng.randint(0, sound_sr * 0.5)] = 0
sample *= (-1)**self.rng.randint(0, 10)
return torch.from_numpy(sample[np.newaxis]).float(), np.argmax(label)
def train(rank, world_size):
ds = torch.utils.data.DataLoader(tch.datasets.MixUpDataset(Dataset(args.train, True, rank, rng), 0.3),
batch_size//world_size,
shuffle=True,
drop_last=True,
num_workers=6,
pin_memory=True)
test_dst = Dataset(args.test, False, rank)
dst = torch.utils.data.DataLoader(test_dst,
batch_size=batch_size//world_size,
num_workers=8,
shuffle=True)
print('Dataset loaded')
model = UpDimV3(num_classes)
if args.from_ckpt != '':
model.load_state_dict(
torch.load(args.from_ckpt, map_location='cuda:' +