diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp index 01bc7cc5b70e7caaef1d6d556c9ace4d4c1bc320..fe8c870c21dcc8bdf2f677d94d8e6dd6c2d98aef 100644 --- a/decoder/include/Decoder.hpp +++ b/decoder/include/Decoder.hpp @@ -25,7 +25,7 @@ class Decoder public : Decoder(ReadingMachine & machine); - void decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement); + std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement); void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted); std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const; std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const; diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 38957afa3eb2189f860d1d748e1099f7e9673e69..70eba0ef0219ad1a1ae3e9d4c820600b7fcc3f5a 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -6,11 +6,12 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) { } -void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement) +std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement) { constexpr int printInterval = 50; int nbExamplesProcessed = 0; + std::size_t totalNbExamplesProcessed = 0; auto pastTime = std::chrono::high_resolution_clock::now(); Beam beam(beamSize, beamThreshold, baseConfig, machine); @@ -20,6 +21,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh while (!beam.isEnded()) { beam.update(machine, debug); + ++totalNbExamplesProcessed; if (printAdvancement) if (++nbExamplesProcessed >= printInterval) @@ -49,6 +51,8 @@ void Decoder::decode(BaseConfig & 
baseConfig, std::size_t beamSize, float beamTh // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script try {baseConfig.addMissingColumns();} catch (std::exception & e) {util::myThrow(e.what());} + + return totalNbExamplesProcessed; } float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index e4b22080c791147f700e88a4dd2a50cbea3cd207..41285a311d8a6440a294b4eaf0fa5a384a2fff9d 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -54,6 +54,7 @@ class Classifier bool isRegression() const; LossFunction & getLossFunction(); bool exampleIsBanned(const Config & config); + void to(torch::Device device); }; #endif diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 3135635f69c7d29956bd2b28c860af497e58281f..503507076b3ea0aa806bc756a905af14c1b95015 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -52,6 +52,7 @@ class ReadingMachine void loadPretrainedClassifiers(); int getNbParameters() const; void resetOptimizers(); + void to(torch::Device device); }; #endif diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 68e89b7ec1d94ddc22ee4671aceabc003321ce53..a2361c8ae69f36dfb6a9d60659a59a90c52530a5 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -300,3 +300,8 @@ bool Classifier::exampleIsBanned(const Config & config) return false; } +void Classifier::to(torch::Device device) +{ + getNN()->to(device); +} + diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 33e08cd945cad3e52886d958bd304e1b62b46895..7c06ebd0e8218a0d30f6898be35a120daa7c168b 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ 
-194,3 +194,9 @@ void ReadingMachine::resetOptimizers() classifier->resetOptimizer(); } +void ReadingMachine::to(torch::Device device) +{ + for (auto & classifier : classifiers) + classifier->to(device); +} + diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index ffbcdea03d406e38aaef06830b54e2809f27b0a4..d96f2647cbb09650f4106148953f294e4092a65b 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -23,6 +23,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder virtual void setCountOcc(bool countOcc) = 0; virtual void removeRareDictElements(float rarityThreshold) = 0; + static torch::Device getPreferredDevice(); static float entropy(torch::Tensor probabilities); }; TORCH_MODULE(NeuralNetwork); diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index 785c8d9d1c7ecca0342405377a73c055827562ad..c85c1602dc028670a88ed1dfacbee3c78e0896a0 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -1,6 +1,6 @@ #include "NeuralNetwork.hpp" -torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU); +torch::Device NeuralNetworkImpl::device(getPreferredDevice()); float NeuralNetworkImpl::entropy(torch::Tensor probabilities) { @@ -13,3 +13,8 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities) return entropy; } +torch::Device NeuralNetworkImpl::getPreferredDevice() +{ + return torch::cuda::is_available() ? 
torch::kCUDA : torch::kCPU; +} + diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 900ab29218bf839115679e60daec6a3358c736c3..efc63417de5f2b3e0308864478716cd89571e150 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -47,7 +47,7 @@ po::options_description MacaonTrain::getOptionsDescription() ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()), "Max norm for the embeddings") ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.") - ("lineByLine", "Treat the TXT input as being one different text per line.") + ("lineByLine", "Process the TXT input as being one different text per line.") ("help,h", "Produce this help message") ("oracleMode", "Don't train a model, transform the corpus into a sequence of transitions."); @@ -323,11 +323,22 @@ int MacaonTrain::main() machine.trainMode(false); machine.setDictsState(Dict::State::Closed); - std::for_each(std::execution::par_unseq, devConfigs.begin(), devConfigs.end(), - [&decoder, &debug, &printAdvancement](BaseConfig & devConfig) - { - decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); - }); + if (devConfigs.size() > 1) + { + NeuralNetworkImpl::device = torch::kCPU; + machine.to(NeuralNetworkImpl::device); + std::for_each(std::execution::par_unseq, devConfigs.begin(), devConfigs.end(), + [&decoder, debug, printAdvancement](BaseConfig & devConfig) + { + decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); + }); + NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice(); + machine.to(NeuralNetworkImpl::device); + } + else if (!devConfigs.empty()) + { + decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement); + } std::vector<const Config *> devConfigsPtrs; for (auto & devConfig : devConfigs)