#include "ContextLSTM.hpp"

#include <utility>

// Builds a module that runs an LSTM over the context positions listed in bufferContext and
// stackContext. Each LSTM timestep concatenates one embedding of size embeddingsSize per entry
// of `columns`, hence an LSTM input size of columns.size()*embeddingsSize.
// The vector arguments are sink parameters: they are moved into the members instead of copied.
ContextLSTMImpl::ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold) : columns(std::move(columns)), bufferContext(std::move(bufferContext)), stackContext(std::move(stackContext)), unknownValueThreshold(unknownValueThreshold)
{
  // The `columns` parameter has been moved from, so read the member explicitly through `this`.
  lstm = register_module("lstm", LSTM(this->columns.size()*embeddingsSize, outEmbeddingsSize, options));
}

// Size of the features produced by forward(): the LSTM output for a sequence with one
// timestep per context position (buffer positions followed by stack positions).
std::size_t ContextLSTMImpl::getOutputSize()
{
  const auto sequenceLength = bufferContext.size() + stackContext.size();
  return lstm->getOutputSize(sequenceLength);
}

// Number of input slots this module consumes: one per (context position, column) pair,
// matching the number of values addToContext appends to each context element.
std::size_t ContextLSTMImpl::getInputSize()
{
  const std::size_t nbPositions = bufferContext.size() + stackContext.size();
  return nbPositions * columns.size();
}

// Appends this module's features to every element of `context`.
//
// For each context position (buffer positions first, then stack positions) and for each column,
// one dictionary index is pushed to every context element:
//   - the index of Dict::nullValueStr when the position does not exist (encoded as -1),
//   - otherwise the index of config.getAsFeature(col, position).
// During training, when a value of a column listed in unknownValueColumns occurs at most
// unknownValueThreshold times in `dict`, a copy of the last context element is appended with that
// value replaced by the index of Dict::unknownValueStr (an extra training example).
//
// @param context  Feature vectors being built; may grow during training (see above).
// @param dict     Dictionary used to map feature strings to indexes (may be extended).
// @param config   Configuration providing buffer/stack positions and feature values.
void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
  std::vector<long> contextIndexes;
  contextIndexes.reserve(bufferContext.size()+stackContext.size());

  // NOTE(review): the -1 check below also applies to buffer positions, so getRelativeWordIndex
  // presumably returns -1 for a non-existent word — confirm.
  for (int index : bufferContext)
    contextIndexes.emplace_back(config.getRelativeWordIndex(index));

  // Missing stack entries are marked with -1.
  for (int index : stackContext)
    if (config.hasStack(index))
      contextIndexes.emplace_back(config.getStack(index));
    else
      contextIndexes.emplace_back(-1);

  for (auto index : contextIndexes)
    for (auto & col : columns)
      if (index == -1)
      {
        for (auto & contextElement : context)
          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
      }
      else
      {
        int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));

        for (auto & contextElement : context)
          contextElement.push_back(dictIndex);

        // Augmentation duplicates context.back(); with an empty context that would be undefined
        // behavior, and there is nothing to duplicate anyway.
        if (is_training() && !context.empty())
          for (auto & targetCol : unknownValueColumns)
            if (col == targetCol)
              if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
              {
                context.emplace_back(context.back());
                // The copy's last value is dictIndex (just pushed above); replace it by "unknown".
                context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
              }
      }
}

// Runs the LSTM over this module's slice of the input.
//
// The slice covers getInputSize() slots along dim 1, starting at firstInputIndex. These slots are
// ordered position-major, column-minor (as produced by addToContext). They are regrouped so that
// each timestep holds the concatenated values of all columns at one context position:
//   (batch, positions*columns, dim) -> (batch, positions, columns*dim)
// NOTE(review): assumes `input` is 3-D with embeddings of size embeddingsSize on dim 2 — confirm.
torch::Tensor ContextLSTMImpl::forward(torch::Tensor input)
{
  auto slice = input.narrow(1, firstInputIndex, getInputSize());

  const int nbColumns = static_cast<int>(columns.size());
  const auto batchSize = slice.size(0);
  const auto nbPositions = slice.size(1) / nbColumns;
  const auto timestepSize = nbColumns * slice.size(2);

  return lstm(slice.view({batchSize, nbPositions, timestepSize}));
}