From 17e6ebe9d4ea7af82e9980b9a2a41125121ad3e5 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sat, 6 Mar 2021 18:45:41 +0100
Subject: [PATCH] Macaon train decoding (devScore) on cpu in parallel

---
 decoder/include/Decoder.hpp                |  2 +-
 decoder/src/Decoder.cpp                    |  6 +++++-
 reading_machine/include/Classifier.hpp     |  1 +
 reading_machine/include/ReadingMachine.hpp |  1 +
 reading_machine/src/Classifier.cpp         |  5 +++++
 reading_machine/src/ReadingMachine.cpp     |  6 ++++++
 torch_modules/include/NeuralNetwork.hpp    |  1 +
 torch_modules/src/NeuralNetwork.cpp        |  7 ++++++-
 trainer/src/MacaonTrain.cpp                | 23 +++++++++++++++++------
 9 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp
index 01bc7cc..fe8c870 100644
--- a/decoder/include/Decoder.hpp
+++ b/decoder/include/Decoder.hpp
@@ -25,7 +25,7 @@ class Decoder
   public :
 
   Decoder(ReadingMachine & machine);
-  void decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement);
+  std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement);
   void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted);
   std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const;
   std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const;
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 38957af..70eba0e 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -6,11 +6,12 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 {
 }
 
-void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement)
+std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement)
 {
   constexpr int printInterval = 50;
   int nbExamplesProcessed = 0;
+  std::size_t totalNbExamplesProcessed = 0;
   auto pastTime = std::chrono::high_resolution_clock::now();
 
   Beam beam(beamSize, beamThreshold, baseConfig, machine);
 
@@ -20,6 +21,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
   while (!beam.isEnded())
   {
     beam.update(machine, debug);
+    ++totalNbExamplesProcessed;
 
     if (printAdvancement)
       if (++nbExamplesProcessed >= printInterval)
@@ -49,6 +51,8 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
   // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script
   try {baseConfig.addMissingColumns();}
   catch (std::exception & e) {util::myThrow(e.what());}
+
+  return totalNbExamplesProcessed;
 }
 
 float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index e4b2208..41285a3 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -54,6 +54,7 @@ class Classifier
   bool isRegression() const;
   LossFunction & getLossFunction();
   bool exampleIsBanned(const Config & config);
+  void to(torch::Device device);
 };
 
 #endif
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 3135635..5035070 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -52,6 +52,7 @@ class ReadingMachine
   void loadPretrainedClassifiers();
   int getNbParameters() const;
   void resetOptimizers();
+  void to(torch::Device device);
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 68e89b7..a2361c8 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -300,3 +300,8 @@ bool Classifier::exampleIsBanned(const Config & config)
   return false;
 }
 
+void Classifier::to(torch::Device device)
+{
+  getNN()->to(device);
+}
+
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 33e08cd..7c06ebd 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -194,3 +194,9 @@ void ReadingMachine::resetOptimizers()
     classifier->resetOptimizer();
 }
 
+void ReadingMachine::to(torch::Device device)
+{
+  for (auto & classifier : classifiers)
+    classifier->to(device);
+}
+
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index ffbcdea..d96f264 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -23,6 +23,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
   virtual void setCountOcc(bool countOcc) = 0;
   virtual void removeRareDictElements(float rarityThreshold) = 0;
 
+  static torch::Device getPreferredDevice();
   static float entropy(torch::Tensor probabilities);
 };
 TORCH_MODULE(NeuralNetwork);
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 785c8d9..c85c160 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -1,6 +1,6 @@
 #include "NeuralNetwork.hpp"
 
-torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
+torch::Device NeuralNetworkImpl::device(getPreferredDevice());
 
 float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
 {
@@ -13,3 +13,8 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
   return entropy;
 }
 
+torch::Device NeuralNetworkImpl::getPreferredDevice()
+{
+  return torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
+}
+
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 900ab29..efc6341 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -47,7 +47,7 @@ po::options_description MacaonTrain::getOptionsDescription()
    ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
      "Max norm for the embeddings")
    ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
-   ("lineByLine", "Treat the TXT input as being one different text per line.")
+   ("lineByLine", "Process the TXT input as being one different text per line.")
    ("help,h", "Produce this help message")
    ("oracleMode", "Don't train a model, transform the corpus into a sequence of transitions.");
 
@@ -323,11 +323,22 @@ int MacaonTrain::main()
     machine.trainMode(false);
     machine.setDictsState(Dict::State::Closed);
 
-    std::for_each(std::execution::par_unseq, devConfigs.begin(), devConfigs.end(),
-      [&decoder, &debug, &printAdvancement](BaseConfig & devConfig)
-      {
-        decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
-      });
+    if (devConfigs.size() > 1)
+    {
+      NeuralNetworkImpl::device = torch::kCPU;
+      machine.to(NeuralNetworkImpl::device);
+      std::for_each(std::execution::par_unseq, devConfigs.begin(), devConfigs.end(),
+        [&decoder, debug, printAdvancement](BaseConfig & devConfig)
+        {
+          decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
+        });
+      NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
+      machine.to(NeuralNetworkImpl::device);
+    }
+    else
+    {
+      decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement);
+    }
 
     std::vector<const Config *> devConfigsPtrs;
     for (auto & devConfig : devConfigs)
-- 
GitLab