Skip to content
Snippets Groups Projects
Commit 726ea6dc authored by Paul Best's avatar Paul Best
Browse files

add medfilt

parent 362b4aa4
No related branches found
No related tags found
No related merge requests found
import torchvision.models as torchmodels
from torch import nn
import utils as u
from filterbank import STFT, MelFilter, Log1p, MedFilt

vgg16 = torchmodels.vgg16(weights=torchmodels.VGG16_Weights.DEFAULT)
vgg16 = vgg16.features[:13]

...@@ -10,6 +10,15 @@ for nm, mod in vgg16.named_modules():
        setattr(vgg16, nm, nn.AvgPool2d(2 ,2))
def frontend_medfilt(sr, nfft, sampleDur, n_mel):
    """Build the spectrogram front-end with a frequency-wise median filter.

    Same pipeline as `frontend` (STFT -> Mel -> Log1p -> InstanceNorm) with an
    additional MedFilt step before cropping to 128 time bins.
    (PEP 8 E731: a named callable should be a `def`, not a lambda assignment.)

    Args:
        sr: sample rate of the input waveform (Hz).
        nfft: FFT window size for the STFT.
        sampleDur: duration of each input sample (seconds).
        n_mel: number of Mel bands.

    Returns:
        nn.Sequential mapping a waveform batch to an (n_mel, 128) spectrogram.
    """
    return nn.Sequential(
        # hop length chosen so one sample yields roughly 128 STFT frames
        STFT(nfft, int((sampleDur * sr - nfft) / 128)),
        MelFilter(sr, nfft, n_mel, sr // nfft, sr // 2),
        Log1p(7, trainable=False),
        nn.InstanceNorm2d(1),
        MedFilt(),
        # keep n_mel bands and the central 128 time bins
        u.Croper2D(n_mel, 128),
    )
frontend = lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
    STFT(nfft, int((sampleDur*sr - nfft)/128)),
    MelFilter(sr, nfft, n_mel, sr//nfft, sr//2),
......
...@@ -15,6 +15,8 @@ parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spec
parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)")
parser.set_defaults(feature=False)
parser.add_argument("-bottleneck", type=int, default=16, help='size of the auto-encoder\'s bottleneck')
args = parser.parse_args()
...@@ -22,7 +24,7 @@ args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
assert args.nMel % 32 == 0 and args.nMel > 0, "nMel argument should be a multiple of 32"
assert args.bottleneck % (args.nMel//32 * 4) == 0, "Bottleneck size must be a multiple of the last volume\'s size (nMel//32 * 4)"
frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
...@@ -36,7 +38,7 @@ loss_fun = torch.nn.MSELoss()
# data loader
df = pd.read_csv(args.detections)
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
modelname = f'{args.detections[:-4]}_AE_{args.bottleneck}_mel{args.nMel}.stdc'
step, writer = 0, SummaryWriter('runs/'+modelname)
......
...@@ -56,4 +56,4 @@ class Croper2D(nn.Module):
        super(Croper2D, self).__init__()
        self.shape = shape
def forward(self, x): def forward(self, x):
return x[:,:,:self.shape[0],:self.shape[1]] return x[:,:,:self.shape[0],(x.shape[-1] - self.shape[1])//2:-(x.shape[-1] - self.shape[1])//2]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment