Skip to content
Snippets Groups Projects
Commit b495167c authored by Franck Dary's avatar Franck Dary
Browse files

Parallel extractExamples

parent 30e51f46
Branches
No related tags found
No related merge requests found
......@@ -5,6 +5,7 @@
#include <unordered_map>
#include <vector>
#include <filesystem>
#include <mutex>
class Dict
{
......@@ -30,6 +31,7 @@ class Dict
std::unordered_map<std::string, int> elementsToIndexes;
std::unordered_map<int, std::string> indexesToElements;
std::vector<int> nbOccs;
std::mutex elementsMutex;
State state;
bool isCountingOccs{false};
......@@ -43,6 +45,7 @@ class Dict
void readFromFile(const char * filename);
void insert(const std::string & element);
void reset();
int _getIndexOrInsert(const std::string & element, const std::string & prefix);
public :
......
......@@ -90,20 +90,33 @@ void Dict::insert(const std::string & element)
}
/// Public entry point that wraps _getIndexOrInsert with locking.
///
/// While the dictionary is in State::Open, the call is serialized through
/// elementsMutex; in any other state no lock is taken.
///
/// @param element Forwarded unchanged to _getIndexOrInsert.
/// @param prefix  Forwarded unchanged to _getIndexOrInsert.
/// @return The index returned by _getIndexOrInsert.
int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
{
  // Deferred RAII lock: the decision to lock is taken exactly once, and the
  // mutex is released on every exit path (including an exception thrown by
  // _getIndexOrInsert), so lock and unlock can never get out of balance.
  std::unique_lock<std::mutex> guard(elementsMutex, std::defer_lock);
  if (state == State::Open)
    guard.lock();

  return _getIndexOrInsert(element, prefix);
}
int Dict::_getIndexOrInsert(const std::string & element, const std::string & prefix)
{
if (element.empty())
return getIndexOrInsert(emptyValueStr, prefix);
return _getIndexOrInsert(emptyValueStr, prefix);
if (util::printedLength(element) == 1 and util::isSeparator(util::utf8char(element)))
{
return getIndexOrInsert(separatorValueStr, prefix);
return _getIndexOrInsert(separatorValueStr, prefix);
}
if (util::isNumber(element))
return getIndexOrInsert(numberValueStr, prefix);
return _getIndexOrInsert(numberValueStr, prefix);
if (util::isUrl(element))
return getIndexOrInsert(urlValueStr, prefix);
return _getIndexOrInsert(urlValueStr, prefix);
auto prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
const auto & found = elementsToIndexes.find(prefixed);
......
#include "Trainer.hpp"
#include "SubConfig.hpp"
#include <execution>
Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
{
......@@ -35,7 +36,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
torch::AutoGradMode useGrad(false);
int maxNbExamplesPerFile = 50000;
std::map<std::string, Examples> examplesPerState;
std::unordered_map<std::string, Examples> examplesPerState;
std::mutex examplesMutex;
std::filesystem::create_directories(dir);
......@@ -46,9 +48,12 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
int totalNbExamples = 0;
std::atomic<int> totalNbExamples = 0;
for (auto & config : configs)
NeuralNetworkImpl::device = torch::kCPU;
machine.to(NeuralNetworkImpl::device);
std::for_each(std::execution::par_unseq, configs.begin(), configs.end(),
[this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config)
{
config.addPredicted(machine.getPredicted());
config.setStrategy(machine.getStrategyDefinition());
......@@ -157,9 +162,11 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
if (totalNbExamples >= (int)safetyNbExamplesMax)
util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
examplesMutex.lock();
examplesPerState[config.getState()].addContext(context);
examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes);
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
examplesMutex.unlock();
}
config.setChosenActionScore(bestScore);
......@@ -179,11 +186,14 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
if (config.needsUpdate())
config.update();
} // End while true
} // End for on configs
}); // End for on configs
for (auto & it : examplesPerState)
it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);
NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
machine.to(NeuralNetworkImpl::device);
std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
if (!f)
util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment