Skip to content
Snippets Groups Projects
Commit 6b13e6ae authored by Franck Dary's avatar Franck Dary
Browse files

Added argument --oracleMode for macaon train, to transform a corpus into a list of transitions

parent 0b1c00a9
No related branches found
No related tags found
No related merge requests found
......@@ -62,6 +62,7 @@ class Trainer
Trainer(ReadingMachine & machine, int batchSize);
void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold);
void extractActionSequence(BaseConfig & config);
void makeDataLoader(std::filesystem::path dir);
void makeDevDataLoader(std::filesystem::path dir);
float epoch(bool printAdvancement);
......
......@@ -46,7 +46,8 @@ po::options_description MacaonTrain::getOptionsDescription()
("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
"Max norm for the embeddings")
("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
("help,h", "Produce this help message");
("help,h", "Produce this help message")
("oracleMode", "Don't train a model, transform the corpus into a sequence of transitions.");
desc.add(req).add(opt);
......@@ -136,6 +137,7 @@ int MacaonTrain::main()
auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
auto explorationThreshold = variables["explorationThreshold"].as<float>();
auto seed = variables["seed"].as<int>();
auto oracleMode = variables.count("oracleMode") == 0 ? false : true;
WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);
......@@ -169,6 +171,12 @@ int MacaonTrain::main()
Trainer trainer(machine, batchSize);
Decoder decoder(machine);
if (oracleMode)
{
trainer.extractActionSequence(goldConfig);
exit(0);
}
float bestDevScore = computeDevScore ? -std::numeric_limits<float>::max() : std::numeric_limits<float>::max();
auto trainInfos = machinePath.parent_path() / "train.info";
......
......@@ -326,3 +326,90 @@ Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
return TrainAction::ExtractGold;
}
void Trainer::extractActionSequence(BaseConfig & config)
{
config.addPredicted(machine.getPredicted());
config.setStrategy(machine.getStrategyDefinition());
config.setState(config.getStrategy().getInitialState());
machine.getClassifier(config.getState())->setState(config.getState());
int curSeq = 0;
int curSeqStartIndex = -1;
int curInputIndex = 0;
int curInputSeqSize = 0;
int curOutputSeqSize = 0;
int maxInputSeqSize = 0;
int maxOutputSeqSize = 0;
bool newSent = true;
std::vector<std::string> transitionsIndexes;
while (true)
{
if (config.hasCharacter(0))
curInputIndex = config.getCharacterIndex();
else
curInputIndex = config.getWordIndex();
if (curSeqStartIndex == -1 or newSent)
{
newSent = false;
curSeqStartIndex = curInputIndex;
}
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
config.setAppliableTransitions(appliableTransitions);
auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true);
Transition * transition = goldTransitions[0];
if (machine.getClassifier(config.getState())->isRegression())
util::myThrow("Regressions are not supported in extract action sequence mode");
transitionsIndexes.push_back(fmt::format("{}", machine.getTransitionSet(config.getState()).getTransitionIndex(transition)));
maxOutputSeqSize = std::max(maxOutputSeqSize, curOutputSeqSize++);
curInputSeqSize = -curSeqStartIndex + curInputIndex;
maxInputSeqSize = std::max(maxInputSeqSize, curInputSeqSize++);
if (util::split(transition->getName(), ' ')[0] == "EOS")
if (++curSeq % 3 == 0)
{
newSent = true;
std::string curSeq = "";
for (int i = curSeqStartIndex; i <= curInputIndex; i++)
curSeq += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", config.getAsFeature("FORM", i));
fmt::print(stdout, "{}\n{}\n\n", config.hasCharacter(0) ? curSeq : util::strip(curSeq), util::join(" ", transitionsIndexes));
curOutputSeqSize = 0;
curInputSeqSize = 0;
transitionsIndexes.clear();
}
transition->apply(config);
config.addToHistory(transition->getName());
auto movement = config.getStrategy().getMovement(config, transition->getName());
if (movement == Strategy::endMovement)
break;
config.setState(movement.first);
machine.getClassifier(config.getState())->setState(movement.first);
config.moveWordIndexRelaxed(movement.second);
}
if (curSeqStartIndex != curInputIndex)
{
std::string curSeq = "";
for (int i = curSeqStartIndex; i <= curInputIndex; i++)
curSeq += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", config.getAsFeature("FORM", i));
fmt::print(stdout, "{}\n{}\n\n", config.hasCharacter(0) ? curSeq : util::strip(curSeq), util::join(" ", transitionsIndexes));
curOutputSeqSize = 0;
curInputSeqSize = 0;
curSeqStartIndex = curInputIndex;
}
fmt::print(stderr, "Longest output sequence : {}\n", maxOutputSeqSize);
fmt::print(stderr, "Longest input sequence : {}\n", maxInputSeqSize);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment