Skip to content
Snippets Groups Projects
Commit 48f4a3a3 authored by ferrari's avatar ferrari
Browse files

initial commit

parents
No related branches found
No related tags found
No related merge requests found
UpDimV3
==================
The repository contain example codes to train the UpDim model
This project require the following libraries:
* torch
* [torchelie](https://github.com/Vermeille/Torchelie)
* scipy
* soundfile
* numpy as np
* [tqdm](https://github.com/tqdm/tqdm)
The files in this project are
* sam.py contains the code for the Sharpness aware minimizer (SAM)
* UpDimV3.py contains the UpDimV3 class
* UpDimV3_quebec.py contains the codes to train the model with Adam optimizer
* UpDimV3_quebec_SAM.py contains the codes to train the model with SAM optimizer
Usage
-----
For the main scripts, the options can be displayed with `python script_name -h`
You can train a network with
```shell
python UpDimV3_quebec_SAM.py --train xeno_train_1500.npy mcaulay_train_2200.npy
```
The train and test files should be array containing the path to the sound files that will be used.
The abiotic file `sons_abiotiques.npy` is also requiered.
The database `xeno_train_1500.npy` and`mcaulay_train_2200.npy` are private.
\ No newline at end of file
import torch
class UpDimV3(torch.nn.Module):
def __init__(self, num_class):
super(UpDimV3, self).__init__()
self.activation = torch.nn.LeakyReLU(0.001, inplace=True)
self.dropout = torch.nn.Dropout()
# Block 1D 1
self.seq1 = torch.nn.Sequential(torch.nn.Conv1d(1, 32, 3, 1, 1),
torch.nn.MaxPool1d(3, 2, 1),
torch.nn.BatchNorm1d(32),
self.activation,
torch.nn.Conv1d(32, 32, 3, 1, 1),
torch.nn.MaxPool1d(3, 2, 1),
torch.nn.BatchNorm1d(32))
self.skip11 = torch.nn.Conv1d(1, 32, 1, 1)
self.skip_pool11 = torch.nn.MaxPool1d(5, 4, 2)
# Block 1D 2
self.seq2 = torch.nn.Sequential(torch.nn.Conv1d(32, 64, 3, 1, 1),
torch.nn.MaxPool1d(3, 2, 1),
torch.nn.BatchNorm1d(64),
self.activation,
torch.nn.Conv1d(64, 128, 5, 1, 2),
torch.nn.MaxPool1d(5, 4, 2),
torch.nn.BatchNorm1d(128))
self.skip12 = torch.nn.Conv1d(32, 128, 1, 1)
self.skip_pool12 = torch.nn.MaxPool1d(9, 8, 4)
# Block 2D 1
self.seq3 = torch.nn.Sequential(torch.nn.Conv2d(1, 32, (3, 5), 1, (1, 2)),
torch.nn.MaxPool2d((1, 3), (1, 2), (0, 1)),
torch.nn.BatchNorm2d(32),
self.activation,
torch.nn.Conv2d(32, 32, (3, 5), 1, (1, 2)),
torch.nn.MaxPool2d((3, 5), (2, 4), (1, 2)),
torch.nn.BatchNorm2d(32))
self.skip21 = torch.nn.Conv2d(1, 32, 1)
self.skip_pool21 = torch.nn.MaxPool2d((3, 9), (2, 8), (1, 4))
# Block 2D 2
self.seq4 = torch.nn.Sequential(torch.nn.Conv2d(32, 64, (3, 5), 1, (1, 2)),
torch.nn.MaxPool2d((3, 5), (2, 4), (1, 2)),
torch.nn.BatchNorm2d(64),
self.activation,
torch.nn.Conv2d(64, 128, (3, 5), 1, (1, 2)),
torch.nn.MaxPool2d(3, 2, 1),
torch.nn.BatchNorm2d(128))
self.skip22 = torch.nn.Conv2d(32, 128, 1)
self.skip_pool22 = torch.nn.MaxPool2d((5, 9), (4, 8), (2, 4))
# Block 3D 1
self.seq5 = torch.nn.Sequential(torch.nn.Conv3d(1, 32, (3, 5, 9), 1, (1, 2, 4)),
torch.nn.MaxPool3d((1,1,3), (1, 1, 2), (0,0,1)),
torch.nn.BatchNorm3d(32),
self.activation,
torch.nn.Conv3d(32, 64, (3, 5, 9), 1, (1, 2, 4)),
torch.nn.MaxPool3d(3, 2, 1),
torch.nn.BatchNorm3d(64))
self.skip31 = torch.nn.Conv3d(1, 64, 1)
self.skip_pool31 = torch.nn.MaxPool3d((3, 3, 5), (2, 2, 4), (1, 1, 2))
# Block 3D 2
self.seq6 = torch.nn.Sequential(torch.nn.Conv3d(64, 128, (3, 5, 15), 1, (1, 2, 7)),
torch.nn.MaxPool3d(3, 2, 1),
torch.nn.BatchNorm3d(128),
self.activation,
torch.nn.Conv3d(128, 256, (3, 5, 15), 1, (1, 2, 7)),
torch.nn.MaxPool3d(3, 2, 1),
torch.nn.BatchNorm3d(256))
self.skip32 = torch.nn.Conv3d(64, 256, 1)
self.skip_pool32 = torch.nn.MaxPool3d(5, 4, 2)
# Fully connected
self.obo = torch.nn.Sequential(torch.nn.Conv1d(8192, 1024, 1, 1),
torch.nn.Conv1d(1024, 1024, 1, 1),
torch.nn.Conv1d(1024, 1024, 1, 1),
self.activation,
torch.nn.Conv1d(1024, 1024, 1, 1),
torch.nn.Conv1d(1024, 1024, 1, 1),
torch.nn.Conv1d(1024, 1, 1, 1))
self.soft_max = torch.nn.Softmax(-1) # If the time stride is too big, the softmax will be done on a singleton
# which always ouput a 1
self.fc1 = torch.nn.Linear(8192, 1024)
self.fc2 = torch.nn.Linear(1024, 512)
self.fc3 = torch.nn.Linear(512, num_class)
def forward(self, x):
# Block 1D 1
out = self.seq1(x)
skip = self.skip_pool11(self.skip11(x))
out = self.activation(out + skip)
# Block 1D 2
skip = self.skip_pool12(self.skip12(out))
out = self.seq2(out)
out = self.activation(out + skip)
# Block 2D 1
out = out.reshape((lambda b, c, h: (b, 1, c, h))(*out.shape))
skip = self.skip_pool21(self.skip21(out))
out = self.seq3(out)
out = self.activation(out + skip)
# Block 2D 2
skip = self.skip_pool22(self.skip22(out))
out = self.seq4(out)
out = self.activation(out + skip)
# Block 3D 1
out = out.reshape((lambda b, c, w, h: (b, 1, c, w, h))(*out.shape))
skip = self.skip_pool31(self.skip31(out))
out = self.seq5(out)
out = self.activation(out + skip)
# Block 3D 2
skip = self.skip_pool32(self.skip32(out))
out = self.seq6(out)
out = self.activation(out + skip)
# Fully connected
out = out.reshape(len(out), 8192, -1)
out = torch.mean(out*self.soft_max(self.obo(out)), -1)
out = self.dropout(self.activation(self.fc1(out)))
out = self.dropout(self.activation(self.fc2(out)))
return self.fc3(out)
import torch
import torchelie as tch
import torchelie.callbacks.callbacks as tcb
import scipy.signal as sg
import soundfile as sf
import argparse
import numpy as np
from tqdm import tqdm, trange
from UpDimV3 import UpDimV3
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="""Train an UpDimV3 model with ADAM optimizer""")
parser.add_argument("--train",
default=['xeno_train_1500.npy'],
type=str, nargs='*', help="Paths to train file(s)")
parser.add_argument("--test",
default=['JF_propres_43.npy'],
type=str, nargs='*', help="Paths to test file(s)")
parser.add_argument("--rng", action='store_false', help='unset rng state')
parser.add_argument("--weight", type=str, default='', help="Model weight for resume")
parser.add_argument("--hot_weight", type=str, default='', help="Model weight for hot restart")
parser.add_argument("--run_name", type=str, default='', help='suffix of the run name')
args = parser.parse_args()
batch_size = 32
sound_len = 7
sound_sr = 22050
num_feature = sound_len * sound_sr
num_classes = 43
if args.rng:
rng = np.random.RandomState(42)
else:
rng = np.random.RandomState()
def pink_noise(size, rng, ncols=16, axis=-1):
"""Generates pink noise using the Voss-McCartney algorithm.
size: either a tuple of int or an int. If an int : number of sample to generate. If a tuple: shape of the return array.
ncols: number of random sources to add. Should be high enough so that num_samples*0.5**(ncols-2) is near zero.
axis: axis which contains the sound samples. Generate white noise otherwise.
returns: NumPy array of shape size
"""
if type(size) is not tuple:
size = (size,)
array = rng.rand(*size)
assert -len(size) <= axis < len(size)
axis %= len(size)
axis += 1
# the total number of changes is nrows
cols = rng.geometric(0.5, size)
cols[cols >= ncols] = 0
cols = (1. * (np.arange(1, ncols).reshape((-1,) + len(size) * (1,)) == cols)).swapaxes(axis, -1)
cols[..., 0] = 1.
cols = np.cumsum(cols).reshape(cols.shape).astype(int).swapaxes(axis, -1)
array = np.concatenate([array[np.newaxis], rng.rand(cols.max() + 1)[cols]], axis=0).sum(0)
return array
class Dataset:
def __init__(self, paths, train, rank, rng=None):
self.path = np.concatenate([np.load(p) for p in paths])
self.wav = len(self.path) * [None]
self.labels = len(self.path) * [None]
for i, p in enumerate(tqdm(self.path)):
self.wav[i] = p
self.labels[i] = p.split('/')[-1][:4]
self.wav = np.array([x for x in self.wav if x is not None])
self.labels = np.array([x for x in self.labels if x is not None])
self.names, self.labels = np.unique(self.labels, return_inverse=True)
assert len(self.wav) == len(self.labels)
print(f'Found {len(self.wav)} wav')
assert len(self.names) == num_classes
self.labels = np.eye(num_classes)[self.labels]
self.classes = self.names
self.train = train
self.rank = rank
self.rng = rng
self.wind = np.load('sons_abiotiques.npy').squeeze()
self.test_loaded = False
if not self.train:
print(f'Loading test dataset')
self.test_tuples = self.__len__() * [(None, None)]
for i in trange(self.__len__()):
self.test_tuples[i] = self.__getitem__(i)
self.test_loaded = True
def __len__(self):
return len(self.wav)
def get_class_names(self):
return self.names
def __getitem__(self, i):
if self.test_loaded and not self.train:
return self.test_tuples[i]
p = self.wav[i]
label = self.labels[i]
sample, sr = sf.read(p, always_2d=True)
sample = sample[:, 0]
if sr != sound_sr:
sample = sg.resample(sample, int(len(sample)*sound_sr/sr))
if self.train:
sample = sg.resample(sample, int(np.clip(len(sample) * rng.normal(1, 0.025, 1),
num_feature+10, 1.10*len(sample))))
p = self.rng.randint(0, len(sample) - num_feature)
r = self.rng.randint(0, 100)
if r < 20:
rev = np.zeros(sound_sr)
if r % 3:
rev += self.rng.normal(0, 0.001, len(rev))
rev[0] = 1
for k in range(r%4):
rev[self.rng.randint(0.001*sound_sr, len(rev))] = self.rng.uniform(0, 1)
sample = torch.nn.functional.conv1d(torch.Tensor(sample[None, None]).to(self.rank),
torch.Tensor(np.flip(rev).copy()[None, None]).to(self.rank),
padding=len(rev)//2+1).cpu().numpy().squeeze()
sample = sample[p:p+num_feature]
else:
sample = sample[:num_feature]
sample -= sample.mean(-1,keepdims=True)
sample /= sample.std(-1,keepdims=True) + 1e-18
if self.train:
sample += self.rng.normal(0, 1, num_feature) * 10**self.rng.uniform(-2.5, -0.12, 1)
sample += (lambda x: x-x.mean(-1,keepdims=True)/x.std(-1,keepdims=True))\
(np.cumsum(self.rng.normal(0, 1, num_feature))) * 10**self.rng.uniform(-2.5, -0.5, 1)
sample += ((lambda x: x-x.mean(-1,keepdims=True)/x.std(-1,keepdims=True))
(pink_noise(num_feature, self.rng))) * 10**self.rng.uniform(-4, -1, 1)
wind, sr_wind = sf.read(self.rng.choice(self.wind), always_2d=True)
wind = wind[:, 0]
if sr != sound_sr:
wind = sg.resample(wind, int(len(wind) * sound_sr / sr))
wind = sg.resample(wind, int(np.clip(len(wind) * rng.normal(1, 0.025, 1),
num_feature + 10, 1.10 * len(wind))))
p_wind = self.rng.randint(0, len(wind) - num_feature)
sample += (lambda x: x-x.mean(-1,keepdims=True)/x.std(-1,keepdims=True))(wind[p_wind:p_wind + num_feature])\
* 10**self.rng.uniform(-2, -0.05, 1) * (-1)**self.rng.randint(0, 10)
sample -= sample.mean(-1, keepdims=True)
sample /= sample.std(-1, keepdims=True)
p = self.rng.randint(0, len(sample) - int(sound_sr * 0.5))
sample[p:p+self.rng.randint(0, sound_sr * 0.5)] = 0
sample *= (-1)**self.rng.randint(0, 10)
return torch.from_numpy(sample[np.newaxis]).float(), np.argmax(label)
def train(rank, world_size):
ds = torch.utils.data.DataLoader(tch.datasets.MixUpDataset(Dataset(args.train, True, rank, rng), 0.3),
batch_size//world_size,
shuffle=True,
drop_last=True,
num_workers=6,
pin_memory=True)
test_dst = Dataset(args.test, False, rank)
dst = torch.utils.data.DataLoader(test_dst,
batch_size=batch_size//world_size,
num_workers=8,
shuffle=True)
print('Dataset loaded')
model = torch.nn.parallel.DistributedDataParallel(UpDimV3(num_classes).to(rank), [rank], rank)
def accuracy(preds, labels):
with torch.no_grad():
id_preds = torch.max(preds, 1)[1]
id_labels = torch.max(labels, 1)[1]
return 100 * (id_preds == id_labels).sum().item() / len(labels)
def train(batch):
x, y = batch
pred = model(x)
loss = tch.loss.continuous_cross_entropy(pred, y)
loss.backward()
return {'loss': loss, 'pred': pred, 'accuracy': accuracy(pred, y)}
def test(batch):
x, y = batch
pred = model(x)
y = torch.eye(num_classes)[y].to(y.device)
loss = tch.loss.continuous_cross_entropy(pred, y)
return {'loss': loss, 'pred': pred, 'accuracy': accuracy(pred, y)}
recipe = tch.recipes.TrainAndTest(model,
train,
test,
ds,
dst,
test_every=1000,
log_every=64,
checkpoint='model3_quebec' + args.run_name if rank == 0 else None,
visdom_env=None,
key_best=(lambda x: x['test_loop']['callbacks']['state']['metrics']['accuracy']))
opt = tch.optim.RAdamW(model.parameters(),
5e-3,
weight_decay=0.005)
recipe.callbacks.add_callbacks([
tcb.Optimizer(opt, log_lr=True, clip_grad_norm=5),
tcb.LRSched(torch.optim.lr_scheduler.MultiStepLR(opt, list(range(20,80)), (1-1/np.exp(1))/2), metric=None),
tcb.WindowedMetricAvg('loss'),
tcb.WindowedMetricAvg('accuracy'),
])
recipe.callbacks.add_epilogues([
tcb.TensorboardLogger(log_dir='UpDimV3_quebec' + args.run_name + '/train' if rank == 0 else None, log_every=64),])
recipe.test_loop.callbacks.add_callbacks([
tcb.EpochMetricAvg('loss', False),
tcb.EpochMetricAvg('accuracy', False),
# tcb.ConfusionMatrix(test_dst.get_class_names(), True)
])
recipe.test_loop.callbacks.add_epilogues([
tcb.TensorboardLogger(log_dir='UpDimV3_quebec' + args.run_name + '/test' if rank == 0 else None,
log_every=-1),])
recipe.to(rank)
if args.hot_weight != '':
ckpt = torch.load(args.hot_weight)
del ckpt['callbacks']['callbacks']['Counter_prologue_0']
del ckpt['callbacks']['callbacks']['LRSched_middle_0']
del ckpt['test_loop']
del ckpt['callbacks']['state']['metrics']['lr_0']
ckpt['callbacks']['callbacks']['Optimizer_middle_0']['opt']['param_groups'][0]['lr'] = opt.param_groups[0]['lr']
ckpt['callbacks']['callbacks']['Optimizer_middle_0']['opt']['param_groups'][0]['eps'] = opt.param_groups[0]['eps']
ckpt['callbacks']['callbacks']['Optimizer_middle_0']['opt']['param_groups'][0]['weight_decay'] = opt.param_groups[0]['weight_decay']
ckpt['callbacks']['callbacks']['Optimizer_middle_0']['opt']['param_groups'][0]['betas'] = opt.param_groups[0]['betas']
recipe.load_state_dict(ckpt)
if args.weight != '':
ckpt = torch.load(args.weight)
recipe.load_state_dict(ckpt)
recipe.run(80)
if __name__ == '__main__':
tch.utils.parallel_run(train)
import torch
import torchelie as tch
import torchelie.callbacks.callbacks as tcb
import scipy.signal as sg
import soundfile as sf
import argparse
import numpy as np
from tqdm import tqdm, trange
from sam import SAMSGD
import torchelie.utils as tu
from UpDimV3 import UpDimV3
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="""Train an UpDimV3 model with SAM optimizer""")
parser.add_argument("--train",
default=['xeno_train_1500.npy'],
type=str, nargs='*', help="Paths to train file(s)")
parser.add_argument("--test",
default=['JF_propres_43.npy'],
type=str, nargs='*', help="Paths to test file(s)")
parser.add_argument("--rng", action='store_false', help='unset rng state')
parser.add_argument("--weight", type=str, default='', help="recipy weight for resume")
parser.add_argument("--from_ckpt", type=str, default='', help="Model weight for resume")
parser.add_argument("--hot_weight", type=str, default='', help="Model weight for hot restart")
parser.add_argument("--run_name", type=str, default='', help='suffix of the run name')
args = parser.parse_args()
batch_size = 32
sound_len = 7
sound_sr = 22050
num_feature = sound_len * sound_sr
num_classes = 43
if args.rng:
rng = np.random.RandomState(42)
else:
rng = np.random.RandomState()
def pink_noise(size, rng, ncols=16, axis=-1):
"""Generates pink noise using the Voss-McCartney algorithm.
size: either a tuple of int or an int. If an int : number of sample to generate. If a tuple: shape of the return array.
ncols: number of random sources to add. Should be high enough so that num_samples*0.5**(ncols-2) is near zero.
axis: axis which contains the sound samples. Generate white noise otherwise.
returns: NumPy array of shape size
"""
if type(size) is not tuple:
size = (size,)
array = rng.rand(*size)
assert -len(size) <= axis < len(size)
axis %= len(size)
axis += 1
# the total number of changes is nrows
cols = rng.geometric(0.5, size)
cols[cols >= ncols] = 0
cols = (1. * (np.arange(1, ncols).reshape((-1,) + len(size) * (1,)) == cols)).swapaxes(axis, -1)
cols[..., 0] = 1.
cols = np.cumsum(cols).reshape(cols.shape).astype(int).swapaxes(axis, -1)
array = np.concatenate([array[np.newaxis], rng.rand(cols.max() + 1)[cols]], axis=0).sum(0)
return array
class Dataset:
def __init__(self, paths, train, rank, rng=None):
self.path = np.concatenate([np.load(p) for p in paths])
self.wav = len(self.path) * [None]
self.labels = len(self.path) * [None]
for i, p in enumerate(tqdm(self.path)):
self.wav[i] = p
self.labels[i] = p.split('/')[-1][:4]
self.wav = np.array([x for x in self.wav if x is not None])
self.labels = np.array([x for x in self.labels if x is not None])
self.names, self.labels = np.unique(self.labels, return_inverse=True)
assert len(self.wav) == len(self.labels)
print(f'Found {len(self.wav)} wav')
assert len(self.names) == num_classes
self.labels = np.eye(num_classes)[self.labels]
self.classes = self.names
self.train = train
self.rank = rank
self.rng = rng
self.wind = np.load('/nfs/NAS4/anatole/data/quebec/sons_abiotiques/sons_abiotiques.npy').squeeze()
self.test_loaded = False
if not self.train:
print(f'Loading test dataset')
self.test_tuples = self.__len__() * [(None, None)]
for i in trange(self.__len__()):
self.test_tuples[i] = self.__getitem__(i)
self.test_loaded = True
def __len__(self):
return len(self.wav)
def get_class_names(self):
return self.names
def __getitem__(self, i):
if self.test_loaded and not self.train:
return self.test_tuples[i]
p = self.wav[i]
label = self.labels[i]
sample, sr = sf.read(p, always_2d=True)
sample = sample[:, 0]
if sr != sound_sr:
sample = sg.resample(sample, int(len(sample)*sound_sr/sr))
if self.train:
sample = sg.resample(sample, int(np.clip(len(sample) * rng.normal(1, 0.025, 1),
num_feature+10, 1.10*len(sample))))
p = self.rng.randint(0, len(sample) - num_feature)
r = self.rng.randint(0, 100)
if r < 20:
rev = np.zeros(sound_sr)
if r % 3:
rev += self.rng.normal(0, 0.001, len(rev))
rev[0] = 1
for k in range(r%4):
rev[self.rng.randint(0.001*sound_sr, len(rev))] = self.rng.uniform(0, 1)
sample = torch.nn.functional.conv1d(torch.Tensor(sample[None, None]).to(self.rank),
torch.Tensor(np.flip(rev).copy()[None, None]).to(self.rank),
padding=len(rev)//2+1).cpu().numpy().squeeze()
sample = sample[p:p+num_feature]
else:
sample = sample[:num_feature]
sample -= sample.mean(-1,keepdims=True)
sample /= sample.std(-1,keepdims=True) + 1e-18
if self.train:
sample += self.rng.normal(0, 1, num_feature) * 10**self.rng.uniform(-2.5, -0.12, 1)
sample += (lambda x: x-x.mean(-1,keepdims=True)/x.std(-1,keepdims=True))\
(np.cumsum(self.rng.normal(0, 1, num_feature))) * 10**self.rng.uniform(-2.5, -0.5, 1)
sample += ((lambda x: x-x.mean(-1,keepdims=True)/x.std(-1,keepdims=True))
(pink_noise(num_feature, self.rng))) * 10**self.rng.uniform(-4, -1, 1)
wind, sr_wind = sf.read(self.rng.choice(self.wind), always_2d=True)
wind = wind[:, 0]
if sr != sound_sr:
wind = sg.resample(wind, int(len(wind) * sound_sr / sr))
wind = sg.resample(wind, int(np.clip(len(wind) * rng.normal(1, 0.025, 1),
num_feature + 10, 1.10 * len(wind))))
p_wind = self.rng.randint(0, len(wind) - num_feature)
sample += (lambda x: x-x.mean(-1,keepdims=True)/x.std(-1,keepdims=True))(wind[p_wind:p_wind + num_feature])\
* 10**self.rng.uniform(-2, -0.05, 1) * (-1)**self.rng.randint(0, 10)
sample -= sample.mean(-1, keepdims=True)
sample /= sample.std(-1, keepdims=True)
p = self.rng.randint(0, len(sample) - int(sound_sr * 0.5))
sample[p:p+self.rng.randint(0, sound_sr * 0.5)] = 0
sample *= (-1)**self.rng.randint(0, 10)
return torch.from_numpy(sample[np.newaxis]).float(), np.argmax(label)
def train(rank, world_size):
ds = torch.utils.data.DataLoader(tch.datasets.MixUpDataset(Dataset(args.train, True, rank, rng), 0.3),
batch_size//world_size,
shuffle=True,
drop_last=True,
num_workers=6,
pin_memory=True)
test_dst = Dataset(args.test, False, rank)
dst = torch.utils.data.DataLoader(test_dst,
batch_size=batch_size//world_size,
num_workers=8,
shuffle=True)
print('Dataset loaded')
model = UpDimV3(num_classes)
if args.from_ckpt != '':
model.load_state_dict(
torch.load(args.from_ckpt, map_location='cuda:' +
str(rank))['model'])
model = torch.nn.parallel.DistributedDataParallel(model.to(rank), [rank], rank)
def accuracy(preds, labels):
with torch.no_grad():
id_preds = torch.max(preds, 1)[1]
id_labels = torch.max(labels, 1)[1]
return 100 * (id_preds == id_labels).sum().item() / len(labels)
opt = SAMSGD(model.parameters(),
lr=5e-3,
weight_decay=0.005,
momentum=0.9)
def train(batch):
x, y = batch
def closure():
opt.zero_grad()
pred = model(x)
loss = tch.loss.continuous_cross_entropy(pred, y)
loss.backward()
return loss, pred
#norm = torch.nn.utils.clip_grad_norm_((p for pg in opt.param_groups for p in pg['params']), 5)
loss, pred = opt.step(closure)
return {'loss': loss, 'pred': pred, 'accuracy': accuracy(pred, y), 'grad_norm':0}
def test(batch):
x, y = batch
pred = model(x)
y = torch.eye(num_classes)[y].to(y.device)
loss = tch.loss.continuous_cross_entropy(pred, y)
return {'loss': loss, 'pred': pred, 'accuracy': accuracy(pred, y)}
recipe = tch.recipes.TrainAndTest(model,
train,
test,
ds,
dst,
test_every=1000,
log_every=64,
checkpoint='model3sam_quebec' + args.run_name if rank == 0 else None,
visdom_env=None,
key_best=(lambda x: x['test_loop']['callbacks']['state']['metrics']['accuracy']))
class OptimizerLog(tu.AutoStateDict):
def __init__(self,
opt,
clip_grad_norm=None,
log_lr=False,
log_mom=False):
super(OptimizerLog, self).__init__()
self.opt = opt
self.log_lr = log_lr
self.log_mom = log_mom
self.clip_grad_norm = clip_grad_norm
def on_batch_start(self, state):
if self.log_lr:
for i in range(len(self.opt.param_groups)):
pg = self.opt.param_groups[i]
state['metrics']['lr_' + str(i)] = pg['lr']
if self.log_mom:
for i in range(len(self.opt.param_groups)):
pg = self.opt.param_groups[i]
if 'momentum' in pg:
state['metrics']['mom_' + str(i)] = pg['momentum']
elif 'betas' in pg:
state['metrics']['mom_' + str(i)] = pg['betas'][0]
recipe.callbacks.add_callbacks([
OptimizerLog(opt, log_lr=True, clip_grad_norm=5),
tcb.LRSched(torch.optim.lr_scheduler.MultiStepLR(opt, list(range(10,80)), 0.95), metric=None),
tcb.WindowedMetricAvg('loss'),
tcb.WindowedMetricAvg('accuracy'),
tcb.WindowedMetricAvg('grad_norm'),
])
recipe.callbacks.add_epilogues([
tcb.TensorboardLogger(log_dir='UpDimV3SAM_quebec' + args.run_name + '/train' if rank == 0 else None, log_every=64),])
recipe.test_loop.callbacks.add_callbacks([
tcb.EpochMetricAvg('loss', False),
tcb.EpochMetricAvg('accuracy', False),
# tcb.ConfusionMatrix(test_dst.get_class_names(), True)
])
recipe.test_loop.callbacks.add_epilogues([
tcb.TensorboardLogger(log_dir='UpDimV3SAM_quebec' + args.run_name + '/test' if rank == 0 else None,
log_every=-1),])
recipe.to(rank)
if args.hot_weight != '':
ckpt = torch.load(args.hot_weight)
del ckpt['callbacks']['callbacks']['Counter_prologue_0']
del ckpt['callbacks']['callbacks']['LRSched_middle_0']
del ckpt['test_loop']
del ckpt['callbacks']['state']['metrics']['lr_0']
ckpt['callbacks']['callbacks']['Optimizer_middle_0']['opt']['param_groups'][0]['lr'] = opt.param_groups[0]['lr']
ckpt['callbacks']['callbacks']['Optimizer_middle_0']['opt']['param_groups'][0]['eps'] = opt.param_groups[0]['eps']
ckpt['callbacks']['callbacks']['Optimizer_middle_0']['opt']['param_groups'][0]['weight_decay'] = opt.param_groups[0]['weight_decay']
ckpt['callbacks']['callbacks']['Optimizer_middle_0']['opt']['param_groups'][0]['betas'] = opt.param_groups[0]['betas']
recipe.load_state_dict(ckpt)
if args.weight != '':
ckpt = torch.load(args.weight)
recipe.load_state_dict(ckpt)
recipe.run(80)
if __name__ == '__main__':
tch.utils.parallel_run(train)
sam.py 0 → 100644
from typing import Iterable
import torch
from torch.optim._multi_tensor import SGD
__all__ = ["SAMSGD"]
class SAMSGD(SGD):
""" SGD wrapped with Sharp-Aware Minimization
Args:
params: tensors to be optimized
lr: learning rate
momentum: momentum factor
dampening: damping factor
weight_decay: weight decay factor
nesterov: enables Nesterov momentum
rho: neighborhood size
"""
def __init__(self,
params: Iterable[torch.Tensor],
lr: float,
momentum: float = 0,
dampening: float = 0,
weight_decay: float = 0,
nesterov: bool = False,
rho: float = 0.05,
):
if rho <= 0:
raise ValueError(f"Invalid neighborhood size: {rho}")
super().__init__(params, lr, momentum, dampening, weight_decay, nesterov)
# todo: generalize this
if len(self.param_groups) > 1:
raise ValueError("Not supported")
self.param_groups[0]["rho"] = rho
@torch.no_grad()
def step(self,
closure
) -> torch.Tensor:
"""
Args:
closure: A closure that reevaluates the model and returns the loss.
Returns: the loss value evaluated on the original point
"""
closure = torch.enable_grad()(closure)
loss, pred = closure()
loss = loss.detach()
pred = pred.detach()
for group in self.param_groups:
grads = []
params_with_grads = []
rho = group['rho']
# update internal_optim's learning rate
for p in group['params']:
if p.grad is not None:
# without clone().detach(), p.grad will be zeroed by closure()
grads.append(p.grad.clone().detach())
params_with_grads.append(p)
device = grads[0].device
# compute \hat{\epsilon}=\rho/\norm{g}\|g\|
grad_norm = torch.stack([g.detach().norm(2).to(device) for g in grads]).norm(2)
epsilon = grads # alias for readability
torch._foreach_mul_(epsilon, rho / grad_norm)
# virtual step toward \epsilon
torch._foreach_add_(params_with_grads, epsilon)
# compute g=\nabla_w L_B(w)|_{w+\hat{\epsilon}}
closure()
# virtual step back to the original point
torch._foreach_sub_(params_with_grads, epsilon)
super().step()
return loss, pred
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment