diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index a5ae77242df2e34c15c48bb0706ead4396336829..8bc9d3a8d38ef5ca99de202ba4973b243f783a41 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -17,6 +17,7 @@ class Dict static constexpr char const * unknownValueStr = "__unknownValue__"; static constexpr char const * nullValueStr = "__nullValue__"; static constexpr char const * emptyValueStr = "__emptyValue__"; + static constexpr char const * separatorValueStr = "__separatorValue__"; static constexpr std::size_t maxEntrySize = 5000; private : diff --git a/common/include/utf8string.hpp b/common/include/utf8string.hpp index ddd468eab5476900b0f31c2cc53be085bb0d9e9e..42fe5d438ca4327099bec706d85cc6ac8b84b262 100644 --- a/common/include/utf8string.hpp +++ b/common/include/utf8string.hpp @@ -14,6 +14,7 @@ class utf8char : public std::array<char, 4> public : utf8char(); + utf8char(const std::string & other); utf8char & operator=(char other); utf8char & operator=(const std::string & other); bool operator==(char other); diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index 6154dc1e64beca1e8fb7ff40a0b4bf896887fcb0..b75457ed71a5db07ba47b3702521dc84178d81d3 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -50,7 +50,7 @@ void Dict::readFromFile(const char * filename) util::myThrow(fmt::format("file '{}' line {} bad format", filename, i)); elementsToIndexes[entryString] = entryIndex; - while (nbOccs.size() <= entryIndex) + while ((int)nbOccs.size() <= entryIndex) nbOccs.emplace_back(0); nbOccs[entryIndex] = nbOccsEntry; } @@ -73,6 +73,9 @@ int Dict::getIndexOrInsert(const std::string & element) if (element.empty()) return getIndexOrInsert(emptyValueStr); + if (element.size() == 1 and util::isSeparator(util::utf8char(element))) + return getIndexOrInsert(separatorValueStr); + const auto & found = elementsToIndexes.find(element); if (found == elementsToIndexes.end()) diff --git a/common/src/utf8string.cpp b/common/src/utf8string.cpp index 688b607ce626a8bf2bd4ecc24ed14c5f1cff3c46..d4301975992b133f815304ca098d47c8c69f03e8 100644 --- a/common/src/utf8string.cpp +++ b/common/src/utf8string.cpp @@ -7,6 +7,11 @@ util::utf8char::utf8char() val = '\0'; } +util::utf8char::utf8char(const std::string & other) +{ + *this = other; +} + util::utf8char & util::utf8char::operator=(char other) { (*this)[0] = other; diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 96c6b2826e6cbf60b3be7be0bf6e566bbed038b7..68bd9749cd78ed97687a80543e8109a583b2d3da 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -87,13 +87,13 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D { for (int i = 0; i < leftWindowRawInput; i++) if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); + context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); else context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); for (int i = 0; i <= rightWindowRawInput; i++) if (config.hasCharacter(config.getCharacterIndex()+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i)))); + context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); else context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); } @@ -159,7 +159,7 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D for (int i = 0; i < maxNbElements[colIndex]; i++) if (i < (int)asUtf8.size()) - elements.emplace_back(fmt::format("Letter({})", asUtf8[i])); + elements.emplace_back(fmt::format("{}", asUtf8[i])); else elements.emplace_back(Dict::nullValueStr); } diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index 84c9e79ec9f53732ace93103c9ad05422ac20844..734770e2093726a2dea3c694c3a42bec56c6470a 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -105,13 +105,13 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, { for (int i = 0; i < leftWindowRawInput; i++) if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); + context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); else context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); for (int i = 0; i <= rightWindowRawInput; i++) if (config.hasCharacter(config.getCharacterIndex()+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i)))); + context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); else context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); } @@ -177,7 +177,7 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, for (int i = 0; i < maxNbElements[colIndex]; i++) if (i < (int)asUtf8.size()) - elements.emplace_back(fmt::format("Letter({})", asUtf8[i])); + elements.emplace_back(fmt::format("{}", asUtf8[i])); else elements.emplace_back(Dict::nullValueStr); }