From 15d915ca7b8798c14da8177c5ee72e9bdb423472 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 8 Apr 2020 22:45:21 +0200 Subject: [PATCH] Fixed some problems in dependency parsing --- decoder/src/Decoder.cpp | 6 ++- reading_machine/include/Config.hpp | 2 + reading_machine/src/Action.cpp | 10 +++- reading_machine/src/BaseConfig.cpp | 22 +++++++++ reading_machine/src/Config.cpp | 12 +++++ .../include/DepthLayerTreeEmbedding.hpp | 14 ++++-- torch_modules/src/DepthLayerTreeEmbedding.cpp | 47 +++++++++++++++++-- 7 files changed, 102 insertions(+), 11 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 78e6204..f08bd03 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -89,11 +89,13 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool // Force EOS when needed if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1) { - Action shift = Action::pushWordIndexOnStack(); - shift.apply(config, shift); + machine.getTransitionSet().getTransition("SHIFT")->apply(config); machine.getTransitionSet().getTransition("EOS")->apply(config); if (debug) + { fmt::print(stderr, "Forcing EOS transition\n"); + config.printForDebug(stderr); + } } // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index c9b7d4d..3ae84fc 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -21,6 +21,8 @@ class Config static constexpr const char * headColName = "HEAD"; static constexpr const char * deprelColName = "DEPREL"; static constexpr const char * idColName = "ID"; + static constexpr const char * isMultiColName = "MULTI"; + static constexpr const char * childsColName = "CHILDS"; static constexpr int nbHypothesesMax = 1; static constexpr int maxNbAppliableSplitTransitions = 8; diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 889aaa7..b2e7adb 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -270,6 +270,9 @@ Action Action::pushWordIndexOnStack() if (config.hasStack(0) and config.getStack(0) == config.getWordIndex()) return false; + if (config.hasStack(0) and !config.isTokenPredicted(config.getStack(0))) + return false; + return (int)config.getWordIndex() != config.getLastPoppedStack(); }; @@ -292,7 +295,7 @@ Action Action::popStack() auto appliable = [](const Config & config, const Action &) { - return config.hasStack(0); + return config.hasStack(0) and config.getStack(0) != config.getWordIndex(); }; return {Type::Pop, apply, undo, appliable}; @@ -499,7 +502,7 @@ Action Action::setRoot() auto appliable = [](const Config & config, const Action &) { - return config.hasStack(0); + return config.hasStack(0) and config.isTokenPredicted(config.getStack(0)) and config.getLastNotEmptyConst(Config::isMultiColName, config.getStack(0)) != Config::EOSSymbol1; }; return {Type::Write, apply, undo, appliable}; @@ -605,6 +608,9 @@ Action Action::attach(Object governorObject, int governorIndex, Object dependent depLineIndex = config.getStack(dependentIndex); } + if (!config.isTokenPredicted(govLineIndex) or !config.isTokenPredicted(depLineIndex)) + return false; + // Check for cycles while (govLineIndex != depLineIndex) { diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index 1eb719e..4997f6e 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -28,6 +28,16 @@ void BaseConfig::readMCD(std::string_view mcdFilename) std::fclose(file); + if (colName2Index.count(isMultiColName)) + util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, isMultiColName)); + colIndex2Name.emplace_back(isMultiColName); + colName2Index.emplace(isMultiColName, colIndex2Name.size()-1); + + if (colName2Index.count(childsColName)) + util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, childsColName)); + colIndex2Name.emplace_back(childsColName); + colName2Index.emplace(childsColName, colIndex2Name.size()-1); + if (colName2Index.count(EOSColName)) util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, EOSColName)); colIndex2Name.emplace_back(EOSColName); @@ -64,6 +74,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) int inputLineIndex = 0; bool inputHasBeenRead = false; int usualNbCol = -1; + int nbMultiwords = 0; while (!std::feof(file)) { @@ -116,6 +127,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) { addLines(1); get(EOSColName, getNbLines()-1, 0) = EOSSymbol0; + get(isMultiColName, getNbLines()-1, 0) = EOSSymbol0; get(0, getNbLines()-1, 0) = std::string(line); continue; } @@ -134,6 +146,13 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) addLines(1); get(EOSColName, getNbLines()-1, 0) = EOSSymbol0; + if (nbMultiwords > 0) + { + get(isMultiColName, getNbLines()-1, 0) = EOSSymbol1; + nbMultiwords--; + } + else + get(isMultiColName, getNbLines()-1, 0) = EOSSymbol0; for (unsigned int i = 0; i < splited.size(); i++) if (i < colIndex2Name.size()) @@ -141,6 +160,9 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) std::string value = std::string(splited[i]); get(i, getNbLines()-1, 0) = value; } + + if (isMultiword(getNbLines()-1)) + nbMultiwords = getMultiwordSize(getNbLines()-1)+1; } std::fclose(file); diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 8681cce..4e5b42e 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -98,6 +98,12 @@ void Config::print(FILE * dest) const } for (unsigned int i = 0; i < getNbColumns()-1; i++) { + if (getColName(i) == isMultiColName or getColName(i) == childsColName) + { + if (i == getNbColumns()-2) + currentSequence.back().back() = '\n'; + continue; + } auto & colContent = getAsFeature(i, getFirstLineIndex()+line); std::string valueToPrint = colContent; try @@ -139,7 +145,11 @@ void Config::printForDebug(FILE * dest) const toPrint.emplace_back(); toPrint.back().emplace_back(""); for (unsigned int i = 0; i < getNbColumns(); i++) + { + if (getColName(i) == isMultiColName or getColName(i) == childsColName) + continue; toPrint.back().emplace_back(getColName(i)); + } for (int line = firstLineToPrint; line <= lastLineToPrint; line++) { @@ -149,6 +159,8 @@ void Config::printForDebug(FILE * dest) const toPrint.back().emplace_back(line == (int)wordIndex ? "=>" : ""); for (unsigned int i = 0; i < getNbColumns(); i++) { + if (getColName(i) == isMultiColName or getColName(i) == childsColName) + continue; std::string colContent = has(i,line,0) ? getAsFeature(i, line).get() : "?"; std::string toPrintCol = colContent; try diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbedding.hpp index d471e6b..6eb069b 100644 --- a/torch_modules/include/DepthLayerTreeEmbedding.hpp +++ b/torch_modules/include/DepthLayerTreeEmbedding.hpp @@ -2,22 +2,28 @@ #define DEPTHLAYERTREEEMBEDDING__H #include <torch/torch.h> -#include "fmt/core.h" +#include "Submodule.hpp" #include "LSTM.hpp" -class DepthLayerTreeEmbeddingImpl : public torch::nn::Module +class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule { private : + std::vector<std::string> columns{"DEPREL"}; + std::vector<int> focusedBuffer{0}; + std::vector<int> focusedStack{0}; + std::string firstElem{"__special_DepthLayerTreeEmbeddingImpl__"}; std::vector<LSTM> depthLstm; int maxDepth; int maxElemPerDepth; public : - DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth); + DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options); torch::Tensor forward(torch::Tensor input); - int getOutputSize(); + std::size_t getOutputSize() override; + std::size_t getInputSize() override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; }; TORCH_MODULE(DepthLayerTreeEmbedding); diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp index d53a04a..3f1926d 100644 --- a/torch_modules/src/DepthLayerTreeEmbedding.cpp +++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp @@ -1,17 +1,58 @@ #include "DepthLayerTreeEmbedding.hpp" -DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth) : maxDepth(maxDepth), maxElemPerDepth(maxElemPerDepth) +DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxDepth(maxDepth), maxElemPerDepth(maxElemPerDepth) { - + for (int i = 0; i < maxDepth; i++) + depthLstm.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(embeddingsSize, outEmbeddingsSize, options))); } torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input) { + auto context = input.narrow(1, firstInputIndex, getInputSize()); + + std::vector<torch::Tensor> outputs; + + for (unsigned int i = 0; i < depthLstm.size(); i++) + for (unsigned int j = 0; j < focusedBuffer.size()+focusedStack.size(); j++) + outputs.emplace_back(depthLstm[i](input.narrow(1,i*(focusedBuffer.size()+focusedStack.size())*columns.size()*maxElemPerDepth + j*maxElemPerDepth, maxElemPerDepth))); + + return torch::cat(outputs, 1); +} + +std::size_t DepthLayerTreeEmbeddingImpl::getOutputSize() +{ + std::size_t outputSize = 0; + for (auto & lstm : depthLstm) + outputSize += lstm->getOutputSize(maxElemPerDepth); + + return outputSize; +} + +std::size_t DepthLayerTreeEmbeddingImpl::getInputSize() +{ + return (focusedBuffer.size()+focusedStack.size())*columns.size()*maxDepth*maxElemPerDepth; } -int DepthLayerTreeEmbeddingImpl::getOutputSize() +void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const { + std::vector<long> focusedIndexes; + + for (int index : focusedBuffer) + focusedIndexes.emplace_back(config.getRelativeWordIndex(index)); + + for (int index : focusedStack) + if (config.hasStack(index)) + focusedIndexes.emplace_back(config.getStack(index)); + else + focusedIndexes.emplace_back(-1); + + for (auto & contextElement : context) + { + for (auto index : focusedIndexes) + { + } + } } -- GitLab