diff --git a/common/include/util.hpp b/common/include/util.hpp index 82620a56f08adcc1e8d354ad84a1af9c238d33c8..1a71486b5dcb0fe0c8c037e1b60ae3151e286d26 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -90,11 +90,13 @@ bool choiceWithProbability(float probability); std::string lower(const std::string & s); -void lower(utf8string & s); +void lowerInPlace(utf8string & s); utf8string lower(const utf8string & s); -void lower(utf8char & c); +utf8char lower(const utf8char & c); + +void lowerInPlace(utf8char & c); std::string upper(const std::string & s); diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index 10cc040d4ad9c272ad9548be1bf526b56f698760..2187a0ca0738b8c11c3ad4b7f81f1eef48f79c6b 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -9,6 +9,7 @@ Dict::Dict(State state) insert(emptyValueStr); insert(numberValueStr); insert(urlValueStr); + insert(separatorValueStr); } Dict::Dict(const char * filename, State state) @@ -78,8 +79,10 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref if (element.empty()) return getIndexOrInsert(emptyValueStr, prefix); - if (element.size() == 1 and util::isSeparator(util::utf8char(element))) + if (util::printedLength(element) == 1 and util::isSeparator(util::utf8char(element))) + { return getIndexOrInsert(separatorValueStr, prefix); + } if (util::isNumber(element)) return getIndexOrInsert(numberValueStr, prefix); diff --git a/common/src/util.cpp b/common/src/util.cpp index ec795d8a4218fd94a679a748558dccdbcb6eda00..b084f81748bce7e5cb252a6e1b25ff1fba8d9312 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -32,7 +32,7 @@ bool util::isSeparator(utf8char c) bool util::isIllegal(utf8char c) { - return c == '\n' || c == '\t'; + return c == '\n' || c == '\t' || c == '\r'; } bool util::isNumber(const std::string & s) @@ -248,12 +248,12 @@ bool util::isUppercase(utf8char c) std::string util::lower(const std::string & s) { auto splited = util::splitAsUtf8(s); - lower(splited); + lowerInPlace(splited); return fmt::format("{}", splited); } -void util::lower(utf8string & s) +void util::lowerInPlace(utf8string & s) { for (auto & c : s) { @@ -265,19 +265,26 @@ void util::lower(utf8string & s) util::utf8string util::lower(const utf8string & s) { - auto result = s; - lower(result); + utf8string result = s; + lowerInPlace(result); return result; } -void util::lower(utf8char & c) +void util::lowerInPlace(utf8char & c) { auto it = upper2lower.find(c); if (it != upper2lower.end()) c = it->second; } +util::utf8char util::lower(const utf8char & c) +{ + auto res = c; + lowerInPlace(res); + return res; +} + std::string util::upper(const std::string & s) { auto splited = util::splitAsUtf8(s); diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index e376f500e244991f10cee0cb166c40a2c15d00fc..a0b23a33e0934a72822b350e7d2190af12b74fb4 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -59,7 +59,7 @@ class Action static Action attach(Config::Object governorObject, int governorIndex, Config::Object dependentObject, int dependentIndex); static Action addCharsToCol(const std::string & col, int n, Config::Object object, int relativeIndex); static Action ignoreCurrentCharacter(); - static Action consumeCharacterIndex(util::utf8string consumed); + static Action consumeCharacterIndex(const util::utf8string & consumed); static Action setMultiwordIds(int multiwordSize); static Action split(int index); static Action setRootUpdateIdsEmptyStackIfSentChanged(); diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index dbd49e96c16f7c8e3ae6a453480074f470323704..42b55af78b19e03d6ccada9e2faf017c2d1d444a 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -120,7 +120,7 @@ class Config String & getFirstEmpty(int colIndex, int lineIndex); String & getFirstEmpty(const std::string & colName, int lineIndex); bool hasCharacter(int letterIndex) const; - util::utf8char getLetter(int letterIndex) const; + const util::utf8char & getLetter(int letterIndex) const; void addToHistory(const std::string & transition); void addToStack(std::size_t index); void popStack(); diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 5d4dedb43d3995e5e143d5d75f01835835fc5045..c15880859ffa1e7f662a6bedd6ec2a468e246aab 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -76,7 +76,7 @@ Action Action::setMultiwordIds(int multiwordSize) return {Type::Write, apply, undo, appliable}; } -Action Action::consumeCharacterIndex(util::utf8string consumed) +Action Action::consumeCharacterIndex(const util::utf8string & consumed) { auto apply = [consumed](Config & config, Action &) { @@ -97,8 +97,14 @@ Action Action::consumeCharacterIndex(util::utf8string consumed) return false; for (unsigned int i = 0; i < consumed.size(); i++) - if (!config.hasCharacter(config.getCharacterIndex()+i) or config.getLetter(config.getCharacterIndex()+i) != consumed[i]) + { + if (!config.hasCharacter(config.getCharacterIndex()+i)) + return false; + const util::utf8char & letter = config.getLetter(config.getCharacterIndex()+i); + const util::utf8char & consumedLetter = consumed[i]; + if (util::lower(letter) != util::lower(consumedLetter)) return false; + } return true; }; diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 5b37d2c1edba4275e127bda9ce4157b2d66c3f39..5bd1f7df5a014ad08938625451b892bb48e6153e 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -384,7 +384,7 @@ bool Config::hasCharacter(int letterIndex) const return letterIndex >= 0 and letterIndex < (int)util::getSize(rawInput); } -util::utf8char Config::getLetter(int letterIndex) const +const util::utf8char & Config::getLetter(int letterIndex) const { return rawInput[letterIndex]; } diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 68d9dbbca5243ec03e03107a80d194fa1eb7d16b..31dfd5d7dbddcc09c1f7c432b81963f2ddc736e9 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -328,7 +328,7 @@ void Transition::initSplitWord(std::vector<std::string> words) return std::numeric_limits<int>::max(); for (unsigned int i = 0; i < words.size(); i++) - if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i]) + if (!config.has("FORM", config.getWordIndex()+i, 0) or util::lower(config.getConst("FORM", config.getWordIndex()+i, 0)) != words[i]) return std::numeric_limits<int>::max(); return 0; diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp index 66bd13d503e087bc888be5609583af68bacedf93..4ed7a52deb8919995b7d2cc81c449f8cfb20211a 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -64,13 +64,13 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, { for (int i = 0; i < leftWindow; i++) if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i)) - contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, config.getLetter(config.getCharacterIndex()-leftWindow+i)), "")); + contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)), prefix)); else contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix)); for (int i = 0; i <= rightWindow; i++) if (config.hasCharacter(config.getCharacterIndex()+i)) - contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, config.getLetter(config.getCharacterIndex()+i)), "")); + contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)), prefix)); else contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix)); }