Skip to content
Snippets Groups Projects
Commit 726ea6dc authored by Paul Best's avatar Paul Best
Browse files

add medfilt

parent 362b4aa4
No related branches found
No related tags found
No related merge requests found
# Model definitions: a truncated pretrained VGG16 feature extractor and
# spectrogram frontends built from the project's filterbank layers.
import torchvision.models as torchmodels
from torch import nn
import utils as u
from filterbank import STFT, MelFilter, Log1p, MedFilt

# Pretrained VGG16; keep only the first 13 layers of `.features`
# (the early convolutional blocks), discarding the rest of the network.
vgg16 = torchmodels.vgg16(weights=torchmodels.VGG16_Weights.DEFAULT)
vgg16 = vgg16.features[:13]
......@@ -10,6 +10,15 @@ for nm, mod in vgg16.named_modules():
setattr(vgg16, nm, nn.AvgPool2d(2 ,2))
def frontend_medfilt(sr, nfft, sampleDur, n_mel):
    """Build a spectrogram frontend with a frequency-wise median filter.

    Pipeline: STFT -> mel filterbank -> Log1p compression -> per-sample
    instance normalisation -> MedFilt -> center crop to (n_mel, 128).

    Args:
        sr: sample rate of the input audio (Hz).
        nfft: FFT window size for the STFT.
        sampleDur: duration (s) of the signal extract; the STFT hop is
            chosen so the extract spans roughly 128 frames.
        n_mel: number of mel bands.

    Returns:
        An ``nn.Sequential`` mapping a waveform batch to a cropped,
        normalised, median-filtered mel spectrogram.
    """
    # NOTE(review): hop = (sampleDur*sr - nfft)/128 targets ~128 time frames;
    # Croper2D then crops to exactly (n_mel, 128) — confirm against Croper2D.
    return nn.Sequential(
        STFT(nfft, int((sampleDur*sr - nfft)/128)),
        MelFilter(sr, nfft, n_mel, sr//nfft, sr//2),
        Log1p(7, trainable=False),
        nn.InstanceNorm2d(1),
        MedFilt(),
        u.Croper2D(n_mel, 128)
    )
frontend = lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
STFT(nfft, int((sampleDur*sr - nfft)/128)),
MelFilter(sr, nfft, n_mel, sr//nfft, sr//2),
......
......@@ -15,6 +15,8 @@ parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spec
# CLI options controlling spectrogram computation and the auto-encoder.
# (`parser` is created earlier in this script, outside this excerpt.)
parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
# Boolean flag: when set, the median-filtered frontend is used and samples
# are loaded with extra context (see the DataLoader construction below).
parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)")
# NOTE(review): no "-feature" argument is added in this excerpt; this default
# may belong to an option defined elsewhere — confirm it is still needed.
parser.set_defaults(feature=False)
parser.add_argument("-bottleneck", type=int, default=16, help='size of the auto-encoder\'s bottleneck')
args = parser.parse_args()
......@@ -22,7 +24,7 @@ args = parser.parse_args()
# Select the compute device; fall back to CPU when CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The encoder downsamples the mel axis by 32, so nMel must divide evenly.
assert args.nMel % 32 == 0 and args.nMel > 0, "nMel argument should be a multiple of 32"
# The bottleneck is reshaped into a (nMel//32, 4) volume; its size must match.
assert args.bottleneck % (args.nMel//32 * 4) == 0, "Bottleneck size must be a multiple of the last volume\'s size (nMel//32 * 4)"
# Pick the median-filtered frontend when -medfilt was requested.
frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
# Full auto-encoder: spectrogram frontend -> encoder -> decoder.
model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
......@@ -36,7 +38,7 @@ loss_fun = torch.nn.MSELoss()
# data loader
df = pd.read_csv(args.detections)
# When median filtering is enabled, load 2 extra seconds around each detection
# so the frequency-wise median is estimated from more context; the frontend's
# final crop discards the surplus frames.
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
# Checkpoint name derived from the detections CSV (extension stripped).
modelname = f'{args.detections[:-4]}_AE_{args.bottleneck}_mel{args.nMel}.stdc'
step, writer = 0, SummaryWriter('runs/'+modelname)
......
......@@ -56,4 +56,4 @@ class Croper2D(nn.Module):
super(Croper2D, self).__init__()
self.shape = shape
def forward(self, x):
return x[:,:,:self.shape[0],:self.shape[1]]
return x[:,:,:self.shape[0],(x.shape[-1] - self.shape[1])//2:-(x.shape[-1] - self.shape[1])//2]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment