Skip to content
Snippets Groups Projects
Commit 52770838 authored by paul.best's avatar paul.best
Browse files
parents 7ac24424 a413afb2
No related branches found
No related tags found
No related merge requests found
......@@ -72,26 +72,27 @@ frontend = {
'Mel': lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
STFT(nfft, int((sampleDur*sr - nfft)/128)),
MelFilter(sr, nfft, n_mel, sr//nfft, sr//2),
nn.BatchNorm2d(1, affine=False),
nn.InstanceNorm2d(1),
u.Croper2D(n_mel, 128)
),
'logMel': lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
STFT(nfft, int((sampleDur*sr - nfft)/128)),
MelFilter(sr, nfft, n_mel, 0, sr//2),
Log1p(7, trainable=False),
nn.BatchNorm2d(1, affine=False),
nn.Instancenorm2d(1),
u.Croper2D(n_mel, 128)
),
'logSTFT': lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
STFT(nfft, int((sampleDur*sr - nfft)/128)),
Log1p(7, trainable=False),
nn.BatchNorm2d(1, affine=False),
nn.InstanceNorm2d(1),
u.Croper2D(n_mel, 128)
),
'pcenMel': lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
STFT(nfft, int((sampleDur*sr - nfft)/128)),
MelFilter(sr, nfft, n_mel, sr//nfft, sr//2),
PCENLayer(n_mel, trainable=False),
nn.InstanceNorm2d(1),
u.Croper2D(n_mel, 128)
)
}
......
......@@ -79,7 +79,7 @@ for epoch in range(100_000//len(loader)):
scheduler.step()
# Actual test
model.eval()
model[1:].eval()
with torch.no_grad():
encodings, idxs = [], []
for x, idx in tqdm(loader, desc='test '+str(step), leave=False):
......@@ -152,5 +152,5 @@ for epoch in range(100_000//len(loader)):
writer.add_histogram('K-Means Recalls ', np.array(recs), step)
df.drop('cluster', axis=1, inplace=True)
print('\r', end='')
model.train()
model[1:].train()
step += 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment