diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 2996b796806af3a633a72bc182e2da673419578e..aa0b0c2b4acc52510e6f13af019de059774b8373 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -7,6 +7,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) { + machine.getClassifier()->getNN()->train(false); config.addPredicted(machine.getPredicted()); try @@ -63,6 +64,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) if (debug) fmt::print(stderr, "Forcing EOS transition\n"); } + + machine.getClassifier()->getNN()->train(true); } float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const