diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 1d81309a5ffc34c8664784d68775b68e78c832ac..44a8edc0f8cbc8d69ced8c273f63e9824af35ecf 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -25,10 +25,10 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
       config.printForDebug(stderr);
 
     auto dictState = machine.getDict(config.getState()).getState();
-    auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
+    auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back();
     machine.getDict(config.getState()).setState(dictState);
 
-    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone().to(NeuralNetworkImpl::device);
+    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
     auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
 
     int chosenTransition = -1;
diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
index 2edac4993051841372c293c07d55a6aeee56088c..0cd54b8f087034863e1fd8dbb2e07089be221a56 100644
--- a/torch_modules/include/CNNNetwork.hpp
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -33,7 +33,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
 
   CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<long> extractContext(Config & config, Dict & dict) const override;
+  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
 };
 
 #endif
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 1ca0919cc3118a2ef5b01c0a466c97ed3c3bd6a5..34bf14b632cd912e1d0743fc667a14ed49e667c2 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -27,7 +27,7 @@ class NeuralNetworkImpl : public torch::nn::Module
   public :
 
   virtual torch::Tensor forward(torch::Tensor input) = 0;
-  virtual std::vector<long> extractContext(Config & config, Dict & dict) const;
+  virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const;
   std::vector<long> extractContextIndexes(const Config & config) const;
   int getContextSize() const;
   void setColumns(const std::vector<std::string> & columns);
diff --git a/torch_modules/include/RLTNetwork.hpp b/torch_modules/include/RLTNetwork.hpp
index 7d350b38fb36a0b31b55ab89b335eb6de62c4124..b996def57a5e540d738bd9db0874a5d8511d2983 100644
--- a/torch_modules/include/RLTNetwork.hpp
+++ b/torch_modules/include/RLTNetwork.hpp
@@ -23,7 +23,7 @@ class RLTNetworkImpl : public NeuralNetworkImpl
 
   RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<long> extractContext(Config & config, Dict & dict) const override;
+  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
 };
 
 #endif
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
index 5e9696eba7062b67c1b36dccb4dc29dd1fb8f7c5..9f9d6a1598a33f0be2f321dc00542b19562f2c58 100644
--- a/torch_modules/src/CNNNetwork.cpp
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -74,118 +74,129 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
   return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
 }
 
-std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
+std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   if (dict.size() >= maxNbEmbeddings)
     util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
 
   std::vector<long> contextIndexes = extractContextIndexes(config);
-  std::vector<long> context;
+  std::vector<std::vector<long>> context;
+  context.emplace_back();
 
   if (rawInputSize > 0)
   {
     for (int i = 0; i < leftWindowRawInput; i++)
       if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
-        context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
       else
-        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
     for (int i = 0; i <= rightWindowRawInput; i++)
       if (config.hasCharacter(config.getCharacterIndex()+i))
-
-        context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
       else
-        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
   }
 
   for (auto index : contextIndexes)
     for (auto & col : columns)
       if (index == -1)
-        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        for (auto & contextElement : context)
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
       else
       {
         int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
-        if (col == "FORM" || col == "LEMMA")
-          if (dict.getNbOccs(dictIndex) < unknownValueThreshold)
-            dictIndex = dict.getIndexOrInsert(Dict::unknownValueStr);
 
-        context.push_back(dictIndex);
-      }
+        for (auto & contextElement : context)
+          contextElement.push_back(dictIndex);
 
-  for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
-  {
-    auto & col = focusedColumns[colIndex];
+        if (is_training())
+          if (col == "FORM" || col == "LEMMA")
+            if (dict.getNbOccs(dictIndex) < unknownValueThreshold)
+            {
+              context.push_back(context.back());
+              context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
+            }
+      }
 
-    std::vector<int> focusedIndexes;
-    for (auto relIndex : focusedBufferIndexes)
+  for (auto & contextElement : context)
+    for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
     {
-      int index = relIndex + leftBorder;
-      if (index < 0 || index >= (int)contextIndexes.size())
-        focusedIndexes.push_back(-1);
-      else
-        focusedIndexes.push_back(contextIndexes[index]);
-    }
-    for (auto index : focusedStackIndexes)
-    {
-      if (!config.hasStack(index))
-        focusedIndexes.push_back(-1);
-      else if (!config.has(col, config.getStack(index), 0))
-        focusedIndexes.push_back(-1);
-      else
-        focusedIndexes.push_back(config.getStack(index));
-    }
+      auto & col = focusedColumns[colIndex];
 
-    for (auto index : focusedIndexes)
-    {
-      if (index == -1)
+      std::vector<int> focusedIndexes;
+      for (auto relIndex : focusedBufferIndexes)
       {
-        for (int i = 0; i < maxNbElements[colIndex]; i++)
-          context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-        continue;
+        int index = relIndex + leftBorder;
+        if (index < 0 || index >= (int)contextIndexes.size())
+          focusedIndexes.push_back(-1);
+        else
+          focusedIndexes.push_back(contextIndexes[index]);
       }
-
-      std::vector<std::string> elements;
-      if (col == "FORM")
+      for (auto index : focusedStackIndexes)
       {
-        auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
-
-        for (int i = 0; i < maxNbElements[colIndex]; i++)
-          if (i < (int)asUtf8.size())
-            elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
-          else
-            elements.emplace_back(Dict::nullValueStr);
+        if (!config.hasStack(index))
+          focusedIndexes.push_back(-1);
+        else if (!config.has(col, config.getStack(index), 0))
+          focusedIndexes.push_back(-1);
+        else
+          focusedIndexes.push_back(config.getStack(index));
       }
-      else if (col == "FEATS")
-      {
-        auto splited = util::split(config.getAsFeature(col, index).get(), '|');
 
-        for (int i = 0; i < maxNbElements[colIndex]; i++)
-          if (i < (int)splited.size())
-            elements.emplace_back(fmt::format("FEATS({})", splited[i]));
-          else
-            elements.emplace_back(Dict::nullValueStr);
-      }
-      else if (col == "ID")
+      for (auto index : focusedIndexes)
       {
-        if (config.isTokenPredicted(index))
-          elements.emplace_back("ID(TOKEN)");
-        else if (config.isMultiwordPredicted(index))
-          elements.emplace_back("ID(MULTIWORD)");
-        else if (config.isEmptyNodePredicted(index))
-          elements.emplace_back("ID(EMPTYNODE)");
+        if (index == -1)
+        {
+          for (int i = 0; i < maxNbElements[colIndex]; i++)
+            contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+          continue;
+        }
+
+        std::vector<std::string> elements;
+        if (col == "FORM")
+        {
+          auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
+
+          for (int i = 0; i < maxNbElements[colIndex]; i++)
+            if (i < (int)asUtf8.size())
+              elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
+            else
+              elements.emplace_back(Dict::nullValueStr);
+        }
+        else if (col == "FEATS")
+        {
+          auto splited = util::split(config.getAsFeature(col, index).get(), '|');
+
+          for (int i = 0; i < maxNbElements[colIndex]; i++)
+            if (i < (int)splited.size())
+              elements.emplace_back(fmt::format("FEATS({})", splited[i]));
+            else
+              elements.emplace_back(Dict::nullValueStr);
+        }
+        else if (col == "ID")
+        {
+          if (config.isTokenPredicted(index))
+            elements.emplace_back("ID(TOKEN)");
+          else if (config.isMultiwordPredicted(index))
+            elements.emplace_back("ID(MULTIWORD)");
+          else if (config.isEmptyNodePredicted(index))
+            elements.emplace_back("ID(EMPTYNODE)");
+        }
+        else
+        {
+          elements.emplace_back(config.getAsFeature(col, index));
+        }
+
+        if ((int)elements.size() != maxNbElements[colIndex])
+          util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
+
+        for (auto & element : elements)
+          contextElement.emplace_back(dict.getIndexOrInsert(element));
       }
-      else
-      {
-        elements.emplace_back(config.getAsFeature(col, index));
-      }
-
-      if ((int)elements.size() != maxNbElements[colIndex])
-        util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
-
-      for (auto & element : elements)
-        context.emplace_back(dict.getIndexOrInsert(element));
     }
-  }
+
+  if (!is_training() && context.size() > 1)
+    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
 
   return context;
 }
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 37c206cf1baf43d2d3d230529e88f1e7271a48ca..fef5519623d91e48bbac5dead3dcd575216c4121 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -35,7 +35,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
   return context;
 }
 
-std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
+std::vector<std::vector<long>> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   std::vector<long> indexes = extractContextIndexes(config);
   std::vector<long> context;
@@ -47,7 +47,7 @@ std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict
       else
         context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));
 
-  return context;
+  return {context};
 }
 
 int NeuralNetworkImpl::getContextSize() const
diff --git a/torch_modules/src/RLTNetwork.cpp b/torch_modules/src/RLTNetwork.cpp
index 85223e776bc2595a594c69fb2fa7abe9c1320b92..38fe64203cbdb0f9546e96cd1b6ac758265af364 100644
--- a/torch_modules/src/RLTNetwork.cpp
+++ b/torch_modules/src/RLTNetwork.cpp
@@ -79,7 +79,7 @@ torch::Tensor RLTNetworkImpl::forward(torch::Tensor input)
   return linear2(torch::relu(linear1(representation)));
 }
 
-std::vector<long> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
+std::vector<std::vector<long>> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   std::vector<long> contextIndexes;
   std::stack<int> leftContext;
@@ -183,6 +183,6 @@ std::vector<long> RLTNetworkImpl::extractContext(Config & config, Dict & dict) c
       else
         context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l)));
 
-  return context;
+  return {context};
 }
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 29637014808715bfbbd8f0e546f1998ce61d53e0..501af8e8d87456e2fb8699210972ebf0a130c460 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -48,20 +48,24 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
       util::myThrow("No transition appliable !");
     }
 
+    std::vector<std::vector<long>> context;
+
     try
     {
-      auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
-      contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device));
+      context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
+      for (auto & element : context)
+        contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device));
     } catch(std::exception & e)
     {
       util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
     }
 
     int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-    auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device));
+    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
     gold[0] = goldIndex;
 
-    classes.emplace_back(gold);
+    for (std::size_t i = 0; i < context.size(); i++)
+      classes.emplace_back(gold);
 
     transition->apply(config);
     config.addToHistory(transition->getName());