diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp index bdec405e338d3af97b03b14bdd960ca493b3d114..01bc7cc5b70e7caaef1d6d556c9ace4d4c1bc320 100644 --- a/decoder/include/Decoder.hpp +++ b/decoder/include/Decoder.hpp @@ -26,7 +26,7 @@ class Decoder Decoder(ReadingMachine & machine); void decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement); - void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted); + void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted); std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const; std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const; std::vector<std::pair<float,std::string>> getRecalls(const std::set<std::string> & colNames) const; diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 73d8984244ef05f069c989feaaaae0975888939a..474db0788b45b047b03b16905bfab0f9ed9f8dce 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -135,12 +135,13 @@ std::string Decoder::getMetricOfColName(const std::string & colName) const return colName; } -void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted) +void Decoder::evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted) { evaluation.clear(); auto predictedTSV = (modelPath/"predicted_dev.tsv").string(); std::FILE * predictedTSVFile = std::fopen(predictedTSV.c_str(), "w"); - config.print(predictedTSVFile); + for (unsigned int i = 0; i < configs.size(); i++) + configs[i]->print(predictedTSVFile, i==0); std::fclose(predictedTSVFile); std::FILE * evalFromUD = popen(fmt::format("{} {} {} -x {}", "../scripts/conll18_ud_eval.py", goldTSV, predictedTSV, util::join(",", std::vector<std::string>(predicted.begin(), predicted.end()))).c_str(), "r"); diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index 0bbf5aab3a6a41bf0835581accf861e409341d29..e4194e84005028034e55ccb79e11cd9d10d3b1cc 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -98,16 +98,19 @@ int MacaonDecode::main() if (inputTXT.empty()) { - BaseConfig config(mcd, inputTSV, util::utf8string()); + BaseConfig config(mcd, inputTSV, util::utf8string(), std::vector<int>()); decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement); config.print(stdout, true); } else { auto inputs = util::readFileAsUtf8(inputTXT, lineByLine); - for (unsigned int i = 0; i < inputs.size(); i++) + for (int i = 0; i < (int)inputs.size(); i++) { - BaseConfig config(mcd, inputTSV, inputs[i]); + std::vector<int> sentIndexes; + if (inputs.size() > 1) + sentIndexes.emplace_back(i); + BaseConfig config(mcd, inputTSV, inputs[i], sentIndexes); decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement); config.print(stdout, i == 0); } diff --git a/reading_machine/include/BaseConfig.hpp b/reading_machine/include/BaseConfig.hpp index d77d0a0720ed14439412258a6fa9813514aac173..a93b1804dfe23a470dd1794a8338a390a923a408 100644 --- a/reading_machine/include/BaseConfig.hpp +++ b/reading_machine/include/BaseConfig.hpp @@ -21,11 +21,11 @@ class BaseConfig : public Config void createColumns(std::string mcd); void readRawInput(std::string_view rawFilename); - void readTSVInput(std::string_view tsvFilename); + void readTSVInput(std::string_view tsvFilename, const std::vector<int> & sentencesIndexes); public : - BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawFilename); + BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawFilename, const std::vector<int> & sentencesIndexes); BaseConfig(const BaseConfig & other); BaseConfig & operator=(const BaseConfig & other) = default; diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index 298d4537b064b4126d7f67e04129abceeaf6529c..c9956d4886037ce308840293f6d800c0b7a3272d 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -43,125 +43,142 @@ void BaseConfig::readRawInput(std::string_view rawFilename) rawInput.replace(util::utf8char("\t"), util::utf8char(" ")); } -void BaseConfig::readTSVInput(std::string_view tsvFilename) +void BaseConfig::readTSVInput(std::string_view tsvFilename, const std::vector<int> & _sentencesIndexes) { - std::FILE * file = std::fopen(tsvFilename.data(), "r"); + auto sentencesIndexes = _sentencesIndexes; + if (sentencesIndexes.empty()) + sentencesIndexes.emplace_back(-1); - if (not file) - util::myThrow(fmt::format("Cannot open file '{}'", tsvFilename)); - - char lineBuffer[100000]; - int inputLineIndex = 0; - bool inputHasBeenRead = false; - int usualNbCol = -1; - int nbMultiwords = 0; - std::vector<std::string> pendingComments; - - while (!std::feof(file)) + for (int targetSentenceIndex : sentencesIndexes) { - if (lineBuffer != std::fgets(lineBuffer, 100000, file)) - break; - - std::string_view line(lineBuffer); - inputLineIndex++; - - if (line.size() < 3) + std::FILE * file = std::fopen(tsvFilename.data(), "r"); + + if (not file) + util::myThrow(fmt::format("Cannot open file '{}'", tsvFilename)); + + char lineBuffer[100000]; + int inputLineIndex = 0; + bool inputHasBeenRead = false; + int usualNbCol = -1; + int nbMultiwords = 0; + int curSentenceIndex = 0; + std::vector<std::string> pendingComments; + + while (!std::feof(file)) { - if (!inputHasBeenRead) - continue; - - get(EOSColName, getNbLines()-1, 0) = EOSSymbol1; - - try + if (lineBuffer != std::fgets(lineBuffer, 100000, file)) + break; + + std::string_view line(lineBuffer); + inputLineIndex++; + + if (line.size() < 3) { - std::map<std::string, int> id2index; - int firstIndexOfSequence = getNbLines()-1; - for (int i = (int)getNbLines()-1; has(0, i, 0); --i) - { - if (!isToken(i)) - continue; + if (!inputHasBeenRead) + continue; - if (i != (int)getNbLines()-1 && getConst(EOSColName, i, 0) == EOSSymbol1) - break; - - firstIndexOfSequence = i; - id2index[getConst(idColName, i, 0)] = i; - } - if (hasColIndex(headColName)) - for (int i = firstIndexOfSequence; i < (int)getNbLines(); ++i) + if (targetSentenceIndex == -1 or targetSentenceIndex == curSentenceIndex) + { + get(EOSColName, getNbLines()-1, 0) = EOSSymbol1; + + try { - if (!isToken(i)) - continue; - auto & head = get(headColName, i, 0); - if (head == "0") - head = "-1"; - else - head = std::to_string(id2index[head]); - } - } catch(std::exception & e) {util::myThrow(e.what());} - - continue; - } - - if (line.back() == '\n') - line.remove_suffix(1); - - if (line[0] == '#') - { - if (util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)#(?:(?:\\s|\\t)*)global.columns(?:(?:\\s|\\t)*)=(?:(?:\\s|\\t)*)(.+)"), line, [this](const auto & sm) + std::map<std::string, int> id2index; + int firstIndexOfSequence = getNbLines()-1; + for (int i = (int)getNbLines()-1; has(0, i, 0); --i) { - createColumns(util::join(",", util::split(util::strip(sm.str(1)), ' '))); - })) - continue; - - pendingComments.emplace_back(line); - continue; - } - - inputHasBeenRead = true; - - auto splited = util::split(line, '\t'); - if (usualNbCol == -1) - usualNbCol = splited.size(); - if ((int)splited.size() != usualNbCol) - util::myThrow(fmt::format("in file {} line {} is invalid, it shoud have {} columns", tsvFilename, line, usualNbCol)); - - // Ignore empty nodes - if (hasColIndex(idColName) && splited[getColIndex(idColName)].find('.') != std::string::npos) - continue; - - addLines(1); - get(EOSColName, getNbLines()-1, 0) = EOSSymbol0; - if (nbMultiwords > 0) - { - get(isMultiColName, getNbLines()-1, 0) = EOSSymbol1; - nbMultiwords--; - } - else - get(isMultiColName, getNbLines()-1, 0) = EOSSymbol0; - - get(commentsColName, getNbLines()-1, 0) = util::join("\n", pendingComments); - pendingComments.clear(); + if (!isToken(i)) + continue; + + if (i != (int)getNbLines()-1 && getConst(EOSColName, i, 0) == EOSSymbol1) + break; + + firstIndexOfSequence = i; + id2index[getConst(idColName, i, 0)] = i; + } + if (hasColIndex(headColName)) + for (int i = firstIndexOfSequence; i < (int)getNbLines(); ++i) + { + if (!isToken(i)) + continue; + auto & head = get(headColName, i, 0); + if (head == "0") + head = "-1"; + else + head = std::to_string(id2index[head]); + } + } catch(std::exception & e) {util::myThrow(e.what());} + } - for (unsigned int i = 0; i < splited.size(); i++) - if (i < colIndex2Name.size() - extraColumns.size()) + curSentenceIndex += 1; + + continue; + } + + if (line.back() == '\n') + line.remove_suffix(1); + + if (line[0] == '#') { - std::string value = std::string(splited[i]); - get(i, getNbLines()-1, 0) = value; + if (util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)#(?:(?:\\s|\\t)*)global.columns(?:(?:\\s|\\t)*)=(?:(?:\\s|\\t)*)(.+)"), line, [this](const auto & sm) + { + createColumns(util::join(",", util::split(util::strip(sm.str(1)), ' '))); + })) + continue; + + if (targetSentenceIndex == -1 or targetSentenceIndex == curSentenceIndex) + pendingComments.emplace_back(line); + continue; } + + inputHasBeenRead = true; + + auto splited = util::split(line, '\t'); + if (usualNbCol == -1) + usualNbCol = splited.size(); + if ((int)splited.size() != usualNbCol) + util::myThrow(fmt::format("in file {} line {} is invalid, it shoud have {} columns", tsvFilename, line, usualNbCol)); + + if (targetSentenceIndex != -1 and targetSentenceIndex != curSentenceIndex) + continue; - if (isMultiword(getNbLines()-1)) - nbMultiwords = getMultiwordSize(getNbLines()-1)+1; - } - - std::fclose(file); + // Ignore empty nodes + if (hasColIndex(idColName) && splited[getColIndex(idColName)].find('.') != std::string::npos) + continue; + + addLines(1); + get(EOSColName, getNbLines()-1, 0) = EOSSymbol0; + if (nbMultiwords > 0) + { + get(isMultiColName, getNbLines()-1, 0) = EOSSymbol1; + nbMultiwords--; + } + else + get(isMultiColName, getNbLines()-1, 0) = EOSSymbol0; + + get(commentsColName, getNbLines()-1, 0) = util::join("\n", pendingComments); + pendingComments.clear(); + + for (unsigned int i = 0; i < splited.size(); i++) + if (i < colIndex2Name.size() - extraColumns.size()) + { + std::string value = std::string(splited[i]); + get(i, getNbLines()-1, 0) = value; + } + + if (isMultiword(getNbLines()-1)) + nbMultiwords = getMultiwordSize(getNbLines()-1)+1; + } + + std::fclose(file); + } // End for targetSentenceIndex } BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name(other.colIndex2Name), colName2Index(other.colName2Index) { } -BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawInput) +BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawInput, const std::vector<int> & sentencesIndexes) { if (tsvFilename.empty() and rawInput.empty()) util::myThrow("tsvFilename and rawInput can't be both empty"); @@ -171,8 +188,9 @@ BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, const util if (not rawInput.empty()) this->rawInput = rawInput; +// sentencesIndexes = index of sentences to keep. Empty vector == keep all sentences. if (not tsvFilename.empty()) - readTSVInput(tsvFilename); + readTSVInput(tsvFilename, sentencesIndexes); if (!has(0,wordIndex,0)) addLines(1); diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index ddc5477fb58fce6e3e465e5549ea7f31ea285bbc..3268d77c1564bf6d1d90af4d35475ed060525410 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -11,7 +11,7 @@ Config::Config(const Config & other) this->appliableSplitTransitions = other.appliableSplitTransitions; this->appliableTransitions = other.appliableTransitions; - this->strategy.reset(new Strategy(*other.strategy)); + this->strategy.reset(other.strategy ? new Strategy(*other.strategy) : nullptr); this->rawInput = other.rawInput; this->wordIndex = other.wordIndex; @@ -45,6 +45,9 @@ void Config::resizeLines(unsigned int nbLines) bool Config::has(int colIndex, int lineIndex, int hypothesisIndex) const { +// fmt::print(stderr, "line index = {}\n", lineIndex); +// fmt::print(stderr, "first line index = {}\n", getFirstLineIndex()); +// fmt::print(stderr, "nbLines = {}\n", getNbLines()); return colIndex >= 0 && colIndex < (int)getNbColumns() && lineIndex >= (int)getFirstLineIndex() && lineIndex < (int)getFirstLineIndex() + (int)getNbLines() && hypothesisIndex >= 0 && hypothesisIndex < nbHypothesesMax+1; } @@ -230,7 +233,7 @@ void Config::printForDebug(FILE * dest) const fmt::print(dest, "{}", getLetter(index)); if (!util::isEmpty(rawInput)) fmt::print(dest, "\n{}\n", longLine); - fmt::print(dest, "State={}\nwordIndex={} characterIndex={}\nhistory=({})\nstack=({})\n", state, wordIndex, characterIndex, historyStr, stackStr); + fmt::print(dest, "State={}\nwordIndex={}/{} characterIndex={}/{}\nhistory=({})\nstack=({})\n", state, wordIndex, getNbLines(), characterIndex, rawInput.size(), historyStr, stackStr); fmt::print(dest, "{}\n", longLine); for (unsigned int line = 0; line < toPrint.size(); line++) diff --git a/reading_machine/src/SubConfig.cpp b/reading_machine/src/SubConfig.cpp index 7b7b65ead097e4bb1d0d33b0a5bbd9d8b55831a9..d9c191227e536239849ce348029ecff0ec2210dc 100644 --- a/reading_machine/src/SubConfig.cpp +++ b/reading_machine/src/SubConfig.cpp @@ -30,7 +30,7 @@ bool SubConfig::update() { unsigned int currentLastLineIndex = firstLineIndex + getNbLines(); - if (currentLastLineIndex >= model.getNbLines()-1) + if (getNbLines() > 0 and currentLastLineIndex >= model.getNbLines()-1) return false; std::size_t newFirstLineIndex = spanSize/2 >= wordIndex ? 0 : wordIndex - spanSize/2; diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index e80ae0b4252f5b935aaebf4e27d12766a576de80..e15546a8642baafaec0507283f41736be23d5e60 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -313,7 +313,7 @@ void Transition::initSplitWord(std::vector<std::string> words) auto consumedWord = util::splitAsUtf8(words[0]); sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0)); sequence.emplace_back(Action::assertIsEmpty("FORM", Config::Object::Buffer, 0)); - sequence.emplace_back(Action::addLinesIfNeeded(words.size())); + sequence.emplace_back(Action::addLinesIfNeeded(words.size()-1)); sequence.emplace_back(Action::addCharsToCol("FORM", consumedWord.size(), Config::Object::Buffer, 0)); sequence.emplace_back(Action::consumeCharacterIndex(consumedWord)); for (unsigned int i = 1; i < words.size(); i++) diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index ef0f25ec18ab3be0ca9994a4f77dc30ec244adb7..dfa465e8a524134b55c7887892940fe0bc1a01cd 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -55,13 +55,13 @@ class Trainer private : - void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold); + void extractExamples(std::vector<SubConfig> & configs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold); float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples); public : Trainer(ReadingMachine & machine, int batchSize); - void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold); + void createDataset(std::vector<BaseConfig> & goldConfigs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold); void extractActionSequence(BaseConfig & config); void makeDataLoader(std::filesystem::path dir); void makeDevDataLoader(std::filesystem::path dir); diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index a0492974a6370f4f48a1d105664e1bcad129154b..dba21cd7549380cc0f5213f75c71c80f95b329fa 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -46,6 +46,7 @@ po::options_description MacaonTrain::getOptionsDescription() ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()), "Max norm for the embeddings") ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.") + ("lineByLine", "Treat the TXT input as being one different text per line.") ("help,h", "Produce this help message") ("oracleMode", "Don't train a model, transform the corpus into a sequence of transitions."); @@ -138,6 +139,7 @@ int MacaonTrain::main() auto explorationThreshold = variables["explorationThreshold"].as<float>(); auto seed = variables["seed"].as<int>(); auto oracleMode = variables.count("oracleMode") == 0 ? false : true; + auto lineByLine = variables.count("lineByLine") == 0 ? false : true; WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>()); WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0); WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0); @@ -165,27 +167,37 @@ int MacaonTrain::main() ReadingMachine machine(machinePath.string(), true); - util::utf8string trainRawInput; + std::vector<util::utf8string> trainRawInputs; if (!trainRawFile.empty()) - { - auto input = util::readFileAsUtf8(trainRawFile, false); - trainRawInput = input[0]; - } - BaseConfig goldConfig(mcd, trainTsvFile, trainRawInput); - util::utf8string devRawInput; - if (!devRawFile.empty()) - { - auto input = util::readFileAsUtf8(devRawFile, false); - devRawInput = input[0]; - } - BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInput); + trainRawInputs = util::readFileAsUtf8(trainRawFile, lineByLine); + std::vector<BaseConfig> goldConfigs; + for (int i = 0; i < (int)trainRawInputs.size(); i++) + if (trainRawInputs.size() == 1) + goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[i], std::vector<int>()); + else + goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[i], std::vector<int>{i}); + if (goldConfigs.empty()) + goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>()); + std::vector<util::utf8string> devRawInputs; + if (!devRawFile.empty()) + devRawInputs = util::readFileAsUtf8(devRawFile, lineByLine); + std::vector<BaseConfig> devGoldConfigs; + for (int i = 0; i < (int)devRawInputs.size(); i++) + if (devRawInputs.size() == 1) + devGoldConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>()); + else + devGoldConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>{i}); + if (devGoldConfigs.empty()) + devGoldConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, util::utf8string(), std::vector<int>()); + Trainer trainer(machine, batchSize); Decoder decoder(machine); if (oracleMode) { - trainer.extractActionSequence(goldConfig); + //TODO : handle more than one + trainer.extractActionSequence(goldConfigs[0]); exit(0); } @@ -230,11 +242,11 @@ int MacaonTrain::main() if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)) { machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open); - trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold); + trainer.createDataset(goldConfigs, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold); if (!computeDevScore) { machine.setDictsState(Dict::State::Closed); - trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold); + trainer.createDataset(devGoldConfigs, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold); } } if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer)) @@ -263,9 +275,21 @@ int MacaonTrain::main() std::vector<std::pair<float,std::string>> devScores; if (computeDevScore) { - BaseConfig devConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInput); - decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); - decoder.evaluate(devConfig, modelPath, devTsvFile, machine.getPredicted()); + std::vector<BaseConfig> devConfigs; + for (int i = 0; i < (int)devRawInputs.size(); i++) + if (devRawInputs.size() == 1) + devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>()); + else + devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>{i}); + if (devConfigs.empty()) + devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, util::utf8string(), std::vector<int>()); + for (auto & devConfig : devConfigs) + decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); + + std::vector<const Config *> devConfigsPtrs; + for (auto & devConfig : devConfigs) + devConfigsPtrs.emplace_back(&devConfig); + decoder.evaluate(devConfigsPtrs, modelPath, devTsvFile, machine.getPredicted()); devScores = decoder.getF1Scores(machine.getPredicted()); } else diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 5b30270013a9b1807c356fc1a6399ef1db9a6cd2..929d8620f530127c313378173d0c423de6e44151 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -17,18 +17,20 @@ void Trainer::makeDevDataLoader(std::filesystem::path dir) devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); } -void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold) +void Trainer::createDataset(std::vector<BaseConfig> & goldConfigs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold) { - SubConfig config(goldConfig, goldConfig.getNbLines()); + std::vector<SubConfig> configs; + for (auto & goldConfig : goldConfigs) + configs.emplace_back(goldConfig, goldConfig.getNbLines()); machine.trainMode(false); - extractExamples(config, debug, dir, epoch, dynamicOracle, explorationThreshold); + extractExamples(configs, debug, dir, epoch, dynamicOracle, explorationThreshold); machine.saveDicts(); } -void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold) +void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold) { torch::AutoGradMode useGrad(false); @@ -37,11 +39,6 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p std::filesystem::create_directories(dir); - config.addPredicted(machine.getPredicted()); - config.setStrategy(machine.getStrategyDefinition()); - config.setState(config.getStrategy().getInitialState()); - machine.getClassifier(config.getState())->setState(config.getState()); - auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle); if (std::filesystem::exists(currentEpochAllExtractedFile)) @@ -51,132 +48,140 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p int totalNbExamples = 0; - while (true) + for (auto & config : configs) { - if (debug) - config.printForDebug(stderr); + config.addPredicted(machine.getPredicted()); + config.setStrategy(machine.getStrategyDefinition()); + config.setState(config.getStrategy().getInitialState()); + machine.getClassifier(config.getState())->setState(config.getState()); - if (machine.hasSplitWordTransitionSet()) - config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); + while (true) + { + if (debug) + config.printForDebug(stderr); - auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config); - config.setAppliableTransitions(appliableTransitions); + if (machine.hasSplitWordTransitionSet()) + config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); - std::vector<std::vector<long>> context; + auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config); + config.setAppliableTransitions(appliableTransitions); - try - { - context = machine.getClassifier(config.getState())->getNN()->extractContext(config); - } catch(std::exception & e) - { - util::myThrow(fmt::format("Failed to extract context : {}", e.what())); - } + std::vector<std::vector<long>> context; - Transition * transition = nullptr; + try + { + context = machine.getClassifier(config.getState())->getNN()->extractContext(config); + } catch(std::exception & e) + { + util::myThrow(fmt::format("Failed to extract context : {}", e.what())); + } - auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle); + Transition * transition = nullptr; - Transition * goldTransition = goldTransitions[0]; - if (config.getState() == "parser") - goldTransitions[std::rand()%goldTransitions.size()]; + auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle); - int nbClasses = machine.getTransitionSet(config.getState()).size(); + Transition * goldTransition = goldTransitions[0]; + if (config.getState() == "parser") + goldTransitions[std::rand()%goldTransitions.size()]; - float bestScore = -std::numeric_limits<float>::max(); + int nbClasses = machine.getTransitionSet(config.getState()).size(); - float entropy = 0.0; - - if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") - { - auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); - auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(0); - entropy = NeuralNetworkImpl::entropy(prediction); - - std::vector<int> candidates; + float bestScore = -std::numeric_limits<float>::max(); - for (unsigned int i = 0; i < prediction.size(0); i++) + float entropy = 0.0; + + if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") { - float score = prediction[i].item<float>(); - if (score > bestScore and appliableTransitions[i]) - bestScore = score; + auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); + auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(0); + entropy = NeuralNetworkImpl::entropy(prediction); + + std::vector<int> candidates; + + for (unsigned int i = 0; i < prediction.size(0); i++) + { + float score = prediction[i].item<float>(); + if (score > bestScore and appliableTransitions[i]) + bestScore = score; + } + + for (unsigned int i = 0; i < prediction.size(0); i++) + { + float score = prediction[i].item<float>(); + if (appliableTransitions[i] and bestScore - score <= explorationThreshold) + candidates.emplace_back(i); + } + + transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]); } - - for (unsigned int i = 0; i < prediction.size(0); i++) + else { - float score = prediction[i].item<float>(); - if (appliableTransitions[i] and bestScore - score <= explorationThreshold) - candidates.emplace_back(i); + transition = goldTransition; } - transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]); - } - else - { - transition = goldTransition; - } - - if (!transition or !goldTransition) - { - config.printForDebug(stderr); - util::myThrow("No transition appliable !"); - } + if (!transition or !goldTransition) + { + config.printForDebug(stderr); + util::myThrow("No transition appliable !"); + } - std::vector<long> goldIndexes; - bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config); + std::vector<long> goldIndexes; + bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config); - if (machine.getClassifier(config.getState())->isRegression()) - { - entropy = 0.0; - auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName()); - auto splited = util::split(transition->getName(), ' '); - if (splited.size() != 3 or splited[0] != "WRITESCORE") - util::myThrow(errMessage); - auto col = splited[2]; - splited = util::split(splited[1], '.'); - if (splited.size() != 2) - util::myThrow(errMessage); - auto object = Config::str2object(splited[0]); - int index = std::stoi(splited[1]); - - float regressionTarget = std::stof(config.getConst(col, config.getRelativeWordIndex(object, index), 0)); - goldIndexes.emplace_back(util::float2long(regressionTarget)); - } - else - { - for (auto & t : goldTransitions) - goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t)); + if (machine.getClassifier(config.getState())->isRegression()) + { + entropy = 0.0; + auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName()); + auto splited = util::split(transition->getName(), ' '); + if (splited.size() != 3 or splited[0] != "WRITESCORE") + util::myThrow(errMessage); + auto col = splited[2]; + splited = util::split(splited[1], '.'); + if (splited.size() != 2) + util::myThrow(errMessage); + auto object = Config::str2object(splited[0]); + int index = std::stoi(splited[1]); + + float regressionTarget = std::stof(config.getConst(col, config.getRelativeWordIndex(object, index), 0)); + goldIndexes.emplace_back(util::float2long(regressionTarget)); + } + else + { + for (auto & t : goldTransitions) + goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t)); - } + } - if (!exampleIsBanned) - { - totalNbExamples += context.size(); - if (totalNbExamples >= (int)safetyNbExamplesMax) - util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax))); + if (!exampleIsBanned) + { + totalNbExamples += context.size(); + if (totalNbExamples >= (int)safetyNbExamplesMax) + util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax))); - examplesPerState[config.getState()].addContext(context); - examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes); - examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle); - } + examplesPerState[config.getState()].addContext(context); + examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes); + examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle); + } - config.setChosenActionScore(bestScore); + config.setChosenActionScore(bestScore); - transition->apply(config, entropy); - config.addToHistory(transition->getName()); + transition->apply(config, entropy); + config.addToHistory(transition->getName()); - auto movement = config.getStrategy().getMovement(config, transition->getName()); - if (debug) - fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); - if (movement == Strategy::endMovement) - break; + auto movement = config.getStrategy().getMovement(config, transition->getName()); + if (debug) + fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); + if (movement == Strategy::endMovement) + break; - config.setState(movement.first); - machine.getClassifier(config.getState())->setState(movement.first); - config.moveWordIndexRelaxed(movement.second); + config.setState(movement.first); + machine.getClassifier(config.getState())->setState(movement.first); + config.moveWordIndexRelaxed(movement.second); - if (config.needsUpdate()) - config.update(); - } + if (config.needsUpdate()) + config.update(); + } // End while true + } // End for on configs for (auto & it : examplesPerState) it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);