diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 2996b796806af3a633a72bc182e2da673419578e..aa0b0c2b4acc52510e6f13af019de059774b8373 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -7,6 +7,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 
 void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
 {
+  machine.getClassifier()->getNN()->train(false);
   config.addPredicted(machine.getPredicted());
 
   try
@@ -63,6 +64,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
     if (debug)
       fmt::print(stderr, "Forcing EOS transition\n");
   }
+
+  machine.getClassifier()->getNN()->train(true);
 }
 
 float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const