Commit f93305a5 authored by Franck Dary

Detect device (cpu or cuda) and use cuda if possible

parent f33d5d8c
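The whole change follows one pattern: pick the device once at startup with torch::cuda::is_available(), then move the model and every input tensor onto it. The standalone C++ sketch below (not project code; the Linear module and the shapes are invented for illustration) shows that pattern in isolation:

#include <torch/torch.h>
#include <iostream>

int main()
{
  // Use the GPU when LibTorch can see one, otherwise fall back to the CPU.
  torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
  std::cout << "Using device : " << device << std::endl;

  // Model parameters and input tensors must live on the same device.
  torch::nn::Linear model(10, 2);
  model->to(device);
  auto input = torch::randn({1, 10}).to(device);
  auto output = model(input);
  std::cout << output.sizes() << std::endl;
  return 0;
}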
 cmake_minimum_required(VERSION 3.0.2)
-project(test_torch)
+project(macaon)
 add_compile_definitions(BOOST_DISABLE_THREADS)
@@ -9,7 +9,6 @@ find_package(Boost 1.53.0 REQUIRED COMPONENTS program_options)
 include_directories(SYSTEM ${TORCH_INCLUDE_DIRS})
 add_library(Torch SHARED IMPORTED)
-set_target_properties(Torch PROPERTIES IMPORTED_LOCATION ${TORCH_LIBRARIES})
 add_library(Boost SHARED IMPORTED)
 set_target_properties(Boost PROPERTIES IMPORTED_LOCATION ${Boost_PROGRAM_OPTIONS_LIBRARY_RELEASE})
......
@@ -23,7 +23,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
 auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
 machine.getDict(config.getState()).setState(dictState);
-auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong);
+auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone().to(NeuralNetworkImpl::device);
 auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
 int chosenTransition = -1;
......
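Why the extra .clone().to(...) above: torch::from_blob only wraps the vector's memory, it neither copies nor owns it, so the tensor must be cloned into its own storage before it can outlive `context` or be shipped to the GPU. A minimal sketch of that idiom (the helper name contextToTensor is hypothetical, not macaon code):

#include <torch/torch.h>
#include <vector>

// Hypothetical helper illustrating the from_blob -> clone -> to(device) idiom.
torch::Tensor contextToTensor(std::vector<long> & context, torch::Device device)
{
  // from_blob wraps context.data() without copying; clone() gives the tensor
  // its own storage, and to() then moves that copy to the target device
  // (a no-op when the device is the CPU).
  return torch::from_blob(context.data(), {(long)context.size()}, at::kLong)
           .clone()
           .to(device);
}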
@@ -78,6 +78,8 @@ int main(int argc, char * argv[])
 if (modelPaths.empty())
   util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
+fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str());
 try
 {
   ReadingMachine machine(machinePath, modelPaths, dictPaths);
......
@@ -67,7 +67,10 @@ void Classifier::initNeuralNetwork(const std::string & topology)
 for (auto & initializer : initializers)
   if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
+  {
+    this->nn->to(NeuralNetworkImpl::device);
     return;
+  }
 std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
 for (auto & initializer : initializers)
......
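Calling this->nn->to(NeuralNetworkImpl::device) right after the network is built moves every registered parameter and buffer onto the chosen device, which has to happen before the first forward pass on device-resident inputs. A self-contained sketch with a made-up module (TinyNet is not part of macaon):

#include <torch/torch.h>

struct TinyNetImpl : torch::nn::Module
{
  torch::nn::Linear fc{nullptr};
  TinyNetImpl() { fc = register_module("fc", torch::nn::Linear(4, 2)); }
  torch::Tensor forward(torch::Tensor x) { return fc(x); }
};
TORCH_MODULE(TinyNet);

int main()
{
  torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
  TinyNet net;
  net->to(device);                                       // relocates fc's weight and bias
  auto out = net->forward(torch::ones({1, 4}, device));  // input created directly on the device
  return 0;
}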
@@ -7,6 +7,10 @@
 class NeuralNetworkImpl : public torch::nn::Module
 {
+  public :
+  static torch::Device device;
   protected :
   int leftBorder{5};
......
#include "NeuralNetwork.hpp" #include "NeuralNetwork.hpp"
torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
{ {
std::stack<int> leftContext; std::stack<int> leftContext;
......
@@ -26,10 +26,10 @@ void Trainer::createDataset(SubConfig & config, bool debug)
 }
 auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
-contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
+contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device));
 int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-auto gold = torch::zeros(1, at::kLong);
+auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device));
 gold[0] = goldIndex;
 classes.emplace_back(gold);
......
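The gold tensor shows the second way data ends up on the device: instead of building on the CPU and moving, torch::zeros is given a TensorOptions that bundles the dtype with the device, so the tensor is allocated there directly. A minimal illustration (the index value 3 is arbitrary):

#include <torch/torch.h>

int main()
{
  torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

  // Bundle dtype and device so the tensor is allocated on the target device
  // from the start rather than created on the CPU and copied over.
  auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(device));
  gold[0] = 3;  // scalar assignment copies the value onto the device
  return 0;
}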
@@ -3,6 +3,7 @@
 #include "util.hpp"
 #include "Trainer.hpp"
 #include "Decoder.hpp"
+#include "NeuralNetwork.hpp"
 namespace po = boost::program_options;
@@ -73,6 +74,8 @@ int main(int argc, char * argv[])
 auto nbEpoch = variables["nbEpochs"].as<int>();
 bool debug = variables.count("debug") == 0 ? false : true;
+fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str());
 try
 {
......