Commit 95131d6d authored by Franck Dary

Added DistanceModule

parent c8db3f36
Config.hpp
@@ -98,6 +98,7 @@ class Config
  std::size_t & getStackRef(int relativeIndex);
  long getRelativeWordIndex(int relativeIndex) const;
  long getRelativeDistance(int fromIndex, int toIndex) const;
  public :
...
Config.cpp
@@ -675,6 +675,32 @@ long Config::getRelativeWordIndex(int relativeIndex) const
  return -1;
}
long Config::getRelativeDistance(int fromIndex, int toIndex) const
{
  // Signed distance between two absolute word indexes, counting only the
  // lines that are not predicted to be comments.
  if (toIndex < fromIndex)
  {
    for (int index = fromIndex, counter = 0; has(0,index,0); --index)
      if (!isCommentPredicted(index))
      {
        if (index == toIndex)
          return counter;
        --counter;
      }
  }
  else
  {
    for (int index = fromIndex, counter = 0; has(0,index,0); ++index)
      if (!isCommentPredicted(index))
      {
        if (index == toIndex)
          return counter;
        ++counter;
      }
  }

  return 0; // toIndex was never reached (out of range or a comment line)
}
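The signed, comment-skipping count above is easy to verify in isolation. Here is a minimal standalone sketch (a hypothetical relativeDistance over a plain comment mask, standing in for Config::getRelativeDistance, isCommentPredicted and has; not the project's Config class):

#include <iostream>
#include <vector>

// Minimal stand-in for Config::getRelativeDistance: isComment[i] plays the
// role of isCommentPredicted(i), and the vector bounds play the role of
// has(0, index, 0).
long relativeDistance(const std::vector<bool> & isComment, long from, long to)
{
  long step = (to < from) ? -1 : 1;
  long counter = 0;
  for (long index = from; index >= 0 && index < (long)isComment.size(); index += step)
    if (!isComment[index])
    {
      if (index == to)
        return counter;
      counter += step;
    }
  return 0; // target unreachable or a comment line
}

int main()
{
  // Line 2 is a comment, so it does not count toward the distance.
  std::vector<bool> isComment{false, false, true, false, false};
  std::cout << relativeDistance(isComment, 0, 3) << '\n'; // prints 2
  std::cout << relativeDistance(isComment, 3, 0) << '\n'; // prints -2
}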
long Config::getRelativeWordIndex(Object object, int relativeIndex) const
{
  if (object == Object::Buffer)
...

DistanceModule.hpp (new file)
#ifndef DISTANCEMODULE__H
#define DISTANCEMODULE__H

#include <torch/torch.h>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "Concat.hpp"

class DistanceModuleImpl : public Submodule
{
  private :

  torch::nn::Embedding wordEmbeddings{nullptr};
  std::shared_ptr<MyModule> myModule{nullptr};
  std::vector<int> fromBuffer, fromStack;
  std::vector<int> toBuffer, toStack;
  int threshold;
  int inSize;

  public :

  DistanceModuleImpl(std::string name, const std::string & definition);
  torch::Tensor forward(torch::Tensor input);
  std::size_t getOutputSize() override;
  std::size_t getInputSize() override;
  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
  void registerEmbeddings(std::filesystem::path pretrained) override;
};
TORCH_MODULE(DistanceModule);

#endif
ModularNetwork.hpp
@@ -12,6 +12,7 @@
#include "UppercaseRateModule.hpp" #include "UppercaseRateModule.hpp"
#include "NumericColumnModule.hpp" #include "NumericColumnModule.hpp"
#include "HistoryModule.hpp" #include "HistoryModule.hpp"
#include "DistanceModule.hpp"
#include "MLP.hpp" #include "MLP.hpp"
class ModularNetworkImpl : public NeuralNetworkImpl class ModularNetworkImpl : public NeuralNetworkImpl
...

DistanceModule.cpp (new file)
#include "DistanceModule.hpp"
DistanceModuleImpl::DistanceModuleImpl(std::string name, const std::string & definition)
{
setName(name);
std::regex regex("(?:(?:\\s|\\t)*)FromBuffer\\{(.*)\\}(?:(?:\\s|\\t)*)FromStack\\{(.*)\\}(?:(?:\\s|\\t)*)ToBuffer\\{(.*)\\}(?:(?:\\s|\\t)*)ToStack\\{(.*)\\}(?:(?:\\s|\\t)*)Threshold\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
{
try
{
for (auto & index : util::split(sm.str(1), ' '))
fromBuffer.emplace_back(std::stoi(index));
for (auto & index : util::split(sm.str(2), ' '))
fromStack.emplace_back(std::stoi(index));
for (auto & index : util::split(sm.str(3), ' '))
toBuffer.emplace_back(std::stoi(index));
for (auto & index : util::split(sm.str(4), ' '))
toStack.emplace_back(std::stoi(index));
threshold = std::stoi(sm.str(5));
auto subModuleType = sm.str(6);
auto subModuleArguments = util::split(sm.str(7), ' ');
auto options = MyModule::ModuleOptions(true)
.bidirectional(std::stoi(subModuleArguments[0]))
.num_layers(std::stoi(subModuleArguments[1]))
.dropout(std::stof(subModuleArguments[2]))
.complete(std::stoi(subModuleArguments[3]));
inSize = std::stoi(sm.str(8));
int outSize = std::stoi(sm.str(9));
if (subModuleType == "LSTM")
myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
}))
util::myThrow(fmt::format("invalid definition '{}'", definition));
}
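Going by the regex and the parsing above, an accepted definition looks like the following (the concrete values are illustrative, not taken from the commit); the four submodule arguments are bidirectional, num_layers, dropout and complete, in that order:

FromBuffer{-2 -1 0} FromStack{0 1 2} ToBuffer{0} ToStack{0 1} Threshold{10} LSTM{1 1 0.3 1} In{64} Out{128}

With these values, getInputSize() below evaluates to (3+3) * (1+2) = 18 distance features per context element, each embedded in 64 dimensions and fed to a one-layer bidirectional LSTM with output size 128.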
torch::Tensor DistanceModuleImpl::forward(torch::Tensor input)
{
  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
}

std::size_t DistanceModuleImpl::getOutputSize()
{
  return myModule->getOutputSize(getInputSize());
}

std::size_t DistanceModuleImpl::getInputSize()
{
  return (fromBuffer.size()+fromStack.size()) * (toBuffer.size()+toStack.size());
}

void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
  auto & dict = getDict();
  std::vector<long> fromIndexes, toIndexes;

  for (int index : fromBuffer)
    fromIndexes.emplace_back(config.getRelativeWordIndex(index));
  for (int index : fromStack)
    if (config.hasStack(index))
      fromIndexes.emplace_back(config.getStack(index));
    else
      fromIndexes.emplace_back(-1);
  for (int index : toBuffer)
    toIndexes.emplace_back(config.getRelativeWordIndex(index));
  for (int index : toStack)
    if (config.hasStack(index))
      toIndexes.emplace_back(config.getStack(index));
    else
      toIndexes.emplace_back(-1);

  for (auto & contextElement : context)
  {
    // One feature per (from, to) pair: the null value if either word is
    // absent, the absolute distance if it is within the threshold, and the
    // unknown value otherwise.
    for (auto from : fromIndexes)
      for (auto to : toIndexes)
      {
        if (from == -1 or to == -1)
        {
          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
          continue;
        }
        long dist = std::abs(config.getRelativeDistance(from, to));
        if (dist <= threshold)
          contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("{}", dist)));
        else
          contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr));
      }
  }
}

void DistanceModuleImpl::registerEmbeddings(std::filesystem::path path)
{
  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
  loadPretrainedW2vEmbeddings(wordEmbeddings, path);
}
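The dictionary bucketing in addToContext can likewise be checked on its own: distances within the threshold are inserted as literal strings, while everything farther away collapses into a single unknown entry. A minimal sketch (hypothetical names; a plain std::map stands in for the project's Dict):

#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for Dict::getIndexOrInsert: assign ids on first use.
long getIndexOrInsert(std::map<std::string, long> & dict, const std::string & key)
{
  auto inserted = dict.emplace(key, (long)dict.size());
  return inserted.first->second;
}

int main()
{
  std::map<std::string, long> dict;
  const long threshold = 10;
  for (long dist : {3, 7, 3, 42, 99})
  {
    std::string key = dist <= threshold ? std::to_string(dist) : "__unknown__";
    std::cout << dist << " -> id " << getIndexOrInsert(dict, key) << '\n';
  }
  // 3 and 7 get their own ids; 42 and 99 share the single unknown id.
}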
ModularNetwork.cpp
@@ -45,6 +45,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
    modules.emplace_back(register_module(name, SplitTransModule(nameH, Config::maxNbAppliableSplitTransitions, splited.second)));
  else if (splited.first == "AppliableTrans")
    modules.emplace_back(register_module(name, AppliableTransModule(nameH, maxNbOutputs)));
  else if (splited.first == "Distance")
    modules.emplace_back(register_module(name, DistanceModule(nameH, splited.second)));
  else if (splited.first == "DepthLayerTree")
    modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(nameH, splited.second)));
  else if (splited.first == "MLP")
...