import torch
import torch.nn as nn
import torch.nn.functional as F
import Features

################################################################################
def createNetwork(name, dicts, outputSizes, incremental) :
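  """Instantiate the network selected by `name`.

  `dicts` provides one vocabulary per embedded feature (the columns plus
  HISTORY and LETTER), `outputSizes` gives the size of each output layer
  (one per decoding state, selected later with setState), and `incremental`
  is passed through to the feature extraction. Feature templates, history
  length, affix sizes and the hidden-layer width are fixed below.
  """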
  featureFunctionAll = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
  featureFunctionNostack = "b.-2 b.-1 b.0 b.1 b.2"
  historyNb = 10
  suffixSize = 4
  prefixSize = 4
  hiddenSize = 1600
  columns = ["UPOS", "FORM"]

  if name == "base" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize)
  elif name == "semi" :
    return SemiNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize)
  elif name == "big" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize*2)
  elif name == "lstm" :
    return LSTMNet(dicts, outputSizes, incremental)
  elif name == "separated" :
    return SeparatedNet(dicts, outputSizes, incremental)
  elif name == "tagger" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, suffixSize, prefixSize, columns, hiddenSize)

  raise ValueError("Unknown network name '%s'"%name)
################################################################################
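
################################################################################
# Minimal usage sketch (illustrative only). `dicts`, `config` and `outputSizes`
# come from the rest of the project and are assumed here, not defined in this
# module:
#
#   network = createNetwork("base", dicts, outputSizes, incremental=False)
#   network.setState(0)
#   features = network.extractFeatures(dicts, config).unsqueeze(0).long()
#   scores = network(features.to(network.currentDevice()))
#
################################################################################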

################################################################################
class BaseNet(nn.Module):
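  """Feed-forward classifier over embedded features.

  Each feature template cell, each recent transition of the history and each
  prefix/suffix letter is embedded; the embeddings are concatenated, passed
  through one shared hidden layer, and scored by the output layer of the
  current decoding state (see setState).
  """
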
  def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns, hiddenSize) :
    super().__init__()
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) # zero-sized parameter used only to expose the module's device

    self.incremental = incremental
    self.state = 0
    self.featureFunction = featureFunction
    self.historyNb = historyNb
    self.suffixSize = suffixSize
    self.prefixSize = prefixSize
    self.columns = columns

    self.embSize = 64
    self.nbTargets = len(self.featureFunction.split())
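    # total number of embedded cells per example; with the defaults of
    # createNetwork ("base"): 2 columns * 17 templates + 10 history + 4 prefix + 4 suffix = 52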
    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
    self.outputSizes = outputSizes
    for name in dicts.dicts :
      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
    self.fc1 = nn.Linear(self.inputSize * self.embSize, hiddenSize)
    for i in range(len(outputSizes)) :
      self.add_module("output_"+str(i), nn.Linear(hiddenSize, outputSizes[i]))
    self.dropout = nn.Dropout(0.3)

    self.apply(self.initWeights)

  def setState(self, state) :
    self.state = state

  def forward(self, x) :
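    # x is a batch of feature index rows, laid out as in extractFeatures:
    # [column features | history | prefix letters | suffix letters]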
    embeddings = []
    for i in range(len(self.columns)) :
      embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
    y = torch.cat(embeddings,-1).view(x.size(0),-1)
    curIndex = len(self.columns)*self.nbTargets
    if self.historyNb > 0 :
      historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
      y = torch.cat([y, historyEmb],-1)
      curIndex = curIndex+self.historyNb
    if self.prefixSize > 0 :
      prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
      y = torch.cat([y, prefixEmb],-1)
      curIndex = curIndex+self.prefixSize
    if self.suffixSize > 0 :
      suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
      y = torch.cat([y, suffixEmb],-1)
      curIndex = curIndex+self.suffixSize
    y = self.dropout(y)
    y = F.relu(self.dropout(self.fc1(y)))
    y = getattr(self, "output_"+str(self.state))(y)
    return y

  def currentDevice(self) :
    return self.dummyParam.device

  def initWeights(self,m) :
    if isinstance(m, nn.Linear) :
      torch.nn.init.xavier_uniform_(m.weight)
      torch.nn.init.constant_(m.bias, 0.01)

  def extractFeatures(self, dicts, config) :
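    # The concatenation order below must match the slicing order in forward().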
    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
    prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
    suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
    return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
################################################################################

################################################################################
class SemiNet(nn.Module):
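  """BaseNet variant with one extra, state-specific hidden layer between the
  shared hidden layer and each output layer. Input layout and feature
  extraction are identical to BaseNet.
  """
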
  def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns, hiddenSize) :
    super().__init__()
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)

    self.incremental = incremental
    self.state = 0
    self.featureFunction = featureFunction
    self.historyNb = historyNb
    self.suffixSize = suffixSize
    self.prefixSize = prefixSize
    self.columns = columns

    self.embSize = 64
    self.nbTargets = len(self.featureFunction.split())
    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
    self.outputSizes = outputSizes
    for name in dicts.dicts :
      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
    self.fc1 = nn.Linear(self.inputSize * self.embSize, hiddenSize)
    for i in range(len(outputSizes)) :
      self.add_module("output_hidden_"+str(i), nn.Linear(hiddenSize, hiddenSize))
      self.add_module("output_"+str(i), nn.Linear(hiddenSize, outputSizes[i]))
    self.dropout = nn.Dropout(0.3)

    self.apply(self.initWeights)

  def setState(self, state) :
    self.state = state

  def forward(self, x) :
    embeddings = []
    for i in range(len(self.columns)) :
      embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
    y = torch.cat(embeddings,-1).view(x.size(0),-1)
    curIndex = len(self.columns)*self.nbTargets
    if self.historyNb > 0 :
      historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
      y = torch.cat([y, historyEmb],-1)
      curIndex = curIndex+self.historyNb
    if self.prefixSize > 0 :
      prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
      y = torch.cat([y, prefixEmb],-1)
      curIndex = curIndex+self.prefixSize
    if self.suffixSize > 0 :
      suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
      y = torch.cat([y, suffixEmb],-1)
      curIndex = curIndex+self.suffixSize
    y = self.dropout(y)
    y = F.relu(self.dropout(self.fc1(y)))
    y = self.dropout(getattr(self, "output_hidden_"+str(self.state))(y))
    y = getattr(self, "output_"+str(self.state))(y)
    return y

  def currentDevice(self) :
    return self.dummyParam.device

  def initWeights(self,m) :
    if isinstance(m, nn.Linear) :
      torch.nn.init.xavier_uniform_(m.weight)
      torch.nn.init.constant_(m.bias, 0.01)

  def extractFeatures(self, dicts, config) :
    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
    prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
    suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
    return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
################################################################################

################################################################################
class SeparatedNet(nn.Module):
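  """Variant where each decoding state has its own embeddings, hidden layer
  and output layer; setState selects which set of parameters is used.
  """
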
  def __init__(self, dicts, outputSizes, incremental) :
    super().__init__()
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)

    self.incremental = incremental
    self.state = 0
    self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
    self.historyNb = 5
    self.suffixSize = 4
    self.prefixSize = 4
    self.columns = ["UPOS", "FORM"]

    self.embSize = 64
    self.nbTargets = len(self.featureFunction.split())
    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
    self.outputSizes = outputSizes

    for i in range(len(outputSizes)) :
      for name in dicts.dicts :
        self.add_module("emb_"+name+"_"+str(i), nn.Embedding(len(dicts.dicts[name]), self.embSize))
      self.add_module("fc1_"+str(i), nn.Linear(self.inputSize * self.embSize, 1600))
      self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
    self.dropout = nn.Dropout(0.3)

    self.apply(self.initWeights)

  def setState(self, state) :
    self.state = state

  def forward(self, x) :
    embeddings = []
    for i in range(len(self.columns)) :
      embeddings.append(getattr(self, "emb_"+self.columns[i]+"_"+str(self.state))(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
    y = torch.cat(embeddings,-1).view(x.size(0),-1)
    curIndex = len(self.columns)*self.nbTargets
    if self.historyNb > 0 :
      historyEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
      y = torch.cat([y, historyEmb],-1)
      curIndex = curIndex+self.historyNb
    if self.prefixSize > 0 :
      prefixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
      y = torch.cat([y, prefixEmb],-1)
      curIndex = curIndex+self.prefixSize
    if self.suffixSize > 0 :
      suffixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
      y = torch.cat([y, suffixEmb],-1)
      curIndex = curIndex+self.suffixSize
    y = self.dropout(y)
    y = F.relu(self.dropout(getattr(self, "fc1_"+str(self.state))(y)))
    y = getattr(self, "output_"+str(self.state))(y)
    return y

  def currentDevice(self) :
    return self.dummyParam.device

  def initWeights(self,m) :
    if isinstance(m, nn.Linear) :
      torch.nn.init.xavier_uniform_(m.weight)
      torch.nn.init.constant_(m.bias, 0.01)

  def extractFeatures(self, dicts, config) :
    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
    prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
    suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
    return torch.cat([colsValues, historyValues, prefixValues, suffixValues])

################################################################################

################################################################################
class LSTMNet(nn.Module):
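  """Variant encoding the buffer window (b.-2 .. b.2) and the transition
  history with bidirectional LSTMs, while the stack features go directly
  through the feed-forward part. No prefix/suffix letter features are used.
  """
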
  def __init__(self, dicts, outputSizes, incremental) :
    super().__init__()
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)

    self.incremental = incremental
    self.state = 0
    self.featureFunctionLSTM = "b.-2 b.-1 b.0 b.1 b.2"
    self.featureFunction = "s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
    self.historyNb = 5
    self.columns = ["UPOS", "FORM"]

    self.embSize = 64
    self.nbInputLSTM = len(self.featureFunctionLSTM.split())
    self.nbInputBase = len(self.featureFunction.split())
    self.nbTargets = self.nbInputBase + self.nbInputLSTM
    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb
    self.outputSizes = outputSizes
    for name in dicts.dicts :
      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
    self.lstmFeat = nn.LSTM(len(self.columns)*self.embSize, len(self.columns)*int(self.embSize/2), 1, batch_first=True, bidirectional = True)
    self.lstmHist = nn.LSTM(self.embSize, int(self.embSize/2), 1, batch_first=True, bidirectional = True)
    self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
    for i in range(len(outputSizes)) :
      self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
    self.dropout = nn.Dropout(0.3)

    self.apply(self.initWeights)

  def setState(self, state) :
    self.state = state

  def forward(self, x) :
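    # x layout, as in extractFeatures:
    # [stack-side features | buffer window features | history]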
    embeddings = []
    embeddingsLSTM = []
    for i in range(len(self.columns)) :
      embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbInputBase:(i+1)*self.nbInputBase]))
    for i in range(len(self.columns)) :
      embeddingsLSTM.append(getattr(self, "emb_"+self.columns[i])(x[...,len(self.columns)*self.nbInputBase+i*self.nbInputLSTM:len(self.columns)*self.nbInputBase+(i+1)*self.nbInputLSTM]))

    z = torch.cat(embeddingsLSTM,-1)
    z = self.lstmFeat(z)[0]
    z = z.reshape(x.size(0), -1)
    y = torch.cat(embeddings,-1).reshape(x.size(0),-1)
    y = torch.cat([y,z], -1)
    if self.historyNb > 0 :
      historyEmb = getattr(self, "emb_HISTORY")(x[...,len(self.columns)*self.nbTargets:len(self.columns)*self.nbTargets+self.historyNb])
      historyEmb = self.lstmHist(historyEmb)[0]
      historyEmb = historyEmb.reshape(x.size(0), -1)
      y = torch.cat([y, historyEmb],-1)
    y = self.dropout(y)
    y = F.relu(self.dropout(self.fc1(y)))
    y = getattr(self, "output_"+str(self.state))(y)
    return y

  def currentDevice(self) :
    return self.dummyParam.device

  def initWeights(self,m) :
    if isinstance(m, nn.Linear) :
      torch.nn.init.xavier_uniform_(m.weight)
      torch.nn.init.constant_(m.bias, 0.01)

  def extractFeatures(self, dicts, config) :
    colsValuesBase = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
    colsValuesLSTM = Features.extractColsFeatures(dicts, config, self.featureFunctionLSTM, self.columns, self.incremental)
    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
    return torch.cat([colsValuesBase, colsValuesLSTM, historyValues])
################################################################################