Commit 2cf0aae6 authored by Franck Dary

Added Contextual module

parent cf56fefc
Showing 248 additions and 28 deletions
ContextualModule.hpp (new file):

#ifndef CONTEXTUALMODULE__H
#define CONTEXTUALMODULE__H

#include <torch/torch.h>
#include <optional>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "GRU.hpp"
#include "LSTM.hpp"
#include "Concat.hpp"

class ContextualModuleImpl : public Submodule
{
  private :

  torch::nn::Embedding wordEmbeddings{nullptr};
  std::shared_ptr<MyModule> myModule{nullptr};
  std::vector<std::string> columns;
  std::vector<std::function<std::string(const std::string &)>> functions;
  // Each target is (object, relative index, optional child index).
  std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets;
  // Inclusive range of relative buffer positions fed to the submodule.
  std::pair<int, int> window;
  int inSize;
  int outSize;
  std::filesystem::path path;
  std::filesystem::path w2vFile;

  public :

  ContextualModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
  torch::Tensor forward(torch::Tensor input);
  std::size_t getOutputSize() override;
  std::size_t getInputSize() override;
  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
  void registerEmbeddings() override;
};
TORCH_MODULE(ContextualModule);

#endif
@@ -3,6 +3,7 @@
 #include "NeuralNetwork.hpp"
 #include "ContextModule.hpp"
+#include "ContextualModule.hpp"
 #include "RawInputModule.hpp"
 #include "SplitTransModule.hpp"
 #include "AppliableTransModule.hpp"
@@ -148,7 +148,7 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
   context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)});
-  return myModule->forward(context);
+  return myModule->forward(context).reshape({input.size(0), -1});
 }

 void ContextModuleImpl::registerEmbeddings()
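The same change recurs throughout this commit: the GRU and LSTM submodules now return their whole output sequence (see GRU.cpp and LSTM.cpp below), and each calling module flattens it per batch element, or, in the new ContextualModule, selects specific positions from it. A minimal standalone sketch of that flattening, with illustrative shapes not taken from the repository:

  // An RNN output of shape {batch, seq, features} becomes one
  // feature vector per batch element.
  auto out = torch::randn({8, 10, 128});       // {batch, seq, features}
  auto flat = out.reshape({out.size(0), -1});  // {8, 1280}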
#include "ContextualModule.hpp"
ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path)
{
setName(name);
std::regex regex("(?:(?:\\s|\\t)*)Window\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)Targets\\{(.*)\\}(?:(?:\\s|\\t)*)");
if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
{
try
{
auto splited = util::split(sm.str(1), ' ');
if (splited.size() != 2)
util::myThrow("bad Window, expected 2 indexes");
window = std::make_pair(std::stoi(splited[0]), std::stoi(splited[1]));
auto funcColumns = util::split(sm.str(2), ' ');
columns.clear();
for (auto & funcCol : funcColumns)
{
functions.emplace_back() = getFunction(funcCol);
columns.emplace_back(util::split(funcCol, ':').back());
}
auto subModuleType = sm.str(3);
auto subModuleArguments = util::split(sm.str(4), ' ');
auto options = MyModule::ModuleOptions(true)
.bidirectional(std::stoi(subModuleArguments[0]))
.num_layers(std::stoi(subModuleArguments[1]))
.dropout(std::stof(subModuleArguments[2]))
.complete(std::stoi(subModuleArguments[3]));
for (auto & target : util::split(sm.str(8), ' '))
{
auto splited = util::split(target, '.');
if (splited.size() != 2 and splited.size() != 3)
util::myThrow(fmt::format("invalid target '{}' expected 'object.index(.childIndex)'", target));
targets.emplace_back(std::make_tuple(Config::str2object(splited[0]), std::stoi(splited[1]), splited.size() == 3 ? std::optional<int>(std::stoi(splited[2])) : std::optional<int>()));
}
inSize = std::stoi(sm.str(5));
outSize = std::stoi(sm.str(6));
if (outSize % 2)
util::myThrow("odd out size is not supported");
if (subModuleType == "LSTM")
myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
w2vFile = sm.str(7);
if (!w2vFile.empty())
{
getDict().loadWord2Vec(this->path / w2vFile);
getDict().setState(Dict::State::Closed);
dictSetPretrained(true);
}
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
}))
util::myThrow(fmt::format("invalid definition '{}'", definition));
}
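For illustration, a definition string that this regex would accept might look like the following (a hypothetical sketch: all values are illustrative, and the object names b and s assume Config::str2object understands such tokens for the buffer and the stack):

  Window{-3 3} Columns{FORM UPOS} LSTM{1 1 0.3 1} In{64} Out{128} w2v{} Targets{b.0 s.0 s.0.2}

This would read two columns over buffer positions -3..3, run a bidirectional single-layer LSTM with dropout 0.3 producing 128 features per position, and extract the states at the buffer head, the stack top, and the stack top's third child.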
std::size_t ContextualModuleImpl::getOutputSize()
{
  return targets.size()*outSize;
}

std::size_t ContextualModuleImpl::getInputSize()
{
  // One dict index per column for each of the 2+window.second-window.first
  // context positions (the window plus a leading sentinel), plus one
  // pointer per target.
  return columns.size()*(2+window.second-window.first)+targets.size();
}
void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
  auto & dict = getDict();
  std::vector<long> contextIndexes;
  std::vector<long> targetIndexes;
  std::map<long, long> configIndex2ContextIndex;

  // Position 0 is a sentinel (-2, embedded as _NONE_); -1 marks window
  // cells that fall outside the configuration (embedded as the null value).
  contextIndexes.emplace_back(-2);
  for (long i = window.first; i <= window.second; i++)
  {
    if (config.hasRelativeWordIndex(Config::Object::Buffer, i))
    {
      contextIndexes.emplace_back(config.getRelativeWordIndex(Config::Object::Buffer, i));
      configIndex2ContextIndex[contextIndexes.back()] = contextIndexes.size()-1;
    }
    else
      contextIndexes.emplace_back(-1);
  }

  // Resolve each target to a word index in the configuration (or -1).
  for (auto & target : targets)
    if (config.hasRelativeWordIndex(std::get<0>(target), std::get<1>(target)))
    {
      int baseIndex = config.getRelativeWordIndex(std::get<0>(target), std::get<1>(target));
      if (!std::get<2>(target))
        targetIndexes.emplace_back(baseIndex);
      else
      {
        // A child index selects among the word's children; negative
        // values count from the end.
        int childIndex = *std::get<2>(target);
        auto childs = util::split(config.getAsFeature(Config::childsColName, baseIndex).get(), '|');
        int candidate = -1;
        if (childIndex >= 0 and childIndex < (int)childs.size())
          candidate = std::stoi(childs[childIndex]);
        else if (childIndex < 0 and ((int)childs.size())+childIndex >= 0)
          candidate = std::stoi(childs[childs.size()+childIndex]);
        targetIndexes.emplace_back(candidate);
      }
    }
    else
      targetIndexes.emplace_back(-1);

  // Emit one dict index per (position, column).
  for (auto index : contextIndexes)
    for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++)
    {
      auto & col = columns[colIndex];
      if (index == -1)
      {
        for (auto & contextElement : context)
          contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, Dict::nullValueStr)));
      }
      else if (index == -2)
      {
        for (auto & contextElement : context)
          contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, "_NONE_")));
      }
      else
      {
        int dictIndex;
        if (col == Config::idColName)
        {
          std::string value;
          if (config.isCommentPredicted(index))
            value = "ID(comment)";
          else if (config.isMultiwordPredicted(index))
            value = "ID(multiword)";
          else if (config.isTokenPredicted(index))
            value = "ID(token)";
          dictIndex = dict.getIndexOrInsert(value);
        }
        else if (col == Config::EOSColName)
        {
          dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
        }
        else
          dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index)));
        for (auto & contextElement : context)
          contextElement.push_back(dictIndex);
      }
    }

  // Emit one pointer per target: a position inside the window if the
  // target is there, 0 (the sentinel) otherwise.
  for (auto index : targetIndexes)
  {
    if (configIndex2ContextIndex.count(index))
      for (auto & contextElement : context)
        contextElement.push_back(configIndex2ContextIndex.at(index)+1);
    else
      for (auto & contextElement : context)
        contextElement.push_back(0);
  }
}
torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
{
  // Embed the window part of the input: {batch, positions, columns*inSize}.
  auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1});
  // The trailing targets.size() values are pointers into the window.
  auto focusedIndexes = input.narrow(1, firstInputIndex+getInputSize()-targets.size(), targets.size());
  auto out = myModule->forward(context);
  std::vector<torch::Tensor> batchElems;
  // For each batch element, gather the states at the target positions
  // and flatten them.
  for (unsigned int i = 0; i < input.size(0); i++)
    batchElems.emplace_back(torch::index_select(out[i], 0, focusedIndexes[i]).view({-1}));
  return torch::stack(batchElems);
}
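The gather in the loop above can be pictured on a single batch element; a standalone sketch with illustrative shapes, not repository code:

  // Recurrent output for one batch element: {positions, features}.
  auto out = torch::randn({8, 128});
  // Pointers produced by addToContext; 0 is the sentinel position.
  auto idx = torch::tensor({1, 4, 0}, torch::kLong);
  // Pick the three target states and flatten them: {3*128}.
  auto picked = torch::index_select(out, 0, idx).view({-1});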
void ContextualModuleImpl::registerEmbeddings()
{
  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
  loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile);
}
@@ -50,7 +50,7 @@ DistanceModuleImpl::DistanceModuleImpl(std::string name, const std::string & def
 torch::Tensor DistanceModuleImpl::forward(torch::Tensor input)
 {
-  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()))).reshape({input.size(0), -1});
 }

 std::size_t DistanceModuleImpl::getOutputSize()
@@ -48,7 +48,7 @@ torch::Tensor FocusedColumnModuleImpl::forward(torch::Tensor input)
 {
   std::vector<torch::Tensor> outputs;
   for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++)
-    outputs.emplace_back(myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements))));
+    outputs.emplace_back(myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements))).reshape({input.size(0), -1}));
   return torch::cat(outputs, 1);
 }
@@ -2,7 +2,7 @@
 GRUImpl::GRUImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
 {
-  auto gruOptions = torch::nn::GRUOptions(inputSize, outputSize)
+  auto gruOptions = torch::nn::GRUOptions(inputSize, std::get<1>(options) ? outputSize/2 : outputSize)
     .batch_first(std::get<0>(options))
     .bidirectional(std::get<1>(options))
     .num_layers(std::get<2>(options))
@@ -13,15 +13,7 @@ GRUImpl::GRUImpl(int inputSize, int outputSize, ModuleOptions options) : outputA
 torch::Tensor GRUImpl::forward(torch::Tensor input)
 {
-  auto gruOut = std::get<0>(gru(input));
-  if (outputAll)
-    return gruOut.reshape({gruOut.size(0), -1});
-  if (gru->options.bidirectional())
-    return torch::cat({gruOut.narrow(1,0,1).squeeze(1), gruOut.narrow(1,gruOut.size(1)-1,1).squeeze(1)}, 1);
-  return gruOut.narrow(1,gruOut.size(1)-1,1).squeeze(1);
+  return std::get<0>(gru(input));
 }

 int GRUImpl::getOutputSize(int sequenceLength)
@@ -38,7 +38,7 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
 torch::Tensor HistoryModuleImpl::forward(torch::Tensor input)
 {
-  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, maxNbElements)));
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, maxNbElements))).reshape({input.size(0), -1});
 }

 std::size_t HistoryModuleImpl::getOutputSize()
@@ -2,7 +2,7 @@
 LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
 {
-  auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize)
+  auto lstmOptions = torch::nn::LSTMOptions(inputSize, std::get<1>(options) ? outputSize/2 : outputSize)
     .batch_first(std::get<0>(options))
     .bidirectional(std::get<1>(options))
     .num_layers(std::get<2>(options))
@@ -13,15 +13,7 @@ LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outpu
 torch::Tensor LSTMImpl::forward(torch::Tensor input)
 {
-  auto lstmOut = std::get<0>(lstm(input));
-  if (outputAll)
-    return lstmOut.reshape({lstmOut.size(0), -1});
-  if (lstm->options.bidirectional())
-    return torch::cat({lstmOut.narrow(1,0,1).squeeze(1), lstmOut.narrow(1,lstmOut.size(1)-1,1).squeeze(1)}, 1);
-  return lstmOut.narrow(1,lstmOut.size(1)-1,1).squeeze(1);
+  return std::get<0>(lstm(input));
 }

 int LSTMImpl::getOutputSize(int sequenceLength)
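Both recurrent wrappers now halve their hidden size when bidirectional, since torch concatenates the forward and backward directions and would otherwise produce twice the requested output width; this is also why ContextualModule rejects odd Out{} values. A minimal standalone check, with illustrative values not taken from the repository:

  #include <torch/torch.h>
  #include <iostream>

  int main()
  {
    int outSize = 128;
    // Hidden size outSize/2 per direction -> outSize features total.
    auto lstm = torch::nn::LSTM(torch::nn::LSTMOptions(64, outSize/2)
      .batch_first(true)
      .bidirectional(true));
    auto output = std::get<0>(lstm(torch::randn({8, 10, 64})));
    std::cout << output.sizes() << std::endl;  // [8, 10, 128]
    return 0;
  }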
@@ -29,6 +29,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
   std::string nameH = fmt::format("{}_{}", getName(), name);
   if (splited.first == "Context")
     modules.emplace_back(register_module(name, ContextModule(nameH, splited.second, path)));
+  else if (splited.first == "Contextual")
+    modules.emplace_back(register_module(name, ContextualModule(nameH, splited.second, path)));
   else if (splited.first == "StateName")
     modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
   else if (splited.first == "History")
@@ -45,7 +45,7 @@ torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
 {
   auto context = input.narrow(1, firstInputIndex, getInputSize());
   auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
-  return myModule->forward(values);
+  return myModule->forward(values).reshape({input.size(0), -1});
 }

 std::size_t NumericColumnModuleImpl::getOutputSize()
@@ -39,7 +39,7 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
 torch::Tensor RawInputModuleImpl::forward(torch::Tensor input)
 {
-  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()))).reshape({input.size(0), -1});
 }

 std::size_t RawInputModuleImpl::getOutputSize()
@@ -38,7 +38,7 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con
 torch::Tensor SplitTransModuleImpl::forward(torch::Tensor input)
 {
-  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()))).reshape({input.size(0), -1});
 }

 std::size_t SplitTransModuleImpl::getOutputSize()
@@ -43,7 +43,7 @@ torch::Tensor UppercaseRateModuleImpl::forward(torch::Tensor input)
 {
   auto context = input.narrow(1, firstInputIndex, getInputSize());
   auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
-  return myModule->forward(values);
+  return myModule->forward(values).reshape({input.size(0), -1});
 }

 std::size_t UppercaseRateModuleImpl::getOutputSize()