Skip to content
Snippets Groups Projects
Commit f93305a5 authored by Franck Dary's avatar Franck Dary
Browse files

Detect device (cpu or cuda) and use cuda if possible

parent f33d5d8c
No related branches found
No related tags found
No related merge requests found
cmake_minimum_required(VERSION 3.0.2)
project(test_torch)
project(macaon)
add_compile_definitions(BOOST_DISABLE_THREADS)
......@@ -9,7 +9,6 @@ find_package(Boost 1.53.0 REQUIRED COMPONENTS program_options)
include_directories(SYSTEM ${TORCH_INCLUDE_DIRS})
add_library(Torch SHARED IMPORTED)
set_target_properties(Torch PROPERTIES IMPORTED_LOCATION ${TORCH_LIBRARIES})
add_library(Boost SHARED IMPORTED)
set_target_properties(Boost PROPERTIES IMPORTED_LOCATION ${Boost_PROGRAM_OPTIONS_LIBRARY_RELEASE})
......
......@@ -23,7 +23,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
machine.getDict(config.getState()).setState(dictState);
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong);
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
int chosenTransition = -1;
......
......@@ -78,6 +78,8 @@ int main(int argc, char * argv[])
if (modelPaths.empty())
util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str());
try
{
ReadingMachine machine(machinePath, modelPaths, dictPaths);
......
......@@ -67,7 +67,10 @@ void Classifier::initNeuralNetwork(const std::string & topology)
for (auto & initializer : initializers)
if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
{
this->nn->to(NeuralNetworkImpl::device);
return;
}
std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
for (auto & initializer : initializers)
......
......@@ -7,6 +7,10 @@
class NeuralNetworkImpl : public torch::nn::Module
{
public :
static torch::Device device;
protected :
int leftBorder{5};
......
#include "NeuralNetwork.hpp"
// Definition of the static member NeuralNetworkImpl::device.
// Its initializer queries torch::cuda::is_available(): the device is
// torch::kCUDA when that call returns true, and torch::kCPU otherwise.
torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
{
std::stack<int> leftContext;
......
......@@ -26,10 +26,10 @@ void Trainer::createDataset(SubConfig & config, bool debug)
}
auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device));
int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
auto gold = torch::zeros(1, at::kLong);
auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device));
gold[0] = goldIndex;
classes.emplace_back(gold);
......
......@@ -3,6 +3,7 @@
#include "util.hpp"
#include "Trainer.hpp"
#include "Decoder.hpp"
#include "NeuralNetwork.hpp"
namespace po = boost::program_options;
......@@ -73,6 +74,8 @@ int main(int argc, char * argv[])
auto nbEpoch = variables["nbEpochs"].as<int>();
bool debug = variables.count("debug") == 0 ? false : true;
fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str());
try
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment