Skip to content
Snippets Groups Projects
Commit 7cf7a308 authored by Franck Dary's avatar Franck Dary
Browse files

Dict dont insert separators

parent 8bec74de
No related branches found
No related tags found
No related merge requests found
......@@ -17,6 +17,7 @@ class Dict
static constexpr char const * unknownValueStr = "__unknownValue__";
static constexpr char const * nullValueStr = "__nullValue__";
static constexpr char const * emptyValueStr = "__emptyValue__";
static constexpr char const * separatorValueStr = "__separatorValue__";
static constexpr std::size_t maxEntrySize = 5000;
private :
......
......@@ -14,6 +14,7 @@ class utf8char : public std::array<char, 4>
public :
utf8char();
utf8char(const std::string & other);
utf8char & operator=(char other);
utf8char & operator=(const std::string & other);
bool operator==(char other);
......
......@@ -50,7 +50,7 @@ void Dict::readFromFile(const char * filename)
util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
elementsToIndexes[entryString] = entryIndex;
while (nbOccs.size() <= entryIndex)
while ((int)nbOccs.size() <= entryIndex)
nbOccs.emplace_back(0);
nbOccs[entryIndex] = nbOccsEntry;
}
......@@ -73,6 +73,9 @@ int Dict::getIndexOrInsert(const std::string & element)
if (element.empty())
return getIndexOrInsert(emptyValueStr);
if (element.size() == 1 and util::isSeparator(util::utf8char(element)))
return getIndexOrInsert(separatorValueStr);
const auto & found = elementsToIndexes.find(element);
if (found == elementsToIndexes.end())
......
......@@ -7,6 +7,11 @@ util::utf8char::utf8char()
val = '\0';
}
util::utf8char::utf8char(const std::string & other)
{
*this = other;
}
util::utf8char & util::utf8char::operator=(char other)
{
(*this)[0] = other;
......
......@@ -87,13 +87,13 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D
{
for (int i = 0; i < leftWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (int i = 0; i <= rightWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
......@@ -159,7 +159,7 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)asUtf8.size())
elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
elements.emplace_back(fmt::format("{}", asUtf8[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
......
......@@ -105,13 +105,13 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
{
for (int i = 0; i < leftWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (int i = 0; i <= rightWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
......@@ -177,7 +177,7 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)asUtf8.size())
elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
elements.emplace_back(fmt::format("{}", asUtf8[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment