// NeuralNetwork.cpp
#include "NeuralNetwork.hpp"

// Device shared by the class: CUDA when torch reports it as available, CPU otherwise.
torch::Device NeuralNetworkImpl::device{torch::cuda::is_available() ? torch::kCUDA : torch::kCPU};

// Collects the indexes designated by bufferContext and stackContext.
//
// Buffer entries come first, each mapped through config.getRelativeWordIndex().
// Stack entries follow; a stack position that config.hasStack() rejects is
// recorded as -1 so the result always has one entry per configured position.
std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config) const
{
  std::vector<long> indexes;

  for (int relative : bufferContext)
    indexes.emplace_back(config.getRelativeWordIndex(relative));

  for (int depth : stackContext)
  {
    if (!config.hasStack(depth))
    {
      // Missing stack element: keep the slot, mark it with the -1 sentinel.
      indexes.emplace_back(-1);
      continue;
    }

    indexes.emplace_back(config.getStack(depth));
  }

  return indexes;
}
// Collects the indexes designated by bufferFocused and stackFocused.
//
// Buffer entries come first, each mapped through config.getRelativeWordIndex().
// Stack entries follow; a stack position that config.hasStack() rejects is
// recorded as -1 so the result always has one entry per configured position.
std::vector<long> NeuralNetworkImpl::extractFocusedIndexes(const Config & config) const
{
  std::vector<long> context;

  for (int index : bufferFocused)
    context.emplace_back(config.getRelativeWordIndex(index));

  for (int index : stackFocused)
    if (config.hasStack(index))
      context.emplace_back(config.getStack(index));
    else
      context.emplace_back(-1);

  return context;
}

// Builds the dictionary-encoded context for the given configuration.
//
// For every column in `columns` and every index returned by
// extractContextIndexes(), one dictionary id is appended: the id of
// Dict::nullValueStr when the index is the -1 sentinel, otherwise the id of
// the feature config.getAsFeature(col, index). Ids are obtained through
// dict.getIndexOrInsert(), so `dict` may be modified by this call.
//
// NOTE(review): the flat context is returned as a single row of the outer
// vector; this matches the declared return type and getContextSize(), but
// confirm against callers.
std::vector<std::vector<long>> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  std::vector<long> indexes = extractContextIndexes(config);
  std::vector<long> context;

  for (auto & col : columns)
    for (auto index : indexes)
      if (index == -1)
        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
      else
        // Only look the feature up for real positions, never for the sentinel.
        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));

  return {context};
}

// Number of context entries: one per (column, configured position) pair,
// where the configured positions are those of bufferContext and stackContext.
int NeuralNetworkImpl::getContextSize() const
{
  const auto positionCount = bufferContext.size()+stackContext.size();

  return columns.size()*positionCount;
}

// Replaces the buffer positions iterated by extractContextIndexes().
void NeuralNetworkImpl::setBufferContext(const std::vector<int> & bufferContext)
{
  this->bufferContext = bufferContext;
}

void NeuralNetworkImpl::setStackContext(const std::vector<int> & stackContext)
Franck Dary's avatar
Franck Dary committed
{
  this->stackContext = stackContext;
void NeuralNetworkImpl::setBufferFocused(const std::vector<int> & bufferFocused)
Franck Dary's avatar
Franck Dary committed
{
  this->bufferFocused = bufferFocused;
void NeuralNetworkImpl::setStackFocused(const std::vector<int> & stackFocused)
Franck Dary's avatar
Franck Dary committed
{
  this->stackFocused = stackFocused;
// Replaces the list of column names stored in the `columns` member.
void NeuralNetworkImpl::setColumns(const std::vector<std::string> & columns)
{
  // The parameter shadows the member, so the member is named explicitly.
  NeuralNetworkImpl::columns = columns;
}