Skip to content
Snippets Groups Projects
Commit b495167c authored by Franck Dary's avatar Franck Dary
Browse files

Parallel extractExamples

parent 30e51f46
Branches
Tags
No related merge requests found
......@@ -5,6 +5,7 @@
#include <unordered_map>
#include <vector>
#include <filesystem>
#include <mutex>
class Dict
{
......@@ -30,6 +31,7 @@ class Dict
std::unordered_map<std::string, int> elementsToIndexes;
std::unordered_map<int, std::string> indexesToElements;
std::vector<int> nbOccs;
std::mutex elementsMutex;
State state;
bool isCountingOccs{false};
......@@ -43,6 +45,7 @@ class Dict
void readFromFile(const char * filename);
void insert(const std::string & element);
void reset();
int _getIndexOrInsert(const std::string & element, const std::string & prefix);
public :
......
......@@ -90,20 +90,33 @@ void Dict::insert(const std::string & element)
}
int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
{
if (state == State::Open)
elementsMutex.lock();
int index = _getIndexOrInsert(element, prefix);
if (state == State::Open)
elementsMutex.unlock();
return index;
}
int Dict::_getIndexOrInsert(const std::string & element, const std::string & prefix)
{
if (element.empty())
return getIndexOrInsert(emptyValueStr, prefix);
return _getIndexOrInsert(emptyValueStr, prefix);
if (util::printedLength(element) == 1 and util::isSeparator(util::utf8char(element)))
{
return getIndexOrInsert(separatorValueStr, prefix);
return _getIndexOrInsert(separatorValueStr, prefix);
}
if (util::isNumber(element))
return getIndexOrInsert(numberValueStr, prefix);
return _getIndexOrInsert(numberValueStr, prefix);
if (util::isUrl(element))
return getIndexOrInsert(urlValueStr, prefix);
return _getIndexOrInsert(urlValueStr, prefix);
auto prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
const auto & found = elementsToIndexes.find(prefixed);
......
#include "Trainer.hpp"
#include "SubConfig.hpp"
#include <execution>
Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
{
......@@ -35,7 +36,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
torch::AutoGradMode useGrad(false);
int maxNbExamplesPerFile = 50000;
std::map<std::string, Examples> examplesPerState;
std::unordered_map<std::string, Examples> examplesPerState;
std::mutex examplesMutex;
std::filesystem::create_directories(dir);
......@@ -46,9 +48,12 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
int totalNbExamples = 0;
std::atomic<int> totalNbExamples = 0;
for (auto & config : configs)
NeuralNetworkImpl::device = torch::kCPU;
machine.to(NeuralNetworkImpl::device);
std::for_each(std::execution::par_unseq, configs.begin(), configs.end(),
[this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config)
{
config.addPredicted(machine.getPredicted());
config.setStrategy(machine.getStrategyDefinition());
......@@ -157,9 +162,11 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
if (totalNbExamples >= (int)safetyNbExamplesMax)
util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
examplesMutex.lock();
examplesPerState[config.getState()].addContext(context);
examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes);
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
examplesMutex.unlock();
}
config.setChosenActionScore(bestScore);
......@@ -179,11 +186,14 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
if (config.needsUpdate())
config.update();
} // End while true
} // End for on configs
}); // End for on configs
for (auto & it : examplesPerState)
it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);
NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
machine.to(NeuralNetworkImpl::device);
std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
if (!f)
util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment