// NeuralNetwork.cpp
#include "NeuralNetwork.hpp"

// Static device used by NeuralNetworkImpl: CUDA when torch::cuda::is_available()
// reports true at the time this static member is initialized, CPU otherwise.
torch::Device NeuralNetworkImpl::device{torch::cuda::is_available() ? torch::kCUDA : torch::kCPU};

std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config) const
  std::stack<long> leftContext;
  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index)
    if (!config.isCommentPredicted(index))
      leftContext.push(index);

  std::vector<long> context;

  while (context.size() < leftBorder-leftContext.size())
    context.emplace_back(-1);
  while (!leftContext.empty())
  {
    context.emplace_back(leftContext.top());
    leftContext.pop();
  }

  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index)
    if (!config.isCommentPredicted(index))
      context.emplace_back(index);

  while (context.size() < leftBorder+rightBorder+1)
    context.emplace_back(-1);

  for (unsigned int i = 0; i < nbStackElements; i++)
    if (config.hasStack(i))
      context.emplace_back(config.getStack(i));
    else
      context.emplace_back(-1);
  return context;
}

/// Convert the context window of @p config into dictionary indexes.
///
/// For every column in `columns` (outer loop) and every slot returned by
/// extractContextIndexes() (inner loop), exactly one entry is appended:
///  - dict.getIndexOrInsert(Dict::nullValueStr) for an empty slot (-1);
///  - dict.getIndexOrInsert(config.getAsFeature(col, index)) otherwise.
/// The result therefore holds columns.size() * extractContextIndexes(config).size()
/// entries, grouped by column.
///
/// @param config Configuration the features are read from.
/// @param dict   Dictionary mapping feature strings to indexes; passed non-const
///               because getIndexOrInsert() may presumably add unseen values.
/// @return Flat vector of dictionary indexes.
std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  const std::vector<long> indexes = extractContextIndexes(config);
  std::vector<long> context;
  context.reserve(columns.size() * indexes.size());
  for (auto & col : columns)
    for (auto index : indexes)
    {
      // One entry per (column, slot). An empty slot must not also be looked up
      // in config; otherwise it would yield two entries and the output size
      // would no longer match getContextSize().
      if (index == -1)
        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
      else
        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));
    }
  return context;
}

/// Size of the flattened context: columns.size() times the window width, where
/// the window holds leftBorder + 1 + rightBorder word slots plus
/// nbStackElements stack slots (the length of extractContextIndexes()'s result).
int NeuralNetworkImpl::getContextSize() const
{
  // The product is computed as std::size_t; the narrowing to int is explicit.
  return static_cast<int>(columns.size()*(1 + leftBorder + rightBorder + nbStackElements));
}

/// Set how many slots after the current word extractContextIndexes() reserves
/// (the right part of the window spans rightBorder + 1 slots, current word included).
void NeuralNetworkImpl::setRightBorder(int rightBorder)
{
  // The qualified name selects the data member rather than the shadowing parameter.
  NeuralNetworkImpl::rightBorder = rightBorder;
}

/// Set how many slots before the current word extractContextIndexes() reserves
/// (unfilled slots are padded with -1).
void NeuralNetworkImpl::setLeftBorder(int leftBorder)
{
  // The qualified name selects the data member rather than the shadowing parameter.
  NeuralNetworkImpl::leftBorder = leftBorder;
}

/// Set how many stack positions (0 .. nbStackElements-1) extractContextIndexes()
/// appends after the word window.
void NeuralNetworkImpl::setNbStackElements(int nbStackElements)
{
  this->nbStackElements = nbStackElements;
}

/// Set the feature columns that extractContext() reads for every window slot.
void NeuralNetworkImpl::setColumns(const std::vector<std::string> & columns)
{
  this->columns = columns;
}