Skip to content
Snippets Groups Projects
Commit 6b13e6ae authored by Franck Dary
Browse files

Added argument --oracleMode for macaon train, to transform a corpus into a list of transitions

parent 0b1c00a9
No related branches found
No related tags found
No related merge requests found
...@@ -62,6 +62,7 @@ class Trainer ...@@ -62,6 +62,7 @@ class Trainer
Trainer(ReadingMachine & machine, int batchSize); Trainer(ReadingMachine & machine, int batchSize);
void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold); void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold);
void extractActionSequence(BaseConfig & config);
void makeDataLoader(std::filesystem::path dir); void makeDataLoader(std::filesystem::path dir);
void makeDevDataLoader(std::filesystem::path dir); void makeDevDataLoader(std::filesystem::path dir);
float epoch(bool printAdvancement); float epoch(bool printAdvancement);
......
...@@ -46,7 +46,8 @@ po::options_description MacaonTrain::getOptionsDescription() ...@@ -46,7 +46,8 @@ po::options_description MacaonTrain::getOptionsDescription()
("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()), ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
"Max norm for the embeddings") "Max norm for the embeddings")
("lockPretrained", "Disable fine tuning of all pretrained word embeddings.") ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
("help,h", "Produce this help message"); ("help,h", "Produce this help message")
("oracleMode", "Don't train a model, transform the corpus into a sequence of transitions.");
desc.add(req).add(opt); desc.add(req).add(opt);
...@@ -136,6 +137,7 @@ int MacaonTrain::main() ...@@ -136,6 +137,7 @@ int MacaonTrain::main()
auto trainStrategyStr = variables["trainStrategy"].as<std::string>(); auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
auto explorationThreshold = variables["explorationThreshold"].as<float>(); auto explorationThreshold = variables["explorationThreshold"].as<float>();
auto seed = variables["seed"].as<int>(); auto seed = variables["seed"].as<int>();
auto oracleMode = variables.count("oracleMode") == 0 ? false : true;
WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>()); WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0); WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0); WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);
...@@ -169,6 +171,12 @@ int MacaonTrain::main() ...@@ -169,6 +171,12 @@ int MacaonTrain::main()
Trainer trainer(machine, batchSize); Trainer trainer(machine, batchSize);
Decoder decoder(machine); Decoder decoder(machine);
if (oracleMode)
{
trainer.extractActionSequence(goldConfig);
exit(0);
}
float bestDevScore = computeDevScore ? -std::numeric_limits<float>::max() : std::numeric_limits<float>::max(); float bestDevScore = computeDevScore ? -std::numeric_limits<float>::max() : std::numeric_limits<float>::max();
auto trainInfos = machinePath.parent_path() / "train.info"; auto trainInfos = machinePath.parent_path() / "train.info";
......
...@@ -326,3 +326,90 @@ Trainer::TrainAction Trainer::str2TrainAction(const std::string & s) ...@@ -326,3 +326,90 @@ Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
return TrainAction::ExtractGold; return TrainAction::ExtractGold;
} }
/// Walks `config` with the best appliable ("gold") transition at every step and
/// dumps the result as text instead of building a training dataset.
///
/// For each emitted sequence, stdout receives two lines followed by a blank line:
///   1. the input covered by the sequence: the letters of the span when
///      `config.hasCharacter(0)` is true, otherwise the space-separated FORM
///      values of the span (stripped);
///   2. the space-separated indexes of the chosen transitions, each index being
///      looked up in the transition set of the state that was current when the
///      transition was chosen.
/// A sequence is emitted every time the number of "EOS" transitions seen so far
/// reaches a multiple of `nbEosPerSequence`; whatever is still pending when the
/// strategy reports `Strategy::endMovement` is emitted once at the end.
/// The sizes of the longest input and output sequences are printed on stderr.
///
/// @param config Configuration to walk. It is modified: predicted columns,
///               strategy, state, history and word index are all updated.
///
/// Errors are reported through `util::myThrow` when the current classifier is a
/// regression, or when no gold transition is available for a configuration.
void Trainer::extractActionSequence(BaseConfig & config)
{
  // Number of EOS transitions that close one emitted sequence.
  // NOTE(review): the rationale for 3 is not visible in this function — confirm
  // against the strategy definition before changing it.
  constexpr int nbEosPerSequence = 3;

  config.addPredicted(machine.getPredicted());
  config.setStrategy(machine.getStrategyDefinition());
  config.setState(config.getStrategy().getInitialState());
  machine.getClassifier(config.getState())->setState(config.getState());

  int nbEosSeen = 0;
  int curSeqStartIndex = -1;
  int curInputIndex = 0;
  int maxInputSeqSize = 0;
  int maxOutputSeqSize = 0;
  bool newSent = true;
  std::vector<std::string> transitionsIndexes;

  // Prints the input span [curSeqStartIndex, curInputIndex] together with the
  // pending transition indexes, then empties the pending list.
  auto flushSequence = [&]()
  {
    const bool charMode = config.hasCharacter(0);
    std::string text;
    for (int i = curSeqStartIndex; i <= curInputIndex; i++)
      text += charMode ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", config.getAsFeature("FORM", i));

    fmt::print(stdout, "{}\n{}\n\n", charMode ? text : util::strip(text), util::join(" ", transitionsIndexes));
    transitionsIndexes.clear();
  };

  while (true)
  {
    if (config.hasCharacter(0))
      curInputIndex = config.getCharacterIndex();
    else
      curInputIndex = config.getWordIndex();

    // The first position reached after a flush becomes the start of the next sequence.
    if (curSeqStartIndex == -1 or newSent)
    {
      newSent = false;
      curSeqStartIndex = curInputIndex;
    }

    if (machine.hasSplitWordTransitionSet())
      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));

    auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
    config.setAppliableTransitions(appliableTransitions);

    auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true);

    // Indexing an empty list would be undefined behaviour: fail explicitly instead.
    if (goldTransitions.empty())
      util::myThrow("No gold transition found in extract action sequence mode");

    Transition * transition = goldTransitions[0];

    if (machine.getClassifier(config.getState())->isRegression())
      util::myThrow("Regressions are not supported in extract action sequence mode");

    transitionsIndexes.push_back(fmt::format("{}", machine.getTransitionSet(config.getState()).getTransitionIndex(transition)));

    // Both statistics include the current element: the transition just pushed,
    // and the position curInputIndex itself (the printed span is inclusive).
    maxOutputSeqSize = std::max(maxOutputSeqSize, static_cast<int>(transitionsIndexes.size()));
    maxInputSeqSize = std::max(maxInputSeqSize, curInputIndex - curSeqStartIndex + 1);

    // The EOS counter only advances on EOS transitions (short-circuit evaluation).
    const bool isEos = util::split(transition->getName(), ' ')[0] == "EOS";
    if (isEos and ++nbEosSeen % nbEosPerSequence == 0)
    {
      newSent = true;
      flushSequence();
    }

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = config.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    machine.getClassifier(config.getState())->setState(movement.first);
    config.moveWordIndexRelaxed(movement.second);
  }

  // Emit what is still pending. Testing the pending list (rather than comparing
  // indexes) avoids re-printing a sequence that was already flushed by the last
  // transition, and does not lose a pending sequence spanning a single position.
  if (!transitionsIndexes.empty())
    flushSequence();

  fmt::print(stderr, "Longest output sequence : {}\n", maxOutputSeqSize);
  fmt::print(stderr, "Longest input sequence : {}\n", maxInputSeqSize);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment