Skip to content
Snippets Groups Projects
Commit 7cf7a308 authored by Franck Dary's avatar Franck Dary
Browse files

Dict: don't insert separators

parent 8bec74de
No related branches found
No related tags found
No related merge requests found
...@@ -17,6 +17,7 @@ class Dict ...@@ -17,6 +17,7 @@ class Dict
static constexpr char const * unknownValueStr = "__unknownValue__"; static constexpr char const * unknownValueStr = "__unknownValue__";
static constexpr char const * nullValueStr = "__nullValue__"; static constexpr char const * nullValueStr = "__nullValue__";
static constexpr char const * emptyValueStr = "__emptyValue__"; static constexpr char const * emptyValueStr = "__emptyValue__";
static constexpr char const * separatorValueStr = "__separatorValue__";
static constexpr std::size_t maxEntrySize = 5000; static constexpr std::size_t maxEntrySize = 5000;
private : private :
......
...@@ -14,6 +14,7 @@ class utf8char : public std::array<char, 4> ...@@ -14,6 +14,7 @@ class utf8char : public std::array<char, 4>
public : public :
utf8char(); utf8char();
utf8char(const std::string & other);
utf8char & operator=(char other); utf8char & operator=(char other);
utf8char & operator=(const std::string & other); utf8char & operator=(const std::string & other);
bool operator==(char other); bool operator==(char other);
......
...@@ -50,7 +50,7 @@ void Dict::readFromFile(const char * filename) ...@@ -50,7 +50,7 @@ void Dict::readFromFile(const char * filename)
util::myThrow(fmt::format("file '{}' line {} bad format", filename, i)); util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
elementsToIndexes[entryString] = entryIndex; elementsToIndexes[entryString] = entryIndex;
while (nbOccs.size() <= entryIndex) while ((int)nbOccs.size() <= entryIndex)
nbOccs.emplace_back(0); nbOccs.emplace_back(0);
nbOccs[entryIndex] = nbOccsEntry; nbOccs[entryIndex] = nbOccsEntry;
} }
...@@ -73,6 +73,9 @@ int Dict::getIndexOrInsert(const std::string & element) ...@@ -73,6 +73,9 @@ int Dict::getIndexOrInsert(const std::string & element)
if (element.empty()) if (element.empty())
return getIndexOrInsert(emptyValueStr); return getIndexOrInsert(emptyValueStr);
if (element.size() == 1 and util::isSeparator(util::utf8char(element)))
return getIndexOrInsert(separatorValueStr);
const auto & found = elementsToIndexes.find(element); const auto & found = elementsToIndexes.find(element);
if (found == elementsToIndexes.end()) if (found == elementsToIndexes.end())
......
...@@ -7,6 +7,11 @@ util::utf8char::utf8char() ...@@ -7,6 +7,11 @@ util::utf8char::utf8char()
val = '\0'; val = '\0';
} }
// Converting constructor: initializes this utf8char from a string by
// forwarding to the class's string assignment operator.
util::utf8char::utf8char(const std::string & other)
{
  operator=(other);
}
util::utf8char & util::utf8char::operator=(char other) util::utf8char & util::utf8char::operator=(char other)
{ {
(*this)[0] = other; (*this)[0] = other;
......
...@@ -87,13 +87,13 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D ...@@ -87,13 +87,13 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D
{ {
for (int i = 0; i < leftWindowRawInput; i++) for (int i = 0; i < leftWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
else else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (int i = 0; i <= rightWindowRawInput; i++) for (int i = 0; i <= rightWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()+i)) if (config.hasCharacter(config.getCharacterIndex()+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i)))); context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
else else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
} }
...@@ -159,7 +159,7 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D ...@@ -159,7 +159,7 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D
for (int i = 0; i < maxNbElements[colIndex]; i++) for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)asUtf8.size()) if (i < (int)asUtf8.size())
elements.emplace_back(fmt::format("Letter({})", asUtf8[i])); elements.emplace_back(fmt::format("{}", asUtf8[i]));
else else
elements.emplace_back(Dict::nullValueStr); elements.emplace_back(Dict::nullValueStr);
} }
......
...@@ -105,13 +105,13 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, ...@@ -105,13 +105,13 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
{ {
for (int i = 0; i < leftWindowRawInput; i++) for (int i = 0; i < leftWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
else else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (int i = 0; i <= rightWindowRawInput; i++) for (int i = 0; i <= rightWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()+i)) if (config.hasCharacter(config.getCharacterIndex()+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i)))); context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
else else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
} }
...@@ -177,7 +177,7 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, ...@@ -177,7 +177,7 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
for (int i = 0; i < maxNbElements[colIndex]; i++) for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)asUtf8.size()) if (i < (int)asUtf8.size())
elements.emplace_back(fmt::format("Letter({})", asUtf8[i])); elements.emplace_back(fmt::format("{}", asUtf8[i]));
else else
elements.emplace_back(Dict::nullValueStr); elements.emplace_back(Dict::nullValueStr);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment