Skip to content
Snippets Groups Projects
Commit 9e3b06af authored by Franck Dary's avatar Franck Dary
Browse files

Introduced trainStrategy

parent 75b4d5ff
No related branches found
No related tags found
No related merge requests found
......@@ -16,14 +16,15 @@ class Beam
BaseConfig config;
int nextTransition{-1};
boost::circular_buffer<double> probabilities{500};
boost::circular_buffer<std::string> name{20};
float meanProbability{0.0};
int nbTransitions = 0;
double totalProbability{0.0};
bool ended{false};
public :
Element(const BaseConfig & model, int nextTransition, const boost::circular_buffer<double> & probabilities, const boost::circular_buffer<std::string> & name);
Element(const Element & other, int nextTransition);
Element(const BaseConfig & model);
};
......
......@@ -8,8 +8,9 @@ Beam::Beam(std::size_t width, float threshold, BaseConfig & model, const Reading
elements.emplace_back(model);
}
Beam::Element::Element(const BaseConfig & model, int nextTransition, const boost::circular_buffer<double> & probabilities, const boost::circular_buffer<std::string> & name) : config(model), nextTransition(nextTransition), probabilities(probabilities), name(name)
// Copy-construct a beam element from an existing one, then override the
// transition it will apply next. Delegates to the (implicit) copy constructor
// so config, probabilities, name and the probability counters are all copied.
Beam::Element::Element(const Element & other, int nextTransition) : Element(other)
{
// Only the pending transition differs from the copied element.
this->nextTransition = nextTransition;
}
Beam::Element::Element(const BaseConfig & model) : config(model)
......@@ -71,22 +72,19 @@ void Beam::update(ReadingMachine & machine, bool debug)
if (width > 1)
for (unsigned int i = 1; i < scoresOfTransitions.size(); i++)
{
elements.emplace_back(elements[index].config, scoresOfTransitions[i].second, elements[index].probabilities, elements[index].name);
elements.emplace_back(elements[index], scoresOfTransitions[i].second);
elements.back().name.push_back(std::to_string(i));
elements.back().probabilities.push_back(scoresOfTransitions[i].first);
elements.back().meanProbability = 0.0;
for (auto & p : elements.back().probabilities)
elements.back().meanProbability += p;
elements.back().meanProbability /= elements.back().probabilities.size();
elements.back().totalProbability += scoresOfTransitions[i].first;
elements.back().nbTransitions++;
elements.back().meanProbability = elements.back().totalProbability / elements.back().nbTransitions;
}
elements[index].nextTransition = scoresOfTransitions[0].second;
elements[index].probabilities.push_back(scoresOfTransitions[0].first);
elements[index].totalProbability += scoresOfTransitions[0].first;
elements[index].nbTransitions++;
elements[index].name.push_back("0");
elements[index].meanProbability = 0.0;
for (auto & p : elements[index].probabilities)
elements[index].meanProbability += p;
elements[index].meanProbability /= elements[index].probabilities.size();
elements[index].meanProbability = elements[index].totalProbability / elements[index].nbTransitions;
if (debug)
{
......
......@@ -19,6 +19,8 @@ class ReadingMachine
std::filesystem::path path;
std::unique_ptr<Classifier> classifier;
std::vector<std::string> strategyDefinition;
std::vector<std::string> classifierDefinition;
std::string classifierName;
std::set<std::string> predicted;
std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
......@@ -48,6 +50,7 @@ class ReadingMachine
void loadLastSaved();
void setCountOcc(bool countOcc);
void removeRareDictElements(float rarityThreshold);
void resetClassifier();
};
#endif
......@@ -58,7 +58,8 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine++], [this,path,&lines,&curLine](auto sm)
{
std::vector<std::string> classifierDefinition;
classifierDefinition.clear();
classifierName = sm.str(1);
if (lines[curLine] != "{")
util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
......@@ -196,3 +197,9 @@ void ReadingMachine::removeRareDictElements(float rarityThreshold)
classifier->getNN()->removeRareDictElements(rarityThreshold);
}
// Discard the current classifier (and its learned parameters) and rebuild a
// fresh one from the stored definition, then reload the dictionaries.
// Used by the ResetParameters train action so training can restart from
// freshly initialized weights.
void ReadingMachine::resetClassifier()
{
// classifierName / classifierDefinition were captured while reading the
// machine file (see readFromFile), so the rebuild matches the original spec.
classifier.reset(new Classifier(classifierName, path, classifierDefinition));
loadDicts();
}
......@@ -6,7 +6,7 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir)
for (auto & entry : std::filesystem::directory_iterator(dir))
if (entry.is_regular_file())
{
auto stem = entry.path().stem().string();
auto stem = util::split(entry.path().stem().string(), '.')[0];
if (stem == "extracted")
continue;
auto state = util::split(stem, '_')[0];
......
......@@ -20,6 +20,10 @@ class MacaonTrain
po::options_description getOptionsDescription();
po::variables_map checkOptions(po::options_description & od);
private :
Trainer::TrainStrategy parseTrainStrategy(std::string s);
public :
MacaonTrain(int argc, char ** argv);
......
......@@ -7,6 +7,20 @@
class Trainer
{
public :
enum TrainAction
{
ExtractGold,
ExtractDynamic,
DeleteExamples,
ResetOptimizer,
ResetParameters,
Save
};
using TrainStrategy = std::map<std::size_t, std::set<TrainAction>>;
static TrainAction str2TrainAction(const std::string & s);
private :
static constexpr std::size_t safetyNbExamplesMax = 10*1000*1000;
......@@ -19,7 +33,7 @@ class Trainer
int currentExampleIndex{0};
int lastSavedIndex{0};
void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold);
void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int currentEpoch, bool dynamicOracle);
void addContext(std::vector<std::vector<long>> & context);
void addClass(int goldIndex);
};
......@@ -41,15 +55,16 @@ class Trainer
private :
void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle);
float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
void fillDicts(SubConfig & config, bool debug);
public :
Trainer(ReadingMachine & machine, int batchSize);
void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle);
void makeDataLoader(std::filesystem::path dir);
void makeDevDataLoader(std::filesystem::path dir);
void fillDicts(BaseConfig & goldConfig, bool debug);
float epoch(bool printAdvancement);
float evalOnDev(bool printAdvancement);
......
......@@ -33,12 +33,12 @@ po::options_description MacaonTrain::getOptionsDescription()
"Number of training epochs")
("batchSize", po::value<int>()->default_value(64),
"Number of examples per batch")
("dynamicOracleInterval", po::value<int>()->default_value(-1),
"Every X epochs, the machine will be used to decode the train and dev corpora. Thus allowing the machine to train using it's own predictions as feature. A value of -1 means the machine will always train on GOLD features. This option slows training down by a LOT.")
("rarityThreshold", po::value<float>()->default_value(70.0),
"During train, the X% rarest elements will be treated as unknown values")
("machine", po::value<std::string>()->default_value(""),
"Reading machine file content")
("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold"),
"Description of what should happen during training")
("pretrainedEmbeddings", po::value<std::string>()->default_value(""),
"File containing pretrained embeddings, w2v format")
("help,h", "Produce this help message");
......@@ -69,6 +69,27 @@ po::variables_map MacaonTrain::checkOptions(po::options_description & od)
return vm;
}
// Parse a train strategy description of the form
//   "<epoch>,<Action>[,<Action>...][:<epoch>,<Action>...]"
// e.g. "0,ExtractGold:10,ResetOptimizer,Save" into a map from epoch number
// to the set of TrainAction to perform at that epoch.
// Any parse error (bad integer, unknown action name, empty group) is
// reported through util::myThrow with the offending input attached.
Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s)
{
Trainer::TrainStrategy ts;
try
{
auto splited = util::split(s, ':');
for (auto & ss : splited)
{
auto elements = util::split(ss, ',');
// Use .at(0) instead of [0]: an empty group (e.g. a trailing or doubled
// ':') would make elements empty, and operator[] on it is undefined
// behavior. .at() throws std::out_of_range, which the catch below turns
// into a proper error message.
int epoch = std::stoi(elements.at(0));
for (unsigned int i = 1; i < elements.size(); i++)
ts[epoch].insert(Trainer::str2TrainAction(elements[i]));
}
} catch (std::exception & e) {util::myThrow(fmt::format("caught '{}' parsing '{}'", e.what(), s));}
return ts;
}
int MacaonTrain::main()
{
auto od = getOptionsDescription();
......@@ -83,13 +104,15 @@ int MacaonTrain::main()
auto devRawFile = variables["devTXT"].as<std::string>();
auto nbEpoch = variables["nbEpochs"].as<int>();
auto batchSize = variables["batchSize"].as<int>();
auto dynamicOracleInterval = variables["dynamicOracleInterval"].as<int>();
auto rarityThreshold = variables["rarityThreshold"].as<float>();
bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
bool computeDevScore = variables.count("devScore") == 0 ? false : true;
auto machineContent = variables["machine"].as<std::string>();
auto pretrainedEmbeddings = variables["pretrainedEmbeddings"].as<std::string>();
auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
auto trainStrategy = parseTrainStrategy(trainStrategyStr);
torch::globalContext().setBenchmarkCuDNN(true);
......@@ -146,20 +169,15 @@ int MacaonTrain::main()
{
if (buffer != std::fgets(buffer, 1024, f))
break;
bool saved = util::split(util::split(buffer, '\t')[0], ' ').back() == "SAVED";
float devScoreMean = std::stof(util::split(buffer, '\t').back());
if (computeDevScore and (devScoreMean > bestDevScore or currentEpoch == dynamicOracleInterval))
bestDevScore = devScoreMean;
if (!computeDevScore and (devScoreMean < bestDevScore or currentEpoch == dynamicOracleInterval))
if (saved)
bestDevScore = devScoreMean;
currentEpoch++;
}
std::fclose(f);
}
trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval);
if (!computeDevScore)
trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
machine.getClassifier()->resetOptimizer();
auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer";
if (std::filesystem::exists(trainInfos))
......@@ -167,9 +185,44 @@ int MacaonTrain::main()
for (; currentEpoch < nbEpoch; currentEpoch++)
{
trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval);
bool saved = false;
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::DeleteExamples))
{
for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/train"))
if (entry.is_regular_file())
std::filesystem::remove(entry.path());
if (!computeDevScore)
for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/dev"))
if (entry.is_regular_file())
std::filesystem::remove(entry.path());
}
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
{
trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
if (!computeDevScore)
trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
}
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer))
{
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters))
{
machine.resetClassifier();
machine.getClassifier()->getNN()->registerEmbeddings(pretrainedEmbeddings);
machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
}
machine.getClassifier()->resetOptimizer();
}
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save))
{
saved = true;
}
trainer.makeDataLoader(modelPath/"examples/train");
if (!computeDevScore)
trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
trainer.makeDevDataLoader(modelPath/"examples/dev");
float loss = trainer.epoch(printAdvancement);
if (debug)
......@@ -201,13 +254,12 @@ int MacaonTrain::main()
if (!devScoresStr.empty())
devScoresStr.pop_back();
devScoreMean /= devScores.size();
bool saved = devScoreMean >= bestDevScore;
if (!computeDevScore)
saved = devScoreMean <= bestDevScore;
if (computeDevScore)
saved = saved or devScoreMean >= bestDevScore;
else
saved = saved or devScoreMean <= bestDevScore;
if (currentEpoch == dynamicOracleInterval)
saved = true;
if (saved)
{
bestDevScore = devScoreMean;
......
......@@ -5,33 +5,29 @@ Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), ba
{
}
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
void Trainer::makeDataLoader(std::filesystem::path dir)
{
SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false);
machine.setDictsState(Dict::State::Closed);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
trainDataset.reset(new Dataset(dir));
dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
// Build the dev-set Dataset from the example tensors stored in 'dir' and wrap
// it in a torch DataLoader. Extraction of the examples themselves is done
// separately (createDataset); this only (re)creates the loading pipeline.
void Trainer::makeDevDataLoader(std::filesystem::path dir)
{
devDataset.reset(new Dataset(dir));
// workers(0)/max_jobs(0): load synchronously in the main thread.
devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
{
SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false);
machine.setDictsState(Dict::State::Closed);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
devDataset.reset(new Dataset(dir));
devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
extractExamples(config, debug, dir, epoch, dynamicOracle);
}
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
{
torch::AutoGradMode useGrad(false);
......@@ -45,22 +41,13 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
config.setState(config.getStrategy().getInitialState());
machine.getClassifier()->setState(config.getState());
auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch);
bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile);
if (epoch != 0 and (dynamicOracleInterval == -1 or epoch % dynamicOracleInterval))
mustExtract = false;
auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);
if (!mustExtract)
if (std::filesystem::exists(currentEpochAllExtractedFile))
return;
bool dynamicOracle = epoch != 0;
fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
for (auto & entry : std::filesystem::directory_iterator(dir))
if (entry.is_regular_file())
std::filesystem::remove(entry.path());
int totalNbExamples = 0;
while (true)
......@@ -88,7 +75,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
goldTransition = machine.getTransitionSet().getBestAppliableTransition(config);
if (dynamicOracle and util::choiceWithProbability(0.8) and config.getState() != "tokenizer" and config.getState() != "segmenter")
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
......@@ -127,7 +114,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
examplesPerState[config.getState()].addContext(context);
examplesPerState[config.getState()].addClass(goldIndex);
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile);
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
transition->apply(config);
config.addToHistory(transition->getName());
......@@ -147,7 +134,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
}
for (auto & it : examplesPerState)
it.second.saveIfNeeded(it.first, dir, 0);
it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);
std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
if (!f)
......@@ -240,7 +227,7 @@ float Trainer::evalOnDev(bool printAdvancement)
return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
}
void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold)
void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle)
{
if (currentExampleIndex-lastSavedIndex < (int)threshold)
return;
......@@ -248,7 +235,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem:
return;
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
auto filename = fmt::format("{}_{}-{}.tensor", state, lastSavedIndex, currentExampleIndex-1);
auto filename = fmt::format("{}_{}-{}.{}.{}.tensor", state, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
torch::save(tensorToSave, dir/filename);
lastSavedIndex = currentExampleIndex;
contexts.clear();
......@@ -340,3 +327,23 @@ void Trainer::fillDicts(SubConfig & config, bool debug)
}
}
// Translate the textual name of a training action (as written in the
// --trainStrategy command line option) into its TrainAction enum value.
// An unrecognized name is reported through util::myThrow.
Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
{
return s == "ExtractGold" ? TrainAction::ExtractGold
: s == "ExtractDynamic" ? TrainAction::ExtractDynamic
: s == "DeleteExamples" ? TrainAction::DeleteExamples
: s == "ResetOptimizer" ? TrainAction::ResetOptimizer
: s == "ResetParameters" ? TrainAction::ResetParameters
: s == "Save" ? TrainAction::Save
// Unknown name: myThrow raises; the trailing value only satisfies the
// expression's type and is never actually returned.
: (util::myThrow(fmt::format("unknown TrainAction '{}'", s)), TrainAction::ExtractGold);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment