#include "DepthLayerTreeEmbeddingModule.hpp"
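
// Parses a textual definition of the form (example values are illustrative only, not from the source):
//   Columns{FORM UPOS} Buffer{0} Stack{0 1} LayerSizes{3 3} LSTM{1 2 0.3 1} In{64} Out{128}
// where LayerSizes gives the maximum number of children kept at each tree depth,
// the fifth block names the per-depth submodule (LSTM, GRU or Concat) and its arguments,
// and In/Out are the embedding and output sizes.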
DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string name, const std::string & definition)
{
  setName(name); // assumes the Submodule base class provides setName, as in the sibling modules
  std::regex regex("(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)LayerSizes\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
    {
      try
      {
        columns = util::split(sm.str(1), ' ');
        for (auto & index : util::split(sm.str(2), ' '))
          focusedBuffer.emplace_back(std::stoi(index));
        for (auto & index : util::split(sm.str(3), ' '))
          focusedStack.emplace_back(std::stoi(index));
        for (auto & elem : util::split(sm.str(4), ' '))
          maxElemPerDepth.emplace_back(std::stoi(elem));
        auto subModuleType = sm.str(5);
        auto subModuleArguments = util::split(sm.str(6), ' ');
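        // the four submodule arguments are, in order: bidirectional, num_layers, dropout, complete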
        auto options = MyModule::ModuleOptions(true)
          .bidirectional(std::stoi(subModuleArguments[0]))
          .num_layers(std::stoi(subModuleArguments[1]))
          .dropout(std::stof(subModuleArguments[2]))
          .complete(std::stoi(subModuleArguments[3]));
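        // In = per-column embedding size, Out = output size of each depth submodule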
        inSize = std::stoi(sm.str(7));
        int outSize = std::stoi(sm.str(8));
        // instantiate one submodule per tree depth
        for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
        {
          std::string name = fmt::format("{}_{}", i, subModuleType);
          if (subModuleType == "LSTM")
            depthModules.emplace_back(register_module(name, LSTM(columns.size()*inSize, outSize, options)));
          else if (subModuleType == "GRU")
            depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options)));
          else if (subModuleType == "Concat")
            depthModules.emplace_back(register_module(name, Concat(inSize)));
          else
            util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
        }
      } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'", e.what(), definition));}
    }))
    util::myThrow(fmt::format("invalid definition '{}'", definition));
}
torch::Tensor DepthLayerTreeEmbeddingModuleImpl::forward(torch::Tensor input)
{
  auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()));
  std::vector<torch::Tensor> outputs;
  int offset = 0;
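  // for every focused word and depth, take the slice of maxElemPerDepth[depth] children
  // (times columns.size() features) and reshape it to (batch, elements, columns*embeddingSize)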
  for (unsigned int focused = 0; focused < focusedBuffer.size()+focusedStack.size(); focused++)
    for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
    {
      outputs.emplace_back(depthModules[depth]->forward(context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({context.size(0), maxElemPerDepth[depth], (long)columns.size()*context.size(2)})));
      offset += maxElemPerDepth[depth]*columns.size();
    }
  return torch::cat(outputs, 1);
}
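
// Output size: each focused word contributes the concatenated outputs of all depth submodules.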
std::size_t DepthLayerTreeEmbeddingModuleImpl::getOutputSize()
{
  std::size_t outputSize = 0;
  for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
    outputSize += depthModules[depth]->getOutputSize(maxElemPerDepth[depth]);
  return outputSize*(focusedBuffer.size()+focusedStack.size());
}
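
// Input size: one dict index per (focused word, depth slot, column).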
std::size_t DepthLayerTreeEmbeddingModuleImpl::getInputSize()
{
  int inputSize = 0;
  for (int maxElem : maxElemPerDepth)
    inputSize += (focusedBuffer.size()+focusedStack.size())*maxElem*columns.size();
  return inputSize;
}
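
// For each configuration in the batch, collects the dict indexes of the chosen
// columns for the children (up to each depth's limit) of every focused word.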
void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
  auto & dict = getDict(); // getDict() is also what registerEmbeddings() uses below
  std::vector<long> focusedIndexes;
  for (int index : focusedBuffer)
    focusedIndexes.emplace_back(config.getRelativeWordIndex(index));
  for (int index : focusedStack)
    if (config.hasStack(index))
      focusedIndexes.emplace_back(config.getStack(index));
    else
      focusedIndexes.emplace_back(-1); // sentinel: absent stack element, embedded as the null value
  for (auto & contextElement : context)
    for (auto index : focusedIndexes)
    {
      std::vector<std::string> childs{std::to_string(index)};
      for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
      {
        // descend one level: the new frontier is the concatenation of every
        // current node's children ('|'-separated in the childs column)
        std::vector<std::string> newChilds;
        for (auto & child : childs)
          if (config.has(Config::childsColName, std::stoi(child), 0))
          {
            auto val = util::split(config.getAsFeature(Config::childsColName, std::stoi(child)).get(), '|');
            newChilds.insert(newChilds.end(), val.begin(), val.end());
          }
        childs = newChilds;
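        // emit exactly maxElemPerDepth[depth] slots per column, padding missing children with the null value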
        for (int i = 0; i < maxElemPerDepth[depth]; i++)
          for (auto & col : columns)
            if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0))
              contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col, std::stoi(newChilds[i]))));
            else
              contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
      }
    }
}
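
// Registers the embedding table used by forward(): one vector of size inSize per dictionary entry.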
void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
{
  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
}