diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index a3e468d2b3f19e554808b6f1e1005f672e76e5ce..c15d63869d859c5bd9bddfc83e2f54e7e120570f 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -224,8 +224,6 @@ Action Action::setRoot() } } - auto & rootId = config.getLastNotEmptyConst(Config::idColName, rootIndex); - for (int i = config.getWordIndex()-1; true; --i) { if (!config.has(0, i, 0)) @@ -249,7 +247,7 @@ Action Action::setRoot() } else { - config.getFirstEmpty(Config::headColName, i) = rootId; + config.getFirstEmpty(Config::headColName, i) = std::to_string(rootIndex); } } } @@ -333,7 +331,7 @@ Action Action::attach(Object governorObject, int governorIndex, Object dependent lineIndex = config.getWordIndex() + governorIndex; else lineIndex = config.getStack(governorIndex); - addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, config.getLastNotEmptyConst(Config::idColName, lineIndex)).apply(config, a); + addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, std::to_string(lineIndex)).apply(config, a); }; auto undo = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a) @@ -357,7 +355,7 @@ Action Action::attach(Object governorObject, int governorIndex, Object dependent govLineIndex = config.getStack(governorIndex); } - return addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, config.getLastNotEmptyConst(Config::idColName, govLineIndex)).appliable(config, action); + return addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, std::to_string(govLineIndex)).appliable(config, action); }; return {Type::Write, apply, undo, appliable}; diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index 94352e6c2d92a53f54be17f7a86de9de93da26b7..ce61016ed8a6ca399e63a5cbd44d75268300de01 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -62,6 +62,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) int inputLineIndex = 0; bool inputHasBeenRead = false; int usualNbCol = -1; + while (!std::feof(file)) { if (lineBuffer != std::fgets(lineBuffer, 100000, file)) @@ -77,6 +78,32 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) get(EOSColName, getNbLines()-1, 0) = EOSSymbol1; + try + { + std::map<std::string, int> id2index; + int firstIndexOfSequence = getNbLines()-1; + for (int i = (int)getNbLines()-1; has(0, i, 0); --i) + { + if (!isToken(i)) + continue; + + if (i != (int)getNbLines()-1 && getConst(EOSColName, i, 0) == EOSSymbol1) + break; + + firstIndexOfSequence = i; + id2index[getConst(idColName, i, 0)] = i; + } + for (int i = firstIndexOfSequence; i < (int)getNbLines(); ++i) + { + if (!isToken(i)) + continue; + auto & head = get(headColName, i, 0); + if (head == "0") + continue; + head = std::to_string(id2index[head]); + } + } catch(std::exception & e) {util::myThrow(e.what());} + continue; } @@ -104,7 +131,10 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) for (unsigned int i = 0; i < splited.size(); i++) if (i < colIndex2Name.size()) - get(i, getNbLines()-1, 0) = std::string(splited[i]); + { + std::string value = std::string(splited[i]); + get(i, getNbLines()-1, 0) = value; + } } std::fclose(file); diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 1f4ba62792e19241c6cb7bed9d66d000138ce3be..58426e41fe497217756d0960bb6f4df2819d09b1 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -93,6 +93,12 @@ void Config::print(FILE * dest) const { auto & colContent = isPredicted(getColName(i)) ? getLastNotEmptyHypConst(i, getFirstLineIndex()+line) : getLastNotEmptyConst(i, getFirstLineIndex()+line); std::string valueToPrint = colContent; + try + { + if (getColName(i) == headColName) + if (valueToPrint != "0") + valueToPrint = getLastNotEmptyConst(idColName, std::stoi(valueToPrint)); + } catch(std::exception &) {} if (valueToPrint.empty()) valueToPrint = "_"; @@ -137,7 +143,14 @@ void Config::printForDebug(FILE * dest) const for (unsigned int i = 0; i < getNbColumns(); i++) { auto & colContent = isPredicted(getColName(i)) ? getLastNotEmptyHypConst(i, line) : getLastNotEmptyConst(i, line); - toPrint.back().emplace_back(util::shrink(colContent, maxWordLength)); + std::string toPrintCol = colContent; + try + { + if (getColName(i) == headColName) + if (toPrintCol != "0") + toPrintCol = getLastNotEmptyConst(idColName, std::stoi(toPrintCol)); + } catch(std::exception &) {} + toPrint.back().emplace_back(util::shrink(toPrintCol, maxWordLength)); } } diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 0489b8b54235167a8abce70c3f800867174273c8..ce3cb811d23ebec4875c08cdca1c5f654a30d8a3 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -87,11 +87,13 @@ void Transition::initShift() cost = [](const Config & config) { + if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1) + return std::numeric_limits<int>::max(); + if (!config.isToken(config.getWordIndex())) return 0; - auto headGov = config.getConst(Config::headColName, config.getWordIndex(), 0); - auto headId = config.getConst(Config::idColName, config.getWordIndex(), 0); + auto headGovIndex = config.getConst(Config::headColName, config.getWordIndex(), 0); int cost = 0; for (int i = 0; config.hasStack(i); ++i) @@ -100,10 +102,9 @@ void Transition::initShift() continue; auto stackIndex = config.getStack(i); - auto stackId = config.getConst(Config::idColName, stackIndex, 0); - auto stackGov = config.getConst(Config::headColName, stackIndex, 0); + auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0); - if (stackGov == headId || headGov == stackId) + if (stackGovIndex == std::to_string(config.getWordIndex()) || headGovIndex == std::to_string(stackIndex)) ++cost; } @@ -126,8 +127,7 @@ void Transition::initLeft(std::string label) int cost = 0; - auto idOfStack = config.getConst(Config::idColName, stackIndex, 0); - auto govIdOfStack = config.getConst(Config::headColName, stackIndex, 0); + auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0); for (int i = wordIndex+1; config.has(0, i, 0); ++i) { @@ -137,15 +137,14 @@ void Transition::initLeft(std::string label) if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) break; - auto idOfOther = config.getConst(Config::idColName, i, 0); - auto govIdOfOther = config.getConst(Config::headColName, i, 0); + auto otherGovIndex = config.getConst(Config::headColName, i, 0); - if (govIdOfStack == idOfOther || govIdOfOther == idOfStack) + if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex)) ++cost; } //TODO : Check if this is necessary - if (govIdOfStack != config.getConst(Config::idColName, wordIndex, 0)) + if (stackGovIndex != std::to_string(wordIndex)) ++cost; if (label != config.getConst(Config::deprelColName, stackIndex, 0)) @@ -170,18 +169,16 @@ void Transition::initRight(std::string label) int cost = 0; - auto idOfBuffer = config.getConst(Config::idColName, wordIndex, 0); - auto govIdOfBuffer = config.getConst(Config::headColName, wordIndex, 0); + auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0); for (int i = wordIndex; config.has(0, i, 0); ++i) { if (!config.isToken(i)) continue; - auto idOfOther = config.getConst(Config::idColName, i, 0); - auto govIdOfOther = config.getConst(Config::headColName, i, 0); + auto otherGovIndex = config.getConst(Config::headColName, i, 0); - if (govIdOfBuffer == idOfOther || govIdOfOther == idOfBuffer) + if (bufferGovIndex == std::to_string(i) || otherGovIndex == std::to_string(wordIndex)) ++cost; if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) @@ -194,15 +191,14 @@ void Transition::initRight(std::string label) continue; auto otherStackIndex = config.getStack(i); - auto stackId = config.getConst(Config::idColName, otherStackIndex, 0); - auto stackGov = config.getConst(Config::headColName, otherStackIndex, 0); + auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0); - if (stackGov == idOfBuffer || govIdOfBuffer == stackId) + if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex)) ++cost; } //TODO : Check if this is necessary - if (govIdOfBuffer != config.getConst(Config::idColName, stackIndex, 0)) + if (bufferGovIndex != std::to_string(stackIndex)) ++cost; if (label != config.getConst(Config::deprelColName, wordIndex, 0)) @@ -226,18 +222,17 @@ void Transition::initReduce() int cost = 0; - auto idOfStack = config.getConst(Config::idColName, config.getStack(0), 0); - auto govIdOfStack = config.getConst(Config::headColName, config.getStack(0), 0); + auto stackIndex = config.getStack(0); + auto stackGovIndex = config.getConst(Config::headColName, config.getStack(0), 0); for (int i = config.getWordIndex(); config.has(0, i, 0); ++i) { if (!config.isToken(i)) continue; - auto idOfOther = config.getConst(Config::idColName, i, 0); - auto govIdOfOther = config.getConst(Config::headColName, i, 0); + auto otherGovIndex = config.getConst(Config::headColName, i, 0); - if (govIdOfStack == idOfOther || govIdOfOther == idOfStack) + if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex)) ++cost; if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) @@ -266,14 +261,10 @@ void Transition::initEOS() if (!config.isToken(config.getStack(0))) return std::numeric_limits<int>::max(); - int cost = 0; - if (config.getConst(Config::EOSColName, config.getStack(0), 0) != Config::EOSSymbol1) - cost += 100; + return std::numeric_limits<int>::max(); - auto topStackIndex = config.getStack(0); - auto topStackGov = config.getConst(Config::headColName, topStackIndex, 0); - auto topStackGovPred = config.getLastNotEmptyHypConst(Config::headColName, topStackIndex); + int cost = 0; --cost; for (int i = 0; config.hasStack(i); ++i) @@ -282,10 +273,9 @@ void Transition::initEOS() continue; auto otherStackIndex = config.getStack(i); - auto stackId = config.getConst(Config::idColName, otherStackIndex, 0); - auto stackGovPred = config.getLastNotEmptyHypConst(Config::headColName, otherStackIndex); + auto otherStackGovPred = config.getLastNotEmptyHypConst(Config::headColName, otherStackIndex); - if (util::isEmpty(stackGovPred)) + if (util::isEmpty(otherStackGovPred)) ++cost; }