diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp
index 9aad4f048b3e065e554732bdf1b56326bbff78c3..c9327339675b254fa33cfe9d6773598fdcfdda79 100644
--- a/decoder/src/Beam.cpp
+++ b/decoder/src/Beam.cpp
@@ -45,8 +45,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
     auto appliableTransitions = machine.getTransitionSet(elements[index].config.getState()).getAppliableTransitions(elements[index].config);
     elements[index].config.setAppliableTransitions(appliableTransitions);
 
-    auto context = classifier.getNN()->extractContext(elements[index].config).back();
-    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
+    auto neuralInput = classifier.getNN()->extractContext(elements[index].config);
 
     auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0), 0);
     float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction);
diff --git a/torch_modules/include/AppliableTransModule.hpp b/torch_modules/include/AppliableTransModule.hpp
index 5e6f9e461109eac691920e9763106681f1461f38..98f5fe13a0b5e5ca9c46caa2318d4ab054509b69 100644
--- a/torch_modules/include/AppliableTransModule.hpp
+++ b/torch_modules/include/AppliableTransModule.hpp
@@ -19,7 +19,7 @@ class AppliableTransModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(AppliableTransModule);
diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index 585188793b32779c8da3f92bfe6603b71ccbb6ae..c2e0668135942d7dbdfaf0fff9e6f5d72fc66880 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -30,7 +30,7 @@ class ContextModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(ContextModule);
diff --git a/torch_modules/include/ContextualModule.hpp b/torch_modules/include/ContextualModule.hpp
index e7fb2a90e62fe2840de0e56a3050960c137d67d3..8483b1a1c6199660fe88c6054888144a6bf1f6e0 100644
--- a/torch_modules/include/ContextualModule.hpp
+++ b/torch_modules/include/ContextualModule.hpp
@@ -31,7 +31,7 @@ class ContextualModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(ContextualModule);
diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index 6da8943b132306bc57ef27780269be9263b1037f..3621e6e5df8165963788872100df71c4adaa7aad 100644
--- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -27,7 +27,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(DepthLayerTreeEmbeddingModule);
diff --git a/torch_modules/include/DistanceModule.hpp b/torch_modules/include/DistanceModule.hpp
index 3702ad58e6f1cf084c8bd0bd823e62bb4112c774..bafa0b8e3d852ab22e6f787ec005b14c0ff81f5f 100644
--- a/torch_modules/include/DistanceModule.hpp
+++ b/torch_modules/include/DistanceModule.hpp
@@ -26,7 +26,7 @@ class DistanceModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(DistanceModule);
diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index af370c64509546a8adc3c3558f3775f14dbbfd03..a7df33187f2486e9aac28d60ca3dbd5a77492ecc 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -29,7 +29,7 @@ class FocusedColumnModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(FocusedColumnModule);
diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp
index 54418a6398ba15f24821def4059b719e24737746..b4a725b8a4cf2ec79b1891cb39cfa21530f2278f 100644
--- a/torch_modules/include/HistoryModule.hpp
+++ b/torch_modules/include/HistoryModule.hpp
@@ -25,7 +25,7 @@ class HistoryModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(HistoryModule);
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index ed73c301bd90134f50b98a4778d5a6539b54f9aa..31685e29ea45326c94acbfa62efe9dcc280a747b 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -25,12 +25,13 @@ class ModularNetworkImpl : public NeuralNetworkImpl
   MLP mlp{nullptr};
   std::vector<std::shared_ptr<Submodule>> modules;
   std::map<std::string,torch::nn::Linear> outputLayersPerState;
+  std::size_t totalInputSize{0};
 
   public :
 
   ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path);
   torch::Tensor forward(torch::Tensor input, const std::string & state) override;
-  std::vector<std::vector<long>> extractContext(Config & config) override;
+  torch::Tensor extractContext(Config & config) override;
   void registerEmbeddings() override;
   void saveDicts(std::filesystem::path path) override;
   void loadDicts(std::filesystem::path path) override;
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 8215ad2fff9438ab0b6e133f1591b9cddc7c369e..ffbcdea03d406e38aaef06830b54e2809f27b0a4 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -15,7 +15,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
   public :
 
   virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0;
-  virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
+  virtual torch::Tensor extractContext(Config & config) = 0;
   virtual void registerEmbeddings() = 0;
   virtual void saveDicts(std::filesystem::path path) = 0;
   virtual void loadDicts(std::filesystem::path path) = 0;
diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp
index e3d46aece017226572fbc469e53f85dfbaddc518..3ee9cb217f31b1eeaf9e0d44554e19ce078d1923 100644
--- a/torch_modules/include/NumericColumnModule.hpp
+++ b/torch_modules/include/NumericColumnModule.hpp
@@ -24,7 +24,7 @@ class NumericColumnModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(NumericColumnModule);
diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp
index 3c559e9c393146a86bb9ffc1c6f3dd42f0a89ac2..33d99a1455b521f88e972d002103baa4053ee5da 100644
--- a/torch_modules/include/RandomNetwork.hpp
+++ b/torch_modules/include/RandomNetwork.hpp
@@ -13,7 +13,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
 
   RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
   torch::Tensor forward(torch::Tensor input, const std::string & state) override;
-  std::vector<std::vector<long>> extractContext(Config &) override;
+  torch::Tensor extractContext(Config &) override;
   void registerEmbeddings() override;
   void saveDicts(std::filesystem::path path) override;
   void loadDicts(std::filesystem::path path) override;
diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index 26237e2fefc63093b693ebf1860f99492e7beeed..0ca658b978bc2d0bacb6a6514e64751f5305d5a5 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -25,7 +25,7 @@ class RawInputModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(RawInputModule);
diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp
index 1ef1796c490f2192b9d12fb1f2a77bdfe8ec786c..b88491e6d6f0f019b7c3c5e8c0f8482308fa5b72 100644
--- a/torch_modules/include/SplitTransModule.hpp
+++ b/torch_modules/include/SplitTransModule.hpp
@@ -24,7 +24,7 @@ class SplitTransModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(SplitTransModule);
diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp
index 3abfe826021001fe3bc1d04016bea3ed353069b2..ace1cbc63d66e164a155a9b9b611a7d74c82233e 100644
--- a/torch_modules/include/StateNameModule.hpp
+++ b/torch_modules/include/StateNameModule.hpp
@@ -21,7 +21,7 @@ class StateNameModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(StateNameModule);
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 1dbbdc7e46844a910a5d0884c46e8f6e62f192ae..f4722bf195537ea4ae0d1b93f9bae38cab268c6e 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -24,7 +24,7 @@ class Submodule : public torch::nn::Module, public DictHolder
   void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
-  virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
+  virtual void addToContext(torch::Tensor & context, const Config & config) = 0;
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual void registerEmbeddings() = 0;
   std::function<std::string(const std::string &)> getFunction(const std::string functionNames);
diff --git a/torch_modules/include/UppercaseRateModule.hpp b/torch_modules/include/UppercaseRateModule.hpp
index dcfb89c3d08f42dfd4f595d79f48dcd639594559..94956613a7974bfbcb91e7ea2768f6ba195a3ecd 100644
--- a/torch_modules/include/UppercaseRateModule.hpp
+++ b/torch_modules/include/UppercaseRateModule.hpp
@@ -22,7 +22,7 @@ class UppercaseRateModuleImpl : public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void addToContext(torch::Tensor & context, const Config & config) override;
   void registerEmbeddings() override;
 };
 TORCH_MODULE(UppercaseRateModule);
diff --git a/torch_modules/src/AppliableTransModule.cpp b/torch_modules/src/AppliableTransModule.cpp
index c50586f1e49ed7001d0ecc6643b2f6af36047226..7a5c830cedff73c185e930141611b396cfdb8b44 100644
--- a/torch_modules/src/AppliableTransModule.cpp
+++ b/torch_modules/src/AppliableTransModule.cpp
@@ -20,15 +20,12 @@ std::size_t AppliableTransModuleImpl::getInputSize()
   return nbTrans;
 }
 
-void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void AppliableTransModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & appliableTrans = config.getAppliableTransitions();
-  for (auto & contextElement : context)
-    for (int i = 0; i < nbTrans; i++)
-      if (i < (int)appliableTrans.size())
-        contextElement.emplace_back(appliableTrans[i]);
-      else
-        contextElement.emplace_back(0);
+  for (int i = 0; i < nbTrans; i++)
+    if (i < (int)appliableTrans.size())
+      context[firstInputIndex+i] = appliableTrans[i];
 }
 
 void AppliableTransModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index f457f314b3595bf0c51d438893c3743609605a98..19c7972123aab7fa47822aaf5e290cb69fbdb54a 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -83,7 +83,7 @@ std::size_t ContextModuleImpl::getInputSize()
   return columns.size()*(targets.size());
 }
 
-void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void ContextModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
   std::vector<long> contextIndexes;
@@ -125,24 +125,22 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
       contextIndexes.emplace_back(-3);
     }
 
+  int insertIndex = 0;
   for (auto index : contextIndexes)
     for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++)
     {
       auto & col = columns[colIndex];
       if (index == -1)
       {
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, col);
       }
       else if (index == -2)
       {
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col));
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::noChildValueStr, col);
       }
       else if (index == -3)
       {
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col));
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::oobValueStr, col);
       }
       else
       {
@@ -162,9 +160,9 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
           dictIndex = dict.getIndexOrInsert(featureValue, col);
         }
 
-        for (auto & contextElement : context)
-          contextElement.push_back(dictIndex);
+        context[firstInputIndex+insertIndex] = dictIndex;
       }
+      insertIndex++;
     }
 }
 
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index f5f8562e26b811af74966ec8343eeaacb6774e2f..1c569831a4b15f71a8314326e5c7eccd96c337d6 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -87,7 +87,7 @@ std::size_t ContextualModuleImpl::getInputSize()
   return columns.size()*(4+window.second-window.first)+targets.size();
 }
 
-void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void ContextualModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
   std::vector<long> contextIndexes;
@@ -132,24 +132,23 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
     else
       targetIndexes.emplace_back(-1);
 
+  int insertIndex = 0;
+
   for (auto index : contextIndexes)
     for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++)
     {
       auto & col = columns[colIndex];
       if (index == -1)
       {
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, col);
       }
       else if (index == -2)
       {
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col));
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::noChildValueStr, col);
       }
       else if (index == -3)
       {
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col));
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::oobValueStr, col);
       }
       else
       {
@@ -169,33 +168,32 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
           dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue), col);
         }
 
-        for (auto & contextElement : context)
-          contextElement.push_back(dictIndex);
+        context[firstInputIndex+insertIndex] = dictIndex;
       }
+
+      insertIndex++;
     }
 
   for (auto index : targetIndexes)
   {
     if (configIndex2ContextIndex.count(index))
     {
-      for (auto & contextElement : context)
-        contextElement.push_back(configIndex2ContextIndex.at(index));
+      context[firstInputIndex+insertIndex] = configIndex2ContextIndex.at(index);
     }
     else
     {
-      for (auto & contextElement : context)
-      {
-        // -1 == doesn't exist (s.0 when no stack)
-        if (index == -1)
-          contextElement.push_back(0);
-        // -2 == nochild
-        else if (index == -2)
-          contextElement.push_back(1);
-        // other == out of context bounds
-        else
-          contextElement.push_back(2);
-      }
+      // -1 == doesn't exist (s.0 when no stack)
+      if (index == -1)
+        context[firstInputIndex+insertIndex] = 0;
+      // -2 == nochild
+      else if (index == -2)
+        context[firstInputIndex+insertIndex] = 1;
+      // other == out of context bounds
+      else
+        context[firstInputIndex+insertIndex] = 2;
     }
+
+    insertIndex++;
   }
 }
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 06c0b5fe222ebbddfa644feed25ff1a72249ce45..ac906908408baf2c8375618294a9892929bc3062 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -84,7 +84,7 @@ std::size_t DepthLayerTreeEmbeddingModuleImpl::getInputSize()
   return inputSize;
 }
 
-void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void DepthLayerTreeEmbeddingModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
   std::vector<long> focusedIndexes;
@@ -98,30 +98,34 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
     else
       focusedIndexes.emplace_back(-1);
 
-  for (auto & contextElement : context)
-    for (auto index : focusedIndexes)
+  int insertIndex = 0;
+  for (auto index : focusedIndexes)
+  {
+    std::vector<std::string> childs{std::to_string(index)};
+
+    for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
     {
-      std::vector<std::string> childs{std::to_string(index)};
+      std::vector<std::string> newChilds;
+      for (auto & child : childs)
+        if (config.has(Config::childsColName, std::stoi(child), 0))
+        {
+          auto val = util::split(config.getAsFeature(Config::childsColName, std::stoi(child)).get(), '|');
+          newChilds.insert(newChilds.end(), val.begin(), val.end());
+        }
+      childs = newChilds;
 
-      for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
-      {
-        std::vector<std::string> newChilds;
-        for (auto & child : childs)
-          if (config.has(Config::childsColName, std::stoi(child), 0))
-          {
-            auto val = util::split(config.getAsFeature(Config::childsColName, std::stoi(child)).get(), '|');
-            newChilds.insert(newChilds.end(), val.begin(), val.end());
-          }
-        childs = newChilds;
-
-        for (int i = 0; i < maxElemPerDepth[depth]; i++)
-          for (auto & col : columns)
-            if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0))
-              contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])), col));
-            else
-              contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
-      }
+      for (int i = 0; i < maxElemPerDepth[depth]; i++)
+        for (auto & col : columns)
+        {
+          if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0))
+            context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])), col);
+          else
+            context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, col);
+
+          insertIndex++;
+        }
     }
+  }
 }
 
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index 2e71e25cdfd6b174af1115ef636e28cc581365e3..fddf2e00867c73b9a8d7559c03bafcc6a059177f 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -63,7 +63,7 @@ std::size_t DistanceModuleImpl::getInputSize()
   return (fromBuffer.size()+fromStack.size()) * (toBuffer.size()+toStack.size());
 }
 
-void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void DistanceModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
   std::vector<long> fromIndexes, toIndexes;
@@ -88,25 +88,26 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 
   std::string prefix = "DISTANCE";
 
-  for (auto & contextElement : context)
-  {
-    for (auto from : fromIndexes)
-      for (auto to : toIndexes)
+  int insertIndex = 0;
+  for (auto from : fromIndexes)
+    for (auto to : toIndexes)
+    {
+      if (from == -1 or to == -1)
+      {
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, prefix);
+      }
+      else
       {
-        if (from == -1 or to == -1)
-        {
-          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
-          continue;
-        }
-
         long dist = std::abs(config.getRelativeDistance(from, to));
 
         if (dist <= threshold)
-          contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, dist), ""));
+          context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(fmt::format("{}({})", prefix, dist), "");
         else
-          contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr, prefix));
+          context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::unknownValueStr, prefix);
       }
-  }
+
+      insertIndex++;
+    }
 }
 
 void DistanceModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 115f918ad3c845a52f1366b277d42b6b35e4b616..3fc25f0ad53ef521f05da2bdbc5aabf0216ef096 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -82,7 +82,7 @@ std::size_t FocusedColumnModuleImpl::getInputSize()
   return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
 }
 
-void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
   std::vector<long> focusedIndexes;
@@ -96,63 +96,67 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
     else
       focusedIndexes.emplace_back(-1);
 
-  for (auto & contextElement : context)
+  int insertIndex = 0;
+  for (auto index : focusedIndexes)
   {
-    for (auto index : focusedIndexes)
+    if (index == -1)
     {
-      if (index == -1)
+      for (int i = 0; i < maxNbElements; i++)
       {
-        for (int i = 0; i < maxNbElements; i++)
-          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, column));
-        continue;
+        context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, column);
+        insertIndex++;
       }
+      continue;
+    }
 
-      std::vector<std::string> elements;
-      if (column == "FORM")
-      {
-        auto asUtf8 = util::splitAsUtf8(func(config.getAsFeature(column, index).get()));
-
-        //TODO don't use nullValueStr here
-        for (int i = 0; i < maxNbElements; i++)
-          if (i < (int)asUtf8.size())
-            elements.emplace_back(fmt::format("{}", asUtf8[i]));
-          else
-            elements.emplace_back(Dict::nullValueStr);
-      }
-      else if (column == "FEATS")
-      {
-        auto splited = util::split(func(config.getAsFeature(column, index).get()), '|');
+    std::vector<std::string> elements;
+    if (column == "FORM")
+    {
+      auto asUtf8 = util::splitAsUtf8(func(config.getAsFeature(column, index).get()));
+
+      //TODO don't use nullValueStr here
+      for (int i = 0; i < maxNbElements; i++)
+        if (i < (int)asUtf8.size())
+          elements.emplace_back(fmt::format("{}", asUtf8[i]));
+        else
+          elements.emplace_back(Dict::nullValueStr);
+    }
+    else if (column == "FEATS")
+    {
+      auto splited = util::split(func(config.getAsFeature(column, index).get()), '|');
 
-        for (int i = 0; i < maxNbElements; i++)
-          if (i < (int)splited.size())
-            elements.emplace_back(splited[i]);
-          else
-            elements.emplace_back(Dict::nullValueStr);
-      }
-      else if (column == "ID")
-      {
-        if (config.isTokenPredicted(index))
-          elements.emplace_back("TOKEN");
-        else if (config.isMultiwordPredicted(index))
-          elements.emplace_back("MULTIWORD");
-        else if (config.isEmptyNodePredicted(index))
-          elements.emplace_back("EMPTYNODE");
-      }
-      else if (column == "EOS")
-      {
-        bool isEOS = func(config.getAsFeature(Config::EOSColName, index)) == Config::EOSSymbol1;
-        elements.emplace_back(fmt::format("{}", isEOS));
-      }
-      else
-      {
-        elements.emplace_back(func(config.getAsFeature(column, index)));
-      }
+      for (int i = 0; i < maxNbElements; i++)
+        if (i < (int)splited.size())
+          elements.emplace_back(splited[i]);
+        else
+          elements.emplace_back(Dict::nullValueStr);
+    }
+    else if (column == "ID")
+    {
+      if (config.isTokenPredicted(index))
+        elements.emplace_back("TOKEN");
+      else if (config.isMultiwordPredicted(index))
+        elements.emplace_back("MULTIWORD");
+      else if (config.isEmptyNodePredicted(index))
+        elements.emplace_back("EMPTYNODE");
+    }
+    else if (column == "EOS")
+    {
+      bool isEOS = func(config.getAsFeature(Config::EOSColName, index)) == Config::EOSSymbol1;
+      elements.emplace_back(fmt::format("{}", isEOS));
+    }
+    else
+    {
+      elements.emplace_back(func(config.getAsFeature(column, index)));
+    }
 
-      if ((int)elements.size() != maxNbElements)
-        util::myThrow(fmt::format("elements.size ({}) != maxNbElements ({})", elements.size(), maxNbElements));
+    if ((int)elements.size() != maxNbElements)
+      util::myThrow(fmt::format("elements.size ({}) != maxNbElements ({})", elements.size(), maxNbElements));
 
-      for (auto & element : elements)
-        contextElement.emplace_back(dict.getIndexOrInsert(element, column));
+    for (auto & element : elements)
+    {
+      context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(element, column);
+      insertIndex++;
     }
   }
 }
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index 7d0912ce154af1025d409df3f7d3de4f40eae683..4a9033fe01e989301b12553245033c1a876b5ace 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -53,18 +53,17 @@ std::size_t HistoryModuleImpl::getInputSize()
   return maxNbElements;
 }
 
-void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void HistoryModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
 
   std::string prefix = "HISTORY";
 
-  for (auto & contextElement : context)
-    for (int i = 0; i < maxNbElements; i++)
-      if (config.hasHistory(i))
-        contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i), prefix));
-      else
-        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
+  for (int i = 0; i < maxNbElements; i++)
+    if (config.hasHistory(i))
+      context[firstInputIndex+i] = dict.getIndexOrInsert(config.getHistory(i), prefix);
+    else
+      context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, prefix);
 }
 
 void HistoryModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index c936f85b75ebfc4d6ed9686b091a183d53bc5adc..1c39f186db7bdc62e6e30c3771974dacefa8af63 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -69,6 +69,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
     currentOutputSize += modules.back()->getOutputSize();
   }
 
+  totalInputSize = currentInputSize;
+
   if (mlpDef.empty())
     util::myThrow("no MLP definition found");
   if (inputDropout.is_empty())
@@ -95,9 +97,9 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string
   return outputLayersPerState.at(state)(mlp(totalInput));
 }
 
-std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config)
+torch::Tensor ModularNetworkImpl::extractContext(Config & config)
 {
-  std::vector<std::vector<long>> context(1);
+  torch::Tensor context = torch::zeros({totalInputSize}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
   for (auto & mod : modules)
     mod->addToContext(context, config);
   return context;
diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
index 15d2b19ef681539418dbd1c64dcffb27f7eabe54..49d3016bb00452efee0ffa7fc0d45d5a80a58bae 100644
--- a/torch_modules/src/NumericColumnModule.cpp
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -60,7 +60,7 @@ std::size_t NumericColumnModuleImpl::getInputSize()
   return focusedBuffer.size() + focusedStack.size();
 }
 
-void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void NumericColumnModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   std::vector<long> focusedIndexes;
 
@@ -73,21 +73,21 @@ void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
     else
       focusedIndexes.emplace_back(-1);
 
-  for (auto & contextElement : context)
-    for (auto index : focusedIndexes)
+  int insertIndex = 0;
+  for (auto index : focusedIndexes)
+  {
+    double res = 0.0;
+    if (index >= 0)
     {
-      double res = 0.0;
-      if (index >= 0)
-      {
-        auto value = config.getAsFeature(column, index).get();
-        try {res = (value == "_" or value == "NA") ? defaultValue : std::stof(value);}
-        catch (std::exception & e)
-          {util::myThrow(fmt::format("{} for '{}'", e.what(), value));}
-      }
-
-      contextElement.emplace_back(0);
-      std::memcpy(&contextElement.back(), &res, sizeof res);
+      auto value = config.getAsFeature(column, index).get();
+      try {res = (value == "_" or value == "NA") ? defaultValue : std::stof(value);}
+      catch (std::exception & e)
+        {util::myThrow(fmt::format("{} for '{}'", e.what(), value));}
     }
+    long bits; std::memcpy(&bits, &res, sizeof res);
+    context[firstInputIndex+insertIndex] = bits;
+    insertIndex++;
+  }
 }
 
 void NumericColumnModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp
index 87a604636595062008ecbd5d442111b4101c8b39..b05d1aa00b26677500bc7f8a2acb59ea12a2cbcd 100644
--- a/torch_modules/src/RandomNetwork.cpp
+++ b/torch_modules/src/RandomNetwork.cpp
@@ -13,9 +13,10 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input, const std::string
   return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(device).requires_grad(true));
 }
 
-std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &)
+torch::Tensor RandomNetworkImpl::extractContext(Config &)
 {
-  return std::vector<std::vector<long>>{{0}};
+  torch::Tensor context = torch::zeros({1}, torch::TensorOptions(torch::kLong).device(device));
+  return context;
 }
 
 void RandomNetworkImpl::registerEmbeddings()
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index d948386e6f2d7ffd376c8170b6a670f895a48948..2d6bd62164e13adb8b21547cd895490a0f173e59 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -11,6 +11,9 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
             leftWindow = std::stoi(sm.str(1));
             rightWindow = std::stoi(sm.str(2));
 
+            if (leftWindow < 0 or rightWindow < 0)
+              util::myThrow(fmt::format("Invalid negative values for leftWindow({}) or rightWindow({})", leftWindow, rightWindow));
+
             auto subModuleType = sm.str(3);
             auto subModuleArguments = util::split(sm.str(4), ' ');
 
@@ -54,27 +57,30 @@ std::size_t RawInputModuleImpl::getInputSize()
   return leftWindow + rightWindow + 1;
 }
 
-void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void RawInputModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
-  if (leftWindow < 0 or rightWindow < 0)
-    return;
-
   std::string prefix = "LETTER";
-
   auto & dict = getDict();
-  for (auto & contextElement : context)
+
+  int insertIndex = 0;
+  for (int i = 0; i < leftWindow; i++)
   {
-    for (int i = 0; i < leftWindow; i++)
-      if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)), prefix));
-      else
-        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
-
-    for (int i = 0; i <= rightWindow; i++)
-      if (config.hasCharacter(config.getCharacterIndex()+i))
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)), prefix));
-      else
-        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
+    if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
+      context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)), prefix);
+    else
+      context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, prefix);
+
+    insertIndex++;
+  }
+
+  for (int i = 0; i <= rightWindow; i++)
+  {
+    if (config.hasCharacter(config.getCharacterIndex()+i))
+      context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)), prefix);
+    else
+      context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, prefix);
+
+    insertIndex++;
   }
 }
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 0c1de2e7f5f9a1dbe7003ac24cd02d474d399048..dcb78e1f42464871301c2556da000f61efb63b28 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -51,16 +51,15 @@ std::size_t SplitTransModuleImpl::getInputSize()
   return maxNbTrans;
 }
 
-void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void SplitTransModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
   auto & splitTransitions = config.getAppliableSplitTransitions();
-  for (auto & contextElement : context)
-    for (int i = 0; i < maxNbTrans; i++)
-      if (i < (int)splitTransitions.size())
-        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName(), ""));
-      else
-        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, ""));
+  for (int i = 0; i < maxNbTrans; i++)
+    if (i < (int)splitTransitions.size())
+      context[firstInputIndex+i] = dict.getIndexOrInsert(splitTransitions[i]->getName(), "");
+    else
+      context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, "");
 }
 
 void SplitTransModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index 0c642b947b78a69a64490ab4e2dc7f070b3277af..f3ac97753a08a2529a0708295dfd1cc21231f974 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -29,11 +29,10 @@ std::size_t StateNameModuleImpl::getInputSize()
   return 1;
 }
 
-void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void StateNameModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   auto & dict = getDict();
-  for (auto & contextElement : context)
-    contextElement.emplace_back(dict.getIndexOrInsert(config.getState(), ""));
+  context[firstInputIndex] = dict.getIndexOrInsert(config.getState(), "");
 }
 
 void StateNameModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp
index 0452eb8db781b8e83a1e62069b88c790b1214678..7c846150c1e96af8f7886520592b8d4363e9c78d 100644
--- a/torch_modules/src/UppercaseRateModule.cpp
+++ b/torch_modules/src/UppercaseRateModule.cpp
@@ -56,7 +56,7 @@ std::size_t UppercaseRateModuleImpl::getInputSize()
   return focusedBuffer.size() + focusedStack.size();
 }
 
-void UppercaseRateModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+void UppercaseRateModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 {
   std::vector<long> focusedIndexes;
 
@@ -69,25 +69,24 @@ void UppercaseRateModuleImpl::addToContext(std::vector<std::vector<long>> & cont
     else
       focusedIndexes.emplace_back(-1);
 
-  for (auto & contextElement : context)
+  int insertIndex = 0;
+  for (auto index : focusedIndexes)
   {
-    for (auto index : focusedIndexes)
+    double res = -1.0;
+    if (index >= 0)
     {
-      double res = -1.0;
-      if (index >= 0)
-      {
-        auto word = util::splitAsUtf8(config.getAsFeature("FORM", index).get());
-        int nbUpper = 0;
-        for (auto & letter : word)
-          if (util::isUppercase(letter))
-            nbUpper++;
-        if (word.size() > 0)
-          res = 1.0*nbUpper/word.size();
-      }
-
-      contextElement.emplace_back(0);
-      std::memcpy(&contextElement.back(), &res, sizeof res);
+      auto word = util::splitAsUtf8(config.getAsFeature("FORM", index).get());
+      int nbUpper = 0;
+      for (auto & letter : word)
+        if (util::isUppercase(letter))
+          nbUpper++;
+      if (word.size() > 0)
+        res = 1.0*nbUpper/word.size();
     }
+
+    long bits; std::memcpy(&bits, &res, sizeof res);
+    context[firstInputIndex+insertIndex] = bits;
+    insertIndex++;
   }
 
 }
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index dfa465e8a524134b55c7887892940fe0bc1a01cd..a5088d0eebd1dd5e60449632718aed378d7c893c 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -34,7 +34,7 @@ class Trainer
     int lastSavedIndex{0};
 
     void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int currentEpoch, bool dynamicOracle);
-    void addContext(std::vector<std::vector<long>> & context);
+    void addContext(torch::Tensor & context);
     void addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes);
   };
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 56ecc44cbfdd3fdf069856b4da3963665ae9c068..4143b2cb5a961f9db485dbd9891a2d35e569fc0c 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -65,7 +65,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
       auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
       config.setAppliableTransitions(appliableTransitions);
 
-      std::vector<std::vector<long>> context;
+      torch::Tensor context;
 
       try
       {
@@ -92,8 +92,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
       if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
       {
         auto & classifier = *machine.getClassifier(config.getState());
-        auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
-        auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0), 0);
+        auto prediction = classifier.isRegression() ? classifier.getNN()->forward(context, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(context, config.getState()).squeeze(0), 0);
         entropy  = NeuralNetworkImpl::entropy(prediction);
     
         std::vector<int> candidates;
@@ -154,7 +153,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
 
       if (!exampleIsBanned)
       {
-        totalNbExamples += context.size();
+        totalNbExamples += 1;
         if (totalNbExamples >= (int)safetyNbExamplesMax)
           util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
 
@@ -295,12 +294,11 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem:
   classes.clear();
 }
 
-void Trainer::Examples::addContext(std::vector<std::vector<long>> & context)
+void Trainer::Examples::addContext(torch::Tensor & context)
 {
-  for (auto & element : context)
-    contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
+  contexts.emplace_back(context);
 
-  currentExampleIndex += context.size();
+  currentExampleIndex += 1;
 }
 
 void Trainer::Examples::addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes)