From 11e87ce1fd20636526e6259ffbad18883b42fce5 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 3 Mar 2021 16:26:42 +0100
Subject: [PATCH] extractContext now directly gives a torch::Tensor

---
 decoder/src/Beam.cpp                          |   3 +-
 .../include/AppliableTransModule.hpp          |   2 +-
 torch_modules/include/ContextModule.hpp       |   2 +-
 torch_modules/include/ContextualModule.hpp    |   2 +-
 .../include/DepthLayerTreeEmbeddingModule.hpp |   2 +-
 torch_modules/include/DistanceModule.hpp      |   2 +-
 torch_modules/include/FocusedColumnModule.hpp |   2 +-
 torch_modules/include/HistoryModule.hpp       |   2 +-
 torch_modules/include/ModularNetwork.hpp      |   3 +-
 torch_modules/include/NeuralNetwork.hpp       |   2 +-
 torch_modules/include/NumericColumnModule.hpp |   2 +-
 torch_modules/include/RandomNetwork.hpp       |   2 +-
 torch_modules/include/RawInputModule.hpp      |   2 +-
 torch_modules/include/SplitTransModule.hpp    |   2 +-
 torch_modules/include/StateNameModule.hpp     |   2 +-
 torch_modules/include/Submodule.hpp           |   2 +-
 torch_modules/include/UppercaseRateModule.hpp |   2 +-
 torch_modules/src/AppliableTransModule.cpp    |  11 +-
 torch_modules/src/ContextModule.cpp           |  16 ++-
 torch_modules/src/ContextualModule.cpp        |  44 ++++----
 .../src/DepthLayerTreeEmbeddingModule.cpp     |  48 ++++----
 torch_modules/src/DistanceModule.cpp          |  29 ++---
 torch_modules/src/FocusedColumnModule.cpp     | 104 +++++++++---------
 torch_modules/src/HistoryModule.cpp           |  13 +--
 torch_modules/src/ModularNetwork.cpp          |   6 +-
 torch_modules/src/NumericColumnModule.cpp     |  28 ++---
 torch_modules/src/RandomNetwork.cpp           |   5 +-
 torch_modules/src/RawInputModule.cpp          |  40 ++++---
 torch_modules/src/SplitTransModule.cpp        |  13 +--
 torch_modules/src/StateNameModule.cpp         |   5 +-
 torch_modules/src/UppercaseRateModule.cpp     |  33 +++---
 trainer/include/Trainer.hpp                   |   2 +-
 trainer/src/Trainer.cpp                       |  14 +--
 33 files changed, 226 insertions(+), 221 deletions(-)

diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp
index 9aad4f0..c932733 100644
--- a/decoder/src/Beam.cpp
+++ b/decoder/src/Beam.cpp
@@ -45,8 +45,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
     auto appliableTransitions = machine.getTransitionSet(elements[index].config.getState()).getAppliableTransitions(elements[index].config);
     elements[index].config.setAppliableTransitions(appliableTransitions);
 
-    auto context = classifier.getNN()->extractContext(elements[index].config).back();
-    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
+    auto neuralInput = classifier.getNN()->extractContext(elements[index].config);
 
     auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0), 0);
     float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction);
diff --git a/torch_modules/include/AppliableTransModule.hpp b/torch_modules/include/AppliableTransModule.hpp
index 5e6f9e4..98f5fe1 100644
--- a/torch_modules/include/AppliableTransModule.hpp
+++ b/torch_modules/include/AppliableTransModule.hpp
@@ -19,7 +19,7 @@ class AppliableTransModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(AppliableTransModule);
diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index 5851887..c2e0668 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -30,7 +30,7 @@ class ContextModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(ContextModule);
diff --git a/torch_modules/include/ContextualModule.hpp b/torch_modules/include/ContextualModule.hpp
index e7fb2a9..8483b1a 100644
--- a/torch_modules/include/ContextualModule.hpp
+++ b/torch_modules/include/ContextualModule.hpp
@@ -31,7 +31,7 @@ class ContextualModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(ContextualModule);
diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index 6da8943..3621e6e 100644
--- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -27,7 +27,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(DepthLayerTreeEmbeddingModule);
diff --git a/torch_modules/include/DistanceModule.hpp b/torch_modules/include/DistanceModule.hpp
index 3702ad5..bafa0b8 100644
--- a/torch_modules/include/DistanceModule.hpp
+++ b/torch_modules/include/DistanceModule.hpp
@@ -26,7 +26,7 @@ class DistanceModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(DistanceModule);
diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index af370c6..a7df331 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -29,7 +29,7 @@ class FocusedColumnModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(FocusedColumnModule);
diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp
index 54418a6..b4a725b 100644
--- a/torch_modules/include/HistoryModule.hpp
+++ b/torch_modules/include/HistoryModule.hpp
@@ -25,7 +25,7 @@ class HistoryModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(HistoryModule);
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index ed73c30..31685e2 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -25,12 +25,13 @@ class ModularNetworkImpl : public NeuralNetworkImpl
   MLP mlp{nullptr};
   std::vector<std::shared_ptr<Submodule>> modules;
   std::map<std::string,torch::nn::Linear> outputLayersPerState;
+  std::size_t totalInputSize{0};
 
   public :
 
   ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path);
   torch::Tensor forward(torch::Tensor input, const std::string & state) override;
-  std::vector<std::vector<long>> extractContext(Config & config) override;
+  torch::Tensor extractContext(Config & config) override;
   void registerEmbeddings() override;
   void saveDicts(std::filesystem::path path) override;
   void loadDicts(std::filesystem::path path) override;
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 8215ad2..ffbcdea 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -15,7 +15,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
   public :
 
   virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0;
-  virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
+  virtual torch::Tensor extractContext(Config & config) = 0;
   virtual void registerEmbeddings() = 0;
   virtual void saveDicts(std::filesystem::path path) = 0;
   virtual void loadDicts(std::filesystem::path path) = 0;
diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp
index e3d46ae..3ee9cb2 100644
--- a/torch_modules/include/NumericColumnModule.hpp
+++ b/torch_modules/include/NumericColumnModule.hpp
@@ -24,7 +24,7 @@ class NumericColumnModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(NumericColumnModule);
diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp
index 3c559e9..33d99a1 100644
--- a/torch_modules/include/RandomNetwork.hpp
+++ b/torch_modules/include/RandomNetwork.hpp
@@ -13,7 +13,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
 
   RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
   torch::Tensor forward(torch::Tensor input, const std::string & state) override;
-  std::vector<std::vector<long>> extractContext(Config &) override;
+  torch::Tensor extractContext(Config &) override;
   void registerEmbeddings() override;
   void saveDicts(std::filesystem::path path) override;
   void loadDicts(std::filesystem::path path) override;
diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index 26237e2..0ca658b 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -25,7 +25,7 @@ class RawInputModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(RawInputModule);
diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp
index 1ef1796..b88491e 100644
--- a/torch_modules/include/SplitTransModule.hpp
+++ b/torch_modules/include/SplitTransModule.hpp
@@ -24,7 +24,7 @@ class SplitTransModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(SplitTransModule);
diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp
index 3abfe82..ace1cbc 100644
--- a/torch_modules/include/StateNameModule.hpp
+++ b/torch_modules/include/StateNameModule.hpp
@@ -21,7 +21,7 @@ class StateNameModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(StateNameModule);
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 1dbbdc7..f4722bf 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -24,7 +24,7 @@ class Submodule : public torch::nn::Module, public DictHolder
   void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
-  virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
+  virtual void addToContext(torch::Tensor & context, const Config & config) = 0;
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual void registerEmbeddings() = 0;
   std::function<std::string(const std::string &)> getFunction(const std::string functionNames);
diff --git a/torch_modules/include/UppercaseRateModule.hpp b/torch_modules/include/UppercaseRateModule.hpp
index dcfb89c..9495661 100644
--- a/torch_modules/include/UppercaseRateModule.hpp
+++ b/torch_modules/include/UppercaseRateModule.hpp
@@ -22,7 +22,7 @@ class UppercaseRateModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(UppercaseRateModule);
diff --git a/torch_modules/src/AppliableTransModule.cpp b/torch_modules/src/AppliableTransModule.cpp
index c50586f..7a5c830 100644
--- a/torch_modules/src/AppliableTransModule.cpp
+++ b/torch_modules/src/AppliableTransModule.cpp
@@ -20,15 +20,12 @@ std::size_t AppliableTransModuleImpl::getInputSize()
   return nbTrans;
 }
 
-void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void AppliableTransModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & appliableTrans = config.getAppliableTransitions();
-  for (auto & contextElement : context)
-    for (int i = 0; i < nbTrans; i++)
-      if (i < (int)appliableTrans.size())
-        contextElement.emplace_back(appliableTrans[i]);
-      else
-        contextElement.emplace_back(0);
+  for (int i = 0; i < nbTrans; i++)
+    if (i < (int)appliableTrans.size())
+      context[firstInputIndex+i] = appliableTrans[i];
 }
 
 void AppliableTransModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index f457f31..19c7972 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -83,7 +83,7 @@ std::size_t ContextModuleImpl::getInputSize()
   return columns.size()*(targets.size());
 }
 
-void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void ContextModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
   std::vector<long> contextIndexes;
@@ -125,24 +125,22 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
       contextIndexes.emplace_back(-3);
     }
 
+  int insertIndex = 0;
   for (auto index : contextIndexes)
     for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++)
     {
       auto & col = columns[colIndex];
       if (index == -1)
       {
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, col);
       }
       else if (index == -2)
       {
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col));
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::noChildValueStr, col);
       }
       else if (index == -3)
       {
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col));
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::oobValueStr, col);
       }
       else
       {
@@ -162,9 +160,9 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
           dictIndex = dict.getIndexOrInsert(featureValue, col);
         }
 
-        for (auto & contextElement : context)
-          contextElement.push_back(dictIndex);
+        context[firstInputIndex+insertIndex] = dictIndex;
       }
+      insertIndex++;
     }
 }
 
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index f5f8562..1c56983 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -87,7 +87,7 @@ std::size_t ContextualModuleImpl::getInputSize()
   return columns.size()*(4+window.second-window.first)+targets.size();
 }
 
-void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void ContextualModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
   std::vector<long> contextIndexes;
@@ -132,24 +132,23 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
     else
       targetIndexes.emplace_back(-1);
 
+  int insertIndex = 0;
+
   for (auto index : contextIndexes)
     for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++)
     {
       auto & col = columns[colIndex];
       if (index == -1)
       {
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, col);
       }
       else if (index == -2)
       {
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col));
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::noChildValueStr, col);
       }
       else if (index == -3)
       {
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col));
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::oobValueStr, col);
       }
       else
       {
@@ -169,33 +168,32 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
           dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue), col);
         }
 
-        for (auto & contextElement : context)
-          contextElement.push_back(dictIndex);
+        context[firstInputIndex+insertIndex] = dictIndex;
       }
+
+      insertIndex++;
     }
 
   for (auto index : targetIndexes)
   {
     if (configIndex2ContextIndex.count(index))
     {
-      for (auto & contextElement : context)
-        contextElement.push_back(configIndex2ContextIndex.at(index));
+      context[firstInputIndex+insertIndex] = configIndex2ContextIndex.at(index);
     }
     else
     {
-      for (auto & contextElement : context)
-      {
-        // -1 == doesn't exist (s.0 when no stack)
-        if (index == -1)
-          contextElement.push_back(0);
-        // -2 == nochild
-        else if (index == -2)
-          contextElement.push_back(1);
-        // other == out of context bounds
-        else
-          contextElement.push_back(2);
-      }
+      // -1 == doesn't exist (s.0 when no stack)
+      if (index == -1)
+        context[firstInputIndex+insertIndex] = 0;
+      // -2 == nochild
+      else if (index == -2)
+        context[firstInputIndex+insertIndex] = 1;
+      // other == out of context bounds
+      else
+        context[firstInputIndex+insertIndex] = 2;
     }
+
+    insertIndex++;
   }
 }
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 06c0b5f..ac90690 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -84,7 +84,7 @@ std::size_t DepthLayerTreeEmbeddingModuleImpl::getInputSize()
   return inputSize;
 }
 
-void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void DepthLayerTreeEmbeddingModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
   std::vector<long> focusedIndexes;
@@ -98,30 +98,34 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
     else
       focusedIndexes.emplace_back(-1);
 
-  for (auto & contextElement : context)
-    for (auto index : focusedIndexes)
+  int insertIndex = 0;
+  for (auto index : focusedIndexes)
+  {
+    std::vector<std::string> childs{std::to_string(index)};
+
+    for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
     {
-      std::vector<std::string> childs{std::to_string(index)};
+      std::vector<std::string> newChilds;
+      for (auto & child : childs)
+        if (config.has(Config::childsColName, std::stoi(child), 0))
+        {
+          auto val = util::split(config.getAsFeature(Config::childsColName, std::stoi(child)).get(), '|');
+          newChilds.insert(newChilds.end(), val.begin(), val.end());
+        }
+      childs = newChilds;
 
-      for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
-      {
-        std::vector<std::string> newChilds;
-        for (auto & child : childs)
-          if (config.has(Config::childsColName, std::stoi(child), 0))
-          {
-            auto val = util::split(config.getAsFeature(Config::childsColName, std::stoi(child)).get(), '|');
-            newChilds.insert(newChilds.end(), val.begin(), val.end());
-          }
-        childs = newChilds;
-
-        for (int i = 0; i < maxElemPerDepth[depth]; i++)
-          for (auto & col : columns)
-            if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0))
-              contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])), col));
-            else
-              contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
-      }
+      for (int i = 0; i < maxElemPerDepth[depth]; i++)
+        for (auto & col : columns)
+        {
+          if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0))
+            context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])), col);
+          else
+            context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, col);
+
+          insertIndex++;
+        }
     }
+  }
 }
 
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index 2e71e25..fddf2e0 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -63,7 +63,7 @@ std::size_t DistanceModuleImpl::getInputSize()
   return (fromBuffer.size()+fromStack.size()) * (toBuffer.size()+toStack.size());
 }
 
-void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void DistanceModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
   std::vector<long> fromIndexes, toIndexes;
@@ -88,25 +88,26 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
   std::string prefix = "DISTANCE";
 
-  for (auto & contextElement : context)
-  {
-    for (auto from : fromIndexes)
-      for (auto to : toIndexes)
+  int insertIndex = 0;
+  for (auto from : fromIndexes)
+    for (auto to : toIndexes)
+    {
+      if (from == -1 or to == -1)
+      {
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, prefix);
+      }
+      else
       {
-        if (from == -1 or to == -1)
-        {
-          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
-          continue;
-        }
-
         long dist = std::abs(config.getRelativeDistance(from, to));
 
         if (dist <= threshold)
-          contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, dist), ""));
+          context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(fmt::format("{}({})", prefix, dist), "");
         else
-          contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr, prefix));
+          context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::unknownValueStr, prefix);
       }
-  }
+
+      insertIndex++;
+    }
 }
 
 void DistanceModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 115f918..3fc25f0 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -82,7 +82,7 @@ std::size_t FocusedColumnModuleImpl::getInputSize()
   return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
 }
 
-void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
   std::vector<long> focusedIndexes;
@@ -96,63 +96,67 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
     else
       focusedIndexes.emplace_back(-1);
 
-  for (auto & contextElement : context)
+  int insertIndex = 0;
+  for (auto index : focusedIndexes)
   {
-    for (auto index : focusedIndexes)
+    if (index == -1)
     {
-      if (index == -1)
+      for (int i = 0; i < maxNbElements; i++)
       {
-        for (int i = 0; i < maxNbElements; i++)
-          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, column));
-        continue;
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, column);
+        insertIndex++;
       }
+      continue;
+    }
 
-      std::vector<std::string> elements;
-      if (column == "FORM")
-      {
-        auto asUtf8 = util::splitAsUtf8(func(config.getAsFeature(column, index).get()));
-
-        //TODO don't use nullValueStr here
-        for (int i = 0; i < maxNbElements; i++)
-          if (i < (int)asUtf8.size())
-            elements.emplace_back(fmt::format("{}", asUtf8[i]));
-          else
-            elements.emplace_back(Dict::nullValueStr);
-      }
-      else if (column == "FEATS")
-      {
-        auto splited = util::split(func(config.getAsFeature(column, index).get()), '|');
+    std::vector<std::string> elements;
+    if (column == "FORM")
+    {
+      auto asUtf8 = util::splitAsUtf8(func(config.getAsFeature(column, index).get()));
+
+      //TODO don't use nullValueStr here
+      for (int i = 0; i < maxNbElements; i++)
+        if (i < (int)asUtf8.size())
+          elements.emplace_back(fmt::format("{}", asUtf8[i]));
+        else
+          elements.emplace_back(Dict::nullValueStr);
+    }
+    else if (column == "FEATS")
+    {
+      auto splited = util::split(func(config.getAsFeature(column, index).get()), '|');
 
-        for (int i = 0; i < maxNbElements; i++)
-          if (i < (int)splited.size())
-            elements.emplace_back(splited[i]);
-          else
-            elements.emplace_back(Dict::nullValueStr);
-      }
-      else if (column == "ID")
-      {
-        if (config.isTokenPredicted(index))
-          elements.emplace_back("TOKEN");
-        else if (config.isMultiwordPredicted(index))
-          elements.emplace_back("MULTIWORD");
-        else if (config.isEmptyNodePredicted(index))
-          elements.emplace_back("EMPTYNODE");
-      }
-      else if (column == "EOS")
-      {
-        bool isEOS = func(config.getAsFeature(Config::EOSColName, index)) == Config::EOSSymbol1;
-        elements.emplace_back(fmt::format("{}", isEOS));
-      }
-      else
-      {
-        elements.emplace_back(func(config.getAsFeature(column, index)));
-      }
+      for (int i = 0; i < maxNbElements; i++)
+        if (i < (int)splited.size())
+          elements.emplace_back(splited[i]);
+        else
+          elements.emplace_back(Dict::nullValueStr);
+    }
+    else if (column == "ID")
+    {
+      if (config.isTokenPredicted(index))
+        elements.emplace_back("TOKEN");
+      else if (config.isMultiwordPredicted(index))
+        elements.emplace_back("MULTIWORD");
+      else if (config.isEmptyNodePredicted(index))
+        elements.emplace_back("EMPTYNODE");
+    }
+    else if (column == "EOS")
+    {
+      bool isEOS = func(config.getAsFeature(Config::EOSColName, index)) == Config::EOSSymbol1;
+      elements.emplace_back(fmt::format("{}", isEOS));
+    }
+    else
+    {
+      elements.emplace_back(func(config.getAsFeature(column, index)));
+    }
 
-      if ((int)elements.size() != maxNbElements)
-        util::myThrow(fmt::format("elements.size ({}) != maxNbElements ({})", elements.size(), maxNbElements));
+    if ((int)elements.size() != maxNbElements)
+      util::myThrow(fmt::format("elements.size ({}) != maxNbElements ({})", elements.size(), maxNbElements));
 
-      for (auto & element : elements)
-        contextElement.emplace_back(dict.getIndexOrInsert(element, column));
+    for (auto & element : elements)
+    {
+      context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(element, column);
+      insertIndex++;
     }
   }
 }
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index 7d0912c..4a9033f 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -53,18 +53,17 @@ std::size_t HistoryModuleImpl::getInputSize()
   return maxNbElements;
 }
 
-void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void HistoryModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
 
   std::string prefix = "HISTORY";
 
-  for (auto & contextElement : context)
-    for (int i = 0; i < maxNbElements; i++)
-      if (config.hasHistory(i))
-        contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i), prefix));
-      else
-        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
+  for (int i = 0; i < maxNbElements; i++)
+    if (config.hasHistory(i))
+      context[firstInputIndex+i] = dict.getIndexOrInsert(config.getHistory(i), prefix);
+    else
+      context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, prefix);
 }
 
 void HistoryModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index c936f85..1c39f18 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -69,6 +69,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
     currentOutputSize += modules.back()->getOutputSize();
   }
 
+  totalInputSize = currentInputSize;
+
   if (mlpDef.empty())
     util::myThrow("no MLP definition found");
   if (inputDropout.is_empty())
@@ -95,9 +97,9 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string
   return outputLayersPerState.at(state)(mlp(totalInput));
 }
 
-std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config)
+torch::Tensor ModularNetworkImpl::extractContext(Config & config)
 {
-  std::vector<std::vector<long>> context(1);
+  torch::Tensor context = torch::zeros({totalInputSize}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
   for (auto & mod : modules)
     mod->addToContext(context, config);
   return context;
diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
index 15d2b19..49d3016 100644
--- a/torch_modules/src/NumericColumnModule.cpp
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -60,7 +60,7 @@ std::size_t NumericColumnModuleImpl::getInputSize()
   return focusedBuffer.size() + focusedStack.size();
 }
 
-void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void NumericColumnModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   std::vector<long> focusedIndexes;
 
@@ -73,21 +73,21 @@ void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
     else
       focusedIndexes.emplace_back(-1);
 
-  for (auto & contextElement : context)
-    for (auto index : focusedIndexes)
+  int insertIndex = 0;
+  for (auto index : focusedIndexes)
+  {
+    double res = 0.0;
+    if (index >= 0)
     {
-      double res = 0.0;
-      if (index >= 0)
-      {
-        auto value = config.getAsFeature(column, index).get();
-        try {res = (value == "_" or value == "NA") ? defaultValue : std::stof(value);}
-        catch (std::exception & e)
-          {util::myThrow(fmt::format("{} for '{}'", e.what(), value));}
-      }
-
-      contextElement.emplace_back(0);
-      std::memcpy(&contextElement.back(), &res, sizeof res);
+      auto value = config.getAsFeature(column, index).get();
+      try {res = (value == "_" or value == "NA") ? defaultValue : std::stof(value);}
+      catch (std::exception & e)
+        {util::myThrow(fmt::format("{} for '{}'", e.what(), value));}
     }
+
+    long bits; std::memcpy(&bits, &res, sizeof res); // preserve double bit pattern in the long tensor, as the old code did
+    context[firstInputIndex+insertIndex] = bits; insertIndex++;
+  }
 }
 
 void NumericColumnModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp
index 87a6046..b05d1aa 100644
--- a/torch_modules/src/RandomNetwork.cpp
+++ b/torch_modules/src/RandomNetwork.cpp
@@ -13,9 +13,10 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input, const std::string
   return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(device).requires_grad(true));
 }
 
-std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &)
+torch::Tensor RandomNetworkImpl::extractContext(Config &)
 {
-  return std::vector<std::vector<long>>{{0}};
+  // Old code returned {{0}}; forward() calls input.size(0), so the tensor must be defined.
+  return torch::zeros({1}, torch::TensorOptions(torch::kLong).device(device));
 }
 
 void RandomNetworkImpl::registerEmbeddings()
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index d948386..2d6bd62 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -11,6 +11,9 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
             leftWindow = std::stoi(sm.str(1));
             rightWindow = std::stoi(sm.str(2));
 
+            if (leftWindow < 0 or rightWindow < 0)
+              util::myThrow(fmt::format("Invalid negative values for leftWindow({}) or rightWindow({})", leftWindow, rightWindow));
+
             auto subModuleType = sm.str(3);
             auto subModuleArguments = util::split(sm.str(4), ' ');
 
@@ -54,27 +57,30 @@ std::size_t RawInputModuleImpl::getInputSize()
   return leftWindow + rightWindow + 1;
 }
 
-void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void RawInputModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
-  if (leftWindow < 0 or rightWindow < 0)
-    return;
-
   std::string prefix = "LETTER";
-
   auto & dict = getDict();
-  for (auto & contextElement : context)
+
+  int insertIndex = 0;
+  for (int i = 0; i < leftWindow; i++)
   {
-    for (int i = 0; i < leftWindow; i++)
-      if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)), prefix));
-      else
-        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
-
-    for (int i = 0; i <= rightWindow; i++)
-      if (config.hasCharacter(config.getCharacterIndex()+i))
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)), prefix));
-      else
-        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
+    if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
+      context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)), prefix);
+    else
+      context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, prefix);
+  
+    insertIndex++;
+  }
+
+  for (int i = 0; i <= rightWindow; i++)
+  {
+    if (config.hasCharacter(config.getCharacterIndex()+i))
+      context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)), prefix);
+    else
+      context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, prefix);
+
+    insertIndex++;
   }
 }
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 0c1de2e..dcb78e1 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -51,16 +51,15 @@ std::size_t SplitTransModuleImpl::getInputSize()
   return maxNbTrans;
 }
 
-void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void SplitTransModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
   auto & splitTransitions = config.getAppliableSplitTransitions();
-  for (auto & contextElement : context)
-    for (int i = 0; i < maxNbTrans; i++)
-      if (i < (int)splitTransitions.size())
-        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName(), ""));
-      else
-        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, ""));
+  for (int i = 0; i < maxNbTrans; i++)
+    if (i < (int)splitTransitions.size())
+      context[firstInputIndex+i] = dict.getIndexOrInsert(splitTransitions[i]->getName(), "");
+    else
+      context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, "");
 }
 
 void SplitTransModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index 0c642b9..f3ac977 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -29,11 +29,10 @@ std::size_t StateNameModuleImpl::getInputSize()
   return 1;
 }
 
-void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void StateNameModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
-  for (auto & contextElement : context)
-    contextElement.emplace_back(dict.getIndexOrInsert(config.getState(), ""));
+  context[firstInputIndex] = dict.getIndexOrInsert(config.getState(), "");
 }
 
 void StateNameModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp
index 0452eb8..7c84615 100644
--- a/torch_modules/src/UppercaseRateModule.cpp
+++ b/torch_modules/src/UppercaseRateModule.cpp
@@ -56,7 +56,7 @@ std::size_t UppercaseRateModuleImpl::getInputSize()
   return focusedBuffer.size() + focusedStack.size();
 }
 
-void UppercaseRateModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void UppercaseRateModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   std::vector<long> focusedIndexes;
 
@@ -69,25 +69,24 @@ void UppercaseRateModuleImpl::addToContext(std::vector<std::vector<long>> & cont
     else
       focusedIndexes.emplace_back(-1);
 
-  for (auto & contextElement : context)
+  int insertIndex = 0;
+  for (auto index : focusedIndexes)
   {
-    for (auto index : focusedIndexes)
+    double res = -1.0;
+    if (index >= 0)
     {
-      double res = -1.0;
-      if (index >= 0)
-      {
-        auto word = util::splitAsUtf8(config.getAsFeature("FORM", index).get());
-        int nbUpper = 0;
-        for (auto & letter : word)
-          if (util::isUppercase(letter))
-            nbUpper++;
-        if (word.size() > 0)
-          res = 1.0*nbUpper/word.size();
-      }
-
-      contextElement.emplace_back(0);
-      std::memcpy(&contextElement.back(), &res, sizeof res);
+      auto word = util::splitAsUtf8(config.getAsFeature("FORM", index).get());
+      int nbUpper = 0;
+      for (auto & letter : word)
+        if (util::isUppercase(letter))
+          nbUpper++;
+      if (word.size() > 0)
+        res = 1.0*nbUpper/word.size();
     }
+
+    long bits; std::memcpy(&bits, &res, sizeof res); // preserve double bit pattern in the long tensor, as the old code did
+    context[firstInputIndex+insertIndex] = bits;
+    insertIndex++;
   }
 
 }
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index dfa465e..a5088d0 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -34,7 +34,7 @@ class Trainer
     int lastSavedIndex{0};
 
     void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int currentEpoch, bool dynamicOracle);
-    void addContext(std::vector<std::vector<long>> & context);
+    void addContext(torch::Tensor & context);
     void addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes);
   };
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 56ecc44..4143b2c 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -65,7 +65,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
       auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
       config.setAppliableTransitions(appliableTransitions);
 
-      std::vector<std::vector<long>> context;
+      torch::Tensor context;
 
       try
       {
@@ -92,8 +92,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
       if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
       {
         auto & classifier = *machine.getClassifier(config.getState());
-        auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
-        auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0), 0);
+        auto prediction = classifier.isRegression() ? classifier.getNN()->forward(context, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(context, config.getState()).squeeze(0), 0);
         entropy  = NeuralNetworkImpl::entropy(prediction);
     
         std::vector<int> candidates;
@@ -154,7 +153,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
 
       if (!exampleIsBanned)
       {
-        totalNbExamples += context.size();
+        totalNbExamples += 1;
         if (totalNbExamples >= (int)safetyNbExamplesMax)
           util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
 
@@ -295,12 +294,11 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem:
   classes.clear();
 }
 
-void Trainer::Examples::addContext(std::vector<std::vector<long>> & context)
+void Trainer::Examples::addContext(torch::Tensor & context)
 {
-  for (auto & element : context)
-    contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
+  contexts.emplace_back(context);
 
-  currentExampleIndex += context.size();
+  currentExampleIndex += 1;
 }
 
 void Trainer::Examples::addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes)
-- 
GitLab