From 3489e3885fd1ceb262157040d399650d1000af68 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 12 Feb 2020 20:50:45 +0100
Subject: [PATCH] Config is now aware of what is predicted

Config stores a set of predicted column names, filled through
addPredicted() and queried through isPredicted(). Decoder::decode and
Trainer::createDataset copy the set exposed by
ReadingMachine::getPredicted() into the configuration before setting the
initial state.

New accessors getLastNotEmptyHyp / getLastNotEmptyHypConst scan the slots
baseIndex+nbHypothesesMax down to baseIndex+1 and return the first
non-empty one, falling back to baseIndex+1 (slot baseIndex itself is never
returned).

Config::print and Config::printForDebug read predicted columns through
getLastNotEmptyHypConst and every other column through
getLastNotEmptyConst. In printForDebug both accessors receive the same
line index as the call this patch replaces; adding getFirstLineIndex() in
only one of the two branches would make them address different rows.

NOTE(review): context-line indentation (two spaces) was reconstructed;
confirm against the target tree before applying.

---
 decoder/src/Decoder.cpp                    |  2 +
 reading_machine/include/Config.hpp         |  8 +++-
 reading_machine/include/ReadingMachine.hpp |  1 +
 reading_machine/src/Config.cpp             | 52 +++++++++++++++++++++-
 reading_machine/src/ReadingMachine.cpp     |  5 +++
 trainer/src/Trainer.cpp                    |  1 +
 6 files changed, 66 insertions(+), 3 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 8e74076..543dbbf 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -7,6 +7,8 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 
 void Decoder::decode(BaseConfig & config, std::size_t beamSize)
 {
+  config.addPredicted(machine.getPredicted());
+
   try
   {
     config.setState(machine.getStrategy().getInitialState());
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index 5f34b24..82b3444 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -30,6 +30,7 @@ class Config
   private :
 
   std::vector<String> lines;
+  std::set<std::string> predicted;
 
   protected :
 
@@ -61,6 +62,8 @@ class Config
   String & get(int colIndex, int lineIndex, int hypothesisIndex);
   const String & getConst(int colIndex, int lineIndex, int hypothesisIndex) const;
   String & getLastNotEmpty(int colIndex, int lineIndex);
+  String & getLastNotEmptyHyp(int colIndex, int lineIndex);
+  const String & getLastNotEmptyHypConst(int colIndex, int lineIndex) const;
   const String & getLastNotEmptyConst(int colIndex, int lineIndex) const;
   ValueIterator getIterator(int colIndex, int lineIndex, int hypothesisIndex);
   ConstValueIterator getConstIterator(int colIndex, int lineIndex, int hypothesisIndex) const;
@@ -75,6 +78,8 @@ class Config
   const String & getConst(const std::string & colName, int lineIndex, int hypothesisIndex) const;
   String & getLastNotEmpty(const std::string & colName, int lineIndex);
   const String & getLastNotEmptyConst(const std::string & colName, int lineIndex) const;
+  String & getLastNotEmptyHyp(const std::string & colName, int lineIndex);
+  const String & getLastNotEmptyHypConst(const std::string & colName, int lineIndex) const;
   String & getFirstEmpty(int colIndex, int lineIndex);
   String & getFirstEmpty(const std::string & colName, int lineIndex);
   bool hasCharacter(int letterIndex) const;
@@ -100,7 +105,8 @@ class Config
   void setState(const std::string state);
   bool stateIsDone() const;
   std::vector<long> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
-
+  void addPredicted(const std::set<std::string> & predicted);
+  bool isPredicted(const std::string & colName) const;
 };
 
 #endif
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 3e9eaf5..11058e9 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -39,6 +39,7 @@ class ReadingMachine
   Classifier * getClassifier();
   void save() const;
   bool isPredicted(const std::string & columnName) const;
+  const std::set<std::string> & getPredicted() const;
 };
 
 #endif
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 2f6e7cf..bd640ff 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -70,7 +70,10 @@ void Config::print(FILE * dest) const
       continue;
     }
     for (unsigned int i = 0; i < getNbColumns()-1; i++)
-      fmt::print(dest, "{}{}", getLastNotEmptyConst(i, getFirstLineIndex()+line), i < getNbColumns()-2 ? "\t" : "\n");
+    {
+      auto & colContent = isPredicted(getColName(i)) ? getLastNotEmptyHypConst(i, getFirstLineIndex()+line) : getLastNotEmptyConst(i, getFirstLineIndex()+line);
+      fmt::print(dest, "{}{}", colContent, i < getNbColumns()-2 ? "\t" : "\n");
+    }
     if (getLastNotEmptyConst(EOSColName, getFirstLineIndex()+line) == EOSSymbol1)
       fmt::print(dest, "\n");
   }
@@ -105,7 +108,10 @@ void Config::printForDebug(FILE * dest) const
     toPrint.emplace_back();
     toPrint.back().emplace_back(line == (int)wordIndex ? "=>" : "");
     for (unsigned int i = 0; i < getNbColumns(); i++)
-      toPrint.back().emplace_back(util::shrink(getLastNotEmptyConst(i, line), maxWordLength));
+    {
+      auto & colContent = isPredicted(getColName(i)) ? getLastNotEmptyHypConst(i, line) : getLastNotEmptyConst(i, line);
+      toPrint.back().emplace_back(util::shrink(colContent, maxWordLength));
+    }
   }
 
   std::vector<std::size_t> colLength(toPrint[0].size(), 0);
@@ -167,6 +173,17 @@ Config::String & Config::getLastNotEmpty(int colIndex, int lineIndex)
   return lines[baseIndex];
 }
 
+Config::String & Config::getLastNotEmptyHyp(int colIndex, int lineIndex)
+{
+  int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex);
+
+  for (int i = nbHypothesesMax; i > 0; --i)
+    if (!util::isEmpty(lines[baseIndex+i]))
+      return lines[baseIndex+i];
+
+  return lines[baseIndex+1];
+}
+
 Config::String & Config::getFirstEmpty(int colIndex, int lineIndex)
 {
   int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex);
@@ -194,16 +211,37 @@ const Config::String & Config::getLastNotEmptyConst(int colIndex, int lineIndex)
   return lines[baseIndex];
 }
 
+const Config::String & Config::getLastNotEmptyHypConst(int colIndex, int lineIndex) const
+{
+  int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex);
+
+  for (int i = nbHypothesesMax; i > 0; --i)
+    if (!util::isEmpty(lines[baseIndex+i]))
+      return lines[baseIndex+i];
+
+  return lines[baseIndex+1];
+}
+
 Config::String & Config::getLastNotEmpty(const std::string & colName, int lineIndex)
 {
   return getLastNotEmpty(getColIndex(colName), lineIndex);
 }
 
+Config::String & Config::getLastNotEmptyHyp(const std::string & colName, int lineIndex)
+{
+  return getLastNotEmptyHyp(getColIndex(colName), lineIndex);
+}
+
 const Config::String & Config::getLastNotEmptyConst(const std::string & colName, int lineIndex) const
 {
   return getLastNotEmptyConst(getColIndex(colName), lineIndex);
 }
 
+const Config::String & Config::getLastNotEmptyHypConst(const std::string & colName, int lineIndex) const
+{
+  return getLastNotEmptyHypConst(getColIndex(colName), lineIndex);
+}
+
 Config::ValueIterator Config::getIterator(int colIndex, int lineIndex, int hypothesisIndex)
 {
   return lines.begin() + getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex) + hypothesisIndex;
@@ -393,3 +431,13 @@ std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict &
   return context;
 }
 
+void Config::addPredicted(const std::set<std::string> & predicted)
+{
+  this->predicted.insert(predicted.begin(), predicted.end());
+}
+
+bool Config::isPredicted(const std::string & colName) const
+{
+  return predicted.count(colName);
+}
+
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 0a7838c..2b5cb61 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -119,3 +119,8 @@ bool ReadingMachine::isPredicted(const std::string & columnName) const
   return predicted.count(columnName);
 }
 
+const std::set<std::string> & ReadingMachine::getPredicted() const
+{
+  return predicted;
+}
+
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 51d1a3b..6496aa3 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -7,6 +7,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine)
 
 void Trainer::createDataset(SubConfig & config)
 {
+  config.addPredicted(machine.getPredicted());
   config.setState(machine.getStrategy().getInitialState());
 
   std::vector<torch::Tensor> contexts;
-- 
GitLab