Commit 8c6cc68d authored by Franck Dary

Added functions getDevice and setDevice
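
This commit encapsulates the static NeuralNetworkImpl::device member: the member becomes private, and every direct read or write across the code base is replaced with the new getDevice()/setDevice() accessors. A minimal standalone sketch of the resulting interface (simplified: the real class also derives from torch::nn::Module and NameHolder, and the member's actual initial value is not shown in this diff):

#include <torch/torch.h>

class NeuralNetworkImpl
{
  private :

  // Process-wide compute device, now reachable only through the accessors below.
  static torch::Device device;

  public :

  // Prefer CUDA when a GPU is visible, otherwise fall back to the CPU.
  static torch::Device getPreferredDevice()
  {
    return torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
  }

  static torch::Device getDevice()
  {
    return device;
  }

  static void setDevice(torch::Device device)
  {
    // The parameter shadows the member, hence the qualified name.
    NeuralNetworkImpl::device = device;
  }
};

// Out-of-class definition; kCPU is an assumed default for this sketch.
torch::Device NeuralNetworkImpl::device = torch::kCPU;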

parent 0ab37e18
@@ -90,7 +90,7 @@ int MacaonDecode::main()
   if (modelPaths.empty())
     util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
-  fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str());
+  fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::getDevice().str());
   try
   {
@@ -127,8 +127,8 @@ int MacaonDecode::main()
     if (configs.size() > 1)
     {
-      NeuralNetworkImpl::device = torch::kCPU;
-      machine.to(NeuralNetworkImpl::device);
+      NeuralNetworkImpl::setDevice(torch::kCPU);
+      machine.to(NeuralNetworkImpl::getDevice());
       std::for_each(std::execution::par, configs.begin(), configs.end(),
         [&decoder, debug, printAdvancement, beamSize, beamThreshold](BaseConfig & config)
         {
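
Here and in the MacaonTrain and Trainer hunks below, the same pattern surrounds every std::for_each(std::execution::par, ...) loop: pin the global device to the CPU, move the machine there, run the parallel section, then restore the preferred device, presumably so that worker threads do not contend for a single GPU. A hypothetical RAII guard (not part of this commit) could make that switch exception-safe:

// Hypothetical helper, not in the repository: pins the global device to the
// CPU for the guard's lifetime, then restores the device that was active.
struct CpuDeviceGuard
{
  torch::Device saved;
  ReadingMachine & machine;

  explicit CpuDeviceGuard(ReadingMachine & machine)
    : saved(NeuralNetworkImpl::getDevice()), machine(machine)
  {
    NeuralNetworkImpl::setDevice(torch::kCPU);
    machine.to(NeuralNetworkImpl::getDevice());
  }

  ~CpuDeviceGuard()
  {
    NeuralNetworkImpl::setDevice(saved);
    machine.to(NeuralNetworkImpl::getDevice());
  }
};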
@@ -83,20 +83,21 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
   if (!train)
   {
-    torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::device);
+    fmt::print(stderr, "Before load on {}\n", NeuralNetworkImpl::getDevice() == torch::kCPU ? "cpu" : "gpu");
+    torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::getDevice());
+    fmt::print(stderr, "After load\n");
     getNN()->registerEmbeddings();
-    getNN()->to(NeuralNetworkImpl::device);
+    getNN()->to(NeuralNetworkImpl::getDevice());
   }
   else if (std::filesystem::exists(getLastFilename()))
   {
-    torch::load(getNN(), getLastFilename(), NeuralNetworkImpl::device);
-    getNN()->to(NeuralNetworkImpl::device);
+    torch::load(getNN(), getLastFilename(), NeuralNetworkImpl::getDevice());
     resetOptimizer();
     loadOptimizer();
   }
   else
   {
-    getNN()->to(NeuralNetworkImpl::device);
+    getNN()->to(NeuralNetworkImpl::getDevice());
   }
 }
@@ -183,7 +184,7 @@ void Classifier::loadOptimizer()
 {
   auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name));
   if (std::filesystem::exists(optimizerPath))
-    torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::device);
+    torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::getDevice());
 }

 void Classifier::saveOptimizer()
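
torch::load's optional last argument is a target device that is forwarded to the deserializer, so stored tensors are mapped onto that device as they are read rather than loaded first and moved afterwards. A small sketch of the call, with a made-up module and file name:

#include <torch/torch.h>

int main()
{
  // Hypothetical module and file name, for illustration only.
  auto net = torch::nn::Linear(4, 2);
  torch::save(net, "best.pt");

  // Remap all stored tensors onto the chosen device while loading.
  torch::Device device = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
  torch::load(net, "best.pt", device);
  net->to(device); // anything created after loading must follow explicitly
}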
@@ -8,7 +8,7 @@
 class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
 {
-  public :
+  private :

   static torch::Device device;
@@ -24,6 +24,8 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
   virtual void removeRareDictElements(float rarityThreshold) = 0;

   static torch::Device getPreferredDevice();
+  static torch::Device getDevice();
+  static void setDevice(torch::Device device);
   static float entropy(torch::Tensor probabilities);
 };
 TORCH_MODULE(NeuralNetwork);
@@ -93,8 +93,8 @@ void ConfigDataset::Holder::reset()
   loadedTensorIndex = 0;
   nextIndexToGive = 0;
   nbGiven = 0;
-  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
-  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)));
+  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice());
+  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice())));
 }

 c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
@@ -107,8 +107,8 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset
     if (loadedTensorIndex >= (int)files.size())
       return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
     nextIndexToGive = 0;
-    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
-    loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)));
+    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice());
+    loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice())));
   }
   int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
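
The reload-and-shuffle idiom above draws a random permutation of row indexes on the same device as the data, then gathers the rows in that order, avoiding a host/device round trip. A self-contained sketch of the same idiom (the tensor contents are made up):

#include <torch/torch.h>
#include <iostream>

int main()
{
  torch::Device device = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;

  // Pretend each row is one training example: 4 examples of 3 features.
  auto examples = torch::arange(12, torch::TensorOptions(at::kLong).device(device)).reshape({4, 3});

  // randperm on the data's device keeps the gather entirely on-device.
  auto perm = torch::randperm(examples.size(0), torch::TensorOptions(at::kLong).device(device));
  std::cout << torch::index_select(examples, 0, perm) << '\n';
}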
@@ -3,7 +3,7 @@
 torch::Tensor CustomHingeLoss::operator()(torch::Tensor prediction, torch::Tensor gold)
 {
-  torch::Tensor loss = torch::zeros(1).to(NeuralNetworkImpl::device);
+  torch::Tensor loss = torch::zeros(1, NeuralNetworkImpl::getDevice());
   for (unsigned int i = 0; i < prediction.size(0); i++)
   {
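
This change is more than cosmetic: torch::zeros(1).to(device) allocates on the CPU and then copies, whereas passing the device to the factory (torch::Device converts implicitly to torch::TensorOptions) allocates directly on the target device. Side by side:

// Two ways to get a zero scalar on the current device; the second skips
// the intermediate CPU allocation and copy.
torch::Tensor slow = torch::zeros(1).to(NeuralNetworkImpl::getDevice());
torch::Tensor fast = torch::zeros(1, NeuralNetworkImpl::getDevice());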
@@ -51,13 +51,13 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::
   if (index == 0 or index == 2 or index == 4)
   {
-    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
+    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
     gold[0] = goldIndexes.at(0);
     return gold;
   }
   if (index == 1 or index == 3)
   {
-    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
+    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
     for (auto goldIndex : goldIndexes)
       gold[goldIndex] = 1;
     return gold;
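
getGoldFromClassesIndexes apparently produces two gold formats depending on the loss variant selected by index: a single-element tensor holding one class index, or a multi-hot vector over all classes. A sketch of the two shapes, assuming nbClasses == 5 and goldIndexes == {1, 3}:

// Single-target format: shape {1}, containing one class index.
auto goldIndex = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
goldIndex[0] = 1;        // tensor([1])

// Multi-hot format: shape {nbClasses}, 1 at every gold class.
auto goldMultiHot = torch::zeros(5, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
goldMultiHot[1] = 1;
goldMultiHot[3] = 1;     // tensor([0, 1, 0, 1, 0])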
@@ -99,7 +99,7 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string
 torch::Tensor ModularNetworkImpl::extractContext(Config & config)
 {
-  torch::Tensor context = torch::zeros({totalInputSize}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
+  torch::Tensor context = torch::zeros({totalInputSize}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
   for (auto & mod : modules)
     mod->addToContext(context, config);
   return context;
@@ -18,3 +18,13 @@ torch::Device NeuralNetworkImpl::getPreferredDevice()
   return torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
 }
+
+torch::Device NeuralNetworkImpl::getDevice()
+{
+  return device;
+}
+
+void NeuralNetworkImpl::setDevice(torch::Device device)
+{
+  NeuralNetworkImpl::device = device;
+}
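
In setDevice the parameter deliberately shadows the static member, which is why the assignment uses the qualified name NeuralNetworkImpl::device. Typical use, mirroring the call sites touched by this commit (machine being the ReadingMachine those sites operate on):

// Pick CUDA when available, announce it, and move the model accordingly.
NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
fmt::print(stderr, "Using device : {}\n", NeuralNetworkImpl::getDevice().str());
machine.to(NeuralNetworkImpl::getDevice());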
@@ -10,7 +10,7 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input, const std::string
   if (input.dim() == 1)
     input = input.unsqueeze(0);
-  return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(device).requires_grad(true));
+  return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(NeuralNetworkImpl::getDevice()).requires_grad(true));
 }

 torch::Tensor RandomNetworkImpl::extractContext(Config &)
@@ -161,7 +161,7 @@ int MacaonTrain::main()
     std::fclose(file);
   }
-  fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::device.str());
+  fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::getDevice().str());
   try
   {
@@ -325,15 +325,15 @@ int MacaonTrain::main()
       if (devConfigs.size() > 1)
       {
-        NeuralNetworkImpl::device = torch::kCPU;
-        machine.to(NeuralNetworkImpl::device);
+        NeuralNetworkImpl::setDevice(torch::kCPU);
+        machine.to(NeuralNetworkImpl::getDevice());
         std::for_each(std::execution::par, devConfigs.begin(), devConfigs.end(),
           [&decoder, debug, printAdvancement](BaseConfig & devConfig)
           {
             decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
           });
-        NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
-        machine.to(NeuralNetworkImpl::device);
+        NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
+        machine.to(NeuralNetworkImpl::getDevice());
       }
       else
       {
@@ -50,8 +50,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
   std::atomic<int> totalNbExamples = 0;

-  NeuralNetworkImpl::device = torch::kCPU;
-  machine.to(NeuralNetworkImpl::device);
+  NeuralNetworkImpl::setDevice(torch::kCPU);
+  machine.to(NeuralNetworkImpl::getDevice());
   std::for_each(std::execution::par, configs.begin(), configs.end(),
     [this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config)
     {
@@ -191,8 +191,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
   for (auto & it : examplesPerState)
     it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);

-  NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
-  machine.to(NeuralNetworkImpl::device);
+  NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
+  machine.to(NeuralNetworkImpl::getDevice());
   std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
   if (!f)