diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 8e74076e10b750790dc0cee69f386bfe4733325a..543dbbf37b1326c3d1621d173e32fd72da510918 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -7,6 +7,8 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) void Decoder::decode(BaseConfig & config, std::size_t beamSize) { + config.addPredicted(machine.getPredicted()); + try { config.setState(machine.getStrategy().getInitialState()); diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 5f34b241d7e494a48b345b2d0713647f3e6000e3..82b344499c2f782e225bfcd0f298ab0eba54fd3b 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -30,6 +30,7 @@ class Config private : std::vector<String> lines; + std::set<std::string> predicted; protected : @@ -61,6 +62,8 @@ class Config String & get(int colIndex, int lineIndex, int hypothesisIndex); const String & getConst(int colIndex, int lineIndex, int hypothesisIndex) const; String & getLastNotEmpty(int colIndex, int lineIndex); + String & getLastNotEmptyHyp(int colIndex, int lineIndex); + const String & getLastNotEmptyHypConst(int colIndex, int lineIndex) const; const String & getLastNotEmptyConst(int colIndex, int lineIndex) const; ValueIterator getIterator(int colIndex, int lineIndex, int hypothesisIndex); ConstValueIterator getConstIterator(int colIndex, int lineIndex, int hypothesisIndex) const; @@ -75,6 +78,8 @@ class Config const String & getConst(const std::string & colName, int lineIndex, int hypothesisIndex) const; String & getLastNotEmpty(const std::string & colName, int lineIndex); const String & getLastNotEmptyConst(const std::string & colName, int lineIndex) const; + String & getLastNotEmptyHyp(const std::string & colName, int lineIndex); + const String & getLastNotEmptyHypConst(const std::string & colName, int lineIndex) const; String & getFirstEmpty(int colIndex, int lineIndex); String & getFirstEmpty(const std::string & colName, int lineIndex); bool hasCharacter(int letterIndex) const; @@ -100,7 +105,8 @@ class Config void setState(const std::string state); bool stateIsDone() const; std::vector<long> extractContext(int leftBorder, int rightBorder, Dict & dict) const; - + void addPredicted(const std::set<std::string> & predicted); + bool isPredicted(const std::string & colName) const; }; #endif diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 3e9eaf5158e1f1321409d1fd0825e2bd6e8ccbe1..11058e97eee70b29a110a341a9ff4816a367ae63 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -39,6 +39,7 @@ class ReadingMachine Classifier * getClassifier(); void save() const; bool isPredicted(const std::string & columnName) const; + const std::set<std::string> & getPredicted() const; }; #endif diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 2f6e7cf138875e225877990894964700e78fbc35..bd640ffbf40e41cbbf4e2f94e06173e6101ca9b8 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -70,7 +70,10 @@ void Config::print(FILE * dest) const continue; } for (unsigned int i = 0; i < getNbColumns()-1; i++) - fmt::print(dest, "{}{}", getLastNotEmptyConst(i, getFirstLineIndex()+line), i < getNbColumns()-2 ? "\t" : "\n"); + { + auto & colContent = isPredicted(getColName(i)) ? getLastNotEmptyHypConst(i, getFirstLineIndex()+line) : getLastNotEmptyConst(i, getFirstLineIndex()+line); + fmt::print(dest, "{}{}", colContent, i < getNbColumns()-2 ? "\t" : "\n"); + } if (getLastNotEmptyConst(EOSColName, getFirstLineIndex()+line) == EOSSymbol1) fmt::print(dest, "\n"); } @@ -105,7 +108,10 @@ void Config::printForDebug(FILE * dest) const toPrint.emplace_back(); toPrint.back().emplace_back(line == (int)wordIndex ? "=>" : ""); for (unsigned int i = 0; i < getNbColumns(); i++) - toPrint.back().emplace_back(util::shrink(getLastNotEmptyConst(i, line), maxWordLength)); + { + auto & colContent = isPredicted(getColName(i)) ? getLastNotEmptyHypConst(i, line) : getLastNotEmptyConst(i, getFirstLineIndex()+line); + toPrint.back().emplace_back(util::shrink(colContent, maxWordLength)); + } } std::vector<std::size_t> colLength(toPrint[0].size(), 0); @@ -167,6 +173,17 @@ Config::String & Config::getLastNotEmpty(int colIndex, int lineIndex) return lines[baseIndex]; } +Config::String & Config::getLastNotEmptyHyp(int colIndex, int lineIndex) +{ + int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex); + + for (int i = nbHypothesesMax; i > 0; --i) + if (!util::isEmpty(lines[baseIndex+i])) + return lines[baseIndex+i]; + + return lines[baseIndex+1]; +} + Config::String & Config::getFirstEmpty(int colIndex, int lineIndex) { int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex); @@ -194,16 +211,37 @@ const Config::String & Config::getLastNotEmptyConst(int colIndex, int lineIndex) return lines[baseIndex]; } +const Config::String & Config::getLastNotEmptyHypConst(int colIndex, int lineIndex) const +{ + int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex); + + for (int i = nbHypothesesMax; i > 0; --i) + if (!util::isEmpty(lines[baseIndex+i])) + return lines[baseIndex+i]; + + return lines[baseIndex+1]; +} + Config::String & Config::getLastNotEmpty(const std::string & colName, int lineIndex) { return getLastNotEmpty(getColIndex(colName), lineIndex); } +Config::String & Config::getLastNotEmptyHyp(const std::string & colName, int lineIndex) +{ + return getLastNotEmptyHyp(getColIndex(colName), lineIndex); +} + const Config::String & Config::getLastNotEmptyConst(const std::string & colName, int lineIndex) const { return getLastNotEmptyConst(getColIndex(colName), lineIndex); } +const Config::String & Config::getLastNotEmptyHypConst(const std::string & colName, int lineIndex) const +{ + return getLastNotEmptyHypConst(getColIndex(colName), lineIndex); +} + Config::ValueIterator Config::getIterator(int colIndex, int lineIndex, int hypothesisIndex) { return lines.begin() + getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex) + hypothesisIndex; @@ -393,3 +431,13 @@ std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict & return context; } +void Config::addPredicted(const std::set<std::string> & predicted) +{ + this->predicted.insert(predicted.begin(), predicted.end()); +} + +bool Config::isPredicted(const std::string & colName) const +{ + return predicted.count(colName); +} + diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 0a7838c5984e9c79ad0d058aab0ad4146b878b64..2b5cb61aa29e1a57e46e1c2c46f46d4e9ea538b9 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -119,3 +119,8 @@ bool ReadingMachine::isPredicted(const std::string & columnName) const return predicted.count(columnName); } +const std::set<std::string> & ReadingMachine::getPredicted() const +{ + return predicted; +} + diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 51d1a3bb878c885cc18b823e0f8049a48da5bcf2..6496aa3fa68421a96fc87ce52f68c60cb1d4acf8 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -7,6 +7,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine) void Trainer::createDataset(SubConfig & config) { + config.addPredicted(machine.getPredicted()); config.setState(machine.getStrategy().getInitialState()); std::vector<torch::Tensor> contexts;