From 7cf7a308a175d4ff89dc70291c28a86c816b4e99 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 19 Mar 2020 15:21:44 +0100 Subject: [PATCH] Dict don't insert separators --- common/include/Dict.hpp | 1 + common/include/utf8string.hpp | 1 + common/src/Dict.cpp | 5 ++++- common/src/utf8string.cpp | 5 +++++ torch_modules/src/CNNNetwork.cpp | 6 +++--- torch_modules/src/LSTMNetwork.cpp | 6 +++--- 6 files changed, 17 insertions(+), 7 deletions(-) diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index a5ae772..8bc9d3a 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -17,6 +17,7 @@ class Dict static constexpr char const * unknownValueStr = "__unknownValue__"; static constexpr char const * nullValueStr = "__nullValue__"; static constexpr char const * emptyValueStr = "__emptyValue__"; + static constexpr char const * separatorValueStr = "__separatorValue__"; static constexpr std::size_t maxEntrySize = 5000; private : diff --git a/common/include/utf8string.hpp b/common/include/utf8string.hpp index ddd468e..42fe5d4 100644 --- a/common/include/utf8string.hpp +++ b/common/include/utf8string.hpp @@ -14,6 +14,7 @@ class utf8char : public std::array<char, 4> public : utf8char(); + utf8char(const std::string & other); utf8char & operator=(char other); utf8char & operator=(const std::string & other); bool operator==(char other); diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index 6154dc1..b75457e 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -50,7 +50,7 @@ void Dict::readFromFile(const char * filename) util::myThrow(fmt::format("file '{}' line {} bad format", filename, i)); elementsToIndexes[entryString] = entryIndex; - while (nbOccs.size() <= entryIndex) + while ((int)nbOccs.size() <= entryIndex) nbOccs.emplace_back(0); nbOccs[entryIndex] = nbOccsEntry; } @@ -73,6 +73,9 @@ int Dict::getIndexOrInsert(const std::string & element) if (element.empty()) return getIndexOrInsert(emptyValueStr); + if 
(element.size() == 1 and util::isSeparator(util::utf8char(element))) + return getIndexOrInsert(separatorValueStr); + const auto & found = elementsToIndexes.find(element); if (found == elementsToIndexes.end()) diff --git a/common/src/utf8string.cpp b/common/src/utf8string.cpp index 688b607..d430197 100644 --- a/common/src/utf8string.cpp +++ b/common/src/utf8string.cpp @@ -7,6 +7,11 @@ util::utf8char::utf8char() val = '\0'; } +util::utf8char::utf8char(const std::string & other) +{ + *this = other; +} + util::utf8char & util::utf8char::operator=(char other) { (*this)[0] = other; diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 96c6b28..68bd974 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -87,13 +87,13 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D { for (int i = 0; i < leftWindowRawInput; i++) if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); + context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); else context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); for (int i = 0; i <= rightWindowRawInput; i++) if (config.hasCharacter(config.getCharacterIndex()+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i)))); + context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); else context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); } @@ -159,7 +159,7 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D for (int i = 0; i < maxNbElements[colIndex]; i++) if (i < (int)asUtf8.size()) - 
elements.emplace_back(fmt::format("Letter({})", asUtf8[i])); + elements.emplace_back(fmt::format("{}", asUtf8[i])); else elements.emplace_back(Dict::nullValueStr); } diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index 84c9e79..734770e 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -105,13 +105,13 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, { for (int i = 0; i < leftWindowRawInput; i++) if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); + context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); else context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); for (int i = 0; i <= rightWindowRawInput; i++) if (config.hasCharacter(config.getCharacterIndex()+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i)))); + context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); else context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); } @@ -177,7 +177,7 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, for (int i = 0; i < maxNbElements[colIndex]; i++) if (i < (int)asUtf8.size()) - elements.emplace_back(fmt::format("Letter({})", asUtf8[i])); + elements.emplace_back(fmt::format("{}", asUtf8[i])); else elements.emplace_back(Dict::nullValueStr); } -- GitLab