From 15d915ca7b8798c14da8177c5ee72e9bdb423472 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 8 Apr 2020 22:45:21 +0200
Subject: [PATCH] Fixed some problems in dependency parsing

---
 decoder/src/Decoder.cpp                       |  6 ++-
 reading_machine/include/Config.hpp            |  2 +
 reading_machine/src/Action.cpp                | 10 +++-
 reading_machine/src/BaseConfig.cpp            | 22 +++++++++
 reading_machine/src/Config.cpp                | 12 +++++
 .../include/DepthLayerTreeEmbedding.hpp       | 14 ++++--
 torch_modules/src/DepthLayerTreeEmbedding.cpp | 47 +++++++++++++++++--
 7 files changed, 102 insertions(+), 11 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 78e6204..f08bd03 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -89,11 +89,13 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
   // Force EOS when needed
   if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
   {
-    Action shift = Action::pushWordIndexOnStack();
-    shift.apply(config, shift);
+    machine.getTransitionSet().getTransition("SHIFT")->apply(config);
     machine.getTransitionSet().getTransition("EOS")->apply(config);
     if (debug)
+    {
       fmt::print(stderr, "Forcing EOS transition\n");
+      config.printForDebug(stderr);
+    }
   }
 
   // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index c9b7d4d..3ae84fc 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -21,6 +21,8 @@ class Config
   static constexpr const char * headColName = "HEAD";
   static constexpr const char * deprelColName = "DEPREL";
   static constexpr const char * idColName = "ID";
+  static constexpr const char * isMultiColName = "MULTI";
+  static constexpr const char * childsColName = "CHILDS";
   static constexpr int nbHypothesesMax = 1;
   static constexpr int maxNbAppliableSplitTransitions = 8;
 
diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp
index 889aaa7..b2e7adb 100644
--- a/reading_machine/src/Action.cpp
+++ b/reading_machine/src/Action.cpp
@@ -270,6 +270,9 @@ Action Action::pushWordIndexOnStack()
     if (config.hasStack(0) and config.getStack(0) == config.getWordIndex())
       return false;
 
+    if (config.hasStack(0) and !config.isTokenPredicted(config.getStack(0)))
+      return false;
+
     return (int)config.getWordIndex() != config.getLastPoppedStack();
   };
 
@@ -292,7 +295,7 @@ Action Action::popStack()
 
   auto appliable = [](const Config & config, const Action &)
   {
-    return config.hasStack(0);
+    return config.hasStack(0) and config.getStack(0) != config.getWordIndex();
   };
 
   return {Type::Pop, apply, undo, appliable}; 
@@ -499,7 +502,7 @@ Action Action::setRoot()
 
   auto appliable = [](const Config & config, const Action &)
   {
-    return config.hasStack(0);
+    return config.hasStack(0) and config.isTokenPredicted(config.getStack(0)) and config.getLastNotEmptyConst(Config::isMultiColName, config.getStack(0)) != Config::EOSSymbol1;
   };
 
   return {Type::Write, apply, undo, appliable}; 
@@ -605,6 +608,9 @@ Action Action::attach(Object governorObject, int governorIndex, Object dependent
       depLineIndex = config.getStack(dependentIndex);
     }
 
+    if (!config.isTokenPredicted(govLineIndex) or !config.isTokenPredicted(depLineIndex))
+      return false;
+
     // Check for cycles
     while (govLineIndex != depLineIndex)
     {
diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp
index 1eb719e..4997f6e 100644
--- a/reading_machine/src/BaseConfig.cpp
+++ b/reading_machine/src/BaseConfig.cpp
@@ -28,6 +28,16 @@ void BaseConfig::readMCD(std::string_view mcdFilename)
 
   std::fclose(file);
 
+  if (colName2Index.count(isMultiColName))
+    util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, isMultiColName));
+  colIndex2Name.emplace_back(isMultiColName);
+  colName2Index.emplace(isMultiColName, colIndex2Name.size()-1);
+
+  if (colName2Index.count(childsColName))
+    util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, childsColName));
+  colIndex2Name.emplace_back(childsColName);
+  colName2Index.emplace(childsColName, colIndex2Name.size()-1);
+
   if (colName2Index.count(EOSColName))
     util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, EOSColName));
   colIndex2Name.emplace_back(EOSColName);
@@ -64,6 +74,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename)
   int inputLineIndex = 0;
   bool inputHasBeenRead = false;
   int usualNbCol = -1;
+  int nbMultiwords = 0;
 
   while (!std::feof(file))
   {
@@ -116,6 +127,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename)
     {
       addLines(1);
       get(EOSColName, getNbLines()-1, 0) = EOSSymbol0;
+      get(isMultiColName, getNbLines()-1, 0) = EOSSymbol0;
       get(0, getNbLines()-1, 0) = std::string(line);
       continue;
     }
@@ -134,6 +146,13 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename)
 
     addLines(1);
     get(EOSColName, getNbLines()-1, 0) = EOSSymbol0;
+    if (nbMultiwords > 0)
+    {
+      get(isMultiColName, getNbLines()-1, 0) = EOSSymbol1;
+      nbMultiwords--;
+    }
+    else
+      get(isMultiColName, getNbLines()-1, 0) = EOSSymbol0;
 
     for (unsigned int i = 0; i < splited.size(); i++)
       if (i < colIndex2Name.size())
@@ -141,6 +160,9 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename)
         std::string value = std::string(splited[i]);
         get(i, getNbLines()-1, 0) = value;
       }
+
+    if (isMultiword(getNbLines()-1))
+      nbMultiwords = getMultiwordSize(getNbLines()-1)+1;
   }
 
   std::fclose(file);
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 8681cce..4e5b42e 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -98,6 +98,12 @@ void Config::print(FILE * dest) const
     }
     for (unsigned int i = 0; i < getNbColumns()-1; i++)
     {
+      if (getColName(i) == isMultiColName or getColName(i) == childsColName)
+      {
+        if (i == getNbColumns()-2)
+          currentSequence.back().back() = '\n';
+        continue;
+      }
       auto & colContent = getAsFeature(i, getFirstLineIndex()+line);
       std::string valueToPrint = colContent;
       try
@@ -139,7 +145,11 @@ void Config::printForDebug(FILE * dest) const
   toPrint.emplace_back();
   toPrint.back().emplace_back("");
   for (unsigned int i = 0; i < getNbColumns(); i++)
+  {
+    if (getColName(i) == isMultiColName or getColName(i) == childsColName)
+      continue;
     toPrint.back().emplace_back(getColName(i));
+  }
 
   for (int line = firstLineToPrint; line <= lastLineToPrint; line++)
   {
@@ -149,6 +159,8 @@ void Config::printForDebug(FILE * dest) const
     toPrint.back().emplace_back(line == (int)wordIndex ? "=>" : "");
     for (unsigned int i = 0; i < getNbColumns(); i++)
     {
+      if (getColName(i) == isMultiColName or getColName(i) == childsColName)
+        continue;
       std::string colContent = has(i,line,0) ? getAsFeature(i, line).get() : "?";
       std::string toPrintCol = colContent;
       try
diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbedding.hpp
index d471e6b..6eb069b 100644
--- a/torch_modules/include/DepthLayerTreeEmbedding.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbedding.hpp
@@ -2,22 +2,28 @@
 #define DEPTHLAYERTREEEMBEDDING__H
 
 #include <torch/torch.h>
-#include "fmt/core.h"
+#include "Submodule.hpp"
 #include "LSTM.hpp"
 
-class DepthLayerTreeEmbeddingImpl : public torch::nn::Module
+class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
 {
   private :
 
+  std::vector<std::string> columns{"DEPREL"};
+  std::vector<int> focusedBuffer{0};
+  std::vector<int> focusedStack{0};
+  std::string firstElem{"__special_DepthLayerTreeEmbeddingImpl__"};
   std::vector<LSTM> depthLstm;
   int maxDepth;
   int maxElemPerDepth;
 
   public :
 
-  DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth);
+  DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
   torch::Tensor forward(torch::Tensor input);
-  int getOutputSize();
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
 };
 TORCH_MODULE(DepthLayerTreeEmbedding);
 
diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp
index d53a04a..3f1926d 100644
--- a/torch_modules/src/DepthLayerTreeEmbedding.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp
@@ -1,17 +1,58 @@
 #include "DepthLayerTreeEmbedding.hpp"
 
-DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth) : maxDepth(maxDepth), maxElemPerDepth(maxElemPerDepth)
+DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxDepth(maxDepth), maxElemPerDepth(maxElemPerDepth)
 {
-
+  for (int i = 0; i < maxDepth; i++)
+    depthLstm.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(embeddingsSize, outEmbeddingsSize, options)));
 }
 
 torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input)
 {
+  auto context = input.narrow(1, firstInputIndex, getInputSize());
+
+  std::vector<torch::Tensor> outputs;
+
+  for (unsigned int i = 0; i < depthLstm.size(); i++)
+    for (unsigned int j = 0; j < focusedBuffer.size()+focusedStack.size(); j++)
+      outputs.emplace_back(depthLstm[i](input.narrow(1,i*(focusedBuffer.size()+focusedStack.size())*columns.size()*maxElemPerDepth + j*maxElemPerDepth, maxElemPerDepth)));
+
+  return torch::cat(outputs, 1);
+}
+
+std::size_t DepthLayerTreeEmbeddingImpl::getOutputSize()
+{
+  std::size_t outputSize = 0;
 
+  for (auto & lstm : depthLstm)
+    outputSize += lstm->getOutputSize(maxElemPerDepth);
+
+  return outputSize;
+}
+
+std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
+{
+  return (focusedBuffer.size()+focusedStack.size())*columns.size()*maxDepth*maxElemPerDepth;
 }
 
-int DepthLayerTreeEmbeddingImpl::getOutputSize()
+void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
 {
+  std::vector<long> focusedIndexes;
+
+  for (int index : focusedBuffer)
+    focusedIndexes.emplace_back(config.getRelativeWordIndex(index));
+
+  for (int index : focusedStack)
+    if (config.hasStack(index))
+      focusedIndexes.emplace_back(config.getStack(index));
+    else
+      focusedIndexes.emplace_back(-1);
+
+  for (auto & contextElement : context)
+  {
+    for (auto index : focusedIndexes)
+    {
 
+    }
+  }
 }
 
-- 
GitLab