# Author: Jan Schlüter
import numpy as np
import torch
from torch import nn


def create_mel_filterbank(sample_rate, frame_len, num_bands, min_freq, max_freq,
                          norm=True, crop=False):
    """
    Creates a mel filterbank of `num_bands` triangular filters, with the first
    filter starting at `min_freq` and the last one stopping at `max_freq`.
    Returns the filterbank as a matrix suitable for a dot product against
    magnitude spectra created from samples at a sample rate of `sample_rate`
    with a window length of `frame_len` samples. If `norm`, will normalize
    each filter by its area. If `crop`, will exclude rows that exceed the
    maximum frequency and are therefore zero.
    """
    # mel-spaced peak frequencies (note: this uses a corner frequency of
    # 7000 Hz in both directions, rather than the 700 Hz of the standard
    # HTK mel scale)
    min_mel = 1127 * np.log1p(min_freq / 7000.0)
    max_mel = 1127 * np.log1p(max_freq / 7000.0)
    peaks_mel = torch.linspace(min_mel, max_mel, num_bands + 2)
    peaks_hz = 7000 * torch.expm1(peaks_mel / 1127)
    # convert peak frequencies from Hz to spectrogram bin indices
    peaks_bin = peaks_hz * frame_len / sample_rate

    # create filterbank
    input_bins = (frame_len // 2) + 1
    if crop:
        input_bins = min(input_bins,
                         int(np.ceil(max_freq * frame_len /
                                     float(sample_rate))))
    x = torch.arange(input_bins, dtype=peaks_bin.dtype)[:, np.newaxis]
    l, c, r = peaks_bin[0:-2], peaks_bin[1:-1], peaks_bin[2:]
    # triangles are the minimum of two linear functions f(x) = a*x + b
    # left side of triangles: f(l) = 0, f(c) = 1 -> a=1/(c-l), b=-a*l
    tri_left = (x - l) / (c - l)
    # right side of triangles: f(c) = 1, f(r) = 0 -> a=1/(c-r), b=-a*r
    tri_right = (x - r) / (c - r)
    # combine by taking the minimum of the left and right sides
    tri = torch.min(tri_left, tri_right)
    # and clip to only keep positive values
    filterbank = torch.clamp(tri, min=0)

    # normalize by area
    if norm:
        filterbank /= filterbank.sum(0)

    return filterbank
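
# Example (a hedged sketch; all parameter values below are assumptions): for
# 22050 Hz audio analyzed with 1024-sample frames, an 80-band filterbank
# spanning 27.5 Hz to 8000 Hz, cropped to ceil(8000 * 1024 / 22050) = 372
# rows instead of the full 1024 // 2 + 1 = 513 spectrogram bins:
#
#     bank = create_mel_filterbank(22050, 1024, 80, 27.5, 8000, crop=True)
#     assert bank.shape == (372, 80)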


class MelFilter(nn.Module):
    """
    Applies a fixed mel filterbank to a magnitude spectrogram, turning FFT
    bins into mel bands. The filterbank is kept as a buffer and excluded
    from checkpoints, since it is a constant derived from the constructor
    arguments.
    """
    def __init__(self, sample_rate, winsize, num_bands, min_freq, max_freq):
        super(MelFilter, self).__init__()
        melbank = create_mel_filterbank(sample_rate, winsize, num_bands,
                                        min_freq, max_freq, crop=True)
        self.register_buffer('bank', melbank)

    def forward(self, x):
        x = x.transpose(-1, -2)  # put fft bands last
        x = x[..., :self.bank.shape[0]]  # remove unneeded fft bands
        x = x.matmul(self.bank)  # turn fft bands into mel bands
        x = x.transpose(-1, -2)  # put time last
        return x

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        result = super(MelFilter, self).state_dict(destination=destination,
                                                   prefix=prefix,
                                                   keep_vars=keep_vars)
        # remove all buffers; we use them as cached constants
        for k in self._buffers:
            del result[prefix + k]
        return result

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # ignore stored buffers for backwards compatibility
        for k in self._buffers:
            state_dict.pop(prefix + k, None)
        # temporarily hide the buffers; we do not want to restore them
        buffers = self._buffers
        self._buffers = {}
        result = super(MelFilter, self)._load_from_state_dict(state_dict, prefix, *args, **kwargs)
        self._buffers = buffers
        return result
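
# Usage sketch (hypothetical shapes): MelFilter maps a magnitude spectrogram
# of shape (batch, channels, bins, time) to (batch, channels, num_bands, time):
#
#     melfilter = MelFilter(22050, 1024, 80, 27.5, 8000)
#     spec = torch.rand(2, 1, 513, 100)  # e.g. from the STFT module below
#     mel = melfilter(spec)              # -> shape (2, 1, 80, 100)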

class STFT(nn.Module):
    """
    Computes the short-time Fourier transform of a batch of waveforms with a
    Hann window of `winsize` samples and a hop of `hopsize` samples, returning
    magnitudes unless `complex` is set.
    """
    def __init__(self, winsize, hopsize, complex=False):
        super(STFT, self).__init__()
        self.winsize = winsize
        self.hopsize = hopsize
        self.register_buffer('window',
                             torch.hann_window(winsize, periodic=False))
        self.complex = complex

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        result = super(STFT, self).state_dict(destination=destination,
                                              prefix=prefix,
                                              keep_vars=keep_vars)
        # remove all buffers; we use them as cached constants
        for k in self._buffers:
            del result[prefix + k]
        return result

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # ignore stored buffers for backwards compatibility
        for k in self._buffers:
            state_dict.pop(prefix + k, None)
        # temporarily hide the buffers; we do not want to restore them
        buffers = self._buffers
        self._buffers = {}
        result = super(STFT, self)._load_from_state_dict(state_dict, prefix, *args, **kwargs)
        self._buffers = buffers
        return result

    def forward(self, x):
        x = x.unsqueeze(1)
        # we want each channel to be treated separately, so we mash
        # up the channels and batch size and split them up afterwards
        batchsize, channels = x.shape[:2]
        x = x.reshape((-1,) + x.shape[2:])
        # we apply the STFT
        x = torch.stft(x, self.winsize, self.hopsize, window=self.window,
                       center=False, return_complex=True)
        x = torch.view_as_real(x)  # real and imaginary parts in the last dim
        # we compute magnitudes, if requested
        if not self.complex:
            x = x.norm(p=2, dim=-1)
        # restore original batchsize and channels in case we mashed them
        x = x.reshape((batchsize, channels, -1) + x.shape[2:])
        return x
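
# Shape sketch (assumed values): for a (batch, samples) waveform the output is
# (batch, 1, winsize // 2 + 1, frames), where center=False gives
# frames = (samples - winsize) // hopsize + 1:
#
#     stft = STFT(winsize=1024, hopsize=256)
#     spec = stft(torch.randn(2, 22050))  # -> shape (2, 1, 513, 83)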



class TemporalBatchNorm(nn.Module):
    """
    Batch normalization of a (batch, channels, bands, time) tensor over all
    dimensions except the penultimate one (the frequency bands).
    """
    def __init__(self, num_bands):
        super(TemporalBatchNorm, self).__init__()
        self.bn = nn.BatchNorm1d(num_bands)

    def forward(self, x):
        shape = x.shape
        # squash channels into the batch dimension
        x = x.reshape((-1,) + x.shape[-2:])
        # pass through 1D batch normalization
        x = self.bn(x)
        # restore squashed dimensions
        return x.reshape(shape)
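
# Shape sketch: nn.BatchNorm1d treats its input as (N, C, L) and normalizes
# per channel C, so squashing batch and channels into N keeps the frequency
# bands as C, i.e. statistics are computed per band over batch, channels and
# time:
#
#     tbn = TemporalBatchNorm(num_bands=80)
#     out = tbn(torch.rand(2, 1, 80, 100))  # same shape, normalized per band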

class Log1p(nn.Module):
    """
    Applies log(1 + 10**a * x), with scale fixed or trainable.
    """
    def __init__(self, a=0, trainable=False):
        super(Log1p, self).__init__()
        if trainable:
            a = nn.Parameter(torch.tensor(a, dtype=torch.get_default_dtype()))
        self.a = a
        self.trainable = trainable

    def forward(self, x):
        if self.trainable or self.a != 0:
            x = torch.log1p(10 ** self.a * x)
        return x

    def extra_repr(self):
        return 'trainable={}'.format(repr(self.trainable))
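

if __name__ == '__main__':
    # Smoke test (a minimal sketch; all parameter values are assumptions):
    # chain the modules into the usual waveform -> log-mel front end and
    # check the resulting shape.
    frontend = nn.Sequential(
        STFT(winsize=1024, hopsize=256),
        MelFilter(22050, 1024, num_bands=80, min_freq=27.5, max_freq=8000),
        Log1p(a=3),
        TemporalBatchNorm(num_bands=80),
    )
    waveform = torch.randn(2, 22050)  # two hypothetical 1-second clips
    melspec = frontend(waveform)
    print(melspec.shape)  # expected: torch.Size([2, 1, 80, 83])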