diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 9b6b3a67033013989b31b560dab167de3fcc08eb..33c483783f9c02060a40057dba26597735c34dac 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
 {
   torch::AutoGradMode useGrad(false);
   machine.trainMode(false);
+  machine.splitUnknown(false);
   machine.setDictsState(Dict::State::Closed);
   machine.getStrategy().reset();
   config.addPredicted(machine.getPredicted());
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 9eb09d038a853625dcbb0b649f02556a06eea94c..5f3ff1c6449e98f666a007f84f4b2b1b4d673726 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -47,6 +47,7 @@ class ReadingMachine
   bool isPredicted(const std::string & columnName) const;
   const std::set<std::string> & getPredicted() const;
   void trainMode(bool isTrainMode);
+  void splitUnknown(bool splitUnknown);
   void setDictsState(Dict::State state);
   void saveBest() const;
   void saveLast() const;
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 38f79c84c0a1e14e7adffe168486d47d1674a944..138c2494791fda62ddf933589d39c6b297260e6e 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -182,6 +182,11 @@ void ReadingMachine::trainMode(bool isTrainMode)
   classifier->getNN()->train(isTrainMode);
 }
 
+void ReadingMachine::splitUnknown(bool splitUnknown)
+{
+  classifier->getNN()->setSplitUnknown(splitUnknown);
+}
+
 void ReadingMachine::setDictsState(Dict::State state)
 {
   for (auto & it : dicts)
diff --git a/torch_modules/include/ContextLSTM.hpp b/torch_modules/include/ContextLSTM.hpp
index 136029cd33d2c3a1825ce23d8902c860242eecf8..3e3bbacac0e56cfd38e981279a0f6a54c1f41b3d 100644
--- a/torch_modules/include/ContextLSTM.hpp
+++ b/torch_modules/include/ContextLSTM.hpp
@@ -22,7 +22,7 @@ class ContextLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(ContextLSTM);
 
diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbedding.hpp
index 436a082a06121a2c62f50da0b5f5ef4b79b99ba8..2a8f7e8ca0ccd8fea1313e4b4437700c9bdd6bef 100644
--- a/torch_modules/include/DepthLayerTreeEmbedding.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbedding.hpp
@@ -21,7 +21,7 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(DepthLayerTreeEmbedding);
 
diff --git a/torch_modules/include/FocusedColumnLSTM.hpp b/torch_modules/include/FocusedColumnLSTM.hpp
index 6ea836a041017fb1fdf6725506b6eb1f561bdb99..fd5d915df6d42d24294e6a75dd42c87d6e81dec1 100644
--- a/torch_modules/include/FocusedColumnLSTM.hpp
+++ b/torch_modules/include/FocusedColumnLSTM.hpp
@@ -20,7 +20,7 @@ class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(FocusedColumnLSTM);
 
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index be25c873978d61edc57ee014695e2110b8cb189b..1237f09e15989dc6534e150e9ca03cfe983f797b 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -11,6 +11,10 @@ class NeuralNetworkImpl : public torch::nn::Module
 
   static torch::Device device;
 
+  private :
+
+  bool splitUnknown{false};
+
   protected : 
 
   static constexpr int maxNbEmbeddings = 150000;
@@ -19,6 +23,8 @@ class NeuralNetworkImpl : public torch::nn::Module
 
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
+  bool mustSplitUnknown() const;
+  void setSplitUnknown(bool splitUnknown);
 };
 TORCH_MODULE(NeuralNetwork);
 
diff --git a/torch_modules/include/RawInputLSTM.hpp b/torch_modules/include/RawInputLSTM.hpp
index db17d6f0014e615474a462635844e3d5251f3fb0..0e08560836b735f181849571ff0beec8f02bc335 100644
--- a/torch_modules/include/RawInputLSTM.hpp
+++ b/torch_modules/include/RawInputLSTM.hpp
@@ -18,7 +18,7 @@ class RawInputLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(RawInputLSTM);
 
diff --git a/torch_modules/include/SplitTransLSTM.hpp b/torch_modules/include/SplitTransLSTM.hpp
index f90c0edcfedd699d7d9d18626db8d80cd40d385e..85d542ce8510bd0c1d11b2ca6c1f280aeb386d55 100644
--- a/torch_modules/include/SplitTransLSTM.hpp
+++ b/torch_modules/include/SplitTransLSTM.hpp
@@ -18,7 +18,7 @@ class SplitTransLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(SplitTransLSTM);
 
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 437bbfa4e82ca29fb35b33924ba1fc3c16cbb126..cc381013aea518aeefe8422b36537283d5d0da94 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -15,7 +15,7 @@ class Submodule
   void setFirstInputIndex(std::size_t firstInputIndex);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
-  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
+  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const = 0;
 };
 
 #endif
diff --git a/torch_modules/src/ContextLSTM.cpp b/torch_modules/src/ContextLSTM.cpp
index 95daa696df15c70797332f0d938c5646111e418d..5da02e729e4425d366ff9c1220c43ae477f5c926 100644
--- a/torch_modules/src/ContextLSTM.cpp
+++ b/torch_modules/src/ContextLSTM.cpp
@@ -15,7 +15,7 @@ std::size_t ContextLSTMImpl::getInputSize()
   return columns.size()*(bufferContext.size()+stackContext.size());
 }
 
-void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const
 {
   std::vector<long> contextIndexes;
 
@@ -31,8 +31,10 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic
   for (auto index : contextIndexes)
     for (auto & col : columns)
       if (index == -1)
+      {
         for (auto & contextElement : context)
           contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      }
       else
       {
         int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
@@ -40,7 +42,7 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic
         for (auto & contextElement : context)
           contextElement.push_back(dictIndex);
 
-        if (is_training())
+        if (splitUnknown)
           for (auto & targetCol : unknownValueColumns)
             if (col == targetCol)
               if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
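
Aside: a minimal, self-contained sketch of the behaviour the new splitUnknown flag gates in ContextLSTMImpl::addToContext. It assumes, based on the hunk above and on the multiple-variants check in LSTMNetwork.cpp below, that the code following the threshold test duplicates each context row so that a rare value is seen both with its real index and with an unknown-value index; the helper name addRareValue and the unknownIndex parameter are hypothetical and exist only for illustration.

#include <cstddef>
#include <vector>

// Hypothetical helper illustrating the presumed split-unknown duplication:
// every existing context variant receives the real index and, when splitting
// is enabled and the value is rare, a copy of each variant is appended in
// which that last index is replaced by the unknown-value index.
void addRareValue(std::vector<std::vector<long>> & context, long dictIndex,
                  long unknownIndex, bool splitUnknown, bool isRare)
{
  std::size_t originalSize = context.size();

  // Every existing variant gets the real dictionary index.
  for (auto & contextElement : context)
    contextElement.push_back(dictIndex);

  // In split-unknown mode, rare values additionally yield an <unknown> variant.
  if (splitUnknown and isRare)
    for (std::size_t i = 0; i < originalSize; i++)
    {
      std::vector<long> variant = context[i];
      variant.back() = unknownIndex;
      context.emplace_back(std::move(variant));
    }
}

In decode mode the flag is false, so the context keeps a single variant and the guard in LSTMNetworkImpl::extractContext below holds.
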
diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp
index 6e1342a0270a0dea0fcc4d41bef758a050709595..b506f9219fd8284094960907975294fbd3a5b28a 100644
--- a/torch_modules/src/DepthLayerTreeEmbedding.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp
@@ -42,7 +42,7 @@ std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
   return inputSize;
 }
 
-void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;
 
diff --git a/torch_modules/src/FocusedColumnLSTM.cpp b/torch_modules/src/FocusedColumnLSTM.cpp
index 4e0da0ebb99e7aab3abe9f18b89301df0a448591..e39af636c817fdc1677cfd9131b85ec7fb1bd3ba 100644
--- a/torch_modules/src/FocusedColumnLSTM.cpp
+++ b/torch_modules/src/FocusedColumnLSTM.cpp
@@ -24,7 +24,7 @@ std::size_t FocusedColumnLSTMImpl::getInputSize()
   return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
 }
 
-void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;
 
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index cfa004e58e66b3162231fd4046b499ac62dc5636..a4f5863ce39dab0fd51730e9edb52a8a0a43e2a2 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -94,21 +94,21 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
 
   context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
 
-  contextLSTM->addToContext(context, dict, config);
+  contextLSTM->addToContext(context, dict, config, mustSplitUnknown());
 
   if (!rawInputLSTM.is_empty())
-    rawInputLSTM->addToContext(context, dict, config);
+    rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown());
 
   if (!treeEmbedding.is_empty())
-    treeEmbedding->addToContext(context, dict, config);
+    treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
 
-  splitTransLSTM->addToContext(context, dict, config);
+  splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown());
 
   for (auto & lstm : focusedLstms)
-    lstm->addToContext(context, dict, config);
+    lstm->addToContext(context, dict, config, mustSplitUnknown());
 
-  if (!is_training() && context.size() > 1)
-    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
+  if (!mustSplitUnknown() && context.size() > 1)
+    util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
 
   return context;
 }
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 02e8a191bfb4b2bc718b6e815a266bec252fb24b..235c67793305280d0e09a3f1d45593fa727d13a3 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -2,3 +2,13 @@
 
 torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
 
+bool NeuralNetworkImpl::mustSplitUnknown() const
+{
+  return splitUnknown;
+}
+
+void NeuralNetworkImpl::setSplitUnknown(bool splitUnknown)
+{
+  this->splitUnknown = splitUnknown;
+}
+
diff --git a/torch_modules/src/RawInputLSTM.cpp b/torch_modules/src/RawInputLSTM.cpp
index 2aa8cfd9c2f97c9fd99d6ed08d3e2a0aa75c35d8..c6da426a7807b90bfd52eaf06abe7599c4c517c3 100644
--- a/torch_modules/src/RawInputLSTM.cpp
+++ b/torch_modules/src/RawInputLSTM.cpp
@@ -20,7 +20,7 @@ std::size_t RawInputLSTMImpl::getInputSize()
   return leftWindow + rightWindow + 1;
 }
 
-void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   if (leftWindow < 0 or rightWindow < 0)
     return;
diff --git a/torch_modules/src/SplitTransLSTM.cpp b/torch_modules/src/SplitTransLSTM.cpp
index a83894abf93efda1ed0124fe58057e3ce042be06..99a1b35650e0b60c8c34c22f0a863d1ab1f8c990 100644
--- a/torch_modules/src/SplitTransLSTM.cpp
+++ b/torch_modules/src/SplitTransLSTM.cpp
@@ -21,7 +21,7 @@ std::size_t SplitTransLSTMImpl::getInputSize()
   return maxNbTrans;
 }
 
-void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   auto & splitTransitions = config.getAppliableSplitTransitions();
   for (auto & contextElement : context)
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 03b7f880742df3dd737f555fd96ec1490a6c5a7c..95e98eba3e8162e01a71d8a6a5373a0d96293d30 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -10,6 +10,7 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
   machine.trainMode(false);
+  machine.splitUnknown(true);
   machine.setDictsState(Dict::State::Open);
 
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
@@ -23,6 +24,7 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
   machine.trainMode(false);
+  machine.splitUnknown(false);
   machine.setDictsState(Dict::State::Closed);
 
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
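
Aside: a minimal sketch of what the widened Submodule::addToContext signature asks of an implementation, assuming the pure-virtual methods visible in the Submodule.hpp hunk above are the relevant ones. ExampleSubmodule is a hypothetical module used only to show the new fourth parameter; submodules that never split unknown values can simply ignore it, as DepthLayerTreeEmbedding, FocusedColumnLSTM, RawInputLSTM and SplitTransLSTM do above.

class ExampleSubmoduleImpl : public torch::nn::Module, public Submodule
{
  public :

  std::size_t getOutputSize() override { return 1; }
  std::size_t getInputSize() override { return 1; }

  // The flag is forwarded by LSTMNetworkImpl::extractContext via mustSplitUnknown();
  // an implementation that never splits unknown values can leave it unnamed.
  void addToContext(std::vector<std::vector<long>> & context, Dict & dict,
                    const Config & config, bool /*splitUnknown*/) const override
  {
    for (auto & contextElement : context)
      contextElement.push_back(dict.getIndexOrInsert(config.getState()));
  }
};
TORCH_MODULE(ExampleSubmodule);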