#include "LSTMNetwork.hpp"

// Builds the network: one context LSTM, an optional raw-input LSTM, one
// split-transition LSTM, one LSTM per focused column, a shared embedding
// table, two dropout layers and a two-layer feed-forward head.
//
// While sub-modules are created, two running counters are maintained:
//  - currentInputSize  : index of the first input column the next sub-module
//                        will be told to use (setFirstInputIndex). It starts
//                        at 1 because column 0 is consumed directly by
//                        forward().
//  - currentOutputSize : width of the concatenated feature vector fed to
//                        linear1. It starts at embeddingsSize to account for
//                        the embedding of column 0.
//
// The raw-input LSTM is only created when both leftWindowRawInput and
// rightWindowRawInput are non-negative.
//
// maxNbElements is indexed in parallel with focusedColumns; it must hold at
// least focusedColumns.size() entries, otherwise std::out_of_range is thrown.
//
// NOTE(review): sub-modules are registered in a fixed order; keep it stable,
// since registration order presumably matters for serialized models — confirm.
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, int hiddenSize, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout)
{
  // Base options shared by the context and focused-column LSTMs.
  LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
  // Variant used by the raw-input and split-transition LSTMs: identical except
  // that the fifth tuple element is switched on.
  // NOTE(review): the meaning of that element is defined by LSTMImpl — confirm there.
  auto lstmOptionsAll = lstmOptions;
  std::get<4>(lstmOptionsAll) = true;

  int currentOutputSize = embeddingsSize;
  int currentInputSize = 1;

  contextLSTM = register_module("contextLSTM", ContextLSTM(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, lstmOptions, unknownValueThreshold));
  contextLSTM->setFirstInputIndex(currentInputSize);
  currentOutputSize += contextLSTM->getOutputSize();
  currentInputSize += contextLSTM->getInputSize();

  if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
  {
    hasRawInputLSTM = true;
    rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll));
    rawInputLSTM->setFirstInputIndex(currentInputSize);
    currentOutputSize += rawInputLSTM->getOutputSize();
    currentInputSize += rawInputLSTM->getInputSize();
  }

  splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, lstmOptionsAll));
  splitTransLSTM->setFirstInputIndex(currentInputSize);
  currentOutputSize += splitTransLSTM->getOutputSize();
  currentInputSize += splitTransLSTM->getInputSize();

  for (unsigned int i = 0; i < focusedColumns.size(); i++)
  {
    // .at() rather than operator[]: a maxNbElements shorter than
    // focusedColumns raises std::out_of_range instead of reading out of bounds.
    focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnLSTM(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements.at(i), embeddingsSize, focusedLSTMSize, lstmOptions)));
    focusedLstms.back()->setFirstInputIndex(currentInputSize);
    currentOutputSize += focusedLstms.back()->getOutputSize();
    currentInputSize += focusedLstms.back()->getInputSize();
  }

  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
  hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));

  // The head consumes the concatenation of every sub-module output.
  linear1 = register_module("linear1", torch::nn::Linear(currentOutputSize, hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}

// Forward pass.
//
// A 1-D input is first promoted to a batch of one. Every index is embedded
// (with dropout), then each sub-module produces its own feature tensor from
// the embedded input. Features are concatenated along dimension 1 in this
// order: embedding of column 0, context LSTM, raw-input LSTM (if present),
// split-transition LSTM, then each focused-column LSTM. The result goes
// through linear1 -> ReLU -> dropout -> linear2.
torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
{
  const bool isSingleExample = (input.dim() == 1);
  if (isSingleExample)
    input = input.unsqueeze(0);

  torch::Tensor embedded = embeddingsDropout(wordEmbeddings(input));

  std::vector<torch::Tensor> features;

  // Column 0 is used as-is, without going through any LSTM.
  features.push_back(embedded.narrow(1,0,1).squeeze(1));
  features.push_back(contextLSTM(embedded));
  if (hasRawInputLSTM)
    features.push_back(rawInputLSTM(embedded));
  features.push_back(splitTransLSTM(embedded));
  for (auto & focused : focusedLstms)
    features.push_back(focused(embedded));

  torch::Tensor concatenated = torch::cat(features, 1);
  torch::Tensor activated = torch::relu(linear1(concatenated));

  return linear2(hiddenDropout(activated));
}

// Builds the index rows fed to forward() for the current configuration.
//
// The result starts as a single row whose first element is the dictionary
// index of config.getState(). Each sub-module then appends its own indexes
// through addToContext, in the same order forward() consumes them: context
// LSTM, raw-input LSTM (if present), split-transition LSTM, focused-column
// LSTMs.
//
// A warning is emitted once the dictionary holds at least maxNbEmbeddings
// entries. Outside training mode the result must consist of exactly one row;
// otherwise util::myThrow is called.
// NOTE(review): sub-modules apparently may add extra rows during training —
// confirm against their addToContext implementations.
std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  // The message states ">=" so that it matches the condition actually tested.
  if (dict.size() >= maxNbEmbeddings)
    util::warning(fmt::format("dict.size()={} >= maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));

  std::vector<std::vector<long>> context;
  context.emplace_back();

  // Column 0: the current state, read directly by forward().
  context.back().emplace_back(dict.getIndexOrInsert(config.getState()));

  contextLSTM->addToContext(context, dict, config);
  if (hasRawInputLSTM)
    rawInputLSTM->addToContext(context, dict, config);
  splitTransLSTM->addToContext(context, dict, config);
  for (auto & lstm : focusedLstms)
    lstm->addToContext(context, dict, config);

  if (!is_training() && context.size() > 1)
    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));

  return context;
}