Skip to content
Snippets Groups Projects
Commit 9e3b06af authored by Franck Dary's avatar Franck Dary
Browse files

Introduced trainStrategy

parent 75b4d5ff
No related branches found
No related tags found
No related merge requests found
...@@ -16,14 +16,15 @@ class Beam ...@@ -16,14 +16,15 @@ class Beam
BaseConfig config; BaseConfig config;
int nextTransition{-1}; int nextTransition{-1};
boost::circular_buffer<double> probabilities{500};
boost::circular_buffer<std::string> name{20}; boost::circular_buffer<std::string> name{20};
float meanProbability{0.0}; float meanProbability{0.0};
int nbTransitions = 0;
double totalProbability{0.0};
bool ended{false}; bool ended{false};
public : public :
Element(const BaseConfig & model, int nextTransition, const boost::circular_buffer<double> & probabilities, const boost::circular_buffer<std::string> & name); Element(const Element & other, int nextTransition);
Element(const BaseConfig & model); Element(const BaseConfig & model);
}; };
......
...@@ -8,8 +8,9 @@ Beam::Beam(std::size_t width, float threshold, BaseConfig & model, const Reading ...@@ -8,8 +8,9 @@ Beam::Beam(std::size_t width, float threshold, BaseConfig & model, const Reading
elements.emplace_back(model); elements.emplace_back(model);
} }
// Copy-constructs a beam element from another one, overriding only the
// transition to take next. A delegating constructor cannot also have
// member initializers, so nextTransition is assigned in the body.
Beam::Element::Element(const Element & other, int nextTransition) : Element(other)
{
  this->nextTransition = nextTransition;
}
Beam::Element::Element(const BaseConfig & model) : config(model) Beam::Element::Element(const BaseConfig & model) : config(model)
...@@ -71,22 +72,19 @@ void Beam::update(ReadingMachine & machine, bool debug) ...@@ -71,22 +72,19 @@ void Beam::update(ReadingMachine & machine, bool debug)
if (width > 1) if (width > 1)
for (unsigned int i = 1; i < scoresOfTransitions.size(); i++) for (unsigned int i = 1; i < scoresOfTransitions.size(); i++)
{ {
elements.emplace_back(elements[index].config, scoresOfTransitions[i].second, elements[index].probabilities, elements[index].name); elements.emplace_back(elements[index], scoresOfTransitions[i].second);
elements.back().name.push_back(std::to_string(i)); elements.back().name.push_back(std::to_string(i));
elements.back().probabilities.push_back(scoresOfTransitions[i].first); elements.back().totalProbability += scoresOfTransitions[i].first;
elements.back().meanProbability = 0.0; elements.back().nbTransitions++;
for (auto & p : elements.back().probabilities) elements.back().meanProbability = elements.back().totalProbability / elements.back().nbTransitions;
elements.back().meanProbability += p;
elements.back().meanProbability /= elements.back().probabilities.size();
} }
elements[index].nextTransition = scoresOfTransitions[0].second; elements[index].nextTransition = scoresOfTransitions[0].second;
elements[index].probabilities.push_back(scoresOfTransitions[0].first); elements[index].totalProbability += scoresOfTransitions[0].first;
elements[index].nbTransitions++;
elements[index].name.push_back("0"); elements[index].name.push_back("0");
elements[index].meanProbability = 0.0; elements[index].meanProbability = 0.0;
for (auto & p : elements[index].probabilities) elements[index].meanProbability = elements[index].totalProbability / elements[index].nbTransitions;
elements[index].meanProbability += p;
elements[index].meanProbability /= elements[index].probabilities.size();
if (debug) if (debug)
{ {
......
...@@ -19,6 +19,8 @@ class ReadingMachine ...@@ -19,6 +19,8 @@ class ReadingMachine
std::filesystem::path path; std::filesystem::path path;
std::unique_ptr<Classifier> classifier; std::unique_ptr<Classifier> classifier;
std::vector<std::string> strategyDefinition; std::vector<std::string> strategyDefinition;
std::vector<std::string> classifierDefinition;
std::string classifierName;
std::set<std::string> predicted; std::set<std::string> predicted;
std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr}; std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
...@@ -48,6 +50,7 @@ class ReadingMachine ...@@ -48,6 +50,7 @@ class ReadingMachine
void loadLastSaved(); void loadLastSaved();
void setCountOcc(bool countOcc); void setCountOcc(bool countOcc);
void removeRareDictElements(float rarityThreshold); void removeRareDictElements(float rarityThreshold);
void resetClassifier();
}; };
#endif #endif
...@@ -58,7 +58,8 @@ void ReadingMachine::readFromFile(std::filesystem::path path) ...@@ -58,7 +58,8 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine++], [this,path,&lines,&curLine](auto sm) while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine++], [this,path,&lines,&curLine](auto sm)
{ {
std::vector<std::string> classifierDefinition; classifierDefinition.clear();
classifierName = sm.str(1);
if (lines[curLine] != "{") if (lines[curLine] != "{")
util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine])); util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
...@@ -196,3 +197,9 @@ void ReadingMachine::removeRareDictElements(float rarityThreshold) ...@@ -196,3 +197,9 @@ void ReadingMachine::removeRareDictElements(float rarityThreshold)
classifier->getNN()->removeRareDictElements(rarityThreshold); classifier->getNN()->removeRareDictElements(rarityThreshold);
} }
void ReadingMachine::resetClassifier()
{
classifier.reset(new Classifier(classifierName, path, classifierDefinition));
loadDicts();
}
...@@ -6,7 +6,7 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir) ...@@ -6,7 +6,7 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir)
for (auto & entry : std::filesystem::directory_iterator(dir)) for (auto & entry : std::filesystem::directory_iterator(dir))
if (entry.is_regular_file()) if (entry.is_regular_file())
{ {
auto stem = entry.path().stem().string(); auto stem = util::split(entry.path().stem().string(), '.')[0];
if (stem == "extracted") if (stem == "extracted")
continue; continue;
auto state = util::split(stem, '_')[0]; auto state = util::split(stem, '_')[0];
......
...@@ -20,6 +20,10 @@ class MacaonTrain ...@@ -20,6 +20,10 @@ class MacaonTrain
po::options_description getOptionsDescription(); po::options_description getOptionsDescription();
po::variables_map checkOptions(po::options_description & od); po::variables_map checkOptions(po::options_description & od);
private :
Trainer::TrainStrategy parseTrainStrategy(std::string s);
public : public :
MacaonTrain(int argc, char ** argv); MacaonTrain(int argc, char ** argv);
......
...@@ -7,6 +7,20 @@ ...@@ -7,6 +7,20 @@
class Trainer class Trainer
{ {
public :
// Actions that the train strategy can schedule for a given epoch.
enum TrainAction
{
ExtractGold, // extract training examples following the gold (oracle) transitions
ExtractDynamic, // extract training examples following the machine's own predictions (dynamic oracle)
DeleteExamples, // delete previously extracted example files before extracting new ones
ResetOptimizer, // re-create the optimizer, discarding its accumulated state
ResetParameters, // re-create the classifier, resetting all learned parameters
Save // force the model to be saved at the end of this epoch
};
// Maps an epoch number to the set of actions to perform at that epoch.
using TrainStrategy = std::map<std::size_t, std::set<TrainAction>>;
// Parses a TrainAction from its textual name; throws (via util::myThrow) on unknown names.
static TrainAction str2TrainAction(const std::string & s);
private : private :
static constexpr std::size_t safetyNbExamplesMax = 10*1000*1000; static constexpr std::size_t safetyNbExamplesMax = 10*1000*1000;
...@@ -19,7 +33,7 @@ class Trainer ...@@ -19,7 +33,7 @@ class Trainer
int currentExampleIndex{0}; int currentExampleIndex{0};
int lastSavedIndex{0}; int lastSavedIndex{0};
void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold); void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int currentEpoch, bool dynamicOracle);
void addContext(std::vector<std::vector<long>> & context); void addContext(std::vector<std::vector<long>> & context);
void addClass(int goldIndex); void addClass(int goldIndex);
}; };
...@@ -41,15 +55,16 @@ class Trainer ...@@ -41,15 +55,16 @@ class Trainer
private : private :
void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle);
float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples); float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
void fillDicts(SubConfig & config, bool debug); void fillDicts(SubConfig & config, bool debug);
public : public :
Trainer(ReadingMachine & machine, int batchSize); Trainer(ReadingMachine & machine, int batchSize);
void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle);
void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); void makeDataLoader(std::filesystem::path dir);
void makeDevDataLoader(std::filesystem::path dir);
void fillDicts(BaseConfig & goldConfig, bool debug); void fillDicts(BaseConfig & goldConfig, bool debug);
float epoch(bool printAdvancement); float epoch(bool printAdvancement);
float evalOnDev(bool printAdvancement); float evalOnDev(bool printAdvancement);
......
...@@ -33,12 +33,12 @@ po::options_description MacaonTrain::getOptionsDescription() ...@@ -33,12 +33,12 @@ po::options_description MacaonTrain::getOptionsDescription()
"Number of training epochs") "Number of training epochs")
("batchSize", po::value<int>()->default_value(64), ("batchSize", po::value<int>()->default_value(64),
"Number of examples per batch") "Number of examples per batch")
("dynamicOracleInterval", po::value<int>()->default_value(-1),
"Every X epochs, the machine will be used to decode the train and dev corpora. Thus allowing the machine to train using it's own predictions as feature. A value of -1 means the machine will always train on GOLD features. This option slows training down by a LOT.")
("rarityThreshold", po::value<float>()->default_value(70.0), ("rarityThreshold", po::value<float>()->default_value(70.0),
"During train, the X% rarest elements will be treated as unknown values") "During train, the X% rarest elements will be treated as unknown values")
("machine", po::value<std::string>()->default_value(""), ("machine", po::value<std::string>()->default_value(""),
"Reading machine file content") "Reading machine file content")
("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold"),
"Description of what should happen during training")
("pretrainedEmbeddings", po::value<std::string>()->default_value(""), ("pretrainedEmbeddings", po::value<std::string>()->default_value(""),
"File containing pretrained embeddings, w2v format") "File containing pretrained embeddings, w2v format")
("help,h", "Produce this help message"); ("help,h", "Produce this help message");
...@@ -69,6 +69,27 @@ po::variables_map MacaonTrain::checkOptions(po::options_description & od) ...@@ -69,6 +69,27 @@ po::variables_map MacaonTrain::checkOptions(po::options_description & od)
return vm; return vm;
} }
// Parses a train strategy description of the form
//   "epoch,Action[,Action...]:epoch,Action[,Action...]:..."
// e.g. "0,ExtractGold:10,ExtractDynamic,ResetParameters".
// Throws (via util::myThrow) with the offending input if parsing fails.
Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s)
{
  Trainer::TrainStrategy ts;

  try
  {
    // Each ':'-separated group describes the actions of one epoch.
    for (auto & group : util::split(s, ':'))
    {
      auto elements = util::split(group, ',');
      if (elements.empty())
        continue; // tolerate stray ':' separators instead of indexing an empty vector
      // First element is the epoch number, the remaining ones are action names.
      int epoch = std::stoi(elements[0]);
      for (unsigned int i = 1; i < elements.size(); i++)
        ts[epoch].insert(Trainer::str2TrainAction(elements[i]));
    }
  } catch (const std::exception & e) {util::myThrow(fmt::format("caught '{}' parsing '{}'", e.what(), s));}

  return ts;
}
int MacaonTrain::main() int MacaonTrain::main()
{ {
auto od = getOptionsDescription(); auto od = getOptionsDescription();
...@@ -83,13 +104,15 @@ int MacaonTrain::main() ...@@ -83,13 +104,15 @@ int MacaonTrain::main()
auto devRawFile = variables["devTXT"].as<std::string>(); auto devRawFile = variables["devTXT"].as<std::string>();
auto nbEpoch = variables["nbEpochs"].as<int>(); auto nbEpoch = variables["nbEpochs"].as<int>();
auto batchSize = variables["batchSize"].as<int>(); auto batchSize = variables["batchSize"].as<int>();
auto dynamicOracleInterval = variables["dynamicOracleInterval"].as<int>();
auto rarityThreshold = variables["rarityThreshold"].as<float>(); auto rarityThreshold = variables["rarityThreshold"].as<float>();
bool debug = variables.count("debug") == 0 ? false : true; bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
bool computeDevScore = variables.count("devScore") == 0 ? false : true; bool computeDevScore = variables.count("devScore") == 0 ? false : true;
auto machineContent = variables["machine"].as<std::string>(); auto machineContent = variables["machine"].as<std::string>();
auto pretrainedEmbeddings = variables["pretrainedEmbeddings"].as<std::string>(); auto pretrainedEmbeddings = variables["pretrainedEmbeddings"].as<std::string>();
auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
auto trainStrategy = parseTrainStrategy(trainStrategyStr);
torch::globalContext().setBenchmarkCuDNN(true); torch::globalContext().setBenchmarkCuDNN(true);
...@@ -146,20 +169,15 @@ int MacaonTrain::main() ...@@ -146,20 +169,15 @@ int MacaonTrain::main()
{ {
if (buffer != std::fgets(buffer, 1024, f)) if (buffer != std::fgets(buffer, 1024, f))
break; break;
bool saved = util::split(util::split(buffer, '\t')[0], ' ').back() == "SAVED";
float devScoreMean = std::stof(util::split(buffer, '\t').back()); float devScoreMean = std::stof(util::split(buffer, '\t').back());
if (computeDevScore and (devScoreMean > bestDevScore or currentEpoch == dynamicOracleInterval)) if (saved)
bestDevScore = devScoreMean;
if (!computeDevScore and (devScoreMean < bestDevScore or currentEpoch == dynamicOracleInterval))
bestDevScore = devScoreMean; bestDevScore = devScoreMean;
currentEpoch++; currentEpoch++;
} }
std::fclose(f); std::fclose(f);
} }
trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval);
if (!computeDevScore)
trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
machine.getClassifier()->resetOptimizer(); machine.getClassifier()->resetOptimizer();
auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer"; auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer";
if (std::filesystem::exists(trainInfos)) if (std::filesystem::exists(trainInfos))
...@@ -167,9 +185,44 @@ int MacaonTrain::main() ...@@ -167,9 +185,44 @@ int MacaonTrain::main()
for (; currentEpoch < nbEpoch; currentEpoch++) for (; currentEpoch < nbEpoch; currentEpoch++)
{ {
trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval); bool saved = false;
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::DeleteExamples))
{
for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/train"))
if (entry.is_regular_file())
std::filesystem::remove(entry.path());
if (!computeDevScore)
for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/dev"))
if (entry.is_regular_file())
std::filesystem::remove(entry.path());
}
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
{
trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
if (!computeDevScore)
trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
}
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer))
{
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters))
{
machine.resetClassifier();
machine.getClassifier()->getNN()->registerEmbeddings(pretrainedEmbeddings);
machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
}
machine.getClassifier()->resetOptimizer();
}
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save))
{
saved = true;
}
trainer.makeDataLoader(modelPath/"examples/train");
if (!computeDevScore) if (!computeDevScore)
trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval); trainer.makeDevDataLoader(modelPath/"examples/dev");
float loss = trainer.epoch(printAdvancement); float loss = trainer.epoch(printAdvancement);
if (debug) if (debug)
...@@ -201,13 +254,12 @@ int MacaonTrain::main() ...@@ -201,13 +254,12 @@ int MacaonTrain::main()
if (!devScoresStr.empty()) if (!devScoresStr.empty())
devScoresStr.pop_back(); devScoresStr.pop_back();
devScoreMean /= devScores.size(); devScoreMean /= devScores.size();
bool saved = devScoreMean >= bestDevScore;
if (!computeDevScore) if (computeDevScore)
saved = devScoreMean <= bestDevScore; saved = saved or devScoreMean >= bestDevScore;
else
saved = saved or devScoreMean <= bestDevScore;
if (currentEpoch == dynamicOracleInterval)
saved = true;
if (saved) if (saved)
{ {
bestDevScore = devScoreMean; bestDevScore = devScoreMean;
......
...@@ -5,33 +5,29 @@ Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), ba ...@@ -5,33 +5,29 @@ Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), ba
{ {
} }
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) void Trainer::makeDataLoader(std::filesystem::path dir)
{ {
SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false);
machine.setDictsState(Dict::State::Closed);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
trainDataset.reset(new Dataset(dir)); trainDataset.reset(new Dataset(dir));
dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
} }
void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) void Trainer::makeDevDataLoader(std::filesystem::path dir)
{
devDataset.reset(new Dataset(dir));
devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
{ {
SubConfig config(goldConfig, goldConfig.getNbLines()); SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false); machine.trainMode(false);
machine.setDictsState(Dict::State::Closed); machine.setDictsState(Dict::State::Closed);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval); extractExamples(config, debug, dir, epoch, dynamicOracle);
devDataset.reset(new Dataset(dir));
devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
} }
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
{ {
torch::AutoGradMode useGrad(false); torch::AutoGradMode useGrad(false);
...@@ -45,22 +41,13 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -45,22 +41,13 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
config.setState(config.getStrategy().getInitialState()); config.setState(config.getStrategy().getInitialState());
machine.getClassifier()->setState(config.getState()); machine.getClassifier()->setState(config.getState());
auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch); auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);
bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile);
if (epoch != 0 and (dynamicOracleInterval == -1 or epoch % dynamicOracleInterval))
mustExtract = false;
if (!mustExtract) if (std::filesystem::exists(currentEpochAllExtractedFile))
return; return;
bool dynamicOracle = epoch != 0;
fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : ""); fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
for (auto & entry : std::filesystem::directory_iterator(dir))
if (entry.is_regular_file())
std::filesystem::remove(entry.path());
int totalNbExamples = 0; int totalNbExamples = 0;
while (true) while (true)
...@@ -88,7 +75,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -88,7 +75,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
goldTransition = machine.getTransitionSet().getBestAppliableTransition(config); goldTransition = machine.getTransitionSet().getBestAppliableTransition(config);
if (dynamicOracle and util::choiceWithProbability(0.8) and config.getState() != "tokenizer" and config.getState() != "segmenter") if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{ {
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
...@@ -127,7 +114,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -127,7 +114,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
examplesPerState[config.getState()].addContext(context); examplesPerState[config.getState()].addContext(context);
examplesPerState[config.getState()].addClass(goldIndex); examplesPerState[config.getState()].addClass(goldIndex);
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile); examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
transition->apply(config); transition->apply(config);
config.addToHistory(transition->getName()); config.addToHistory(transition->getName());
...@@ -147,7 +134,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -147,7 +134,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
} }
for (auto & it : examplesPerState) for (auto & it : examplesPerState)
it.second.saveIfNeeded(it.first, dir, 0); it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);
std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w"); std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
if (!f) if (!f)
...@@ -240,7 +227,7 @@ float Trainer::evalOnDev(bool printAdvancement) ...@@ -240,7 +227,7 @@ float Trainer::evalOnDev(bool printAdvancement)
return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value()); return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
} }
void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold) void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle)
{ {
if (currentExampleIndex-lastSavedIndex < (int)threshold) if (currentExampleIndex-lastSavedIndex < (int)threshold)
return; return;
...@@ -248,7 +235,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem: ...@@ -248,7 +235,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem:
return; return;
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1); auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
auto filename = fmt::format("{}_{}-{}.tensor", state, lastSavedIndex, currentExampleIndex-1); auto filename = fmt::format("{}_{}-{}.{}.{}.tensor", state, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
torch::save(tensorToSave, dir/filename); torch::save(tensorToSave, dir/filename);
lastSavedIndex = currentExampleIndex; lastSavedIndex = currentExampleIndex;
contexts.clear(); contexts.clear();
...@@ -340,3 +327,23 @@ void Trainer::fillDicts(SubConfig & config, bool debug) ...@@ -340,3 +327,23 @@ void Trainer::fillDicts(SubConfig & config, bool debug)
} }
} }
// Converts the textual name of a training action into the corresponding
// TrainAction value. Throws (via util::myThrow) for unknown names.
Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
{
  static const std::map<std::string, TrainAction> name2action
  {
    {"ExtractGold", TrainAction::ExtractGold},
    {"ExtractDynamic", TrainAction::ExtractDynamic},
    {"DeleteExamples", TrainAction::DeleteExamples},
    {"ResetOptimizer", TrainAction::ResetOptimizer},
    {"ResetParameters", TrainAction::ResetParameters},
    {"Save", TrainAction::Save},
  };

  auto found = name2action.find(s);
  if (found != name2action.end())
    return found->second;

  util::myThrow(fmt::format("unknown TrainAction '{}'", s));

  // myThrow is expected to throw; this return only silences compiler warnings.
  return TrainAction::ExtractGold;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment