NumericColumnModule.cpp 3.22 KB
Newer Older
Franck Dary's avatar
Franck Dary committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include "NumericColumnModule.hpp"
#include "NeuralNetwork.hpp"

/// @brief Parse a module definition of the form
///        "Column{c} Buffer{i...} Stack{i...} Type{args} Out{n}"
///        and construct the wrapped recurrent/concat submodule.
/// @param name       Module name, forwarded to setName().
/// @param definition Textual definition; throws (via util::myThrow) if it
///                   does not match the expected grammar.
NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::string & definition)
{
  setName(name);
  std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
        {
          try
          {
            column = sm.str(1);

            // Relative word indexes observed in the buffer.
            for (auto & index : util::split(sm.str(2), ' '))
              focusedBuffer.emplace_back(std::stoi(index));

            // Stack depths observed on the stack.
            for (auto & index : util::split(sm.str(3), ' '))
              focusedStack.emplace_back(std::stoi(index));

            auto subModuleType = sm.str(4);
            auto subModuleArguments = util::split(sm.str(5), ' ');

            // Guard before indexing: operator[] past size() is undefined
            // behavior, not an exception the catch below would see.
            if (subModuleArguments.size() < 4)
              util::myThrow(fmt::format("expected 4 arguments (bidirectional numLayers dropout complete), got {}", subModuleArguments.size()));

            auto options = MyModule::ModuleOptions(true)
              .bidirectional(std::stoi(subModuleArguments[0]))
              .num_layers(std::stoi(subModuleArguments[1]))
              .dropout(std::stof(subModuleArguments[2]))
              .complete(std::stoi(subModuleArguments[3]));

            int outSize = std::stoi(sm.str(6));

            // Each focused value is a single scalar, hence input size 1.
            if (subModuleType == "LSTM")
              myModule = register_module("myModule", LSTM(1, outSize, options));
            else if (subModuleType == "GRU")
              myModule = register_module("myModule", GRU(1, outSize, options));
            else if (subModuleType == "Concat")
              myModule = register_module("myModule", Concat(1));
            else
              util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
          } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
        }))
    util::myThrow(fmt::format("invalid definition '{}'", definition));
}

/// @brief Extract this module's slice of the context and run the submodule.
/// @param input Batched context tensor; columns [firstInputIndex,
///              firstInputIndex + getInputSize()) belong to this module.
/// @return Output of the wrapped submodule.
torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
{
  auto context = input.narrow(1, firstInputIndex, getInputSize());
  // The context stores raw double bit patterns smuggled inside integral
  // values (see addToContext's memcpy), so reinterpret the same storage as
  // kDouble via from_blob, then convert to float and add a trailing feature
  // dimension of size 1. The kDouble -> kFloat conversion already produces
  // a freshly allocated tensor, so no extra clone() is needed.
  auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1);
  return myModule->forward(values);
}

/// @brief Output width of this module, as reported by the wrapped submodule
///        for the current input width.
std::size_t NumericColumnModuleImpl::getOutputSize()
{
  const auto inputSize = getInputSize();
  return myModule->getOutputSize(inputSize);
}

/// @brief Number of context cells consumed: one scalar per focused buffer
///        element plus one per focused stack element.
std::size_t NumericColumnModuleImpl::getInputSize()
{
  const std::size_t fromBuffer = focusedBuffer.size();
  const std::size_t fromStack = focusedStack.size();
  return fromBuffer + fromStack;
}

/// @brief Append this module's numeric feature values to every context row.
/// @param context One vector<long> per context element; values are appended.
/// @param config  Current configuration queried for word indexes and the
///                numeric column's feature values.
///
/// The double value is bit-copied into a long cell so the generic long-based
/// context can carry it; forward() reinterprets those bits back as doubles.
void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
  // The memcpy below smuggles a double's bit pattern inside a long; this
  // only round-trips if both types have the same size.
  static_assert(sizeof(long) == sizeof(double), "context cells must be able to hold a double bit pattern");

  std::vector<long> focusedIndexes;

  for (int index : focusedBuffer)
    focusedIndexes.emplace_back(config.getRelativeWordIndex(index));

  // Missing stack positions are encoded as -1 and yield the value 0.0 below.
  for (int index : focusedStack)
    if (config.hasStack(index))
      focusedIndexes.emplace_back(config.getStack(index));
    else
      focusedIndexes.emplace_back(-1);

  for (auto & contextElement : context)
    for (auto index : focusedIndexes)
    {
      double res = 0.0;
      if (index >= 0)
        // stod, not stof: the value is stored and later read back as a
        // double, so parsing to float would silently lose precision.
        res = std::stod(config.getAsFeature(column, index).get());

      contextElement.emplace_back(0);
      std::memcpy(&contextElement.back(), &res, sizeof res);
    }
}

86
/// @brief Intentionally a no-op: numeric columns feed raw scalar values into
///        the submodule (see addToContext), so there is no embedding table
///        to load or register.
void NumericColumnModuleImpl::registerEmbeddings(std::filesystem::path)
{
}