Skip to content
Snippets Groups Projects
Select Git revision
  • 529d0a4e8c45c6f1b19bb642712fcfec8ccc553a
  • master default protected
2 results

setup.py

Blame
  • frontend.py 7.31 KiB
    
    import torch
    import numpy as np
    
    
    class Log1p(torch.nn.Module):
        """
        Compressive nonlinearity log(1 + 10**a * x).

        The scale exponent `a` is either a fixed Python number or, when
        `trainable=True`, a learnable scalar parameter.
        """
        def __init__(self, a=0, trainable=False):
            super(Log1p, self).__init__()
            if trainable:
                a = torch.nn.Parameter(torch.tensor(a, dtype=torch.get_default_dtype()))
            self.a = a
            self.trainable = trainable

        def forward(self, x):
            # with a fixed exponent of zero the transform would be log1p(x);
            # the original behavior is to skip it entirely in that case
            if not self.trainable and self.a == 0:
                return x
            return torch.log1p(10 ** self.a * x)

        def extra_repr(self):
            return 'trainable={!r}'.format(self.trainable)
    
    
    class PCENLayer(torch.nn.Module):
        """
        Trainable Per-Channel Energy Normalization (PCEN).

        Computes (x / (eps + m)**alpha + delta)**r - delta**r per mel band,
        where m is a first-order IIR low-pass ("smoother") of the input along
        time.  s, alpha, delta and r are learned per band; they are stored in
        log domain so that exp() keeps them positive during optimization.

        Parameters
        ----------
        num_bands : number of mel bands (last-but-one input dimension)
        s, alpha, delta, r : initial values of the per-band PCEN parameters
        eps : small constant to avoid division by zero
        init_smoother_from_data : if True, the smoother is seeded with the
            first frame of the input; if False, with zeros
        """
        def __init__(self, num_bands,
                     s=0.025,
                     alpha=.8,
                     delta=10.,
                     r=.25,
                     eps=1e-6,
                     init_smoother_from_data=True):
            super(PCENLayer, self).__init__()
            self.log_s = torch.nn.Parameter( torch.log(torch.ones((1,1,num_bands)) * s))
            self.log_alpha = torch.nn.Parameter( torch.log(torch.ones((1,1,num_bands,1)) * alpha))
            self.log_delta = torch.nn.Parameter( torch.log(torch.ones((1,1,num_bands,1)) * delta))
            self.log_r = torch.nn.Parameter( torch.log(torch.ones((1,1,num_bands,1)) * r))
            # keep eps as a plain float: a CPU tensor here would cause a
            # device mismatch once the module runs on GPU inputs
            self.eps = float(eps)
            self.init_smoother_from_data = init_smoother_from_data

        def forward(self, input): # expected input (batch, channel, freqs, time)
            init = input[:,:,:,0]  # initialize the filter with the first frame
            if not self.init_smoother_from_data:
                # zeros_like keeps device and dtype consistent with the input
                init = torch.zeros_like(init)

            # first-order IIR low-pass along time; inherently sequential,
            # so we keep the loop but hoist the invariant coefficient
            s = torch.exp(self.log_s)
            filtered = [init]
            for iframe in range(1, input.shape[-1]):
                filtered.append( (1 - s) * filtered[-1] + s * input[:,:,:,iframe] )
            filtered = torch.stack(filtered).permute(1,2,3,0)

            alpha, delta, r = torch.exp(self.log_alpha), torch.exp(self.log_delta), torch.exp(self.log_r)
            # direct formula; a stable reformulation due to Vincent Lostanlen is:
            #   filtered = exp(-alpha * (log(self.eps) + log(1 + filtered / self.eps)))
            #   return (input * filtered + delta)**r - delta**r
            return (input / (self.eps + filtered)**alpha + delta)**r - delta**r
    
    
    def create_mel_filterbank(sample_rate, frame_len, num_bands, min_freq, max_freq,
                              norm=True, crop=False):
        """
        Creates a mel filterbank of `num_bands` triangular filters, with the first
        filter starting at `min_freq` and the last one stopping at `max_freq`.
        Returns the filterbank as a matrix suitable for a dot product against
        magnitude spectra created from samples at a sample rate of `sample_rate`
        with a window length of `frame_len` samples. If `norm`, will normalize
        each filter by its area. If `crop`, will exclude rows that exceed the
        maximum frequency and are therefore zero.
        """
        def hz_to_mel(hz):
            return 1127 * np.log1p(hz / 700.0)

        # peak frequencies: equally spaced on the mel scale, then mapped back
        # to Hz and finally to (fractional) FFT bin positions
        mel_peaks = torch.linspace(hz_to_mel(min_freq), hz_to_mel(max_freq),
                                   num_bands + 2)
        hz_peaks = 700 * torch.expm1(mel_peaks / 1127)
        bin_peaks = hz_peaks * frame_len / sample_rate

        # number of FFT bins the filterbank covers
        num_bins = (frame_len // 2) + 1
        if crop:
            num_bins = min(num_bins,
                           int(np.ceil(max_freq * frame_len /
                                       float(sample_rate))))

        # evaluate each triangle at every bin position: a triangle is the
        # minimum of its rising edge (0 at l, 1 at c) and its falling edge
        # (1 at c, 0 at r), clipped to non-negative values
        bins = torch.arange(num_bins, dtype=bin_peaks.dtype).unsqueeze(1)
        left, center, right = bin_peaks[:-2], bin_peaks[1:-1], bin_peaks[2:]
        rising = (bins - left) / (center - left)
        falling = (bins - right) / (center - right)
        filterbank = torch.clamp(torch.min(rising, falling), min=0)

        # normalize each filter by its area
        if norm:
            filterbank = filterbank / filterbank.sum(0)

        return filterbank
    
    
    class MelFilter(torch.nn.Module):
        """
        Maps magnitude spectrograms to mel spectrograms via a fixed mel
        filterbank kept as a (non-persistent, manually excluded) buffer.
        """
        def __init__(self, sample_rate, winsize, num_bands, min_freq, max_freq):
            super(MelFilter, self).__init__()
            melbank = create_mel_filterbank(sample_rate, winsize, num_bands,
                                            min_freq, max_freq, crop=True)
            self.register_buffer('bank', melbank)

        def forward(self, x):
            # input has fft bands at dim -2 and time at dim -1
            x = x.transpose(-1, -2)  # put fft bands last
            x = x[..., :self.bank.shape[0]]  # remove unneeded fft bands
            x = x.matmul(self.bank)  # turn fft bands into mel bands
            x = x.transpose(-1, -2)  # put time last
            return x

        def state_dict(self, destination=None, prefix='', keep_vars=False):
            # pass by keyword: positional arguments to Module.state_dict are
            # deprecated since torch 1.13 (signature is now *args-based)
            result = super(MelFilter, self).state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars)
            # remove all buffers; we use them as cached constants
            for k in self._buffers:
                del result[prefix + k]
            return result

        def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
            # ignore stored buffers for backwards compatibility
            for k in self._buffers:
                state_dict.pop(prefix + k, None)
            # temporarily hide the buffers; we do not want to restore them
            buffers = self._buffers
            self._buffers = {}
            result = super(MelFilter, self)._load_from_state_dict(state_dict, prefix, *args, **kwargs)
            self._buffers = buffers
            return result
    
    
    class STFT(torch.nn.Module):
        """
        Short-time Fourier transform frontend.

        Input is (batch, time) samples; output is (batch, 1, freqs, frames)
        magnitudes, or (batch, 1, freqs, frames, 2) real/imaginary pairs when
        `complex=True`.  The Hann window is kept as a buffer that is excluded
        from the state dict.
        """
        def __init__(self, winsize, hopsize, complex=False):
            super(STFT, self).__init__()
            self.winsize = winsize
            self.hopsize = hopsize
            self.register_buffer('window',
                                 torch.hann_window(winsize, periodic=False))
            self.complex = complex

        def state_dict(self, destination=None, prefix='', keep_vars=False):
            # pass by keyword: positional arguments to Module.state_dict are
            # deprecated since torch 1.13 (signature is now *args-based)
            result = super(STFT, self).state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars)
            # remove all buffers; we use them as cached constants
            for k in self._buffers:
                del result[prefix + k]
            return result

        def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
            # ignore stored buffers for backwards compatibility
            for k in self._buffers:
                state_dict.pop(prefix + k, None)
            # temporarily hide the buffers; we do not want to restore them
            buffers = self._buffers
            self._buffers = {}
            result = super(STFT, self)._load_from_state_dict(state_dict, prefix, *args, **kwargs)
            self._buffers = buffers
            return result

        def forward(self, x):
            x = x.unsqueeze(1)
            # we want each channel to be treated separately, so we mash
            # up the channels and batch size and split them up afterwards
            batchsize, channels = x.shape[:2]
            x = x.reshape((-1,) + x.shape[2:])
            # we apply the STFT; return_complex=True is required going forward
            # (return_complex=False is deprecated and slated for removal)
            x = torch.stft(x, self.winsize, self.hopsize, window=self.window,
                           center=False, return_complex=True)
            if self.complex:
                # expose real/imag as a trailing dimension of size 2,
                # matching the layout of the old return_complex=False output
                x = torch.view_as_real(x)
            else:
                # magnitudes: same values the old norm(p=2, dim=-1) produced
                x = x.abs()
            # restore original batchsize and channels in case we mashed them
            x = x.reshape((batchsize, channels, -1) + x.shape[2:])
            return x