Commit 2cf0aae6 authored by Franck Dary

Added Contextual module

parent cf56fefc
#ifndef CONTEXTUALMODULE__H
#define CONTEXTUALMODULE__H
#include <torch/torch.h>
#include <optional>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "GRU.hpp"
#include "LSTM.hpp"
#include "Concat.hpp"
class ContextualModuleImpl : public Submodule
{
private :
torch::nn::Embedding wordEmbeddings{nullptr};
std::shared_ptr<MyModule> myModule{nullptr};
std::vector<std::string> columns;
std::vector<std::function<std::string(const std::string &)>> functions;
std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets;
std::pair<int,int> window;
int inSize;
int outSize;
std::filesystem::path path;
std::filesystem::path w2vFile;
public :
ContextualModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(ContextualModule);
#endif
@@ -3,6 +3,7 @@
#include "NeuralNetwork.hpp"
#include "ContextModule.hpp"
#include "ContextualModule.hpp"
#include "RawInputModule.hpp"
#include "SplitTransModule.hpp"
#include "AppliableTransModule.hpp"
......
@@ -148,7 +148,7 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)});
-return myModule->forward(context);
+return myModule->forward(context).reshape({input.size(0), -1});
}
void ContextModuleImpl::registerEmbeddings()
......
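Note on the added reshape (here and in the matching one-line changes to the other modules below): the recurrent submodules now return their full sequence output of shape (batch, sequence, features), so each caller flattens it to one row per batch element before module outputs are concatenated. A minimal shape sketch, with hypothetical sizes:

#include <torch/torch.h>
// Illustration only: 'out' stands in for a submodule's sequence output
void reshapeSketch()
{
  auto out = torch::rand({8, 5, 64});          // (batch, sequence, features)
  auto flat = out.reshape({out.size(0), -1});  // flat.sizes() == (8, 320)
}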
#include "ContextualModule.hpp"
ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path)
{
setName(name);
std::regex regex("(?:(?:\\s|\\t)*)Window\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)Targets\\{(.*)\\}(?:(?:\\s|\\t)*)");
if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
{
try
{
auto splited = util::split(sm.str(1), ' ');
if (splited.size() != 2)
util::myThrow("bad Window, expected 2 indexes");
window = std::make_pair(std::stoi(splited[0]), std::stoi(splited[1]));
auto funcColumns = util::split(sm.str(2), ' ');
columns.clear();
for (auto & funcCol : funcColumns)
{
functions.emplace_back() = getFunction(funcCol);
columns.emplace_back(util::split(funcCol, ':').back());
}
auto subModuleType = sm.str(3);
auto subModuleArguments = util::split(sm.str(4), ' ');
auto options = MyModule::ModuleOptions(true)
.bidirectional(std::stoi(subModuleArguments[0]))
.num_layers(std::stoi(subModuleArguments[1]))
.dropout(std::stof(subModuleArguments[2]))
.complete(std::stoi(subModuleArguments[3]));
for (auto & target : util::split(sm.str(8), ' '))
{
auto splited = util::split(target, '.');
if (splited.size() != 2 and splited.size() != 3)
util::myThrow(fmt::format("invalid target '{}' expected 'object.index(.childIndex)'", target));
targets.emplace_back(std::make_tuple(Config::str2object(splited[0]), std::stoi(splited[1]), splited.size() == 3 ? std::optional<int>(std::stoi(splited[2])) : std::optional<int>()));
}
inSize = std::stoi(sm.str(5));
outSize = std::stoi(sm.str(6));
if (outSize % 2)
util::myThrow("odd out size is not supported");
if (subModuleType == "LSTM")
myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
w2vFile = sm.str(7);
if (!w2vFile.empty())
{
getDict().loadWord2Vec(this->path / w2vFile);
getDict().setState(Dict::State::Closed);
dictSetPretrained(true);
}
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
}))
util::myThrow(fmt::format("invalid definition '{}'", definition));
}
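For reference, a definition accepted by the regex above could look like the following (all values hypothetical; the recognized object names and column functions depend on Config::str2object and getFunction):

Window{-2 1} Columns{FORM} LSTM{1 1 0.3 1} In{64} Out{128} w2v{} Targets{b.0 s.0 s.0.0}

The four submodule arguments are bidirectional, num_layers, dropout and complete, in that order; a column may carry a function prefix (split on ':'), each target is 'object.index' with an optional '.childIndex' suffix, and Out must be even.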
std::size_t ContextualModuleImpl::getOutputSize()
{
return targets.size()*outSize;
}
std::size_t ContextualModuleImpl::getInputSize()
{
return columns.size()*(2+window.second-window.first)+targets.size();
}
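As a worked check with the hypothetical definition above (Window{-2 1}, one column, three targets, Out{128}): getInputSize() = 1*(2+1-(-2)) + 3 = 8, i.e. five embedding indexes (four window positions plus the reserved 'none' slot) followed by three target positions, and getOutputSize() = 3*128 = 384.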
void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
auto & dict = getDict();
std::vector<long> contextIndexes;
std::vector<long> targetIndexes;
std::map<long,long> configIndex2ContextIndex;
contextIndexes.emplace_back(-2);
for (long i = window.first; i <= window.second; i++)
{
if (config.hasRelativeWordIndex(Config::Object::Buffer, i))
{
contextIndexes.emplace_back(config.getRelativeWordIndex(Config::Object::Buffer, i));
configIndex2ContextIndex[contextIndexes.back()] = contextIndexes.size()-1;
}
else
contextIndexes.emplace_back(-1);
}
for (auto & target : targets)
if (config.hasRelativeWordIndex(std::get<0>(target), std::get<1>(target)))
{
int baseIndex = config.getRelativeWordIndex(std::get<0>(target), std::get<1>(target));
if (!std::get<2>(target))
targetIndexes.emplace_back(baseIndex);
else
{
int childIndex = *std::get<2>(target);
auto childs = util::split(config.getAsFeature(Config::childsColName, baseIndex).get(), '|');
int candidate = -1;
if (childIndex >= 0 and childIndex < (int)childs.size())
candidate = std::stoi(childs[childIndex]);
else if (childIndex < 0 and ((int)childs.size())+childIndex >= 0)
candidate = std::stoi(childs[childs.size()+childIndex]);
targetIndexes.emplace_back(candidate);
}
}
else
targetIndexes.emplace_back(-1);
for (auto index : contextIndexes)
for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++)
{
auto & col = columns[colIndex];
if (index == -1)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, Dict::nullValueStr)));
}
else if (index == -2)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, "_NONE_")));
}
else
{
int dictIndex;
if (col == Config::idColName)
{
std::string value;
if (config.isCommentPredicted(index))
value = "ID(comment)";
else if (config.isMultiwordPredicted(index))
value = "ID(multiword)";
else if (config.isTokenPredicted(index))
value = "ID(token)";
dictIndex = dict.getIndexOrInsert(value);
}
else if (col == Config::EOSColName)
{
dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
}
else
dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index)));
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
}
}
for (auto index : targetIndexes)
{
if (configIndex2ContextIndex.count(index))
for (auto & contextElement : context)
contextElement.push_back(configIndex2ContextIndex.at(index));
else
for (auto & contextElement : context)
contextElement.push_back(0);
}
}
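A note on the index scheme built above: contextIndexes starts with the -2 sentinel, so position 0 of the embedded context is a reserved "_NONE_" slot and window words occupy positions 1 onwards, while window positions holding no word use the -1 sentinel and embed the null value. A resolved target is encoded as the context position of its word, and an unresolved target as 0, so the index_select in the forward pass falls back to the "_NONE_" output for missing targets.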
torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
{
auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1});
auto focusedIndexes = input.narrow(1, firstInputIndex+getInputSize()-targets.size(), targets.size());
auto out = myModule->forward(context);
std::vector<torch::Tensor> batchElems;
for (unsigned int i = 0; i < input.size(0); i++)
batchElems.emplace_back(torch::index_select(out[i], 0, focusedIndexes[i]).view({-1}));
return torch::stack(batchElems);
}
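A shape sketch of this forward pass, with hypothetical sizes (four window positions plus the 'none' slot, two targets, outSize 32) and a hypothetical helper name:

#include <torch/torch.h>
#include <vector>
// Illustration only: pick one RNN output row per target, then flatten
torch::Tensor selectTargets(torch::Tensor out, torch::Tensor focused)
{
  std::vector<torch::Tensor> rows;
  for (long i = 0; i < out.size(0); i++)
    rows.emplace_back(torch::index_select(out[i], 0, focused[i]).view({-1}));
  return torch::stack(rows);
}
// selectTargets(torch::rand({2, 5, 32}), torch::tensor({{0, 4}, {2, 2}}, torch::kLong))
// yields a (2, 64) tensor, matching getOutputSize() = nbTargets * outSize.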
void ContextualModuleImpl::registerEmbeddings()
{
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile);
}
@@ -50,7 +50,7 @@ DistanceModuleImpl::DistanceModuleImpl(std::string name, const std::string & def
torch::Tensor DistanceModuleImpl::forward(torch::Tensor input)
{
-return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
+return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()))).reshape({input.size(0), -1});
}
std::size_t DistanceModuleImpl::getOutputSize()
......
@@ -48,7 +48,7 @@ torch::Tensor FocusedColumnModuleImpl::forward(torch::Tensor input)
{
std::vector<torch::Tensor> outputs;
for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++)
-outputs.emplace_back(myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements))));
+outputs.emplace_back(myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements))).reshape({input.size(0), -1}));
return torch::cat(outputs, 1);
}
......
@@ -2,7 +2,7 @@
GRUImpl::GRUImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
{
-auto gruOptions = torch::nn::GRUOptions(inputSize, outputSize)
+auto gruOptions = torch::nn::GRUOptions(inputSize, std::get<1>(options) ? outputSize/2 : outputSize)
.batch_first(std::get<0>(options))
.bidirectional(std::get<1>(options))
.num_layers(std::get<2>(options))
@@ -13,15 +13,7 @@ GRUImpl::GRUImpl(int inputSize, int outputSize, ModuleOptions options) : outputA
torch::Tensor GRUImpl::forward(torch::Tensor input)
{
-auto gruOut = std::get<0>(gru(input));
-if (outputAll)
-return gruOut.reshape({gruOut.size(0), -1});
-if (gru->options.bidirectional())
-return torch::cat({gruOut.narrow(1,0,1).squeeze(1), gruOut.narrow(1,gruOut.size(1)-1,1).squeeze(1)}, 1);
-return gruOut.narrow(1,gruOut.size(1)-1,1).squeeze(1);
+return std::get<0>(gru(input));
}
int GRUImpl::getOutputSize(int sequenceLength)
......
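Why the hidden size is halved in the GRU options above (the LSTM change below makes the same adjustment): a bidirectional RNN concatenates the forward and backward states at each timestep, so a hidden size of outputSize/2 keeps the per-timestep output at outputSize; this is also why the ContextualModule constructor rejects odd Out values. A minimal sketch with hypothetical sizes:

#include <torch/torch.h>
void bidirectionalSizeSketch() // illustration only
{
  auto opts = torch::nn::GRUOptions(16, 64/2).batch_first(true).bidirectional(true);
  torch::nn::GRU gru(opts);
  auto out = std::get<0>(gru(torch::rand({4, 10, 16})));
  // out.sizes() == (4, 10, 64): both directions concatenated per timestep
}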
@@ -38,7 +38,7 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
torch::Tensor HistoryModuleImpl::forward(torch::Tensor input)
{
-return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, maxNbElements)));
+return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, maxNbElements))).reshape({input.size(0), -1});
}
std::size_t HistoryModuleImpl::getOutputSize()
......
@@ -2,7 +2,7 @@
LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
{
-auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize)
+auto lstmOptions = torch::nn::LSTMOptions(inputSize, std::get<1>(options) ? outputSize/2 : outputSize)
.batch_first(std::get<0>(options))
.bidirectional(std::get<1>(options))
.num_layers(std::get<2>(options))
@@ -13,15 +13,7 @@ LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outpu
torch::Tensor LSTMImpl::forward(torch::Tensor input)
{
-auto lstmOut = std::get<0>(lstm(input));
-if (outputAll)
-return lstmOut.reshape({lstmOut.size(0), -1});
-if (lstm->options.bidirectional())
-return torch::cat({lstmOut.narrow(1,0,1).squeeze(1), lstmOut.narrow(1,lstmOut.size(1)-1,1).squeeze(1)}, 1);
-return lstmOut.narrow(1,lstmOut.size(1)-1,1).squeeze(1);
+return std::get<0>(lstm(input));
}
int LSTMImpl::getOutputSize(int sequenceLength)
......
@@ -29,6 +29,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
std::string nameH = fmt::format("{}_{}", getName(), name);
if (splited.first == "Context")
modules.emplace_back(register_module(name, ContextModule(nameH, splited.second, path)));
else if (splited.first == "Contextual")
modules.emplace_back(register_module(name, ContextualModule(nameH, splited.second, path)));
else if (splited.first == "StateName")
modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
else if (splited.first == "History")
......
@@ -45,7 +45,7 @@ torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
{
auto context = input.narrow(1, firstInputIndex, getInputSize());
auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
-return myModule->forward(values);
+return myModule->forward(values).reshape({input.size(0), -1});
}
std::size_t NumericColumnModuleImpl::getOutputSize()
......
@@ -39,7 +39,7 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
torch::Tensor RawInputModuleImpl::forward(torch::Tensor input)
{
-return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
+return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()))).reshape({input.size(0), -1});
}
std::size_t RawInputModuleImpl::getOutputSize()
......
@@ -38,7 +38,7 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con
torch::Tensor SplitTransModuleImpl::forward(torch::Tensor input)
{
-return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
+return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()))).reshape({input.size(0), -1});
}
std::size_t SplitTransModuleImpl::getOutputSize()
......
@@ -43,7 +43,7 @@ torch::Tensor UppercaseRateModuleImpl::forward(torch::Tensor input)
{
auto context = input.narrow(1, firstInputIndex, getInputSize());
auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
-return myModule->forward(values);
+return myModule->forward(values).reshape({input.size(0), -1});
}
std::size_t UppercaseRateModuleImpl::getOutputSize()
......