// NOTE(review): removed 57 lines of web-viewer extraction residue here
// ("Newer"/"Older" pager links and a bare line-number column 1-55) that
// preceded the first real line of this file.
#include "ContextualModule.hpp"
/// Parses a module definition string of the form
///   Window{first last} Columns{func:col ...} TYPE{args} In{n} Out{n}
///   w2v{prefix,file ...} Targets{object.index(.childIndex) ...}
/// and builds the corresponding submodule (LSTM / GRU / Concat).
///
/// @param name       module name registered via setName.
/// @param definition textual definition matching the regex below.
/// @param path       base directory used to resolve w2v file paths.
/// @throws via util::myThrow on any malformed field.
ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path)
{
  setName(name);
  std::regex regex("(?:(?:\\s|\\t)*)Window\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)Targets\\{(.*)\\}(?:(?:\\s|\\t)*)");
  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
    {
      try
      {
        // Window{first last} : inclusive relative range of buffer words
        // used as context.
        auto splited = util::split(sm.str(1), ' ');
        if (splited.size() != 2)
          util::myThrow("bad Window, expected 2 indexes");
        window = std::make_pair(std::stoi(splited[0]), std::stoi(splited[1]));
        // Columns{func:col ...} : one extraction function + column name
        // per entry; the column name is the part after the last ':'.
        auto funcColumns = util::split(sm.str(2), ' ');
        columns.clear();
        for (auto & funcCol : funcColumns)
        {
          functions.emplace_back() = getFunction(funcCol);
          columns.emplace_back(util::split(funcCol, ':').back());
        }
        auto subModuleType = sm.str(3);
        auto subModuleArguments = util::split(sm.str(4), ' ');
        // Submodule args : bidirectional numLayers dropout complete
        auto options = MyModule::ModuleOptions(true)
          .bidirectional(std::stoi(subModuleArguments[0]))
          .num_layers(std::stoi(subModuleArguments[1]))
          .dropout(std::stof(subModuleArguments[2]))
          .complete(std::stoi(subModuleArguments[3]));
        // Targets{object.index(.childIndex) ...} : positions whose encoded
        // representation this module will output.
        for (auto & target : util::split(sm.str(8), ' '))
        {
          auto splited = util::split(target, '.');
          if (splited.size() != 2 and splited.size() != 3)
            util::myThrow(fmt::format("invalid target '{}' expected 'object.index(.childIndex)'", target));
          targets.emplace_back(std::make_tuple(Config::str2object(splited[0]), std::stoi(splited[1]), splited.size() == 3 ? std::optional<int>(std::stoi(splited[2])) : std::optional<int>()));
        }
        inSize = std::stoi(sm.str(5));
        outSize = std::stoi(sm.str(6));
        // Bidirectional recurrent outputs are split in two halves, so an
        // odd size cannot be represented.
        if (outSize % 2)
          util::myThrow("odd out size is not supported");
        if (subModuleType == "LSTM")
          myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options));
        else if (subModuleType == "GRU")
          myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
        else if (subModuleType == "Concat")
          myModule = register_module("myModule", Concat(inSize));
        else
          util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
        // BUGFIX(review): w2vFiles was read below without ever being
        // assigned (the assignment was lost where 67 lines of extraction
        // residue sat in this body), so the pretrained-embedding loading
        // was a silent no-op. Capture 7 of the regex holds the w2v list.
        w2vFiles = sm.str(7);
        if (!w2vFiles.empty())
        {
          // w2v{prefix,file ...} : each entry loads a word2vec file into
          // the dictionary under the given prefix.
          auto pathes = util::split(w2vFiles.string(), ' ');
          for (auto & p : pathes)
          {
            auto splited = util::split(p, ',');
            if (splited.size() != 2)
              util::myThrow("expected 'prefix,file.w2v'");
            auto pretrained = getDict().loadWord2Vec(this->path / splited[1], splited[0]);
            if (pretrained)
            {
              // Pretrained vectors freeze the dictionary: no new entries.
              getDict().setState(Dict::State::Closed);
              dictSetPretrained(true);
            }
          }
        }
      } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
    }))
    util::myThrow(fmt::format("invalid definition '{}'", definition));
}
std::size_t ContextualModuleImpl::getOutputSize()
{
return targets.size()*outSize;
}
/// Number of input columns this module consumes: one value per column for
/// each of the (2 + last - first) context slots, plus one index per target.
std::size_t ContextualModuleImpl::getInputSize()
{
  const auto nbContextSlots = 2 + window.second - window.first;
  return targets.size() + columns.size() * nbContextSlots;
}
// Appends this module's input columns to every context row: dictionary
// indexes for each (window slot, column) pair, then one window-slot index
// per target.
// NOTE(review): this function appears corrupted by extraction — several
// scopes opened here are never closed before the next function begins,
// `contextElement` is used outside any loop that declares it (twice), and
// a stray '{' sits where an else-branch header evidently was. The code is
// left byte-identical; restore the lost lines from version control.
void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
auto & dict = getDict();
// Absolute config indexes of the window slots; -1 marks an absent word,
// -2 is a sentinel first slot (purpose not visible here — TODO confirm).
std::vector<long> contextIndexes;
// For each target: its absolute config index, or -1 when unresolved.
std::vector<long> targetIndexes;
// Maps an absolute config index back to its slot within contextIndexes.
std::map<long,long> configIndex2ContextIndex;
contextIndexes.emplace_back(-2);
// Collect the word indexes covered by the inclusive window range.
for (long i = window.first; i <= window.second; i++)
{
if (config.hasRelativeWordIndex(Config::Object::Buffer, i))
{
contextIndexes.emplace_back(config.getRelativeWordIndex(Config::Object::Buffer, i));
configIndex2ContextIndex[contextIndexes.back()] = contextIndexes.size()-1;
}
else
contextIndexes.emplace_back(-1);
}
// Resolve each target (object, index, optional childIndex) to an absolute
// word index; -1 when the target cannot be resolved.
for (auto & target : targets)
if (config.hasRelativeWordIndex(std::get<0>(target), std::get<1>(target)))
{
int baseIndex = config.getRelativeWordIndex(std::get<0>(target), std::get<1>(target));
if (!std::get<2>(target))
targetIndexes.emplace_back(baseIndex);
else
{
// A child index selects one entry of the word's '|'-separated children
// list; negative values count from the end; out of range yields -1.
int childIndex = *std::get<2>(target);
auto childs = util::split(config.getAsFeature(Config::childsColName, baseIndex).get(), '|');
int candidate = -1;
if (childIndex >= 0 and childIndex < (int)childs.size())
candidate = std::stoi(childs[childIndex]);
else if (childIndex < 0 and ((int)childs.size())+childIndex >= 0)
candidate = std::stoi(childs[childs.size()+childIndex]);
targetIndexes.emplace_back(candidate);
}
}
else
targetIndexes.emplace_back(-1);
// Emit one dictionary index per (window slot, column) pair.
for (auto index : contextIndexes)
for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++)
{
auto & col = columns[colIndex];
if (index == -1)
{
// Absent word: push this column's null value into every row.
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
//TODO maybe change this to a unique value like Dict::noneValueStr
// NOTE(review): the push_back below uses `contextElement` outside the
// loop above — most likely a separate branch (perhaps for the -2
// sentinel slot) was lost in extraction. TODO restore.
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
}
else
{
int dictIndex;
if (col == Config::idColName)
{
// The ID column is encoded as a token-class string rather than the
// raw feature value; presumably other classes were assigned in lines
// lost here — TODO confirm against version control.
std::string value;
if (config.isMultiwordPredicted(index))
value = "token";
dictIndex = dict.getIndexOrInsert(value, col);
// NOTE(review): stray '{' below — an else-branch header for the
// non-ID-column case was evidently lost here.
{
// Regular column: apply the column's extraction function to the
// feature value before dictionary lookup.
std::string featureValue = config.getAsFeature(col, index);
dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue), col);
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
}
}
// Append, per target, its slot within the window when it falls inside
// it; the unconditional 0-push below looks like a lost else-branch —
// TODO confirm.
for (auto index : targetIndexes)
{
if (configIndex2ContextIndex.count(index))
contextElement.push_back(configIndex2ContextIndex.at(index));
}
for (auto & contextElement : context)
contextElement.push_back(0);
/// Gathers, along dimension `dim`, the entries of `input` selected per
/// batch row by `index` (a batched generalization of index_select).
torch::Tensor batchedIndexSelect(torch::Tensor input, int dim, torch::Tensor index)
{
  // Give `index` the same rank as `input` by inserting singleton axes
  // everywhere except along `dim`.
  for (int axis = 1; axis < input.dim(); ++axis)
    if (axis != dim)
      index = index.unsqueeze(axis);
  // Broadcast the index to input's shape on every axis except 0 and `dim`
  // (-1 tells expand to keep the axis's current size).
  std::vector<long> broadcastShape(input.dim(), -1);
  for (std::size_t axis = 1; axis < broadcastShape.size(); ++axis)
    broadcastShape[axis] = input.size(axis);
  broadcastShape[0] = -1;
  broadcastShape[dim] = -1;
  index = index.expand(broadcastShape);
  return torch::gather(input, dim, index);
}
/// Embeds this module's slice of the input, runs it through the submodule
/// and returns only the rows corresponding to the targets, flattened to
/// (batch, nbTargets * encoding size).
torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
{
  const auto batchSize = input.size(0);
  const auto contextLength = 2 + window.second - window.first;
  // The leading part of the slice holds the window columns; embed and
  // reshape to (batch, window length, embedding size).
  auto windowPart = input.narrow(1, firstInputIndex, getInputSize() - targets.size());
  auto embedded = wordEmbeddings(windowPart).view({batchSize, contextLength, -1});
  // The trailing columns hold, per target, its slot inside the window.
  auto targetPositions = input.narrow(1, firstInputIndex + getInputSize() - targets.size(), targets.size());
  auto encoded = myModule->forward(embedded);
  return batchedIndexSelect(encoded, 1, targetPositions).view({batchSize, -1});
}
void ContextualModuleImpl::registerEmbeddings()
{
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
{
auto splited = util::split(p, ',');
loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);