import torch
import torch.nn as nn
import torch.nn.functional as F

import Features
import Transition

################################################################################
def readPretrainedSize(w2vFile) :
  # The first line of a word2vec text file is "<nbWords> <embeddingSize>".
  for line in open(w2vFile, "r") :
    return int(line.strip().split()[1])
################################################################################

################################################################################
def loadW2v(w2vFile, weights, dicts, colname) :
  # Copy pretrained word2vec vectors into 'weights', one row per word of the
  # dictionary 'colname'. Words may contain spaces : everything except the
  # last 'size' fields of a line is the word.
  size = None
  for line in open(w2vFile, "r") :
    line = line.strip()
    if size is None :
      size = int(line.split()[1])
      continue
    splited = line.split()
    word = " ".join(splited[0:len(splited)-size])
    emb = torch.tensor(list(map(float, splited[len(splited)-size:])))
    weights[dicts.get(colname, word)] = emb
################################################################################

################################################################################
def getNeededDicts(name) :
  # Return the names of the dictionaries required by the network 'name'.
  names = ["FORM", "UPOS"]
  if "Lexicon" in name :
    names.append("LEXICON")
  if "NoLetters" not in name :
    names.append("LETTER")
  return names
################################################################################

################################################################################
def createNetwork(name, dicts, outputSizes, incremental, pretrained, hasBack) :
  # Instantiate the network architecture corresponding to 'name'.
  featureFunctionAll = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
  featureFunctionNostack = "b.-2 b.-1 b.0 b.1 b.2"
  historyNb = 10
  historyPopNb = 5
  suffixSize = 4
  prefixSize = 4
  columns = ["UPOS", "FORM"]

  if name == "base" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, 1600, 64, pretrained, hasBack)
  elif name == "big" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, 3200, 128, pretrained, hasBack)
  elif name == "baseNoLetters" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, 0, 0, columns, 1600, 64, pretrained, hasBack)
  elif name == "tagger" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, columns, 1600, 64, pretrained, hasBack)
  elif name == "taggerLexicon" :
    return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, ["UPOS", "FORM", "LEXICON"], 1600, 64, pretrained, hasBack)

  raise Exception("Unknown network name '%s'"%name)
################################################################################

################################################################################
class LockedEmbeddings(nn.Module) :
  # Embedding layer whose pretrained rows are frozen, while the first
  # len(specialIndexes) rows (special symbols such as padding or unknown)
  # stay trainable in a separate table.
  def __init__(self, nbElems, embSize, specialIndexes) :
    super().__init__()
    self.embNormal = nn.Embedding(nbElems, embSize)
    self.embNormal.weight.requires_grad = False
    self.embSpecial = nn.Embedding(len(specialIndexes), embSize)
    for index in specialIndexes :
      if index not in range(len(specialIndexes)) :
        raise Exception("Special indexes not contiguous from 0 :", specialIndexes)

  def forward(self, x) :
    # mask is True for indexes that belong to the frozen (pretrained) table.
    mask = x >= self.embSpecial.weight.size(0)
    specialIndexes = torch.ones(x.size(), device=self.embNormal.weight.device, dtype=torch.int)
    specialIndexes[mask] = 0
    normalRes = self.embNormal(x)
    specialRes = self.embSpecial(x*specialIndexes)
    # Keep normalRes for pretrained indexes and specialRes for special ones.
    normalIndexes = torch.ones(normalRes.size(), device=self.embNormal.weight.device, dtype=torch.int)
    specialIndexes = torch.ones(specialRes.size(), device=self.embNormal.weight.device, dtype=torch.int)
    specialIndexes[mask] = 0
    normalIndexes[~mask] = 0
    return normalIndexes*normalRes + specialIndexes*specialRes

  def getNormalWeights(self) :
    return self.embNormal.weight
################################################################################

################################################################################
class BaseNet(nn.Module) :
  # Feedforward network scoring transitions from embedded window, history,
  # prefix and suffix features. One output layer per decoding state.
  def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, embSize, pretrained, hasBack) :
    super().__init__()
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
    self.incremental = incremental
    self.state = 0
    self.featureFunction = featureFunction
    self.historyNb = historyNb
    self.historyPopNb = historyPopNb
    self.suffixSize = suffixSize
    self.prefixSize = prefixSize
    self.columns = columns
    self.hasBack = hasBack
    self.embSize = embSize
    embSizes = {}
    self.nbTargets = len(self.featureFunction.split())
    self.outputSizes = outputSizes
    # One embedding table per dictionary, frozen and preloaded when a
    # pretrained word2vec file is given for that column.
    for name in dicts.dicts :
      if name in pretrained :
        pretrainedSize = readPretrainedSize(pretrained[name])
        embSizes[name] = pretrainedSize
        emb = LockedEmbeddings(len(dicts.dicts[name]), pretrainedSize, dicts.getSpecialIndexes(name))
        loadW2v(pretrained[name], emb.getNormalWeights(), dicts, name)
        self.add_module("emb_"+name, emb)
      else :
        embSizes[name] = self.embSize
        self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
    self.inputSize = (self.historyNb+self.historyPopNb)*embSizes.get("HISTORY",0) + (self.suffixSize+self.prefixSize)*embSizes.get("LETTER",0) + sum([self.nbTargets*embSizes.get(col,0) for col in self.columns])
    self.fc1 = nn.Linear(self.inputSize, hiddenSize)
    for i in range(len(outputSizes)) :
      self.add_module("output_"+str(i), nn.Linear(hiddenSize+(1 if self.hasBack > 0 else 0), outputSizes[i]))
    self.dropout = nn.Dropout(0.3)
    self.apply(self.initWeights)

  def setState(self, state) :
    self.state = state

  def forward(self, x) :
    embeddings = []
    if self.hasBack > 0 :
      # The first feature tells whether a BACK transition is applicable.
      canBack = x[...,0:1]
      x = x[...,1:]
    for i in range(len(self.columns)) :
      embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
    y = torch.cat(embeddings,-1).view(x.size(0),-1)
    curIndex = len(self.columns)*self.nbTargets
    if self.historyNb > 0 :
      historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
      y = torch.cat([y, historyEmb],-1)
      curIndex = curIndex+self.historyNb
    if self.historyPopNb > 0 :
      historyPopEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyPopNb]).view(x.size(0),-1)
      y = torch.cat([y, historyPopEmb],-1)
      curIndex = curIndex+self.historyPopNb
    if self.prefixSize > 0 :
      prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
      y = torch.cat([y, prefixEmb],-1)
      curIndex = curIndex+self.prefixSize
    if self.suffixSize > 0 :
      suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
      y = torch.cat([y, suffixEmb],-1)
      curIndex = curIndex+self.suffixSize
    y = self.dropout(y)
    y = F.relu(self.dropout(self.fc1(y)))
    if self.hasBack > 0 :
      y = torch.cat([y,canBack], 1)
    # Use the output layer of the current decoding state.
    y = getattr(self, "output_"+str(self.state))(y)
    return y

  def currentDevice(self) :
    return self.dummyParam.device

  def initWeights(self, m) :
    if type(m) == nn.Linear :
      torch.nn.init.xavier_uniform_(m.weight)
      m.bias.data.fill_(0.01)

  def extractFeatures(self, dicts, config) :
    # Build the flat index tensor expected by forward(), in the same order :
    # [canBack?] columns, history, popped history, prefix letters, suffix letters.
    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
    historyPopValues = Features.extractHistoryPopFeatures(dicts, config, self.historyPopNb)
    prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
    suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
    backAction = None
    if self.hasBack > 0 :
      backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK %d"%self.hasBack).appliable(config) else torch.zeros(1, dtype=torch.int)
    allFeatures = [f for f in [backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues] if f is not None]
    return torch.cat(allFeatures)
################################################################################
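
################################################################################
# Minimal usage sketch for LockedEmbeddings (illustration only, not part of the
# training pipeline) : indexes below len(specialIndexes), e.g. padding or
# unknown symbols, are looked up in a trainable table, while the remaining
# indexes use the frozen, pretrained table. The sizes and index values below
# are arbitrary.
if __name__ == "__main__" :
  emb = LockedEmbeddings(nbElems=10, embSize=4, specialIndexes=[0, 1, 2])
  x = torch.tensor([[0, 1, 5, 9]]) # mixes special (0, 1) and normal (5, 9) indexes
  out = emb(x)
  print(out.shape) # torch.Size([1, 4, 4])
  # Only the special table receives gradients :
  print(emb.embNormal.weight.requires_grad, emb.embSpecial.weight.requires_grad) # False True
################################################################################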