diff --git a/CMakeLists.txt b/CMakeLists.txt index 78867d241d7897bf67236112b7d184caf96d909c..7c6507c71ec6594bb9077e62dbdf57fd93d087a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.0.2) -project(test_torch) +project(macaon) add_compile_definitions(BOOST_DISABLE_THREADS) @@ -9,7 +9,6 @@ find_package(Boost 1.53.0 REQUIRED COMPONENTS program_options) include_directories(SYSTEM ${TORCH_INCLUDE_DIRS}) add_library(Torch SHARED IMPORTED) -set_target_properties(Torch PROPERTIES IMPORTED_LOCATION ${TORCH_LIBRARIES}) add_library(Boost SHARED IMPORTED) set_target_properties(Boost PROPERTIES IMPORTED_LOCATION ${Boost_PROGRAM_OPTIONS_LIBRARY_RELEASE}) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index e7b2285363208b1607f5143aa9ba7d092e4e20c7..fc5f386ae9d085c2a0489357085447122a27f476 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -23,7 +23,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())); machine.getDict(config.getState()).setState(dictState); - auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong); + auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); int chosenTransition = -1; diff --git a/decoder/src/macaon_decode.cpp b/decoder/src/macaon_decode.cpp index f673c2ca8c2fc23d0e131cb77efd2cc817515c5d..5b7272b0e7d1fb4aa150eca63eb1287902a633a8 100644 --- a/decoder/src/macaon_decode.cpp +++ b/decoder/src/macaon_decode.cpp @@ -78,6 +78,8 @@ int main(int argc, char * argv[]) if (modelPaths.empty()) util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, ""))); + fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str()); + try { ReadingMachine machine(machinePath, modelPaths, dictPaths); diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index f234c73e4b1101cd68927ee26897bd6fcf082abe..ad03c314757c302ccd73ae5768fc4c9b42c631f2 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -67,7 +67,10 @@ void Classifier::initNeuralNetwork(const std::string & topology) for (auto & initializer : initializers) if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer))) + { + this->nn->to(NeuralNetworkImpl::device); return; + } std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology); for (auto & initializer : initializers) diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index cdc55145bd7f20cb593e9e9fc82afd02a0564b9c..47092164e0abd730c081e7306182757b56eb14be 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -7,6 +7,10 @@ class NeuralNetworkImpl : public torch::nn::Module { + public : + + static torch::Device device; + protected : int leftBorder{5}; diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index cdb8ad3426050ddab9581713da34d42fe4394ae6..0719d84430f7bf194a2b887a022c1c473667572a 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -1,5 +1,7 @@ #include "NeuralNetwork.hpp" +torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU); + std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const { std::stack<int> leftContext; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index a74078cd8f6e7a567687e248b9f924b3e4b6535a..191ac47278e1c6d17efd8fa9663271a6620c654f 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -26,10 +26,10 @@ void Trainer::createDataset(SubConfig & config, bool debug) } auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); - contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone()); + contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device)); int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); - auto gold = torch::zeros(1, at::kLong); + auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)); gold[0] = goldIndex; classes.emplace_back(gold); diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index 34a009a43a0317fab79d9992faee67f9fc0ae011..70aee20d3e45bd0b6423b023e2a83a44607cd002 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -3,6 +3,7 @@ #include "util.hpp" #include "Trainer.hpp" #include "Decoder.hpp" +#include "NeuralNetwork.hpp" namespace po = boost::program_options; @@ -73,6 +74,8 @@ int main(int argc, char * argv[]) auto nbEpoch = variables["nbEpochs"].as<int>(); bool debug = variables.count("debug") == 0 ? false : true; + fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str()); + try {