#include "ContextLSTM.hpp" ContextLSTMImpl::ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold) : columns(columns), bufferContext(bufferContext), stackContext(stackContext), unknownValueThreshold(unknownValueThreshold) { lstm = register_module("lstm", LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)); } std::size_t ContextLSTMImpl::getOutputSize() { return lstm->getOutputSize(bufferContext.size()+stackContext.size()); } std::size_t ContextLSTMImpl::getInputSize() { return columns.size()*(bufferContext.size()+stackContext.size()); } void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const { std::vector<long> contextIndexes; for (int index : bufferContext) contextIndexes.emplace_back(config.getRelativeWordIndex(index)); for (int index : stackContext) if (config.hasStack(index)) contextIndexes.emplace_back(config.getStack(index)); else contextIndexes.emplace_back(-1); for (auto index : contextIndexes) for (auto & col : columns) if (index == -1) for (auto & contextElement : context) contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); else { int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index)); for (auto & contextElement : context) contextElement.push_back(dictIndex); if (is_training()) for (auto & targetCol : unknownValueColumns) if (col == targetCol) if (dict.getNbOccs(dictIndex) <= unknownValueThreshold) { context.emplace_back(context.back()); context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr); } } } torch::Tensor ContextLSTMImpl::forward(torch::Tensor input) { auto context = input.narrow(1, firstInputIndex, getInputSize()); context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)}); return lstm(context); }