Commit 17e6ebe9 authored by Franck Dary's avatar Franck Dary
Browse files

Macaon train decoding (devScore) on cpu in parallel

parent c7267759
......@@ -25,7 +25,7 @@ class Decoder
public :
Decoder(ReadingMachine & machine);
void decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement);
std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement);
void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted);
std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const;
std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const;
......
......@@ -6,11 +6,12 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
{
}
void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement)
std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement)
{
constexpr int printInterval = 50;
int nbExamplesProcessed = 0;
std::size_t totalNbExamplesProcessed = 0;
auto pastTime = std::chrono::high_resolution_clock::now();
Beam beam(beamSize, beamThreshold, baseConfig, machine);
......@@ -20,6 +21,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
while (!beam.isEnded())
{
beam.update(machine, debug);
++totalNbExamplesProcessed;
if (printAdvancement)
if (++nbExamplesProcessed >= printInterval)
......@@ -49,6 +51,8 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
// Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script
try {baseConfig.addMissingColumns();}
catch (std::exception & e) {util::myThrow(e.what());}
return totalNbExamplesProcessed;
}
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
......
......@@ -54,6 +54,7 @@ class Classifier
bool isRegression() const;
LossFunction & getLossFunction();
bool exampleIsBanned(const Config & config);
void to(torch::Device device);
};
#endif
......@@ -52,6 +52,7 @@ class ReadingMachine
void loadPretrainedClassifiers();
int getNbParameters() const;
void resetOptimizers();
void to(torch::Device device);
};
#endif
......@@ -300,3 +300,8 @@ bool Classifier::exampleIsBanned(const Config & config)
return false;
}
void Classifier::to(torch::Device device)
{
getNN()->to(device);
}
......@@ -194,3 +194,9 @@ void ReadingMachine::resetOptimizers()
classifier->resetOptimizer();
}
void ReadingMachine::to(torch::Device device)
{
for (auto & classifier : classifiers)
classifier->to(device);
}
......@@ -23,6 +23,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
virtual void setCountOcc(bool countOcc) = 0;
virtual void removeRareDictElements(float rarityThreshold) = 0;
static torch::Device getPreferredDevice();
static float entropy(torch::Tensor probabilities);
};
TORCH_MODULE(NeuralNetwork);
......
#include "NeuralNetwork.hpp"
torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
torch::Device NeuralNetworkImpl::device(getPreferredDevice());
float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
{
......@@ -13,3 +13,8 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
return entropy;
}
torch::Device NeuralNetworkImpl::getPreferredDevice()
{
return torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
}
......@@ -47,7 +47,7 @@ po::options_description MacaonTrain::getOptionsDescription()
("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
"Max norm for the embeddings")
("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
("lineByLine", "Treat the TXT input as being one different text per line.")
("lineByLine", "Process the TXT input as being one different text per line.")
("help,h", "Produce this help message")
("oracleMode", "Don't train a model, transform the corpus into a sequence of transitions.");
......@@ -323,11 +323,22 @@ int MacaonTrain::main()
machine.trainMode(false);
machine.setDictsState(Dict::State::Closed);
std::for_each(std::execution::par_unseq, devConfigs.begin(), devConfigs.end(),
[&decoder, &debug, &printAdvancement](BaseConfig & devConfig)
{
decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
});
if (devConfigs.size() > 1)
{
NeuralNetworkImpl::device = torch::kCPU;
machine.to(NeuralNetworkImpl::device);
std::for_each(std::execution::par_unseq, devConfigs.begin(), devConfigs.end(),
[&decoder, debug, printAdvancement](BaseConfig & devConfig)
{
decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
});
NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
machine.to(NeuralNetworkImpl::device);
}
else
{
decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement);
}
std::vector<const Config *> devConfigsPtrs;
for (auto & devConfig : devConfigs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment