Commit 17e6ebe9 authored by Franck Dary

Macaon train decoding (devScore) on cpu in parallel

parent c7267759
@@ -25,7 +25,7 @@ class Decoder
   public :
     Decoder(ReadingMachine & machine);
-    void decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement);
+    std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement);
     void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted);
     std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const;
     std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const;
...
@@ -6,11 +6,12 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 {
 }

-void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement)
+std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement)
 {
   constexpr int printInterval = 50;
   int nbExamplesProcessed = 0;
+  std::size_t totalNbExamplesProcessed = 0;
   auto pastTime = std::chrono::high_resolution_clock::now();

   Beam beam(beamSize, beamThreshold, baseConfig, machine);
@@ -20,6 +21,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement)
   while (!beam.isEnded())
   {
     beam.update(machine, debug);
+    ++totalNbExamplesProcessed;

     if (printAdvancement)
       if (++nbExamplesProcessed >= printInterval)
@@ -49,6 +51,8 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement)
   // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script
   try {baseConfig.addMissingColumns();}
   catch (std::exception & e) {util::myThrow(e.what());}
+
+  return totalNbExamplesProcessed;
 }

 float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
...
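decode() now returns the number of beam updates (transitions) it performed instead of void. Below is a minimal sketch of how a caller could use that return value to report dev-set decoding throughput; decodeAndReport, DecoderT and ConfigT are hypothetical stand-ins (the repo's real types are Decoder and BaseConfig), and this helper is not part of the commit:

#include <chrono>
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical helper: sums the new decode() return value over all dev configs
// and prints a throughput figure.
template <typename DecoderT, typename ConfigT>
std::size_t decodeAndReport(DecoderT & decoder, std::vector<ConfigT> & devConfigs)
{
  auto start = std::chrono::steady_clock::now();
  std::size_t total = 0;
  for (auto & devConfig : devConfigs)
    total += decoder.decode(devConfig, 1, 0.0f, /*debug=*/false, /*printAdvancement=*/false);
  double seconds = std::chrono::duration<double>(std::chrono::steady_clock::now() - start).count();
  std::printf("decoded %zu transitions in %.2fs (%.1f transitions/s)\n",
              total, seconds, seconds > 0.0 ? static_cast<double>(total) / seconds : 0.0);
  return total;
}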
@@ -54,6 +54,7 @@ class Classifier
     bool isRegression() const;
     LossFunction & getLossFunction();
     bool exampleIsBanned(const Config & config);
+    void to(torch::Device device);
 };

 #endif
@@ -52,6 +52,7 @@ class ReadingMachine
     void loadPretrainedClassifiers();
     int getNbParameters() const;
     void resetOptimizers();
+    void to(torch::Device device);
 };

 #endif
@@ -300,3 +300,8 @@ bool Classifier::exampleIsBanned(const Config & config)
   return false;
 }
+
+void Classifier::to(torch::Device device)
+{
+  getNN()->to(device);
+}
@@ -194,3 +194,9 @@ void ReadingMachine::resetOptimizers()
     classifier->resetOptimizer();
 }
+
+void ReadingMachine::to(torch::Device device)
+{
+  for (auto & classifier : classifiers)
+    classifier->to(device);
+}
@@ -23,6 +23,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
     virtual void setCountOcc(bool countOcc) = 0;
     virtual void removeRareDictElements(float rarityThreshold) = 0;

+    static torch::Device getPreferredDevice();
     static float entropy(torch::Tensor probabilities);
 };
 TORCH_MODULE(NeuralNetwork);
...
#include "NeuralNetwork.hpp" #include "NeuralNetwork.hpp"
torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU); torch::Device NeuralNetworkImpl::device(getPreferredDevice());
float NeuralNetworkImpl::entropy(torch::Tensor probabilities) float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
{ {
...@@ -13,3 +13,8 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities) ...@@ -13,3 +13,8 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
return entropy; return entropy;
} }
torch::Device NeuralNetworkImpl::getPreferredDevice()
{
return torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
}
@@ -47,7 +47,7 @@ po::options_description MacaonTrain::getOptionsDescription()
     ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
       "Max norm for the embeddings")
     ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
-    ("lineByLine", "Treat the TXT input as being one different text per line.")
+    ("lineByLine", "Process the TXT input as being one different text per line.")
     ("help,h", "Produce this help message")
     ("oracleMode", "Don't train a model, transform the corpus into a sequence of transitions.");
@@ -323,11 +323,22 @@ int MacaonTrain::main()
   machine.trainMode(false);
   machine.setDictsState(Dict::State::Closed);

+  if (devConfigs.size() > 1)
+  {
+    NeuralNetworkImpl::device = torch::kCPU;
+    machine.to(NeuralNetworkImpl::device);
   std::for_each(std::execution::par_unseq, devConfigs.begin(), devConfigs.end(),
-    [&decoder, &debug, &printAdvancement](BaseConfig & devConfig)
+    [&decoder, debug, printAdvancement](BaseConfig & devConfig)
     {
       decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
     });
+    NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
+    machine.to(NeuralNetworkImpl::device);
+  }
+  else
+  {
+    decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement);
+  }

   std::vector<const Config *> devConfigsPtrs;
   for (auto & devConfig : devConfigs)
...
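The core of the commit: when several dev configs exist, the machine is moved to CPU, the configs are decoded concurrently with a parallel standard algorithm, and the preferred device is restored afterwards. Below is a self-contained sketch of that std::for_each pattern with stand-in types (DevConfig and decodeOne are illustrative only, not names from the repo; with libstdc++ this needs -std=c++17 and linking against TBB):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <execution>
#include <vector>

// Stand-in for one dev corpus; the real code uses BaseConfig.
struct DevConfig { std::size_t nbDecoded = 0; };

// Each call mutates only its own config, so the parallel loop needs no locking.
void decodeOne(DevConfig & config) { config.nbDecoded = 42; }

int main()
{
  std::vector<DevConfig> devConfigs(4);

  // Same shape as the commit: decode every dev config concurrently on CPU.
  std::for_each(std::execution::par_unseq, devConfigs.begin(), devConfigs.end(),
                [](DevConfig & devConfig) { decodeOne(devConfig); });

  std::size_t total = 0;
  for (auto & config : devConfigs)
    total += config.nbDecoded;
  std::printf("decoded %zu transitions across %zu configs\n", total, devConfigs.size());
}

Note that the commit also changes the lambda to capture debug and printAdvancement by value rather than by reference, presumably so each worker reads its own copy of these flags, while decoder stays captured by reference because it is shared by all workers.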