Skip to content
Snippets Groups Projects
Commit a680d747 authored by Paul Best's avatar Paul Best
Browse files

switch to instance norm

parent 3f55dc05
Branches
No related tags found
No related merge requests found
...@@ -72,26 +72,27 @@ frontend = { ...@@ -72,26 +72,27 @@ frontend = {
'Mel': lambda sr, nfft, sampleDur, n_mel : nn.Sequential( 'Mel': lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
STFT(nfft, int((sampleDur*sr - nfft)/128)), STFT(nfft, int((sampleDur*sr - nfft)/128)),
MelFilter(sr, nfft, n_mel, sr//nfft, sr//2), MelFilter(sr, nfft, n_mel, sr//nfft, sr//2),
nn.BatchNorm2d(1, affine=False), nn.InstanceNorm2d(1),
u.Croper2D(n_mel, 128) u.Croper2D(n_mel, 128)
), ),
'logMel': lambda sr, nfft, sampleDur, n_mel : nn.Sequential( 'logMel': lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
STFT(nfft, int((sampleDur*sr - nfft)/128)), STFT(nfft, int((sampleDur*sr - nfft)/128)),
MelFilter(sr, nfft, n_mel, 0, sr//2), MelFilter(sr, nfft, n_mel, 0, sr//2),
Log1p(7, trainable=False), Log1p(7, trainable=False),
nn.BatchNorm2d(1, affine=False), nn.Instancenorm2d(1),
u.Croper2D(n_mel, 128) u.Croper2D(n_mel, 128)
), ),
'logSTFT': lambda sr, nfft, sampleDur, n_mel : nn.Sequential( 'logSTFT': lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
STFT(nfft, int((sampleDur*sr - nfft)/128)), STFT(nfft, int((sampleDur*sr - nfft)/128)),
Log1p(7, trainable=False), Log1p(7, trainable=False),
nn.BatchNorm2d(1, affine=False), nn.InstanceNorm2d(1),
u.Croper2D(n_mel, 128) u.Croper2D(n_mel, 128)
), ),
'pcenMel': lambda sr, nfft, sampleDur, n_mel : nn.Sequential( 'pcenMel': lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
STFT(nfft, int((sampleDur*sr - nfft)/128)), STFT(nfft, int((sampleDur*sr - nfft)/128)),
MelFilter(sr, nfft, n_mel, sr//nfft, sr//2), MelFilter(sr, nfft, n_mel, sr//nfft, sr//2),
PCENLayer(n_mel), PCENLayer(n_mel),
nn.InstanceNorm2d(1),
u.Croper2D(n_mel, 128) u.Croper2D(n_mel, 128)
) )
} }
......
...@@ -79,7 +79,7 @@ for epoch in range(100_000//len(loader)): ...@@ -79,7 +79,7 @@ for epoch in range(100_000//len(loader)):
scheduler.step() scheduler.step()
# Actual test # Actual test
model.eval() model[1:].eval()
with torch.no_grad(): with torch.no_grad():
encodings, idxs = [], [] encodings, idxs = [], []
for x, idx in tqdm(loader, desc='test '+str(step), leave=False): for x, idx in tqdm(loader, desc='test '+str(step), leave=False):
...@@ -142,5 +142,5 @@ for epoch in range(100_000//len(loader)): ...@@ -142,5 +142,5 @@ for epoch in range(100_000//len(loader)):
writer.add_histogram('K-Means Recalls ', np.array(recs), step) writer.add_histogram('K-Means Recalls ', np.array(recs), step)
df.drop('cluster', axis=1, inplace=True) df.drop('cluster', axis=1, inplace=True)
print('\r', end='') print('\r', end='')
model.train() model[1:].train()
step += 1 step += 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment