diff --git a/common/include/upper2lower b/common/include/upper2lower index dfde4b913f3d730ca65ff1c49559bc92762c4986..fc4e8a9fa98eb5e880067d6236b42858371cfbde 100644 --- a/common/include/upper2lower +++ b/common/include/upper2lower @@ -1,3 +1,5 @@ +#ifndef UPPER2LOWER +#define UPPER2LOWER namespace util { @@ -1793,5 +1795,8 @@ std::map<utf8char, utf8char> upper2lower {"𞤡", "𞥃"}, }; +std::map<utf8char, utf8char> lower2upper = inverseMap(upper2lower); + } +#endif diff --git a/common/include/util.hpp b/common/include/util.hpp index 9acb179402362c20edc7e904bc11f5a3ea243921..e0fcf2a60eb6e0ac8cb199f9d1ab61b3e6d40cab 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -79,6 +79,22 @@ bool doIfNameMatch(const std::regex & reg, std::string_view name, const std::fun bool choiceWithProbability(float probability); +std::string lower(const std::string & s); + +void lower(utf8string & s); + +utf8string lower(const utf8string & s); + +void lower(utf8char & c); + +std::string upper(const std::string & s); + +void upper(utf8string & s); + +utf8string upper(const utf8string & s); + +void upper(utf8char & c); + template <typename T> std::string join(const std::string & delim, const std::vector<T> elems) { @@ -101,7 +117,15 @@ std::string join(const std::string & delim, const boost::circular_buffer<T> elem return result; } -std::string lower(const std::string & s); +template <typename K, typename V> +std::map<V, K> inverseMap(const std::map<K, V> & model) +{ + std::map<V, K> res; + for (auto & it : model) + res[it.second] = it.first; + + return res; +} }; diff --git a/common/src/util.cpp b/common/src/util.cpp index a40362ca5b57470c7e5e8ee712166c7ae675e2b2..5d2be74858d0cc6f5ea7eab124f5f43a5a56a304 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -209,13 +209,66 @@ bool util::isUppercase(utf8char c) std::string util::lower(const std::string & s) { auto splited = util::splitAsUtf8(s); - for (auto & c : splited) + lower(splited); + + return fmt::format("{}", splited); +} + +void util::lower(utf8string & s) +{ + for (auto & c : s) { auto it = upper2lower.find(c); if (it != upper2lower.end()) c = it->second; } +} + +util::utf8string util::lower(const utf8string & s) +{ + auto result = s; + lower(result); + + return result; +} + +void util::lower(utf8char & c) +{ + auto it = upper2lower.find(c); + if (it != upper2lower.end()) + c = it->second; +} + +std::string util::upper(const std::string & s) +{ + auto splited = util::splitAsUtf8(s); + upper(splited); return fmt::format("{}", splited); } +void util::upper(utf8string & s) +{ + for (auto & c : s) + { + auto it = lower2upper.find(c); + if (it != lower2upper.end()) + c = it->second; + } +} + +util::utf8string util::upper(const utf8string & s) +{ + auto result = s; + upper(result); + + return result; +} + +void util::upper(utf8char & c) +{ + auto it = lower2upper.find(c); + if (it != lower2upper.end()) + c = it->second; +} + diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index f0e4bfca71d0f71c1a79f4caa2ac1ac114139e69..f9c437787e991b9a9b99eef335eedc7a0913c3f3 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -129,6 +129,8 @@ std::string Decoder::getMetricOfColName(const std::string & colName) const return "UFeats"; if (colName == "FORM") return "Words"; + if (colName == "LEMMA") + return "Lemmas"; return colName; } diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index b511ab5046c8de52582ff0c59b599161d4be75c6..47308d3e34686134a5357d7e68cd728668eb6077 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -62,6 +62,11 @@ class Action static Action split(int index); static Action setRootUpdateIdsEmptyStackIfSentChanged(); static Action deprel(std::string value); + static Action transformSuffix(std::string fromCol, Config::Object fromObj, int fromIndex, std::string toCol, Config::Object toObj, int toIndex, util::utf8string toRemove, util::utf8string toAdd); + static Action uppercase(std::string col, Config::Object obj, int index); + static Action uppercaseIndex(std::string col, Config::Object obj, int index, int inIndex); + static Action lowercase(std::string col, Config::Object obj, int index); + static Action lowercaseIndex(std::string col, Config::Object obj, int index, int inIndex); }; #endif diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index 422cbdd15af21b3c5ffeb60b70df410f755524a2..a5b43617410a1898469055ba340c2a8810ef7604 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -43,6 +43,11 @@ class Transition void initAddCharToWord(); void initSplitWord(std::vector<std::string> words); void initSplit(int index); + void initTransformSuffix(std::string fromCol, std::string fromObj, std::string fromIndex, std::string toCol, std::string toObj, std::string toIndex, std::string rule); + void initUppercase(std::string col, std::string obj, std::string index); + void initUppercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex); + void initLowercase(std::string col, std::string obj, std::string index); + void initLowercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex); public : diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 5ebbda5d57926f0f3ec8b15fada4bb1a78f86a20..26fe98497cae24dfad0021c4f3a522d277c1be8d 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -822,3 +822,184 @@ Action Action::deprel(std::string value) return {Type::Write, apply, undo, appliable}; } +Action Action::transformSuffix(std::string fromCol, Config::Object fromObj, int fromIndex, std::string toCol, Config::Object toObj, int toIndex, util::utf8string toRemove, util::utf8string toAdd) +{ + auto apply = [fromCol, fromObj, fromIndex, toCol, toObj, toIndex, toRemove, toAdd](Config & config, Action & a) + { + int fromLineIndex = config.getRelativeWordIndex(fromObj, fromIndex); + int toLineIndex = config.getRelativeWordIndex(toObj, toIndex); + + if (toRemove.empty() and toAdd.empty()) + { + addHypothesis(toCol, toLineIndex, config.getAsFeature(fromCol, fromLineIndex).get()).apply(config, a); + return; + } + + util::utf8string res = util::splitAsUtf8(util::lower(config.getAsFeature(fromCol, fromLineIndex).get())); + for (unsigned int i = 0; i < toRemove.size(); i++) + res.pop_back(); + for (auto & letter : toAdd) + res.push_back(letter); + addHypothesis(toCol, toLineIndex, fmt::format("{}", res)).apply(config, a); + }; + + auto undo = [toCol, toObj, toIndex](Config & config, Action & a) + { + int toLineIndex = config.getRelativeWordIndex(toObj, toIndex); + addHypothesis(toCol, toLineIndex, "").undo(config, a); + }; + + auto appliable = [fromCol, fromObj, fromIndex, toCol, toObj, toIndex, toRemove, toAdd](const Config & config, const Action & a) + { + if (!config.hasRelativeWordIndex(fromObj, fromIndex) or !config.hasRelativeWordIndex(toObj, toIndex)) + return false; + + int fromLineIndex = config.getRelativeWordIndex(fromObj, fromIndex); + int toLineIndex = config.getRelativeWordIndex(toObj, toIndex); + util::utf8string res = util::splitAsUtf8(util::lower(config.getAsFeature(fromCol, fromLineIndex).get())); + if (res.size() < toRemove.size()) + return false; + + for (unsigned int i = 0; i < toRemove.size(); i++) + { + if (res.back() != toRemove[toRemove.size()-1-i]) + return false; + res.pop_back(); + } + + for (auto & letter : toAdd) + res.push_back(letter); + + return addHypothesis(toCol, toLineIndex, fmt::format("{}", res)).appliable(config, a); + }; + + return {Type::Write, apply, undo, appliable}; +} + +Action Action::uppercase(std::string col, Config::Object obj, int index) +{ + auto apply = [col, obj, index](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(obj, index); + auto res = util::upper(config.getAsFeature(col, lineIndex)); + + addHypothesis(col, lineIndex, res).apply(config, a); + }; + + auto undo = [col, obj, index](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(obj, index); + addHypothesis(col, lineIndex, "").undo(config, a); + }; + + auto appliable = [col, obj, index](const Config & config, const Action & a) + { + if (!config.hasRelativeWordIndex(obj, index)) + return false; + + int lineIndex = config.getRelativeWordIndex(obj, index); + + return addHypothesis(col, lineIndex, "").appliable(config, a); + }; + + return {Type::Write, apply, undo, appliable}; +} + +Action Action::uppercaseIndex(std::string col, Config::Object obj, int index, int inIndex) +{ + auto apply = [col, obj, index, inIndex](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(obj, index); + auto res = util::splitAsUtf8(config.getAsFeature(col, lineIndex).get()); + util::upper(res[inIndex]); + + addHypothesis(col, lineIndex, fmt::format("{}", res)).apply(config, a); + }; + + auto undo = [col, obj, index](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(obj, index); + auto & value = config.getLastNotEmptyHyp(col, lineIndex); + auto res = util::splitAsUtf8(value.get()); + value = fmt::format("{}", res); + }; + + auto appliable = [col, obj, index, inIndex](const Config & config, const Action & a) + { + if (!config.hasRelativeWordIndex(obj, index)) + return false; + + int lineIndex = config.getRelativeWordIndex(obj, index); + + if ((int)util::splitAsUtf8(config.getAsFeature(col, lineIndex).get()).size() <= inIndex) + return false; + + return addHypothesis(col, lineIndex, "").appliable(config, a); + }; + + return {Type::Write, apply, undo, appliable}; +} + +Action Action::lowercase(std::string col, Config::Object obj, int index) +{ + auto apply = [col, obj, index](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(obj, index); + auto res = util::lower(config.getAsFeature(col, lineIndex)); + + addHypothesis(col, lineIndex, res).apply(config, a); + }; + + auto undo = [col, obj, index](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(obj, index); + addHypothesis(col, lineIndex, "").undo(config, a); + }; + + auto appliable = [col, obj, index](const Config & config, const Action & a) + { + if (!config.hasRelativeWordIndex(obj, index)) + return false; + + int lineIndex = config.getRelativeWordIndex(obj, index); + + return addHypothesis(col, lineIndex, "").appliable(config, a); + }; + + return {Type::Write, apply, undo, appliable}; +} + +Action Action::lowercaseIndex(std::string col, Config::Object obj, int index, int inIndex) +{ + auto apply = [col, obj, index, inIndex](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(obj, index); + auto res = util::splitAsUtf8(config.getAsFeature(col, lineIndex).get()); + util::lower(res[inIndex]); + + addHypothesis(col, lineIndex, fmt::format("{}", res)).apply(config, a); + }; + + auto undo = [col, obj, index](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(obj, index); + auto & value = config.getLastNotEmptyHyp(col, lineIndex); + auto res = util::splitAsUtf8(value.get()); + value = fmt::format("{}", res); + }; + + auto appliable = [col, obj, index, inIndex](const Config & config, const Action & a) + { + if (!config.hasRelativeWordIndex(obj, index)) + return false; + + int lineIndex = config.getRelativeWordIndex(obj, index); + + if ((int)util::splitAsUtf8(config.getAsFeature(col, lineIndex).get()).size() <= inIndex) + return false; + + return addHypothesis(col, lineIndex, "").appliable(config, a); + }; + + return {Type::Write, apply, undo, appliable}; +} diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 7589612dfa4bb80ae68521ed36765f165c8abea2..9aed16e4aeadf9459a37a05b913cb6adceb08a43 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -43,6 +43,16 @@ Transition::Transition(const std::string & name) [this](auto){initAddCharToWord();}}, {std::regex("SPLIT (.+)"), [this](auto sm){(initSplit(std::stoi(sm.str(1))));}}, + {std::regex("TRANSFORMSUFFIX (.+) ([bs])\\.(.+) (.+) ([bs])\\.(.+) (.+)"), + [this](auto sm){(initTransformSuffix(sm[1], sm[2], sm[3], sm[4], sm[5], sm[6], sm[7]));}}, + {std::regex("UPPERCASE (.+) ([bs])\\.(.+)"), + [this](auto sm){(initUppercase(sm[1], sm[2], sm[3]));}}, + {std::regex("UPPERCASEINDEX (.+) ([bs])\\.(.+) (.+)"), + [this](auto sm){(initUppercaseIndex(sm[1], sm[2], sm[3], sm[4]));}}, + {std::regex("LOWERCASE (.+) ([bs])\\.(.+)"), + [this](auto sm){(initLowercase(sm[1], sm[2], sm[3]));}}, + {std::regex("LOWERCASEINDEX (.+) ([bs])\\.(.+) (.+)"), + [this](auto sm){(initLowercaseIndex(sm[1], sm[2], sm[3], sm[4]));}}, {std::regex("SPLITWORD ([^@]+)(:?(:?@[^@]+)+)"), [this](auto sm) { @@ -479,6 +489,137 @@ void Transition::initDeprel(std::string label) }; } +void Transition::initTransformSuffix(std::string fromCol, std::string fromObj, std::string fromIndex, std::string toCol, std::string toObj, std::string toIndex, std::string rule) +{ + auto fromObjectValue = Config::str2object(fromObj); + int fromIndexValue = std::stoi(fromIndex); + auto toObjectValue = Config::str2object(toObj); + int toIndexValue = std::stoi(toIndex); + + std::string toRemove, toAdd; + util::utf8string toRemoveUtf8, toAddUtf8; + std::size_t index = 0; + for (index = 1; index < rule.size() and rule[index] != '\t'; index++) + toRemove.push_back(rule[index]); + index++; + for (; index < rule.size() and rule[index] != '\t'; index++) + toAdd.push_back(rule[index]); + + toRemoveUtf8 = util::splitAsUtf8(toRemove); + toAddUtf8 = util::splitAsUtf8(toAdd); + sequence.emplace_back(Action::transformSuffix(fromCol, fromObjectValue, fromIndexValue, toCol, toObjectValue, toIndexValue, toRemoveUtf8, toAddUtf8)); + + cost = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config) + { + int fromLineIndex = config.getRelativeWordIndex(fromObjectValue, fromIndexValue); + int toLineIndex = config.getRelativeWordIndex(toObjectValue, toIndexValue); + util::utf8string res = util::splitAsUtf8(util::lower(config.getAsFeature(fromCol, fromLineIndex).get())); + for (unsigned int i = 0; i < toRemoveUtf8.size(); i++) + res.pop_back(); + for (auto & letter : toAddUtf8) + res.push_back(letter); + + if (fmt::format("{}", res) == util::lower(config.getConst(toCol, toLineIndex, 0))) + return 0; + + return 1; + }; +} + +void Transition::initUppercase(std::string col, std::string obj, std::string index) +{ + auto objectValue = Config::str2object(obj); + int indexValue = std::stoi(index); + + sequence.emplace_back(Action::uppercase(col, objectValue, indexValue)); + + cost = [col, objectValue, indexValue](const Config & config) + { + int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); + auto & expectedValue = config.getConst(col, lineIndex, 0); + std::string currentValue = config.getAsFeature(col, lineIndex).get(); + if (expectedValue == currentValue) + return 1; + + if (util::upper(currentValue) == expectedValue) + return 0; + + return 1; + }; +} + +void Transition::initUppercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex) +{ + auto objectValue = Config::str2object(obj); + int indexValue = std::stoi(index); + int inIndexValue = std::stoi(inIndex); + + sequence.emplace_back(Action::uppercaseIndex(col, objectValue, indexValue, inIndexValue)); + + cost = [col, objectValue, indexValue, inIndexValue](const Config & config) + { + int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); + auto & expectedValue = config.getConst(col, lineIndex, 0); + std::string currentValue = config.getAsFeature(col, lineIndex).get(); + if (expectedValue == currentValue) + return 1; + + auto currentValueUtf8 = util::splitAsUtf8(currentValue); + util::upper(currentValueUtf8[inIndexValue]); + if (fmt::format("{}", currentValueUtf8) == expectedValue) + return 0; + + return 1; + }; +} + +void Transition::initLowercase(std::string col, std::string obj, std::string index) +{ + auto objectValue = Config::str2object(obj); + int indexValue = std::stoi(index); + + sequence.emplace_back(Action::lowercase(col, objectValue, indexValue)); + + cost = [col, objectValue, indexValue](const Config & config) + { + int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); + auto & expectedValue = config.getConst(col, lineIndex, 0); + std::string currentValue = config.getAsFeature(col, lineIndex).get(); + if (expectedValue == currentValue) + return 1; + + if (util::lower(currentValue) == expectedValue) + return 0; + + return 1; + }; +} + +void Transition::initLowercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex) +{ + auto objectValue = Config::str2object(obj); + int indexValue = std::stoi(index); + int inIndexValue = std::stoi(inIndex); + + sequence.emplace_back(Action::lowercaseIndex(col, objectValue, indexValue, inIndexValue)); + + cost = [col, objectValue, indexValue, inIndexValue](const Config & config) + { + int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); + auto & expectedValue = config.getConst(col, lineIndex, 0); + std::string currentValue = config.getAsFeature(col, lineIndex).get(); + if (expectedValue == currentValue) + return 1; + + auto currentValueUtf8 = util::splitAsUtf8(currentValue); + util::lower(currentValueUtf8[inIndexValue]); + if (fmt::format("{}", currentValueUtf8) == expectedValue) + return 0; + + return 1; + }; +} + int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config) { auto govIndex = config.getConst(Config::headColName, withIndex, 0); diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp index 2022e5266d002e590520b12fbf5e7ee4f5e8c8cb..d15064e007be26faa109ffa14a841d72915f3baf 100644 --- a/torch_modules/src/ConfigDataset.cpp +++ b/torch_modules/src/ConfigDataset.cpp @@ -9,8 +9,9 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir) auto stem = util::split(entry.path().stem().string(), '.')[0]; if (stem == "extracted") continue; - auto state = util::split(stem, '_')[0]; - auto splited = util::split(util::split(stem, '_')[1], '-'); + auto underSplit = util::split(stem, '_'); + auto state = util::join("_", std::vector<std::string>(underSplit.begin(), underSplit.end()-1)); + auto splited = util::split(underSplit.back(), '-'); int fileSize = 1 + std::stoi(splited[1]) - std::stoi(splited[0]); size_ += fileSize; if (!holders.count(state))