Skip to content
Snippets Groups Projects
Commit 67a09e30 authored by Franck Dary's avatar Franck Dary
Browse files

Finished implementing the dynamical oracle

parent 02a14376
Branches
No related tags found
No related merge requests found
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#include <dynet/timing.h> #include <dynet/timing.h>
#include <dynet/expr.h> #include <dynet/expr.h>
#include "FeatureModel.hpp" #include "FeatureModel.hpp"
#include "TrainingExamples.hpp"
/// @brief Multi Layer Perceptron. /// @brief Multi Layer Perceptron.
/// ///
......
...@@ -9,7 +9,6 @@ ...@@ -9,7 +9,6 @@
#include "TransitionMachine.hpp" #include "TransitionMachine.hpp"
#include "BD.hpp" #include "BD.hpp"
#include "Config.hpp" #include "Config.hpp"
#include "TrainingExamples.hpp"
/// @brief An object capable of training a TransitionMachine given a BD initialized with training examples. /// @brief An object capable of training a TransitionMachine given a BD initialized with training examples.
class Trainer class Trainer
...@@ -68,17 +67,6 @@ class Trainer ...@@ -68,17 +67,6 @@ class Trainer
private : private :
/// @brief Train the TransitionMachine one example at a time.
///
/// For each epoch all the Classifier of the TransitionMachine are fed all the
/// training examples, at the end of the epoch Classifier are evaluated on
/// the devBD if available, and each Classifier will be saved only if its score
/// on the current epoch is its all time best.\n
/// When a Classifier is saved that way, all the Dict involved are also saved.
/// @param nbIter The number of epochs.
/// @param mustShuffle Whether or not to shuffle examples between each epoch.
void trainUnbatched(int nbIter, bool mustShuffle);
/// @brief Compute and print scores for each Classifier on this epoch, and save the Classifier if they achieved their all time best score. /// @brief Compute and print scores for each Classifier on this epoch, and save the Classifier if they achieved their all time best score.
void printScoresAndSave(FILE * output); void printScoresAndSave(FILE * output);
...@@ -107,13 +95,17 @@ class Trainer ...@@ -107,13 +95,17 @@ class Trainer
/// @param debugMode If true, infos will be printed on stderr. /// @param debugMode If true, infos will be printed on stderr.
/// @param removeDuplicates If true, duplicates examples will be removed from the training set. /// @param removeDuplicates If true, duplicates examples will be removed from the training set.
Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode, bool removeDuplicates); Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode, bool removeDuplicates);
/// @brief Train the TransitionMachine.
/// @brief Train the TransitionMachine one example at a time.
/// ///
/// @param nbIter The number of training epochs. /// For each epoch all the Classifier of the TransitionMachine are fed all the
/// @param batchSize The size of each batch. /// training examples, at the end of the epoch Classifier are evaluated on
/// @param mustShuffle Will the examples be shuffled after every epoch ? /// the devBD if available, and each Classifier will be saved only if its score
/// @param batched True if we feed the training algorithm with batches of examples /// on the current epoch is its all time best.\n
void train(int nbIter, int batchSize, bool mustShuffle, bool batched); /// When a Classifier is saved that way, all the Dict involved are also saved.
/// @param nbIter The number of epochs.
/// @param mustShuffle Whether or not to shuffle examples between each epoch.
void train(int nbIter, bool mustShuffle);
}; };
#endif #endif
...@@ -26,6 +26,9 @@ std::map<std::string, float> Trainer::getScoreOnDev() ...@@ -26,6 +26,9 @@ std::map<std::string, float> Trainer::getScoreOnDev()
std::map< std::string, std::pair<int, int> > counts; std::map< std::string, std::pair<int, int> > counts;
if (debugMode)
fprintf(stderr, "Computing score on dev set\n");
while (!devConfig->isFinal()) while (!devConfig->isFinal())
{ {
TransitionMachine::State * currentState = tm.getCurrentState(); TransitionMachine::State * currentState = tm.getCurrentState();
...@@ -73,6 +76,12 @@ std::map<std::string, float> Trainer::getScoreOnDev() ...@@ -73,6 +76,12 @@ std::map<std::string, float> Trainer::getScoreOnDev()
std::string actionName = pAction; std::string actionName = pAction;
Action * action = classifier->getAction(actionName); Action * action = classifier->getAction(actionName);
if (debugMode)
{
devConfig->printForDebug(stderr);
fprintf(stderr, "pAction=<%s> action=<%s>\n", pAction.c_str(), actionName.c_str());
}
action->apply(*devConfig); action->apply(*devConfig);
TransitionMachine::Transition * transition = tm.getTransition(actionName); TransitionMachine::Transition * transition = tm.getTransition(actionName);
tm.takeTransition(transition); tm.takeTransition(transition);
...@@ -87,7 +96,7 @@ std::map<std::string, float> Trainer::getScoreOnDev() ...@@ -87,7 +96,7 @@ std::map<std::string, float> Trainer::getScoreOnDev()
return scores; return scores;
} }
void Trainer::trainUnbatched(int nbIter, bool mustShuffle) void Trainer::train(int nbIter, bool mustShuffle)
{ {
this->nbIter = nbIter; this->nbIter = nbIter;
...@@ -98,6 +107,7 @@ void Trainer::trainUnbatched(int nbIter, bool mustShuffle) ...@@ -98,6 +107,7 @@ void Trainer::trainUnbatched(int nbIter, bool mustShuffle)
for (curIter = 0; curIter < nbIter; curIter++) for (curIter = 0; curIter < nbIter; curIter++)
{ {
tm.reset(); tm.reset();
trainConfig.reset(); trainConfig.reset();
if(mustShuffle) if(mustShuffle)
...@@ -144,6 +154,7 @@ void Trainer::trainUnbatched(int nbIter, bool mustShuffle) ...@@ -144,6 +154,7 @@ void Trainer::trainUnbatched(int nbIter, bool mustShuffle)
if (zeroCostActions.empty()) if (zeroCostActions.empty())
{ {
fprintf(stderr, "ERROR (%s) : Unable to find any zero cost action. Aborting.\n", ERRINFO); fprintf(stderr, "ERROR (%s) : Unable to find any zero cost action. Aborting.\n", ERRINFO);
fprintf(stderr, "State : %s\n", currentState->name.c_str());
trainConfig.printForDebug(stderr); trainConfig.printForDebug(stderr);
exit(1); exit(1);
} }
...@@ -278,8 +289,3 @@ void Trainer::printScoresAndSave(FILE * output) ...@@ -278,8 +289,3 @@ void Trainer::printScoresAndSave(FILE * output)
printColumns(output, {names, acc, train, dev, savedStr}); printColumns(output, {names, acc, train, dev, savedStr});
} }
void Trainer::train(int nbIter, int batchSize, bool mustShuffle, bool batched)
{
trainUnbatched(nbIter, mustShuffle);
}
...@@ -43,14 +43,10 @@ po::options_description getOptionsDescription() ...@@ -43,14 +43,10 @@ po::options_description getOptionsDescription()
"Language you are working with") "Language you are working with")
("nbiter,n", po::value<int>()->default_value(5), ("nbiter,n", po::value<int>()->default_value(5),
"Number of training epochs (iterations)") "Number of training epochs (iterations)")
("batchsize,b", po::value<int>()->default_value(256),
"Size of each training batch (in number of examples)")
("seed,s", po::value<int>()->default_value(100), ("seed,s", po::value<int>()->default_value(100),
"The random seed that will initialize RNG") "The random seed that will initialize RNG")
("duplicates", po::value<bool>()->default_value(true), ("duplicates", po::value<bool>()->default_value(true),
"Remove identical training examples") "Remove identical training examples")
("batched", po::value<bool>()->default_value(true),
"Uses batch of training examples")
("shuffle", po::value<bool>()->default_value(true), ("shuffle", po::value<bool>()->default_value(true),
"Shuffle examples after each iteration"); "Shuffle examples after each iteration");
...@@ -115,10 +111,8 @@ int main(int argc, char * argv[]) ...@@ -115,10 +111,8 @@ int main(int argc, char * argv[])
std::string expName = vm["expName"].as<std::string>(); std::string expName = vm["expName"].as<std::string>();
std::string lang = vm["lang"].as<std::string>(); std::string lang = vm["lang"].as<std::string>();
int nbIter = vm["nbiter"].as<int>(); int nbIter = vm["nbiter"].as<int>();
int batchSize = vm["batchsize"].as<int>();
int randomSeed = vm["seed"].as<int>(); int randomSeed = vm["seed"].as<int>();
bool mustShuffle = vm["shuffle"].as<bool>(); bool mustShuffle = vm["shuffle"].as<bool>();
bool batched = vm["batched"].as<bool>();
bool removeDuplicates = vm["duplicates"].as<bool>(); bool removeDuplicates = vm["duplicates"].as<bool>();
bool debugMode = vm.count("debug") == 0 ? false : true; bool debugMode = vm.count("debug") == 0 ? false : true;
...@@ -159,7 +153,7 @@ int main(int argc, char * argv[]) ...@@ -159,7 +153,7 @@ int main(int argc, char * argv[])
} }
trainer->expPath = expPath; trainer->expPath = expPath;
trainer->train(nbIter, batchSize, mustShuffle, batched); trainer->train(nbIter, mustShuffle);
return 0; return 0;
} }
......
/// @file TrainingExamples.hpp
/// @author Franck Dary
/// @version 1.0
/// @date 2018-08-09

#ifndef TRAININGEXAMPLES__H
#define TRAININGEXAMPLES__H

#include <vector>
#include "FeatureModel.hpp"

/// @brief A set of training examples for one Classifier, iterated over in batches.
///
/// Each example is a FeatureDescription paired with its gold class.
/// Iteration order is indirected through `order`, so shuffling and removal
/// touch only the index vector, never `examples` / `classes` themselves.
class TrainingExamples
{
  public :

  /// @brief Indexes into examples/classes, in the order they will be served.
  std::vector<unsigned int> order;
  /// @brief The features of each training example.
  std::vector<FeatureModel::FeatureDescription> examples;
  /// @brief The gold class of each training example (parallel to examples).
  std::vector<int> classes;
  /// @brief Position in `order` of the next example to serve.
  int nextIndex;

  public :

  /// @brief Construct an empty example set.
  TrainingExamples();
  /// @brief Append one example with its gold class.
  /// @param example The features of the example.
  /// @param gold The index of the correct class for this example.
  void add(const FeatureModel::FeatureDescription & example, int gold);
  /// @brief The number of examples currently iterable.
  unsigned int size();
  /// @brief Serve up to batchSize examples, advancing nextIndex.
  /// @param batchSize Maximum number of examples in the returned batch.
  /// @return A TrainingExamples holding copies of the served examples.
  TrainingExamples getBatch(unsigned int batchSize);
  /// @brief Restart iteration from the first example.
  void reset();
  /// @brief Randomly permute the iteration order.
  void shuffle();
  /// @brief Remove the example with this index from the iteration order.
  /// @param index The example index (as stored in `order`) to remove.
  void remove(int index);
  /// @brief Keep only the first occurrence of each distinct example.
  void removeDuplicates();
};

#endif
...@@ -102,8 +102,8 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na ...@@ -102,8 +102,8 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na
{c.stack.push_back(c.head);}; {c.stack.push_back(c.head);};
auto undo = [](Config & c, Action::BasicAction &) auto undo = [](Config & c, Action::BasicAction &)
{c.stack.pop_back();}; {c.stack.pop_back();};
auto appliable = [](Config &, Action::BasicAction &) auto appliable = [](Config & c, Action::BasicAction &)
{return true;}; {return c.head < (int)c.tapes[0].ref.size()-1;};
Action::BasicAction basicAction = Action::BasicAction basicAction =
{Action::BasicAction::Type::Push, "", apply, undo, appliable}; {Action::BasicAction::Type::Push, "", apply, undo, appliable};
......
...@@ -56,6 +56,9 @@ void Config::readInput(const std::string & filename) ...@@ -56,6 +56,9 @@ void Config::readInput(const std::string & filename)
tape.ref.emplace_back(); tape.ref.emplace_back();
tape.hyp.resize(tape.ref.size()); tape.hyp.resize(tape.ref.size());
tape.ref.emplace_back("0");
tape.hyp.emplace_back("");
} }
} }
...@@ -134,9 +137,9 @@ void Config::printAsOutput(FILE * output) ...@@ -134,9 +137,9 @@ void Config::printAsOutput(FILE * output)
if(bd.mustPrintLine(j)) if(bd.mustPrintLine(j))
lastToPrint = j; lastToPrint = j;
for (unsigned int i = 0; i < tapes[0].hyp.size(); i++) for (unsigned int i = 0; i < tapes[0].hyp.size() - 1; i++)
{
for (unsigned int j = 0; j < tapes.size(); j++) for (unsigned int j = 0; j < tapes.size(); j++)
{
if(bd.mustPrintLine(j)) if(bd.mustPrintLine(j))
fprintf(output, "%s%s", tapes[j][i].empty() ? "0" : tapes[j][i].c_str(), j == lastToPrint ? "\n" : "\t"); fprintf(output, "%s%s", tapes[j][i].empty() ? "0" : tapes[j][i].c_str(), j == lastToPrint ? "\n" : "\t");
} }
...@@ -144,13 +147,13 @@ void Config::printAsOutput(FILE * output) ...@@ -144,13 +147,13 @@ void Config::printAsOutput(FILE * output)
void Config::moveHead(int mvt) void Config::moveHead(int mvt)
{ {
if (mvt + head < (int)getTapeByInputCol(0).hyp.size()) if (head + mvt < (int)tapes[0].ref.size())
head += mvt; head += mvt;
} }
bool Config::isFinal() bool Config::isFinal()
{ {
return head >= (int)getTapeByInputCol(0).hyp.size() -1; return head >= (int)getTapeByInputCol(0).hyp.size()-1 && stack.empty();
} }
void Config::reset() void Config::reset()
...@@ -210,6 +213,8 @@ void Config::shuffle(const std::string & delimiterTape, const std::string & deli ...@@ -210,6 +213,8 @@ void Config::shuffle(const std::string & delimiterTape, const std::string & deli
previousIndex = i+1; previousIndex = i+1;
} }
std::pair<unsigned int, unsigned int> suffix = {delimiters.back().second+1, tape.ref.size()-1};
std::random_shuffle(delimiters.begin(), delimiters.end()); std::random_shuffle(delimiters.begin(), delimiters.end());
std::vector<Tape> newTapes = tapes; std::vector<Tape> newTapes = tapes;
...@@ -224,6 +229,12 @@ void Config::shuffle(const std::string & delimiterTape, const std::string & deli ...@@ -224,6 +229,12 @@ void Config::shuffle(const std::string & delimiterTape, const std::string & deli
std::copy(tapes[tape].ref.begin()+delimiter.first, tapes[tape].ref.begin()+delimiter.second+1, std::back_inserter(newTapes[tape].ref)); std::copy(tapes[tape].ref.begin()+delimiter.first, tapes[tape].ref.begin()+delimiter.second+1, std::back_inserter(newTapes[tape].ref));
std::copy(tapes[tape].hyp.begin()+delimiter.first, tapes[tape].hyp.begin()+delimiter.second+1, std::back_inserter(newTapes[tape].hyp)); std::copy(tapes[tape].hyp.begin()+delimiter.first, tapes[tape].hyp.begin()+delimiter.second+1, std::back_inserter(newTapes[tape].hyp));
} }
if (suffix.first <= suffix.second)
{
std::copy(tapes[tape].ref.begin()+suffix.first, tapes[tape].ref.begin()+suffix.second+1, std::back_inserter(newTapes[tape].ref));
std::copy(tapes[tape].hyp.begin()+suffix.first, tapes[tape].hyp.begin()+suffix.second+1, std::back_inserter(newTapes[tape].hyp));
}
} }
tapes = newTapes; tapes = newTapes;
......
...@@ -86,7 +86,7 @@ void Oracle::createDatabase() ...@@ -86,7 +86,7 @@ void Oracle::createDatabase()
}, },
[](Config & c, Oracle *, const std::string & action) [](Config & c, Oracle *, const std::string & action)
{ {
return action == "WRITE 0 POS " + c.getTape("POS").ref[c.head]; return action == "WRITE 0 POS " + c.getTape("POS").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1;
}))); })));
str2oracle.emplace("morpho", std::unique_ptr<Oracle>(new Oracle( str2oracle.emplace("morpho", std::unique_ptr<Oracle>(new Oracle(
...@@ -102,7 +102,7 @@ void Oracle::createDatabase() ...@@ -102,7 +102,7 @@ void Oracle::createDatabase()
}, },
[](Config & c, Oracle *, const std::string & action) [](Config & c, Oracle *, const std::string & action)
{ {
return action == "WRITE 0 MORPHO " + c.getTape("MORPHO").ref[c.head]; return action == "WRITE 0 MORPHO " + c.getTape("MORPHO").ref[c.head] || c.head >= (int)c.tapes[0].ref.size()-1;
}))); })));
str2oracle.emplace("signature", std::unique_ptr<Oracle>(new Oracle( str2oracle.emplace("signature", std::unique_ptr<Oracle>(new Oracle(
...@@ -211,7 +211,7 @@ void Oracle::createDatabase() ...@@ -211,7 +211,7 @@ void Oracle::createDatabase()
const std::string & lemma = c.getTape("LEMMA").ref[c.head]; const std::string & lemma = c.getTape("LEMMA").ref[c.head];
std::string rule = getRule(form, lemma); std::string rule = getRule(form, lemma);
return action == std::string("RULE LEMMA ON FORM ") + rule; return action == std::string("RULE LEMMA ON FORM ") + rule || c.head >= (int)c.tapes[0].ref.size()-1;
}))); })));
str2oracle.emplace("parser", std::unique_ptr<Oracle>(new Oracle( str2oracle.emplace("parser", std::unique_ptr<Oracle>(new Oracle(
...@@ -244,6 +244,8 @@ void Oracle::createDatabase() ...@@ -244,6 +244,8 @@ void Oracle::createDatabase()
sentenceStart++; sentenceStart++;
while(sentenceEnd < (int)eos.ref.size() && eos.ref[sentenceEnd] != "1") while(sentenceEnd < (int)eos.ref.size() && eos.ref[sentenceEnd] != "1")
sentenceEnd++; sentenceEnd++;
if (sentenceEnd == (int)eos.ref.size())
sentenceEnd--;
auto parts = split(action); auto parts = split(action);
......
#include "TrainingExamples.hpp"

#include <algorithm>
#include <cstdlib>
#include <map>
#include <random>
TrainingExamples::TrainingExamples()
{
nextIndex = 0;
}
/// @brief Append one example with its gold class.
///
/// The new example is also appended to the iteration order.
/// @param example The features of the example.
/// @param gold The index of the correct class for this example.
void TrainingExamples::add(const FeatureModel::FeatureDescription & example, int gold)
{
  const unsigned int newIndex = order.size();

  examples.emplace_back(example);
  classes.emplace_back(gold);
  order.emplace_back(newIndex);
}
/// @brief The number of examples currently iterable.
/// @return The size of the iteration order vector.
unsigned int TrainingExamples::size()
{
  return static_cast<unsigned int>(order.size());
}
/// @brief Serve up to batchSize examples, advancing nextIndex.
///
/// Successive calls walk through `order`; call reset() to start over.
/// @param batchSize Maximum number of examples in the returned batch.
/// @return A TrainingExamples holding copies of the served examples
/// (possibly fewer than batchSize, empty once the set is exhausted).
TrainingExamples TrainingExamples::getBatch(unsigned int batchSize)
{
  TrainingExamples batch;

  // The previous bound was `nextIndex < order.size()-1`, which both skipped
  // the final example and, when order was empty, underflowed the unsigned
  // subtraction and read order[0] out of bounds.
  for (unsigned int i = 0; i < batchSize && (unsigned int)nextIndex < order.size(); i++)
  {
    batch.add(examples[order[nextIndex]], classes[order[nextIndex]]);
    nextIndex++;
  }

  return batch;
}
void TrainingExamples::reset()
{
nextIndex = 0;
}
/// @brief Randomly permute the iteration order.
void TrainingExamples::shuffle()
{
  // std::random_shuffle was deprecated in C++14 and removed in C++17.
  // Use std::shuffle with an explicit engine instead; seeding it once from
  // rand() keeps runs reproducible under the program's srand(seed) setup.
  static std::mt19937 generator(std::rand());

  std::shuffle(order.begin(), order.end(), generator);
}
/// @brief Keep only the first occurrence of each distinct example.
///
/// Examples are compared by their string representation. Later duplicates
/// are dropped from the iteration order only; the underlying examples and
/// classes vectors are left untouched.
void TrainingExamples::removeDuplicates()
{
  std::map<std::string, int> firstSeen;
  std::map<int, bool> duplicates;

  for (unsigned int index = 0; index < examples.size(); index++)
  {
    const std::string key = examples[index].toString();

    if (!firstSeen.count(key))
      firstSeen[key] = index;
    else
      duplicates[index] = true;
  }

  for (auto & entry : duplicates)
    remove(entry.first);
}
/// @brief Remove the example with this index from the iteration order.
///
/// Unordered removal: the matching slot is overwritten with the last slot,
/// then the vector shrinks by one. Order of the remaining entries changes.
/// @param index The example index (as stored in `order`) to remove.
void TrainingExamples::remove(int index)
{
  for (unsigned int pos = 0; pos < order.size(); pos++)
  {
    if ((int)order[pos] != index)
      continue;

    order[pos] = order.back();
    order.pop_back();
  }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment