diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index a61eaf7c6637afb61525c81377252e386486d2c0..a936baddf4e020859f1b2b62b93a6f62d9ca85cd 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -62,6 +62,8 @@ class Trainer
 
   Trainer(ReadingMachine & machine, int batchSize);
   void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold);
+  // Oracle mode : print the corpus as sequences of gold transition indexes instead of training.
+  void extractActionSequence(BaseConfig & config);
   void makeDataLoader(std::filesystem::path dir);
   void makeDevDataLoader(std::filesystem::path dir);
   float epoch(bool printAdvancement);
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index f19b546bd70313adb42cf49b4365b96b7748538e..0faad25d5c39f8d115507849d7980b584deb2469 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -46,7 +46,8 @@ po::options_description MacaonTrain::getOptionsDescription()
     ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
       "Max norm for the embeddings")
     ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
-    ("help,h", "Produce this help message");
+    ("help,h", "Produce this help message")
+    ("oracleMode", "Don't train a model, transform the corpus into a sequence of transitions.");
 
   desc.add(req).add(opt);
 
@@ -136,6 +137,7 @@ int MacaonTrain::main()
   auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
   auto explorationThreshold = variables["explorationThreshold"].as<float>();
   auto seed = variables["seed"].as<int>();
+  auto oracleMode = variables.count("oracleMode") != 0;
   WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
   WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
   WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);
@@ -169,6 +171,13 @@ int MacaonTrain::main()
   Trainer trainer(machine, batchSize);
   Decoder decoder(machine);
 
+  // Oracle mode only prints the transition sequences; return (rather than exit) so that locals are destroyed.
+  if (oracleMode)
+  {
+    trainer.extractActionSequence(goldConfig);
+    return 0;
+  }
+
   float bestDevScore = computeDevScore ? -std::numeric_limits<float>::max() : std::numeric_limits<float>::max();
 
   auto trainInfos = machinePath.parent_path() / "train.info";
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 0f0ad1a4cfa9fbbb1543d0a58a8943243101d208..e781105205fd5fe92ef11e45283fec6d3fa76c23 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -326,3 +326,97 @@ Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
   return TrainAction::ExtractGold;
 }
 
+// Oracle mode : follow the gold transitions over config and print the corpus as sequences.
+// Every time nbSentencesPerSequence EOS transitions have been seen, stdout receives the input
+// text (letters if the config has characters, space separated FORMs otherwise), then the space
+// separated indexes of the gold transitions, then an empty line.
+// The sizes of the longest input and output sequences are reported on stderr.
+// Fails through util::myThrow on regression classifiers or when no gold transition is found.
+void Trainer::extractActionSequence(BaseConfig & config)
+{
+  // Number of sentences (EOS transitions) grouped into one printed sequence.
+  constexpr int nbSentencesPerSequence = 3;
+
+  config.addPredicted(machine.getPredicted());
+  config.setStrategy(machine.getStrategyDefinition());
+  config.setState(config.getStrategy().getInitialState());
+  machine.getClassifier(config.getState())->setState(config.getState());
+
+  int nbSentences = 0;
+  int curSeqStartIndex = -1;
+  int curInputIndex = 0;
+  int maxInputSeqSize = 0;
+  int maxOutputSeqSize = 0;
+  bool newSent = true;
+  std::vector<std::string> transitionsIndexes;
+
+  // Print the input span [curSeqStartIndex, curInputIndex] with its pending transitions, then forget these transitions.
+  auto flushSequence = [&]()
+  {
+    const bool charLevel = config.hasCharacter(0);
+    std::string inputText;
+    for (int i = curSeqStartIndex; i <= curInputIndex; i++)
+      inputText += charLevel ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", config.getAsFeature("FORM", i));
+    fmt::print(stdout, "{}\n{}\n\n", charLevel ? inputText : util::strip(inputText), util::join(" ", transitionsIndexes));
+    transitionsIndexes.clear();
+  };
+
+  while (true)
+  {
+    if (config.hasCharacter(0))
+      curInputIndex = config.getCharacterIndex();
+    else
+      curInputIndex = config.getWordIndex();
+
+    if (curSeqStartIndex == -1 or newSent)
+    {
+      newSent = false;
+      curSeqStartIndex = curInputIndex;
+    }
+
+    if (machine.hasSplitWordTransitionSet())
+      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
+
+    auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
+    config.setAppliableTransitions(appliableTransitions);
+
+    if (machine.getClassifier(config.getState())->isRegression())
+      util::myThrow("Regressions are not supported in extract action sequence mode");
+
+    auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true);
+    if (goldTransitions.empty())
+      util::myThrow("No gold transition appliable in extract action sequence mode");
+
+    Transition * transition = goldTransitions[0];
+    transitionsIndexes.push_back(fmt::format("{}", machine.getTransitionSet(config.getState()).getTransitionIndex(transition)));
+
+    // Sizes are measured once the current transition / input element is counted (the printed span is inclusive).
+    maxOutputSeqSize = std::max(maxOutputSeqSize, static_cast<int>(transitionsIndexes.size()));
+    maxInputSeqSize = std::max(maxInputSeqSize, curInputIndex - curSeqStartIndex + 1);
+
+    if (util::split(transition->getName(), ' ')[0] == "EOS" and ++nbSentences % nbSentencesPerSequence == 0)
+    {
+      newSent = true;
+      flushSequence();
+    }
+
+    transition->apply(config);
+    config.addToHistory(transition->getName());
+
+    auto movement = config.getStrategy().getMovement(config, transition->getName());
+    if (movement == Strategy::endMovement)
+      break;
+
+    config.setState(movement.first);
+    machine.getClassifier(config.getState())->setState(movement.first);
+    config.moveWordIndexRelaxed(movement.second);
+  }
+
+  // Print what was accumulated since the last printed sequence, if anything.
+  if (!transitionsIndexes.empty())
+    flushSequence();
+
+  fmt::print(stderr, "Longest output sequence : {}\n", maxOutputSeqSize);
+  fmt::print(stderr, "Longest input sequence : {}\n", maxInputSeqSize);
+}
+