from torch import nn
from frontend import STFT, MelFilter, PCENLayer, Log1p


class depthwise_separable_conv1d(nn.Module):
    def __init__(self, nin, nout, kernel, padding=0, stride=1):
        super(depthwise_separable_conv1d, self).__init__()
        self.depthwise = nn.Conv1d(nin, nin, kernel_size=kernel, padding=padding, stride=stride, groups=nin)
        self.pointwise = nn.Conv1d(nin, nout, kernel_size=1)
    def forward(self, x):
        out = self.depthwise(x.squeeze(1))
        out = self.pointwise(out)
        return out

class Dropout1d(nn.Module):
    def __init__(self, pdropout=.25):
        super(Dropout1d, self).__init__()
        self.dropout = nn.Dropout2d(pdropout)
    def forward(self, x):
        x = x.unsqueeze(-1)
        x = self.dropout(x)
        return x.squeeze(-1)

PHYSETER_NFEAT = 128
PHYSETER_KERNEL = 7
BALAENOPTERA_NFEAT = 128
BALAENOPTERA_KERNEL = 5

get = {
    'physeter': {
        'weights': 'stft_depthwise_ovs_128_k7_r1.stdc',
        'fs': 50000,
        'archi': nn.Sequential(
            STFT(512, 256),
            MelFilter(50000, 512, 64, 2000, 25000),
            Log1p(trainable=True),
            depthwise_separable_conv1d(64, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2),
            nn.BatchNorm1d(PHYSETER_NFEAT),
            nn.LeakyReLU(),
            Dropout1d(),
            depthwise_separable_conv1d(PHYSETER_NFEAT, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2),
            nn.BatchNorm1d(PHYSETER_NFEAT),
            nn.LeakyReLU(),
            Dropout1d(),
            depthwise_separable_conv1d(PHYSETER_NFEAT, 1, PHYSETER_KERNEL, stride=2)
        ),
    },
    'balaenoptera': {
        'weights': 'dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc',
        'fs': 200,
        'archi': nn.Sequential(
            STFT(256, 32),
            MelFilter(200, 256, 128, 0, 100),
            Log1p(trainable=True),
            depthwise_separable_conv1d(128, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2),
            nn.BatchNorm1d(BALAENOPTERA_NFEAT),
            nn.LeakyReLU(),
            Dropout1d(),
            depthwise_separable_conv1d(BALAENOPTERA_NFEAT, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2),
            nn.BatchNorm1d(BALAENOPTERA_NFEAT),
            nn.LeakyReLU(),
            Dropout1d(),
            depthwise_separable_conv1d(BALAENOPTERA_NFEAT, 1, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2)
        )
    },
    'megaptera' : {
        'weights': 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc',
        'fs': 11025,
        'archi': nn.Sequential(
            nn.Sequential(
                STFT(512, 64),
                MelFilter(11025, 512, 64, 100, 3000),
                PCENLayer(64)
            ),
            nn.Sequential(
                nn.Conv2d(1, 32, 3, bias=False),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 32, 3,bias=False),
                nn.BatchNorm2d(32),
                nn.MaxPool2d(3),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 32, 3, bias=False),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 32, 3, bias=False),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 64, (16, 3), bias=False),
                nn.BatchNorm2d(64),
                nn.MaxPool2d((1,3)),
                nn.LeakyReLU(0.01),
                nn.Dropout(p=.5),
                nn.Conv2d(64, 256, (1, 9), bias=False),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(0.01),
                nn.Dropout(p=.5),
                nn.Conv2d(256, 64, 1, bias=False),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(0.01),
                nn.Dropout(p=.5),
                nn.Conv2d(64, 1, 1, bias=False)
            )
        )
    },
    'delphinid' : {
        'weights': 'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc',
        'fs': 96000,
        'archi': nn.Sequential(
            nn.Sequential(
                STFT(4096, 1024),
                MelFilter(96000, 4096, 128, 3000, 30000),
                PCENLayer(128)
            ),
            nn.Sequential(
                nn.Conv2d(1, 32, 3, bias=False),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 32, 3,bias=False),
                nn.BatchNorm2d(32),
                nn.MaxPool2d(3),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 32, 3, bias=False),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 32, 3, bias=False),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 64, (19, 3), bias=False),
                nn.BatchNorm2d(64),
                nn.MaxPool2d(3),
                nn.LeakyReLU(0.01),
                nn.Dropout(p=.5),
                nn.Conv2d(64, 256, (1, 9), bias=False),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(0.01),
                nn.Dropout(p=.5),
                nn.Conv2d(256, 64, 1, bias=False),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(0.01),
                nn.Dropout(p=.5),
                nn.Conv2d(64, 1, 1, bias=False),
                nn.MaxPool2d((6, 1))
            )
        )
    },
    'orcinus': {
        'weights': 'train_fe76f_00085_85_0',
        'fs': 22050,
        'archi': nn.Sequential(
            nn.Sequential(
                STFT(1024, 128),
                MelFilter(22050, 1024, 80, 300, 11025),
                PCENLayer(80)
            ),
            nn.Sequential(
                nn.Conv2d(1, 32, 3, bias=False),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 32, 3,bias=False),
                nn.BatchNorm2d(32),
                nn.MaxPool2d(3),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 32, 3, bias=False),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 32, 3, bias=False),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 64, (19, 3), bias=False),
                nn.BatchNorm2d(64),
                nn.MaxPool2d(3),
                nn.LeakyReLU(0.01),
                nn.Dropout2d(p=.5),
                nn.Conv2d(64, 256, (1, 9), bias=False),  # for 80 bands
                nn.BatchNorm2d(256),
                nn.LeakyReLU(0.01),
                nn.Dropout2d(p=.5),
                nn.Conv2d(256, 64, 1, bias=False),
                nn.BatchNorm2d(64),
                nn.Dropout2d(p=.5),
                nn.LeakyReLU(0.01),
                nn.Conv2d(64, 1, 1, bias=False),
            )
        )
    },
    'globicephala': {
        'weights': 'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr_GLOBI.005_TRAIN_.stdc',
        'fs': 48000,
        'archi': nn.Sequential(
            nn.Sequential(
                STFT(2048, 512),
                MelFilter(48000, 2048, 128, 2000, 6000),
                PCENLayer(128)
            ),
            nn.Sequential(
                nn.Conv2d(1, 32, 3, bias=False),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 32, 3,bias=False),
                nn.BatchNorm2d(32),
                nn.MaxPool2d(3),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 32, 3, bias=False),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 32, 3, bias=False),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 64, (19, 3), bias=False),
                nn.BatchNorm2d(64),
                nn.MaxPool2d(3),
                nn.LeakyReLU(0.01),
                nn.Dropout(p=.5),
                nn.Conv2d(64, 256, (1, 6), bias=False),  # for 80 bands
                nn.BatchNorm2d(256),
                nn.LeakyReLU(0.01),
                nn.Dropout(p=.5),
                nn.Conv2d(256, 64, 1, bias=False),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(0.01),
                nn.Dropout(p=.5),
                nn.Conv2d(64, 1, 1, bias=False),
                nn.MaxPool2d((6, 1)),
            )
        )
    }
}