from typing import Union, Tuple
import soundfile as sf
import torch
from torch import nn
from torch.autograd import Function
from torch.utils import data
import numpy as np
from scipy.signal import resample


def collate_fn(batch):
    """Drop items that failed to load (None) before default collation."""
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)

class Dataset(data.Dataset):
    """Serves fixed-duration audio excerpts centred on the positions listed in `df`."""
    def __init__(self, df, audiopath, sr, sampleDur, retType=False):
        super(Dataset, self).__init__()
        self.audiopath, self.df, self.retType, self.sr, self.sampleDur = audiopath, df, retType, sr, sampleDur

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        info = sf.info(self.audiopath+row.fn)
        dur, fs = info.duration, info.samplerate
        # clip the window so the excerpt stays within the file bounds
        start = int(np.clip(row.pos - self.sampleDur/2, 0, max(0, dur - self.sampleDur)) * fs)
        sig, fs = sf.read(self.audiopath+row.fn, start=start, stop=start + int(self.sampleDur*fs))
        if sig.ndim == 2:
            sig = sig[:, 0]  # keep only the first channel
        if len(sig) < self.sampleDur * fs:
            # zero-pad excerpts that run past the end of the file
            sig = np.concatenate([sig, np.zeros(int(self.sampleDur * fs) - len(sig))])
        if fs != self.sr:
            sig = resample(sig, int(len(sig)/fs*self.sr))
        if np.std(sig) == 0:
            print('zero-variance signal for sample ' + str(row.name))
        if self.retType:
            return torch.Tensor(norm(sig)).float(), row.name, row.label
        else:
            return torch.Tensor(norm(sig)).float(), row.name
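

# Usage sketch (illustrative, not part of the original pipeline): pairing
# Dataset with collate_fn lets a DataLoader skip items that came back as
# None. The sr, sampleDur and batch_size values below are assumptions.
def _example_loader(df, audiopath):
    dataset = Dataset(df, audiopath, sr=16000, sampleDur=1.0)
    return data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn)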


def norm(arr):
    """Standardise an array to zero mean and unit variance."""
    return (arr - np.mean(arr)) / np.std(arr)


class Flatten(nn.Module):
    """Flatten all but the batch dimension."""
    def __init__(self):
        super(Flatten, self).__init__()
    def forward(self, x):
        return x.view(x.shape[0], -1)


class Reshape(nn.Module):
    """Reshape all but the batch dimension to the given shape."""
    def __init__(self, *shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(x.shape[0], *self.shape)

class Croper2D(nn.Module):
    """Crop the last two (spatial) dimensions to the given shape."""
    def __init__(self, *shape):
        super(Croper2D, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x[:,:,:self.shape[0],:self.shape[1]]
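

# Usage sketch (illustrative): the three shape helpers compose inside an
# nn.Sequential, e.g. around a linear bottleneck. The 1x28x28 input and the
# 24x24 crop are arbitrary assumptions for the demonstration.
def _example_shapes():
    model = nn.Sequential(
        Flatten(),           # (N, 1, 28, 28) -> (N, 784)
        nn.Linear(784, 784),
        Reshape(1, 28, 28),  # (N, 784) -> (N, 1, 28, 28)
        Croper2D(24, 24),    # keep the top-left 24x24 corner
    )
    out = model(torch.randn(2, 1, 28, 28))
    assert out.shape == (2, 1, 24, 24)
    return out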


class VQ(nn.Module):
    """
    Quantization layer from *Neural Discrete Representation Learning*
    Args:
        latent_dim (int): number of features along which to quantize
        num_tokens (int): number of tokens in the codebook
        dim (int): dimension along which to quantize
        return_indices (bool): whether to return the indices of the quantized
            code points
    """
    embedding: nn.Embedding
    dim: int
    commitment: float
    initialized: torch.Tensor
    return_indices: bool
    init_mode: str

    def __init__(self,
                 latent_dim: int,
                 num_tokens: int,
                 dim: int = 1,
                 commitment: float = 0.25,
                 init_mode: str = 'normal',
                 return_indices: bool = True,
                 max_age: int = 1000):
        super(VQ, self).__init__()
        self.embedding = nn.Embedding(num_tokens, latent_dim)
        nn.init.normal_(self.embedding.weight, 0, 1.1)
        self.dim = dim
        self.commitment = commitment
        self.register_buffer('initialized', torch.Tensor([0]))
        self.return_indices = return_indices
        assert init_mode in ['normal', 'first']
        self.init_mode = init_mode
        self.register_buffer('age', torch.empty(num_tokens).fill_(max_age))
        self.max_age = max_age

    def update_usage(self, indices):
        with torch.no_grad():
            self.age += 1
            if torch.distributed.is_initialized():
                # gather the indices used on all ranks so usage stats agree
                n_gpu = torch.distributed.get_world_size()
                all_indices = [torch.empty_like(indices) for _ in range(n_gpu)]
                torch.distributed.all_gather(all_indices, indices)
                indices = torch.cat(all_indices)
            used = torch.unique(indices)
            self.age[used] = 0

    def resample_dead(self, x):
        with torch.no_grad():
            dead = torch.nonzero(self.age > self.max_age, as_tuple=True)[0]
            if len(dead) == 0:
                return

            print(f'{len(dead)} dead codes resampled')
            # reassign dead codes to random vectors from the current batch
            x_flat = x.view(-1, x.shape[-1])
            emb_weight = self.embedding.weight.data
            emb_weight[dead[:len(x_flat)]] = x_flat[torch.randperm(
                len(x_flat))[:len(dead)]].to(emb_weight.dtype)
            self.age[dead[:len(x_flat)]] = 0

            if torch.distributed.is_initialized():
                # keep the codebook identical across ranks
                torch.distributed.broadcast(emb_weight, 0)

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass
        Args:
            x (tensor): input tensor
        Returns:
            quantized tensor, or (quantized tensor, indices) if
            `self.return_indices`
        """
        dim = self.dim
        nb_codes = self.embedding.weight.shape[0]

        codebook = self.embedding.weight
        if (self.init_mode == 'first' and self.initialized.item() == 0 and
                self.training):
            # initialize the codebook from random vectors of the first batch
            ch_first = x.transpose(dim, -1).contiguous().view(-1, x.shape[dim])
            n_samples = ch_first.shape[0]
            idx = torch.randint(0, n_samples, (nb_codes,))
            self.embedding.weight.data.copy_(ch_first[idx])
            self.initialized[:] = 1

        needs_transpose = dim != -1 and dim != x.dim() - 1
        if needs_transpose:
            x = x.transpose(-1, dim).contiguous()

        if self.training:
            self.resample_dead(x)

        codes, indices = quantize(x, codebook, self.commitment, self.dim)

        if self.training:
            self.update_usage(indices)

        if needs_transpose:
            codes = codes.transpose(-1, dim)
            indices = indices.transpose(-1, dim)

        if self.return_indices:
            return codes, indices
        else:
            return codes
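

# Usage sketch (illustrative): quantize a (batch, channels, time) feature map
# along the channel dimension. latent_dim must match the size of that
# dimension; all sizes and hyperparameters here are arbitrary assumptions.
def _example_vq():
    vq = VQ(latent_dim=8, num_tokens=16, dim=1)
    x = torch.randn(4, 8, 32)
    codes, indices = vq(x)  # codes has the shape of x; gradients pass straight through
    assert codes.shape == x.shape
    return codes, indices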


class VectorQuantization(Function):

    @staticmethod
    def compute_indices(inputs_orig, codebook):
        bi = []
        SZ = 10000  # chunk size; bounds the NxK distance matrix held in memory
        for i in range(0, inputs_orig.size(0), SZ):
            inputs = inputs_orig[i:i + SZ]
            # NxK distances from each input vector to every codebook entry
            distances_matrix = torch.cdist(inputs, codebook)
            # Nx1 index of the nearest codebook entry
            indic = torch.min(distances_matrix, dim=-1)[1].unsqueeze(1)
            bi.append(indic)
        return torch.cat(bi, dim=0)

    @staticmethod
    def flatten(x):
        code_dim = x.size(-1)
        return x.view(-1, code_dim)

    @staticmethod
    def restore_shapes(codes, indices, target_shape):
        idx_shape = list(target_shape)
        idx_shape[-1] = 1
        return codes.view(*target_shape), indices.view(*idx_shape)

    @staticmethod
    def forward(ctx, inputs, codebook, commitment=0.25, dim=1):
        inputs_flat = VectorQuantization.flatten(inputs)
        indices = VectorQuantization.compute_indices(inputs_flat, codebook)
        codes = codebook[indices.view(-1), :]
        codes, indices = VectorQuantization.restore_shapes(
            codes, indices, inputs.shape)

        ctx.save_for_backward(codes, inputs, torch.tensor([float(commitment)]),
                              codebook, indices)
        ctx.mark_non_differentiable(indices)
        return codes, indices

    @staticmethod
    def backward(ctx, straight_through, unused_indices):
        codes, inputs, beta, codebook, indices = ctx.saved_tensors

        # TODO: figure out proper vq loss reduction
        # vq_loss = F.mse_loss(inputs, codes).detach()

        # gradient of vq_loss w.r.t. the encoder output
        diff = 2 * (inputs - codes) / inputs.numel()

        commitment = beta.item() * diff

        # accumulate, for each codebook entry, the displacement toward the
        # inputs that selected it
        code_disp = VectorQuantization.flatten(-diff)
        indices = VectorQuantization.flatten(indices)
        code_disp = (torch.zeros_like(codebook).index_add_(
            0, indices.view(-1), code_disp))
        # straight-through estimator: pass the decoder gradient to the encoder
        # unchanged, plus the commitment gradient
        return straight_through + commitment, code_disp, None, None


quantize = VectorQuantization.apply
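

# Usage sketch (illustrative): `quantize` can also be called directly with any
# (N, D) tensor and a (K, D) codebook; the backward pass implements the
# straight-through estimator. The random codebook below is an assumption; in
# VQ it is `self.embedding.weight`.
def _example_quantize():
    x = torch.randn(6, 4, requires_grad=True)
    codebook = torch.randn(10, 4, requires_grad=True)
    codes, indices = quantize(x, codebook, 0.25, -1)
    codes.sum().backward()  # populates x.grad and codebook.grad
    return codes, indices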