From 4bfe61473b0cba6e89743f284f003d13da8a569f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 10 Apr 2020 12:26:07 +0200
Subject: [PATCH] Updated LSTMNetwork and attach action to work with tree
 embedding networks

---
 reading_machine/src/Action.cpp        |  8 ++++++++
 reading_machine/src/Classifier.cpp    | 12 +++++++++---
 torch_modules/include/LSTMNetwork.hpp |  5 ++++-
 torch_modules/src/LSTMNetwork.cpp     | 18 +++++++++++++++++-
 4 files changed, 38 insertions(+), 5 deletions(-)

diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp
index b2e7adb..4e070db 100644
--- a/reading_machine/src/Action.cpp
+++ b/reading_machine/src/Action.cpp
@@ -570,12 +570,20 @@ Action Action::attach(Object governorObject, int governorIndex, Object dependent
       lineIndex = config.getWordIndex() + governorIndex;
     else
       lineIndex = config.getStack(governorIndex);
+    int depIndex = 0;
+    if (dependentObject == Object::Buffer)
+      depIndex = config.getWordIndex() + dependentIndex;
+    else
+      depIndex = config.getStack(dependentIndex);
+
     addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, std::to_string(lineIndex)).apply(config, a);
+    addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, std::to_string(depIndex)).apply(config, a);
   };
 
   auto undo = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a)
   {
     addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, "").undo(config, a);
+    addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, "").apply(config, a);
   };
 
   auto appliable = [governorObject, governorIndex, dependentObject, dependentIndex](const Config & config, const Action & action)
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 27a4ef1..aeb6ac8 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -71,9 +71,8 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
 {
   int unknownValueThreshold;
   std::vector<int> bufferContext, stackContext;
-  std::vector<std::string> columns;
+  std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns;
   std::vector<int> focusedBuffer, focusedStack;
-  std::vector<std::string> focusedColumns;
   std::vector<int> maxNbElements;
   std::vector<std::pair<int, float>> mlp;
   int rawInputLeftWindow, rawInputRightWindow;
@@ -223,6 +222,13 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
         }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value"));
 
-  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout));
+  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm)
+        {
+          treeEmbeddingColumns = util::split(sm.str(1), ' ');
+          curIndex++;
+        }))
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}"));
+
+  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns));
 }
 
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
index 07d689e..9cb051d 100644
--- a/torch_modules/include/LSTMNetwork.hpp
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -7,6 +7,7 @@
 #include "SplitTransLSTM.hpp"
 #include "FocusedColumnLSTM.hpp"
 #include "MLP.hpp"
+#include "DepthLayerTreeEmbedding.hpp"
 
 class LSTMNetworkImpl : public NeuralNetworkImpl
 {
@@ -20,13 +21,15 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
   ContextLSTM contextLSTM{nullptr};
   RawInputLSTM rawInputLSTM{nullptr};
   SplitTransLSTM splitTransLSTM{nullptr};
+  DepthLayerTreeEmbedding treeEmbedding{nullptr};
   std::vector<FocusedColumnLSTM> focusedLstms;
 
   bool hasRawInputLSTM{false};
+  bool hasTreeEmbedding{false};
 
   public :
 
-  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout);
+  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
 };
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index ab06806..2a9b1b4 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -1,6 +1,6 @@
 #include "LSTMNetwork.hpp"
 
-LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout)
+LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns)
 {
   LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
   auto lstmOptionsAll = lstmOptions;
@@ -23,6 +23,15 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
     currentInputSize += rawInputLSTM->getInputSize();
   }
 
+  if (!treeEmbeddingColumns.empty())
+  {
+    hasTreeEmbedding = true;
+    treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(1,3,embeddingsSize,128,treeEmbeddingColumns,focusedBufferIndexes,focusedStackIndexes,lstmOptionsAll));
+    treeEmbedding->setFirstInputIndex(currentInputSize);
+    currentOutputSize += treeEmbedding->getOutputSize();
+    currentInputSize += treeEmbedding->getInputSize();
+  }
+
   splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, lstmOptionsAll));
   splitTransLSTM->setFirstInputIndex(currentInputSize);
   currentOutputSize += splitTransLSTM->getOutputSize();
@@ -56,6 +65,9 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
   if (hasRawInputLSTM)
     outputs.emplace_back(rawInputLSTM(embeddings));
 
+  if (hasTreeEmbedding)
+    outputs.emplace_back(treeEmbedding(embeddings));
+
   outputs.emplace_back(splitTransLSTM(embeddings));
 
   for (auto & lstm : focusedLstms)
@@ -79,6 +91,10 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
   contextLSTM->addToContext(context, dict, config);
   if (hasRawInputLSTM)
     rawInputLSTM->addToContext(context, dict, config);
+  fmt::print(stderr, "before {}\n", context.back().size());
+  if (hasTreeEmbedding)
+    treeEmbedding->addToContext(context, dict, config);
+  fmt::print(stderr, "after {}\n", context.back().size());
   splitTransLSTM->addToContext(context, dict, config);
   for (auto & lstm : focusedLstms)
     lstm->addToContext(context, dict, config);
-- 
GitLab