From f93305a5459af6b1aa689f1701391aa402c756e0 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 3 Mar 2020 17:53:38 +0100 Subject: [PATCH] Detect device (cpu or cuda) and use cuda if possible --- CMakeLists.txt | 3 +-- decoder/src/Decoder.cpp | 2 +- decoder/src/macaon_decode.cpp | 2 ++ reading_machine/src/Classifier.cpp | 3 +++ torch_modules/include/NeuralNetwork.hpp | 4 ++++ torch_modules/src/NeuralNetwork.cpp | 2 ++ trainer/src/Trainer.cpp | 4 ++-- trainer/src/macaon_train.cpp | 3 +++ 8 files changed, 18 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 78867d2..7c6507c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.0.2) -project(test_torch) +project(macaon) add_compile_definitions(BOOST_DISABLE_THREADS) @@ -9,7 +9,6 @@ find_package(Boost 1.53.0 REQUIRED COMPONENTS program_options) include_directories(SYSTEM ${TORCH_INCLUDE_DIRS}) add_library(Torch SHARED IMPORTED) -set_target_properties(Torch PROPERTIES IMPORTED_LOCATION ${TORCH_LIBRARIES}) add_library(Boost SHARED IMPORTED) set_target_properties(Boost PROPERTIES IMPORTED_LOCATION ${Boost_PROGRAM_OPTIONS_LIBRARY_RELEASE}) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index e7b2285..fc5f386 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -23,7 +23,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())); machine.getDict(config.getState()).setState(dictState); - auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong); + auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); int chosenTransition = -1; diff --git a/decoder/src/macaon_decode.cpp b/decoder/src/macaon_decode.cpp index f673c2c..5b7272b 100644 --- a/decoder/src/macaon_decode.cpp +++ b/decoder/src/macaon_decode.cpp @@ -78,6 +78,8 @@ int main(int argc, char * argv[]) if (modelPaths.empty()) util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, ""))); + fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str()); + try { ReadingMachine machine(machinePath, modelPaths, dictPaths); diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index f234c73..ad03c31 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -67,7 +67,10 @@ void Classifier::initNeuralNetwork(const std::string & topology) for (auto & initializer : initializers) if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer))) + { + this->nn->to(NeuralNetworkImpl::device); return; + } std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology); for (auto & initializer : initializers) diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index cdc5514..4709216 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -7,6 +7,10 @@ class NeuralNetworkImpl : public torch::nn::Module { + public : + + static torch::Device device; + protected : int leftBorder{5}; diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index cdb8ad3..0719d84 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -1,5 +1,7 @@ #include "NeuralNetwork.hpp" +torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU); + std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const { std::stack<int> leftContext; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index a74078c..191ac47 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -26,10 +26,10 @@ void Trainer::createDataset(SubConfig & config, bool debug) } auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); - contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone()); + contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device)); int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); - auto gold = torch::zeros(1, at::kLong); + auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)); gold[0] = goldIndex; classes.emplace_back(gold); diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index 34a009a..70aee20 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -3,6 +3,7 @@ #include "util.hpp" #include "Trainer.hpp" #include "Decoder.hpp" +#include "NeuralNetwork.hpp" namespace po = boost::program_options; @@ -73,6 +74,8 @@ int main(int argc, char * argv[]) auto nbEpoch = variables["nbEpochs"].as<int>(); bool debug = variables.count("debug") == 0 ? false : true; + fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str()); + try { -- GitLab