diff --git a/model.py b/model.py new file mode 100644 index 0000000000000000000000000000000000000000..68b2ada10bbbea1e02871e037a89d0940f42e765 --- /dev/null +++ b/model.py @@ -0,0 +1,229 @@ +from torch import nn +import torch +import numpy as np +from torch import tensor, nn, exp, log, ones, stack + + +class PCENLayer(nn.Module): + def __init__(self, num_bands, + s=0.025, + alpha=.8, + delta=10., + r=.25, + eps=1e-6, + init_smoother_from_data=True): + super(PCENLayer, self).__init__() + self.log_s = nn.Parameter( log(ones((1,1,num_bands)) * s)) + self.log_alpha = nn.Parameter( log(ones((1,1,num_bands,1)) * alpha)) + self.log_delta = nn.Parameter( log(ones((1,1,num_bands,1)) * delta)) + self.log_r = nn.Parameter( log(ones((1,1,num_bands,1)) * r)) + self.eps = tensor(eps) + self.init_smoother_from_data = init_smoother_from_data + + def forward(self, input): # expected input (batch, channel, freqs, time) + init = input[:,:,:,0] # initialize the filter with the first frame + if not self.init_smoother_from_data: + init = torch.zeros(init.shape) # initialize with zeros instead + + filtered = [init] + for iframe in range(1, input.shape[-1]): + filtered.append( (1-exp(self.log_s)) * filtered[iframe-1] + exp(self.log_s) * input[:,:,:,iframe] ) + filtered = stack(filtered).permute(1,2,3,0) + + # stable reformulation due to Vincent Lostanlen; original formula was: + alpha, delta, r = exp(self.log_alpha), exp(self.log_delta), exp(self.log_r) + return (input / (self.eps + filtered)**alpha + delta)**r - delta**r +# filtered = exp(-alpha * (log(self.eps) + log(1 + filtered / self.eps))) +# return (input * filtered + delta)**r - delta**r + + +def create_mel_filterbank(sample_rate, frame_len, num_bands, min_freq, max_freq, + norm=True, crop=False): + """ + Creates a mel filterbank of `num_bands` triangular filters, with the first + filter starting at `min_freq` and the last one stopping at `max_freq`. + Returns the filterbank as a matrix suitable for a dot product against + magnitude spectra created from samples at a sample rate of `sample_rate` + with a window length of `frame_len` samples. If `norm`, will normalize + each filter by its area. If `crop`, will exclude rows that exceed the + maximum frequency and are therefore zero. + """ + # mel-spaced peak frequencies + min_mel = 1127 * np.log1p(min_freq / 700.0) + max_mel = 1127 * np.log1p(max_freq / 700.0) + peaks_mel = torch.linspace(min_mel, max_mel, num_bands + 2) + peaks_hz = 700 * (torch.expm1(peaks_mel / 1127)) + peaks_bin = peaks_hz * frame_len / sample_rate + + # create filterbank + input_bins = (frame_len // 2) + 1 + if crop: + input_bins = min(input_bins, + int(np.ceil(max_freq * frame_len / + float(sample_rate)))) + x = torch.arange(input_bins, dtype=peaks_bin.dtype)[:, np.newaxis] + l, c, r = peaks_bin[0:-2], peaks_bin[1:-1], peaks_bin[2:] + # triangles are the minimum of two linear functions f(x) = a*x + b + # left side of triangles: f(l) = 0, f(c) = 1 -> a=1/(c-l), b=-a*l + tri_left = (x - l) / (c - l) + # right side of triangles: f(c) = 1, f(r) = 0 -> a=1/(c-r), b=-a*r + tri_right = (x - r) / (c - r) + # combine by taking the minimum of the left and right sides + tri = torch.min(tri_left, tri_right) + # and clip to only keep positive values + filterbank = torch.clamp(tri, min=0) + + # normalize by area + if norm: + filterbank /= filterbank.sum(0) + + return filterbank + + +class MelFilter(nn.Module): + def __init__(self, sample_rate, winsize, num_bands, min_freq, max_freq): + super(MelFilter, self).__init__() + melbank = create_mel_filterbank(sample_rate, winsize, num_bands, + min_freq, max_freq, crop=True) + self.register_buffer('bank', melbank) + + def forward(self, x): + x = x.transpose(-1, -2) # put fft bands last + x = x[..., :self.bank.shape[0]] # remove unneeded fft bands + x = x.matmul(self.bank) # turn fft bands into mel bands + x = x.transpose(-1, -2) # put time last + return x + + def state_dict(self, destination=None, prefix='', keep_vars=False): + result = super(MelFilter, self).state_dict(destination, prefix, keep_vars) + # remove all buffers; we use them as cached constants + for k in self._buffers: + del result[prefix + k] + return result + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # ignore stored buffers for backwards compatibility + for k in self._buffers: + state_dict.pop(prefix + k, None) + # temporarily hide the buffers; we do not want to restore them + buffers = self._buffers + self._buffers = {} + result = super(MelFilter, self)._load_from_state_dict(state_dict, prefix, *args, **kwargs) + self._buffers = buffers + return result + +class STFT(nn.Module): + def __init__(self, winsize, hopsize, complex=False): + super(STFT, self).__init__() + self.winsize = winsize + self.hopsize = hopsize + self.register_buffer('window', + torch.hann_window(winsize, periodic=False)) + self.complex = complex + + def state_dict(self, destination=None, prefix='', keep_vars=False): + result = super(STFT, self).state_dict(destination, prefix, keep_vars) + # remove all buffers; we use them as cached constants + for k in self._buffers: + del result[prefix + k] + return result + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # ignore stored buffers for backwards compatibility + for k in self._buffers: + state_dict.pop(prefix + k, None) + # temporarily hide the buffers; we do not want to restore them + buffers = self._buffers + self._buffers = {} + result = super(STFT, self)._load_from_state_dict(state_dict, prefix, *args, **kwargs) + self._buffers = buffers + return result + + def forward(self, x): + x = x.unsqueeze(1) + # we want each channel to be treated separately, so we mash + # up the channels and batch size and split them up afterwards + batchsize, channels = x.shape[:2] + x = x.reshape((-1,) + x.shape[2:]) + # we apply the STFT + x = torch.stft(x, self.winsize, self.hopsize, window=self.window, + center=False, return_complex=False) + # we compute magnitudes, if requested + if not self.complex: + x = x.norm(p=2, dim=-1) + # restore original batchsize and channels in case we mashed them + x = x.reshape((batchsize, channels, -1) + x.shape[2:]) #if channels > 1 else x.reshape((batchsize, -1) + x.shape[2:]) + return x + + +HB_model = nn.Sequential(nn.Sequential( + STFT(512, 64), + MelFilter(11025, 512, 64, 100, 3000), + PCENLayer(64), + ), + nn.Sequential( + nn.Conv2d(1, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3,bias=False), + nn.BatchNorm2d(32), + nn.MaxPool2d(3), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 64, (16, 3), bias=False), + nn.BatchNorm2d(64), + nn.MaxPool2d((1,3)), + nn.LeakyReLU(0.01), + nn.Dropout(p=.5), + nn.Conv2d(64, 256, (1, 9), bias=False), # for 80 bands + nn.BatchNorm2d(256), + nn.LeakyReLU(0.01), + nn.Dropout(p=.5), + nn.Conv2d(256, 64, 1, bias=False), + nn.BatchNorm2d(64), + nn.LeakyReLU(0.01), + nn.Dropout(p=.5), + nn.Conv2d(64, 1, 1, bias=False) + ) + ) + +delphi_model = nn.Sequential(nn.Sequential( + STFT(4096, 1024), + MelFilter(96000, 4096, 128, 3000, 30000), + PCENLayer(128), + ), + nn.Sequential( + nn.Conv2d(1, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3,bias=False), + nn.BatchNorm2d(32), + nn.MaxPool2d(3), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 64, (19, 3), bias=False), + nn.BatchNorm2d(64), + nn.MaxPool2d(3), + nn.LeakyReLU(0.01), + nn.Dropout(p=.5), + nn.Conv2d(64, 256, (1, 9), bias=False), # for 80 bands + nn.BatchNorm2d(256), + nn.LeakyReLU(0.01), + nn.Dropout(p=.5), + nn.Conv2d(256, 64, 1, bias=False), + nn.BatchNorm2d(64), + nn.LeakyReLU(0.01), + nn.Dropout(p=.5), + nn.Conv2d(64, 1, 1, bias=False), + ) +) \ No newline at end of file