#include "ContextModule.hpp" ContextModuleImpl::ContextModuleImpl(const std::string & definition) { std::regex regex("(?:(?:\\s|\\t)*)Unk\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) { try { unknownValueThreshold = std::stoi(sm.str(1)); for (auto & index : util::split(sm.str(2), ' ')) bufferContext.emplace_back(std::stoi(index)); for (auto & index : util::split(sm.str(3), ' ')) stackContext.emplace_back(std::stoi(index)); columns = util::split(sm.str(4), ' '); auto subModuleType = sm.str(5); auto subModuleArguments = util::split(sm.str(6), ' '); auto options = MyModule::ModuleOptions(true) .bidirectional(std::stoi(subModuleArguments[0])) .num_layers(std::stoi(subModuleArguments[1])) .dropout(std::stof(subModuleArguments[2])) .complete(std::stoi(subModuleArguments[3])); int inSize = std::stoi(sm.str(7)); int outSize = std::stoi(sm.str(8)); wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize))); if (subModuleType == "LSTM") myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options)); else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} })) util::myThrow(fmt::format("invalid definition '{}'", definition)); } std::size_t ContextModuleImpl::getOutputSize() { return myModule->getOutputSize(bufferContext.size()+stackContext.size()); } std::size_t ContextModuleImpl::getInputSize() { return columns.size()*(bufferContext.size()+stackContext.size()); } void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const { std::vector<long> contextIndexes; for (int index : bufferContext) contextIndexes.emplace_back(config.getRelativeWordIndex(index)); for (int index : stackContext) if (config.hasStack(index)) contextIndexes.emplace_back(config.getStack(index)); else contextIndexes.emplace_back(-1); for (auto index : contextIndexes) for (auto & col : columns) if (index == -1) { for (auto & contextElement : context) contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); } else { int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index)); for (auto & contextElement : context) contextElement.push_back(dictIndex); for (auto & targetCol : unknownValueColumns) if (col == targetCol) if (dict.getNbOccs(dictIndex) <= unknownValueThreshold) context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr); } } torch::Tensor ContextModuleImpl::forward(torch::Tensor input) { auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())); context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)}); return myModule->forward(context); }