diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index 64dda232cccdadb8031433f5c5562f79e6e8db06..22e715f505d66ed9d41c32894d10d3d8395f405f 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -90,7 +90,7 @@ int MacaonDecode::main() if (modelPaths.empty()) util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, ""))); - fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str()); + fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::getDevice().str()); try { @@ -127,8 +127,8 @@ int MacaonDecode::main() if (configs.size() > 1) { - NeuralNetworkImpl::device = torch::kCPU; - machine.to(NeuralNetworkImpl::device); + NeuralNetworkImpl::setDevice(torch::kCPU); + machine.to(NeuralNetworkImpl::getDevice()); std::for_each(std::execution::par, configs.begin(), configs.end(), [&decoder, debug, printAdvancement, beamSize, beamThreshold](BaseConfig & config) { diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index a2361c8ae69f36dfb6a9d60659a59a90c52530a5..a9de8bbaabc1c3f20a2987e65cfb8094aaf470fc 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -83,20 +83,21 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std if (!train) { - torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::device); + fmt::print(stderr, "Before load on {}\n", NeuralNetworkImpl::getDevice() == torch::kCPU ? "cpu" : "gpu"); + torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::getDevice()); + fmt::print(stderr, "After load\n"); getNN()->registerEmbeddings(); - getNN()->to(NeuralNetworkImpl::device); + getNN()->to(NeuralNetworkImpl::getDevice()); } else if (std::filesystem::exists(getLastFilename())) { - torch::load(getNN(), getLastFilename(), NeuralNetworkImpl::device); - getNN()->to(NeuralNetworkImpl::device); + torch::load(getNN(), getLastFilename(), NeuralNetworkImpl::getDevice()); resetOptimizer(); loadOptimizer(); } else { - getNN()->to(NeuralNetworkImpl::device); + getNN()->to(NeuralNetworkImpl::getDevice()); } } @@ -183,7 +184,7 @@ void Classifier::loadOptimizer() { auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name)); if (std::filesystem::exists(optimizerPath)) - torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::device); + torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::getDevice()); } void Classifier::saveOptimizer() diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index d96f2647cbb09650f4106148953f294e4092a65b..6e2319b6663e0611ec3dcc8b480b0f79e2fe5c25 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -8,7 +8,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder { - public : + private : static torch::Device device; @@ -24,6 +24,8 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder virtual void removeRareDictElements(float rarityThreshold) = 0; static torch::Device getPreferredDevice(); + static torch::Device getDevice(); + static void setDevice(torch::Device device); static float entropy(torch::Tensor probabilities); }; TORCH_MODULE(NeuralNetwork); diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp index 91695ad8749e4c23afed51445c2ba3c83441a4ed..e8f40befef18e5b44173ad4aec91d3d3565e3171 100644 --- a/torch_modules/src/ConfigDataset.cpp +++ b/torch_modules/src/ConfigDataset.cpp @@ -93,8 +93,8 @@ void ConfigDataset::Holder::reset() loadedTensorIndex = 0; nextIndexToGive = 0; nbGiven = 0; - torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device); - loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device))); + torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice()); + loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice()))); } c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize) @@ -107,8 +107,8 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset if (loadedTensorIndex >= (int)files.size()) return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>(); nextIndexToGive = 0; - torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device); - loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device))); + torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice()); + loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice()))); } int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive); diff --git a/torch_modules/src/CustomHingeLoss.cpp b/torch_modules/src/CustomHingeLoss.cpp index ec44b6c5bdba5aac28d314a11bfb06175f1b1f7b..e2d0b510cae18d9093181d2551b6d4a78f8b0078 100644 --- a/torch_modules/src/CustomHingeLoss.cpp +++ b/torch_modules/src/CustomHingeLoss.cpp @@ -3,7 +3,7 @@ torch::Tensor CustomHingeLoss::operator()(torch::Tensor prediction, torch::Tensor gold) { - torch::Tensor loss = torch::zeros(1).to(NeuralNetworkImpl::device); + torch::Tensor loss = torch::zeros(1, NeuralNetworkImpl::getDevice()); for (unsigned int i = 0; i < prediction.size(0); i++) { diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp index 2f8f1be9dff754f407c5d3e4bdf5fa14e6cf825a..bb5724932bb3750188e5c92409bb50341cb2b7ee 100644 --- a/torch_modules/src/LossFunction.cpp +++ b/torch_modules/src/LossFunction.cpp @@ -51,13 +51,13 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std:: if (index == 0 or index == 2 or index == 4) { - auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); + auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice())); gold[0] = goldIndexes.at(0); return gold; } if (index == 1 or index == 3) { - auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); + auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice())); for (auto goldIndex : goldIndexes) gold[goldIndex] = 1; return gold; diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index 1c39f186db7bdc62e6e30c3771974dacefa8af63..066b79dbec2c0ba8f09048d5160dea9814ef436e 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -99,7 +99,7 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string torch::Tensor ModularNetworkImpl::extractContext(Config & config) { - torch::Tensor context = torch::zeros({totalInputSize}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); + torch::Tensor context = torch::zeros({totalInputSize}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice())); for (auto & mod : modules) mod->addToContext(context, config); return context; diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index c85c1602dc028670a88ed1dfacbee3c78e0896a0..fe3727dc18827225b17bbf6db9133b7dd67ce54e 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -18,3 +18,13 @@ torch::Device NeuralNetworkImpl::getPreferredDevice() return torch::cuda::is_available() ? torch::kCUDA : torch::kCPU; } +torch::Device NeuralNetworkImpl::getDevice() +{ + return device; +} + +void NeuralNetworkImpl::setDevice(torch::Device device) +{ + NeuralNetworkImpl::device = device; +} + diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index b05d1aa00b26677500bc7f8a2acb59ea12a2cbcd..d27ffe93ed41185fb62c922278362b25f9aa84e3 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -10,7 +10,7 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input, const std::string if (input.dim() == 1) input = input.unsqueeze(0); - return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(device).requires_grad(true)); + return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(NeuralNetworkImpl::getDevice()).requires_grad(true)); } torch::Tensor RandomNetworkImpl::extractContext(Config &) diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 79d608a25c5fe7ffa51cbe03a6b8dcbf6787706d..940945180494e89df1a69ed49713eec621ffdf2f 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -161,7 +161,7 @@ int MacaonTrain::main() std::fclose(file); } - fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::device.str()); + fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::getDevice().str()); try { @@ -325,15 +325,15 @@ int MacaonTrain::main() if (devConfigs.size() > 1) { - NeuralNetworkImpl::device = torch::kCPU; - machine.to(NeuralNetworkImpl::device); + NeuralNetworkImpl::setDevice(torch::kCPU); + machine.to(NeuralNetworkImpl::getDevice()); std::for_each(std::execution::par, devConfigs.begin(), devConfigs.end(), [&decoder, debug, printAdvancement](BaseConfig & devConfig) { decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); }); - NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice(); - machine.to(NeuralNetworkImpl::device); + NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice()); + machine.to(NeuralNetworkImpl::getDevice()); } else { diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 628386c55e4867d57bb7f9787e56f87a5132ca6d..2237a07e099e038023f1ecb66fff45f7a19c83f1 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -50,8 +50,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: std::atomic<int> totalNbExamples = 0; - NeuralNetworkImpl::device = torch::kCPU; - machine.to(NeuralNetworkImpl::device); + NeuralNetworkImpl::setDevice(torch::kCPU); + machine.to(NeuralNetworkImpl::getDevice()); std::for_each(std::execution::par, configs.begin(), configs.end(), [this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config) { @@ -191,8 +191,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: for (auto & it : examplesPerState) it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle); - NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice(); - machine.to(NeuralNetworkImpl::device); + NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice()); + machine.to(NeuralNetworkImpl::getDevice()); std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w"); if (!f)