Commit b495167c authored by Franck Dary

Parallel extractExamples

parent 30e51f46
Dict.hpp

@@ -5,6 +5,7 @@
 #include <unordered_map>
 #include <vector>
 #include <filesystem>
+#include <mutex>

 class Dict
 {
@@ -30,6 +31,7 @@ class Dict
   std::unordered_map<std::string, int> elementsToIndexes;
   std::unordered_map<int, std::string> indexesToElements;
   std::vector<int> nbOccs;
+  std::mutex elementsMutex;
   State state;
   bool isCountingOccs{false};
@@ -43,6 +45,7 @@ class Dict
   void readFromFile(const char * filename);
   void insert(const std::string & element);
   void reset();
+  int _getIndexOrInsert(const std::string & element, const std::string & prefix);

   public :
Dict.cpp

@@ -90,20 +90,33 @@ void Dict::insert(const std::string & element)
 }

 int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
+{
+  if (state == State::Open)
+    elementsMutex.lock();
+
+  int index = _getIndexOrInsert(element, prefix);
+
+  if (state == State::Open)
+    elementsMutex.unlock();
+
+  return index;
+}
+
+int Dict::_getIndexOrInsert(const std::string & element, const std::string & prefix)
 {
   if (element.empty())
-    return getIndexOrInsert(emptyValueStr, prefix);
+    return _getIndexOrInsert(emptyValueStr, prefix);

   if (util::printedLength(element) == 1 and util::isSeparator(util::utf8char(element)))
   {
-    return getIndexOrInsert(separatorValueStr, prefix);
+    return _getIndexOrInsert(separatorValueStr, prefix);
   }

   if (util::isNumber(element))
-    return getIndexOrInsert(numberValueStr, prefix);
+    return _getIndexOrInsert(numberValueStr, prefix);

   if (util::isUrl(element))
-    return getIndexOrInsert(urlValueStr, prefix);
+    return _getIndexOrInsert(urlValueStr, prefix);

   auto prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
   const auto & found = elementsToIndexes.find(prefixed);
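The shape of this change: getIndexOrInsert keeps its public signature but becomes a thin wrapper that takes elementsMutex only while the dictionary is Open (a closed Dict is presumably read-only, so lookups no longer mutate the map), and the actual lookup/insertion logic moves into the new private _getIndexOrInsert. Because std::mutex is not recursive and the original function called itself for the empty/separator/number/URL special cases, keeping that recursion inside the unlocked helper avoids a self-deadlock once the wrapper holds the lock. Below is a minimal, self-contained sketch of the same pattern; MiniDict, the "__empty__" placeholder and the dropped prefix argument are illustrative simplifications, not the project's actual code.

#include <mutex>
#include <string>
#include <unordered_map>
#include <cstdio>

// Sketch of the locking pattern: the public entry point conditionally locks,
// the private helper does the real work and may recurse without relocking.
class MiniDict
{
  public :

  enum class State {Open, Closed};

  int getIndexOrInsert(const std::string & element)
  {
    std::unique_lock<std::mutex> lock(elementsMutex, std::defer_lock);
    if (state == State::Open)          // only synchronize while the dict is mutable
      lock.lock();
    return _getIndexOrInsert(element);
  }

  private :

  int _getIndexOrInsert(const std::string & element)
  {
    if (element.empty())
      return _getIndexOrInsert("__empty__"); // recursion stays lock-free
    auto found = elementsToIndexes.find(element);
    if (found != elementsToIndexes.end())
      return found->second;
    int index = (int)elementsToIndexes.size();
    elementsToIndexes.emplace(element, index);
    return index;
  }

  std::unordered_map<std::string, int> elementsToIndexes;
  std::mutex elementsMutex;
  State state{State::Open};
};

int main()
{
  MiniDict dict;
  std::printf("%d %d %d\n",
              dict.getIndexOrInsert("a"),
              dict.getIndexOrInsert("b"),
              dict.getIndexOrInsert("a")); // prints: 0 1 0
  return 0;
}

The sketch uses std::unique_lock with std::defer_lock so the conditional lock is released automatically; the commit itself calls lock()/unlock() directly, which behaves the same on the non-throwing path.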
Trainer.cpp

 #include "Trainer.hpp"
 #include "SubConfig.hpp"
+#include <execution>

 Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
 {
@@ -35,7 +36,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
   torch::AutoGradMode useGrad(false);
   int maxNbExamplesPerFile = 50000;
-  std::map<std::string, Examples> examplesPerState;
+  std::unordered_map<std::string, Examples> examplesPerState;
+  std::mutex examplesMutex;

   std::filesystem::create_directories(dir);
@@ -46,9 +48,12 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
   fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");

-  int totalNbExamples = 0;
-  for (auto & config : configs)
+  std::atomic<int> totalNbExamples = 0;
+  NeuralNetworkImpl::device = torch::kCPU;
+  machine.to(NeuralNetworkImpl::device);
+  std::for_each(std::execution::par_unseq, configs.begin(), configs.end(),
+    [this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config)
   {
     config.addPredicted(machine.getPredicted());
     config.setStrategy(machine.getStrategyDefinition());
@@ -157,9 +162,11 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
       if (totalNbExamples >= (int)safetyNbExamplesMax)
         util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));

+      examplesMutex.lock();
       examplesPerState[config.getState()].addContext(context);
       examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes);
       examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
+      examplesMutex.unlock();
     }

     config.setChosenActionScore(bestScore);
@@ -179,11 +186,14 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
       if (config.needsUpdate())
         config.update();
     } // End while true
-  } // End for on configs
+  }); // End for on configs

   for (auto & it : examplesPerState)
     it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);

+  NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
+  machine.to(NeuralNetworkImpl::device);
+
   std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
   if (!f)
     util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
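On the Trainer side, the sequential loop over configs becomes a std::for_each with a parallel execution policy: the example counter becomes std::atomic<int>, the shared examplesPerState map (now a std::unordered_map) is guarded by examplesMutex around addContext/addClass/saveIfNeeded, and the machine is moved to torch::kCPU for the duration of extraction and back to the preferred device afterwards, presumably to keep the worker threads off the GPU. Here is a minimal, self-contained sketch of the same combination of parallel for_each, atomic counter and mutex-guarded container; all names below are stand-ins, not the project's API.

#include <algorithm>
#include <atomic>
#include <cstdio>
#include <execution>
#include <mutex>
#include <numeric>
#include <unordered_map>
#include <vector>

int main()
{
  std::vector<int> configs(64);
  std::iota(configs.begin(), configs.end(), 0);   // stand-ins for SubConfig objects

  std::atomic<int> totalNbExamples{0};            // incremented from many threads
  std::unordered_map<int, int> examplesPerState;  // shared map, guarded by a mutex
  std::mutex examplesMutex;

  // std::execution::par: each element may be processed on a different thread.
  std::for_each(std::execution::par, configs.begin(), configs.end(),
    [&totalNbExamples, &examplesPerState, &examplesMutex](int config)
    {
      int state = config % 4;                     // stand-in for config.getState()
      totalNbExamples++;                          // atomic, no lock needed

      std::scoped_lock lock(examplesMutex);       // serialize access to the map
      examplesPerState[state]++;
    });

  std::printf("%d examples over %zu states\n",
              totalNbExamples.load(), examplesPerState.size());
  return 0;
}

The sketch deliberately uses std::execution::par rather than the commit's par_unseq: the standard does not allow blocking synchronization such as taking a std::mutex inside an element access function run under par_unseq, whereas par permits it. With GCC's libstdc++, parallel execution policies also require linking against TBB (-ltbb).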