Commit b495167c authored by Franck Dary's avatar Franck Dary
Browse files

Parallel extractExamples

parent 30e51f46
......@@ -5,6 +5,7 @@
#include <unordered_map>
#include <vector>
#include <filesystem>
#include <mutex>
class Dict
{
......@@ -30,6 +31,7 @@ class Dict
std::unordered_map<std::string, int> elementsToIndexes;
std::unordered_map<int, std::string> indexesToElements;
std::vector<int> nbOccs;
std::mutex elementsMutex;
State state;
bool isCountingOccs{false};
......@@ -43,6 +45,7 @@ class Dict
void readFromFile(const char * filename);
void insert(const std::string & element);
void reset();
int _getIndexOrInsert(const std::string & element, const std::string & prefix);
public :
......
......@@ -90,20 +90,33 @@ void Dict::insert(const std::string & element)
}
int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
{
  // Thread-safe entry point around _getIndexOrInsert.
  //
  // While the dict is Open, concurrent callers may insert new elements, so the
  // lookup/insert must be serialized on elementsMutex. A deferred
  // std::unique_lock replaces the original manual lock()/unlock() pair:
  //  - it releases the mutex even if _getIndexOrInsert throws
  //    (util::myThrow paths), so the lock can never be leaked;
  //  - it remembers whether we actually acquired the lock, so we never
  //    unlock a mutex we do not own. The original re-read `state` after the
  //    call to decide whether to unlock, which is undefined behaviour if
  //    `state` changes concurrently between the two reads.
  std::unique_lock<std::mutex> lock(elementsMutex, std::defer_lock);
  if (state == State::Open)
    lock.lock();
  return _getIndexOrInsert(element, prefix);
}
int Dict::_getIndexOrInsert(const std::string & element, const std::string & prefix)
{
if (element.empty())
return getIndexOrInsert(emptyValueStr, prefix);
return _getIndexOrInsert(emptyValueStr, prefix);
if (util::printedLength(element) == 1 and util::isSeparator(util::utf8char(element)))
{
return getIndexOrInsert(separatorValueStr, prefix);
return _getIndexOrInsert(separatorValueStr, prefix);
}
if (util::isNumber(element))
return getIndexOrInsert(numberValueStr, prefix);
return _getIndexOrInsert(numberValueStr, prefix);
if (util::isUrl(element))
return getIndexOrInsert(urlValueStr, prefix);
return _getIndexOrInsert(urlValueStr, prefix);
auto prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
const auto & found = elementsToIndexes.find(prefixed);
......
#include "Trainer.hpp"
#include "SubConfig.hpp"
#include <execution>
Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
{
......@@ -35,7 +36,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
torch::AutoGradMode useGrad(false);
int maxNbExamplesPerFile = 50000;
std::map<std::string, Examples> examplesPerState;
std::unordered_map<std::string, Examples> examplesPerState;
std::mutex examplesMutex;
std::filesystem::create_directories(dir);
......@@ -46,144 +48,152 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
int totalNbExamples = 0;
std::atomic<int> totalNbExamples = 0;
for (auto & config : configs)
{
config.addPredicted(machine.getPredicted());
config.setStrategy(machine.getStrategyDefinition());
config.setState(config.getStrategy().getInitialState());
while (true)
NeuralNetworkImpl::device = torch::kCPU;
machine.to(NeuralNetworkImpl::device);
std::for_each(std::execution::par_unseq, configs.begin(), configs.end(),
[this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config)
{
if (debug)
config.printForDebug(stderr);
config.addPredicted(machine.getPredicted());
config.setStrategy(machine.getStrategyDefinition());
config.setState(config.getStrategy().getInitialState());
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
while (true)
{
if (debug)
config.printForDebug(stderr);
auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
config.setAppliableTransitions(appliableTransitions);
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
torch::Tensor context;
auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
config.setAppliableTransitions(appliableTransitions);
try
{
context = machine.getClassifier(config.getState())->getNN()->extractContext(config);
} catch(std::exception & e)
{
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
}
torch::Tensor context;
Transition * transition = nullptr;
try
{
context = machine.getClassifier(config.getState())->getNN()->extractContext(config);
} catch(std::exception & e)
{
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
}
auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);
Transition * transition = nullptr;
Transition * goldTransition = goldTransitions[0];
if (config.getState() == "parser")
goldTransitions[std::rand()%goldTransitions.size()];
auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);
int nbClasses = machine.getTransitionSet(config.getState()).size();
Transition * goldTransition = goldTransitions[0];
if (config.getState() == "parser")
goldTransitions[std::rand()%goldTransitions.size()];
float bestScore = -std::numeric_limits<float>::max();
int nbClasses = machine.getTransitionSet(config.getState()).size();
float entropy = 0.0;
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{
auto & classifier = *machine.getClassifier(config.getState());
auto prediction = classifier.isRegression() ? classifier.getNN()->forward(context, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(context, config.getState()).squeeze(0), 0);
entropy = NeuralNetworkImpl::entropy(prediction);
std::vector<int> candidates;
float bestScore = -std::numeric_limits<float>::max();
for (unsigned int i = 0; i < prediction.size(0); i++)
float entropy = 0.0;
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{
float score = prediction[i].item<float>();
if (score > bestScore and appliableTransitions[i])
bestScore = score;
auto & classifier = *machine.getClassifier(config.getState());
auto prediction = classifier.isRegression() ? classifier.getNN()->forward(context, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(context, config.getState()).squeeze(0), 0);
entropy = NeuralNetworkImpl::entropy(prediction);
std::vector<int> candidates;
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if (score > bestScore and appliableTransitions[i])
bestScore = score;
}
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if (appliableTransitions[i] and bestScore - score <= explorationThreshold)
candidates.emplace_back(i);
}
transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]);
}
for (unsigned int i = 0; i < prediction.size(0); i++)
else
{
float score = prediction[i].item<float>();
if (appliableTransitions[i] and bestScore - score <= explorationThreshold)
candidates.emplace_back(i);
transition = goldTransition;
}
transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]);
}
else
{
transition = goldTransition;
}
if (!transition or !goldTransition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
if (!transition or !goldTransition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
std::vector<long> goldIndexes;
bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config);
std::vector<long> goldIndexes;
bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config);
if (machine.getClassifier(config.getState())->isRegression())
{
entropy = 0.0;
auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName());
auto splited = util::split(transition->getName(), ' ');
if (splited.size() != 3 or splited[0] != "WRITESCORE")
util::myThrow(errMessage);
auto col = splited[2];
splited = util::split(splited[1], '.');
if (splited.size() != 2)
util::myThrow(errMessage);
auto object = Config::str2object(splited[0]);
int index = std::stoi(splited[1]);
float regressionTarget = std::stof(config.getConst(col, config.getRelativeWordIndex(object, index), 0));
goldIndexes.emplace_back(util::float2long(regressionTarget));
}
else
{
for (auto & t : goldTransitions)
goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t));
}
if (machine.getClassifier(config.getState())->isRegression())
{
entropy = 0.0;
auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName());
auto splited = util::split(transition->getName(), ' ');
if (splited.size() != 3 or splited[0] != "WRITESCORE")
util::myThrow(errMessage);
auto col = splited[2];
splited = util::split(splited[1], '.');
if (splited.size() != 2)
util::myThrow(errMessage);
auto object = Config::str2object(splited[0]);
int index = std::stoi(splited[1]);
float regressionTarget = std::stof(config.getConst(col, config.getRelativeWordIndex(object, index), 0));
goldIndexes.emplace_back(util::float2long(regressionTarget));
}
else
{
for (auto & t : goldTransitions)
goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t));
if (!exampleIsBanned)
{
totalNbExamples += 1;
if (totalNbExamples >= (int)safetyNbExamplesMax)
util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
}
examplesPerState[config.getState()].addContext(context);
examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes);
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
}
if (!exampleIsBanned)
{
totalNbExamples += 1;
if (totalNbExamples >= (int)safetyNbExamplesMax)
util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
examplesMutex.lock();
examplesPerState[config.getState()].addContext(context);
examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes);
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
examplesMutex.unlock();
}
config.setChosenActionScore(bestScore);
config.setChosenActionScore(bestScore);
transition->apply(config, entropy);
config.addToHistory(transition->getName());
transition->apply(config, entropy);
config.addToHistory(transition->getName());
auto movement = config.getStrategy().getMovement(config, transition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement)
break;
auto movement = config.getStrategy().getMovement(config, transition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement)
break;
config.setState(movement.first);
config.moveWordIndexRelaxed(movement.second);
config.setState(movement.first);
config.moveWordIndexRelaxed(movement.second);
if (config.needsUpdate())
config.update();
} // End while true
} // End for on configs
if (config.needsUpdate())
config.update();
} // End while true
}); // End for on configs
for (auto & it : examplesPerState)
it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);
NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
machine.to(NeuralNetworkImpl::device);
std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
if (!f)
util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment