diff --git a/new_specie/models.py b/new_specie/models.py index f7dcc80573136e04eed5839b15caf38a11c030c7..d5d0bb85352d1682ab8fcd32e2c112ff18bbc70b 100755 --- a/new_specie/models.py +++ b/new_specie/models.py @@ -1,7 +1,7 @@ import torchvision.models as torchmodels from torch import nn import utils as u -from filterbank import STFT, MelFilter, Log1p +from filterbank import STFT, MelFilter, Log1p, MedFilt vgg16 = torchmodels.vgg16(weights=torchmodels.VGG16_Weights.DEFAULT) vgg16 = vgg16.features[:13] @@ -10,6 +10,15 @@ for nm, mod in vgg16.named_modules(): setattr(vgg16, nm, nn.AvgPool2d(2 ,2)) +frontend_medfilt = lambda sr, nfft, sampleDur, n_mel : nn.Sequential( + STFT(nfft, int((sampleDur*sr - nfft)/128)), + MelFilter(sr, nfft, n_mel, sr//nfft, sr//2), + Log1p(7, trainable=False), + nn.InstanceNorm2d(1), + MedFilt(), + u.Croper2D(n_mel, 128) +) + frontend = lambda sr, nfft, sampleDur, n_mel : nn.Sequential( STFT(nfft, int((sampleDur*sr - nfft)/128)), MelFilter(sr, nfft, n_mel, sr//nfft, sr//2), diff --git a/new_specie/train_AE.py b/new_specie/train_AE.py index 15f63ebe99eb00e0d2a958b21476a1c60e6a673d..34100fa7f138c4f93a9fdbbd5c356c9fcec8acee 100755 --- a/new_specie/train_AE.py +++ b/new_specie/train_AE.py @@ -15,6 +15,8 @@ parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spec parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)") parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation") parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded") +parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)") +parser.set_defaults(medfilt=False) parser.add_argument("-bottleneck", type=int, default=16, help='size of the auto-encoder\'s 
bottleneck') args = parser.parse_args() @@ -22,7 +24,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') assert args.nMel % 32 == 0 and args.nMel > 0, "nMel argument should be a multiple of 32" assert args.bottleneck % (args.nMel//32 * 4) == 0, "Bottleneck size must be a multiple of the last volume\'s size (nMel//32 * 4)" -frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel) +frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel) encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4)) decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4)) model = torch.nn.Sequential(frontend, encoder, decoder).to(device) @@ -36,7 +38,7 @@ loss_fun = torch.nn.MSELoss() # data loader df = pd.read_csv(args.detections) -loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn) +loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn) modelname = f'{args.detections[:-4]}_AE_{args.bottleneck}_mel{args.nMel}.stdc' step, writer = 0, SummaryWriter('runs/'+modelname) diff --git a/new_specie/utils.py b/new_specie/utils.py index 09acb635fbd5b8f5b07d63910e52731cf8d5cb38..9697d45954936796fef6ce48982c200a0c4d4944 100755 --- a/new_specie/utils.py +++ b/new_specie/utils.py @@ -56,4 +56,4 @@ class Croper2D(nn.Module): super(Croper2D, self).__init__() self.shape = shape def forward(self, x): - return x[:,:,:self.shape[0],:self.shape[1]] + return x[:,:,:self.shape[0],(x.shape[-1] - self.shape[1])//2:(x.shape[-1] - self.shape[1])//2 + self.shape[1]]