Commit d2655230 authored by Franck Dary

Added NumericColumnModule

parent 014779b1
ModularNetwork.hpp
@@ -9,6 +9,7 @@
#include "DepthLayerTreeEmbeddingModule.hpp"
#include "StateNameModule.hpp"
#include "UppercaseRateModule.hpp"
#include "NumericColumnModule.hpp"
#include "MLP.hpp"
class ModularNetworkImpl : public NeuralNetworkImpl
NumericColumnModule.hpp (new file)
#ifndef NUMERICCOLUMNMODULE__H
#define NUMERICCOLUMNMODULE__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
// Submodule that reads the raw numeric value of a column for a set of focused
// buffer and stack elements and feeds the resulting sequence to a recurrent
// layer (LSTM or GRU).
class NumericColumnModuleImpl : public Submodule
{
  private :

  int outSize;
  // Relative buffer indexes and stack depths whose column value is used.
  std::vector<int> focusedBuffer, focusedStack;
  // Recurrent submodule (LSTM or GRU) applied to the extracted values.
  std::shared_ptr<MyModule> myModule{nullptr};
  // Name of the numeric column to read from the configuration.
  std::string column;

  public :

  NumericColumnModuleImpl(std::string name, const std::string & definition);
  torch::Tensor forward(torch::Tensor input);
  std::size_t getOutputSize() override;
  std::size_t getInputSize() override;
  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
  void registerEmbeddings() override;
};

TORCH_MODULE(NumericColumnModule);
#endif
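
The constructor in NumericColumnModule.cpp below parses the module's textual definition with a regular expression of the form Column{...} Buffer{...} Stack{...} <LSTM|GRU>{...} Out{...}. As a sketch, a definition it would accept (the column name and the numbers are illustrative, not taken from this commit) is:

  Column{ID} Buffer{0 1 2} Stack{0 1} LSTM{1 2 0.3 1} Out{64}

Buffer and Stack list the focused buffer indexes and stack depths, the LSTM (or GRU) arguments are bidirectional, num_layers, dropout and complete in that order, and Out is the size parameter passed to the recurrent layer.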
ModularNetwork.cpp
@@ -27,6 +27,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
      modules.emplace_back(register_module(name, ContextModule(nameH, splited.second)));
    else if (splited.first == "StateName")
      modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
    else if (splited.first == "NumericColumn")
      modules.emplace_back(register_module(name, NumericColumnModule(nameH, splited.second)));
    else if (splited.first == "UppercaseRate")
      modules.emplace_back(register_module(name, UppercaseRateModule(nameH, splited.second)));
    else if (splited.first == "Focused")
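
Assuming (the line-splitting code sits outside this hunk) that each entry of the network description is split into a type prefix (splited.first) and the remainder of the definition (splited.second), the new branch would be triggered by a hypothetical line such as:

  NumericColumn : Column{ID} Buffer{0 1 2} Stack{0} GRU{1 1 0.0 1} Out{32}

where the "NumericColumn :" prefix and separator are an assumption for illustration; only the part after the separator is handed to NumericColumnModule.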
NumericColumnModule.cpp (new file)
#include "NumericColumnModule.hpp"
#include "NeuralNetwork.hpp"
NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::string & definition)
{
setName(name);
std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
{
try
{
column = sm.str(1);
for (auto & index : util::split(sm.str(2), ' '))
focusedBuffer.emplace_back(std::stoi(index));
for (auto & index : util::split(sm.str(3), ' '))
focusedStack.emplace_back(std::stoi(index));
auto subModuleType = sm.str(4);
auto subModuleArguments = util::split(sm.str(5), ' ');
auto options = MyModule::ModuleOptions(true)
.bidirectional(std::stoi(subModuleArguments[0]))
.num_layers(std::stoi(subModuleArguments[1]))
.dropout(std::stof(subModuleArguments[2]))
.complete(std::stoi(subModuleArguments[3]));
int outSize = std::stoi(sm.str(6));
if (subModuleType == "LSTM")
myModule = register_module("myModule", LSTM(1, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(1, outSize, options));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
}))
util::myThrow(fmt::format("invalid definition '{}'", definition));
}
torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
{
auto context = input.narrow(1, firstInputIndex, getInputSize());
void * dataPtr = context.flatten().data_ptr();
auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::kDouble).clone().to(torch::kFloat).to(NeuralNetworkImpl::device).view({(long)context.size(0), (long)context.size(1), 1});
return myModule->forward(values);
}
std::size_t NumericColumnModuleImpl::getOutputSize()
{
return myModule->getOutputSize(getInputSize());
}
std::size_t NumericColumnModuleImpl::getInputSize()
{
return focusedBuffer.size() + focusedStack.size();
}
void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
std::vector<long> focusedIndexes;
for (int index : focusedBuffer)
focusedIndexes.emplace_back(config.getRelativeWordIndex(index));
for (int index : focusedStack)
if (config.hasStack(index))
focusedIndexes.emplace_back(config.getStack(index));
else
focusedIndexes.emplace_back(-1);
for (auto & contextElement : context)
for (auto index : focusedIndexes)
{
double res = 0.0;
if (index >= 0)
res = std::stof(config.getAsFeature(column, index).get());
contextElement.emplace_back(0);
std::memcpy(&contextElement.back(), &res, sizeof res);
}
}
void NumericColumnModuleImpl::registerEmbeddings()
{
}
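
As a standalone sketch (not part of the commit), the following program illustrates the bit-packing trick shared by addToContext and forward: each double is memcpy'd into a long of the integer context, and torch::from_blob later reinterprets those longs as doubles. Like the code above, it assumes long and double are both 64 bits wide.

#include <torch/torch.h>
#include <cstring>
#include <cstdio>
#include <vector>

int main()
{
  static_assert(sizeof(long) == sizeof(double), "the packing below needs a 64-bit long");

  // Pack doubles into longs exactly as addToContext does.
  std::vector<long> context;
  for (double value : {3.14, -0.5, 42.0})
  {
    context.emplace_back(0);
    std::memcpy(&context.back(), &value, sizeof value);
  }

  // Reinterpret the packed longs as doubles, as forward does with the context tensor.
  auto values = torch::from_blob(context.data(), {(long)context.size()}, torch::kDouble)
                  .clone()
                  .to(torch::kFloat);

  std::printf("%f %f %f\n", values[0].item<float>(), values[1].item<float>(), values[2].item<float>());
  return 0;
}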