diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index af823b77e5fbaeed858785d087457cb98ac16a58..a51ad089e1f569e594dedf435f6f990af9ba0b93 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -53,6 +53,8 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) initLSTM(definition, curIndex); else util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType)); + + this->nn->to(NeuralNetworkImpl::device); } void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex)