ContextualModule.cpp

#include "ContextualModule.hpp"

ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path)
{
  setName(name);

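  // Expected definition format:
  //   Window{left right} Columns{col ...} <LSTM|GRU|Concat>{bidirectional layers dropout complete} In{size} Out{size} w2v{prefix,file.w2v ...} Targets{object.index(.childIndex) ...}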
  std::regex regex("(?:(?:\\s|\\t)*)Window\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)Targets\\{(.*)\\}(?:(?:\\s|\\t)*)");
  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
        {
          try
          {
            auto splited = util::split(sm.str(1), ' ');
            if (splited.size() != 2)
              util::myThrow("bad Window, expected 2 indexes");
            window = std::make_pair(std::stoi(splited[0]), std::stoi(splited[1]));

            auto funcColumns = util::split(sm.str(2), ' ');
            columns.clear();
            for (auto & funcCol : funcColumns)
            {
              functions.emplace_back() = getFunction(funcCol);
              columns.emplace_back(util::split(funcCol, ':').back());
            }

            auto subModuleType = sm.str(3);
            auto subModuleArguments = util::split(sm.str(4), ' ');

            auto options = MyModule::ModuleOptions(true)
              .bidirectional(std::stoi(subModuleArguments[0]))
              .num_layers(std::stoi(subModuleArguments[1]))
              .dropout(std::stof(subModuleArguments[2]))
              .complete(std::stoi(subModuleArguments[3]));

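            // Each target is written as 'object.index' or 'object.index.childIndex'.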
            for (auto & target : util::split(sm.str(8), ' '))
            {
              auto splited = util::split(target, '.');
              if (splited.size() != 2 and splited.size() != 3)
                util::myThrow(fmt::format("invalid target '{}' expected 'object.index(.childIndex)'", target));
              targets.emplace_back(std::make_tuple(Config::str2object(splited[0]), std::stoi(splited[1]), splited.size() == 3 ? std::optional<int>(std::stoi(splited[2])) : std::optional<int>()));
            }

            inSize = std::stoi(sm.str(5));
            outSize = std::stoi(sm.str(6));
            if (outSize % 2)
              util::myThrow("odd out size is not supported");

            if (subModuleType == "LSTM")
              myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options));
            else if (subModuleType == "GRU")
              myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
            else if (subModuleType == "Concat")
              myModule = register_module("myModule", Concat(inSize));
            else
              util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));

            w2vFiles = sm.str(7);

            if (!w2vFiles.empty())
            {
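              // w2v files are space-separated 'prefix,file.w2v' pairs; loading
              // pretrained vectors closes the dict to new entries.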
              auto pathes = util::split(w2vFiles.string(), ' ');
              for (auto & p : pathes)
              {
                auto splited = util::split(p, ',');
                if (splited.size() != 2)
                  util::myThrow("expected 'prefix,file.w2v'");
                getDict().loadWord2Vec(this->path / splited[1], splited[0]);
                getDict().setState(Dict::State::Closed);
                dictSetPretrained(true);
              }
            }

          } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
        }))
    util::myThrow(fmt::format("invalid definition '{}'", definition));
}

std::size_t ContextualModuleImpl::getOutputSize()
{
  return targets.size()*outSize;
}

std::size_t ContextualModuleImpl::getInputSize()
{
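  // One cell per (column, window position), one extra cell per column for the
  // sentinel slot at context position 0, and one cell per target position.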
  return columns.size()*(2+window.second-window.first)+targets.size();
}

void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
  auto & dict = getDict();
  std::vector<long> contextIndexes;
  std::vector<long> targetIndexes;
  std::map<long,long> configIndex2ContextIndex;

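  // Context position 0 is a sentinel slot (word index -2).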
  contextIndexes.emplace_back(-2);

  for (long i = window.first; i <= window.second; i++)
  {
    if (config.hasRelativeWordIndex(Config::Object::Buffer, i))
    {
      contextIndexes.emplace_back(config.getRelativeWordIndex(Config::Object::Buffer, i));
      configIndex2ContextIndex[contextIndexes.back()] = contextIndexes.size()-1;
    }
    else
      contextIndexes.emplace_back(-1);
  }

  for (auto & target : targets)
    if (config.hasRelativeWordIndex(std::get<0>(target), std::get<1>(target)))
    {
      int baseIndex = config.getRelativeWordIndex(std::get<0>(target), std::get<1>(target));
      if (!std::get<2>(target))
        targetIndexes.emplace_back(baseIndex);
      else
      {
        int childIndex = *std::get<2>(target);
        auto childs = util::split(config.getAsFeature(Config::childsColName, baseIndex).get(), '|');
        int candidate = -1;

        if (childIndex >= 0 and childIndex < (int)childs.size())
          candidate = std::stoi(childs[childIndex]);
        else if (childIndex < 0 and ((int)childs.size())+childIndex >= 0)
          candidate = std::stoi(childs[childs.size()+childIndex]);

        targetIndexes.emplace_back(candidate);
      }
    }
    else
      targetIndexes.emplace_back(-1);

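  // For every context position, push one dict index per column; missing positions
  // and the sentinel slot get the null value.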
  for (auto index : contextIndexes)
    for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++)
    {
      auto & col = columns[colIndex];
      if (index == -1)
      {
        for (auto & contextElement : context)
          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
      }
      else if (index == -2)
      {
        //TODO maybe change this to a unique value like Dict::noneValueStr
        for (auto & contextElement : context)
          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
      }
      else
      {
        int dictIndex;
        if (col == Config::idColName)
        {
          std::string value;
          if (config.isCommentPredicted(index))
            value = "comment";
          else if (config.isMultiwordPredicted(index))
            value = "multiword";
          else if (config.isTokenPredicted(index))
            value = "token";
          dictIndex = dict.getIndexOrInsert(value, col);
        }
        else
        {
          std::string featureValue = config.getAsFeature(col, index);
          dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue), col);
        }

        for (auto & contextElement : context)
          contextElement.push_back(dictIndex);
      }
    }

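  // For every target, push its position inside the context, or 0 (the sentinel
  // slot) when the target could not be resolved.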
  for (auto index : targetIndexes)
  {
    if (configIndex2ContextIndex.count(index))
    {
      for (auto & contextElement : context)
        contextElement.push_back(configIndex2ContextIndex.at(index));
    }
    else
    {
      for (auto & contextElement : context)
        contextElement.push_back(0);
    }
  }
}

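// For each element of the batch, gathers the slices of input along dimension dim
// whose positions are given by the corresponding row of index.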
torch::Tensor batchedIndexSelect(torch::Tensor input, int dim, torch::Tensor index)
{
  for (int i = 1; i < input.dim(); i++)
    if (i != dim)
      index = index.unsqueeze(i);

  std::vector<long> expanse(input.dim());
  for (unsigned int i = 1; i < expanse.size(); i++)
    expanse[i] = input.size(i);
  expanse[0] = -1;
  expanse[dim] = -1;
  index = index.expand(expanse);

  return torch::gather(input, dim, index);
}

torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
{
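  // Input layout: [context word indexes | target positions]. Embed the context part
  // and reshape it to (batch, context length, columns*inSize) before running myModule.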
  auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1});
  auto focusedIndexes = input.narrow(1, firstInputIndex+getInputSize()-targets.size(), targets.size());

  auto out = myModule->forward(context);

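  // Keep only the submodule outputs at the target positions, flattened per batch element.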
  return batchedIndexSelect(out, 1, focusedIndexes).view({input.size(0), -1});
}

void ContextualModuleImpl::registerEmbeddings()
{
  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));

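  // Reload the pretrained word2vec vectors into the freshly registered embedding matrix.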
  auto pathes = util::split(w2vFiles.string(), ' ');
  for (auto & p : pathes)
  {
    auto splited = util::split(p, ',');
    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
  }
}