Skip to content
Snippets Groups Projects
Commit 9458cb45 authored by Franck Dary
Browse files

Added way to set dictCapacity for each dict in .dicts file

parent 3f47e8a0
Branches
No related tags found
No related merge requests found
...@@ -65,6 +65,8 @@ class Dict ...@@ -65,6 +65,8 @@ class Dict
Mode mode; Mode mode;
/// @brief The Policy of this Dict. /// @brief The Policy of this Dict.
Policy policy; Policy policy;
/// @brief Maximal number of entries.
int dictCapacity;
public : public :
...@@ -181,7 +183,8 @@ class Dict ...@@ -181,7 +183,8 @@ class Dict
/// @param name The name the new Dict. /// @param name The name the new Dict.
/// @param policy The Policy of the new Dict. /// @param policy The Policy of the new Dict.
/// @param filename The filename we will read the new Dict from. /// @param filename The filename we will read the new Dict from.
Dict(const std::string & name, Policy policy, const std::string & filename); /// @param dictCapacity Maximal number of entries.
Dict(const std::string & name, Policy policy, const std::string & filename, int dictCapacity);
/// @brief Get a pointer to the entry matching s. /// @brief Get a pointer to the entry matching s.
/// ///
/// This is used when we need a permanent pointer to a string matching s, /// This is used when we need a permanent pointer to a string matching s,
...@@ -202,7 +205,8 @@ class Dict ...@@ -202,7 +205,8 @@ class Dict
/// @param name The name of the Dict to construct. /// @param name The name of the Dict to construct.
/// @param dimension The dimension of the vectors in the new Dict. /// @param dimension The dimension of the vectors in the new Dict.
/// @param mode The Mode of the new Dict. /// @param mode The Mode of the new Dict.
Dict(const std::string & name, int dimension, Mode mode); /// @param dictCapacity Maximal number of entries.
Dict(const std::string & name, int dimension, Mode mode, int dictCapacity);
void init(dynet::ParameterCollection & pc); void init(dynet::ParameterCollection & pc);
......
...@@ -52,7 +52,7 @@ const char * Dict::mode2str(Mode mode) ...@@ -52,7 +52,7 @@ const char * Dict::mode2str(Mode mode)
return "Embeddings"; return "Embeddings";
} }
Dict::Dict(const std::string & name, int dimension, Mode mode) Dict::Dict(const std::string & name, int dimension, Mode mode, int dictCapacity)
{ {
this->isInit = false; this->isInit = false;
this->isTrained = false; this->isTrained = false;
...@@ -60,12 +60,13 @@ Dict::Dict(const std::string & name, int dimension, Mode mode) ...@@ -60,12 +60,13 @@ Dict::Dict(const std::string & name, int dimension, Mode mode)
this->filename = name; this->filename = name;
this->name = name; this->name = name;
this->oneHotIndex = 0; this->oneHotIndex = 0;
this->dictCapacity = dictCapacity;
this->mode = mode; this->mode = mode;
this->dimension = dimension; this->dimension = dimension;
} }
Dict::Dict(const std::string & name, Policy policy, const std::string & filename) Dict::Dict(const std::string & name, Policy policy, const std::string & filename, int dictCapacity)
{ {
this->isInit = false; this->isInit = false;
this->isTrained = true; this->isTrained = true;
...@@ -73,6 +74,7 @@ Dict::Dict(const std::string & name, Policy policy, const std::string & filename ...@@ -73,6 +74,7 @@ Dict::Dict(const std::string & name, Policy policy, const std::string & filename
this->filename = filename; this->filename = filename;
this->oneHotIndex = 0; this->oneHotIndex = 0;
this->name = name; this->name = name;
this->dictCapacity = dictCapacity;
} }
void Dict::init(dynet::ParameterCollection & pc) void Dict::init(dynet::ParameterCollection & pc)
...@@ -84,7 +86,7 @@ void Dict::init(dynet::ParameterCollection & pc) ...@@ -84,7 +86,7 @@ void Dict::init(dynet::ParameterCollection & pc)
} }
isInit = true; isInit = true;
this->lookupParameter = pc.add_lookup_parameters(ProgramParameters::dictCapacity, {(unsigned int)dimension}); this->lookupParameter = pc.add_lookup_parameters(dictCapacity, {(unsigned int)dimension});
addEntry(nullValueStr); addEntry(nullValueStr);
addEntry(unknownValueStr); addEntry(unknownValueStr);
} }
...@@ -139,7 +141,7 @@ void Dict::initFromFile(dynet::ParameterCollection & pc) ...@@ -139,7 +141,7 @@ void Dict::initFromFile(dynet::ParameterCollection & pc)
} }
ftVector.reset(new fasttext::Vector(dimension)); ftVector.reset(new fasttext::Vector(dimension));
this->lookupParameter = pc.add_lookup_parameters(ProgramParameters::dictCapacity, {(unsigned int)dimension}); this->lookupParameter = pc.add_lookup_parameters(dictCapacity, {(unsigned int)dimension});
} }
// If policy is FromZero, we don't need to read the current entries // If policy is FromZero, we don't need to read the current entries
...@@ -154,7 +156,7 @@ void Dict::initFromFile(dynet::ParameterCollection & pc) ...@@ -154,7 +156,7 @@ void Dict::initFromFile(dynet::ParameterCollection & pc)
if (readIndex == -1) // No parameters to read if (readIndex == -1) // No parameters to read
{ {
this->lookupParameter = pc.add_lookup_parameters(ProgramParameters::dictCapacity, {(unsigned int)dimension}); this->lookupParameter = pc.add_lookup_parameters(dictCapacity, {(unsigned int)dimension});
addEntry(nullValueStr); addEntry(nullValueStr);
addEntry(unknownValueStr); addEntry(unknownValueStr);
return; return;
...@@ -380,9 +382,9 @@ unsigned int Dict::addEntry(const std::string & s) ...@@ -380,9 +382,9 @@ unsigned int Dict::addEntry(const std::string & s)
auto index = str2index.size(); auto index = str2index.size();
str2index.emplace(s, index); str2index.emplace(s, index);
if ((int)str2index.size() >= ProgramParameters::dictCapacity) if ((int)str2index.size() >= dictCapacity)
{ {
fprintf(stderr, "ERROR (%s) : Dict %s of maximal capacity %d is full. Saving dict then aborting.\n", ERRINFO, name.c_str(), ProgramParameters::dictCapacity); fprintf(stderr, "ERROR (%s) : Dict %s of maximal capacity %d is full. Saving dict then aborting.\n", ERRINFO, name.c_str(), dictCapacity);
save(); save();
exit(1); exit(1);
} }
...@@ -420,9 +422,9 @@ unsigned int Dict::addEntry(const std::string & s, const std::vector<float> & em ...@@ -420,9 +422,9 @@ unsigned int Dict::addEntry(const std::string & s, const std::vector<float> & em
auto index = str2index.size(); auto index = str2index.size();
str2index.emplace(s, index); str2index.emplace(s, index);
if ((int)str2index.size() >= ProgramParameters::dictCapacity) if ((int)str2index.size() >= dictCapacity)
{ {
fprintf(stderr, "ERROR (%s) : Dict %s of maximal capacity %d is full. Aborting.\n", ERRINFO, name.c_str(), ProgramParameters::dictCapacity); fprintf(stderr, "ERROR (%s) : Dict %s of maximal capacity %d is full. Aborting.\n", ERRINFO, name.c_str(), dictCapacity);
exit(1); exit(1);
} }
...@@ -448,7 +450,7 @@ Dict * Dict::getDict(Policy policy, const std::string & filename) ...@@ -448,7 +450,7 @@ Dict * Dict::getDict(Policy policy, const std::string & filename)
if(it != str2dict.end()) if(it != str2dict.end())
return it->second.get(); return it->second.get();
Dict * dict = new Dict(util::removeSuffix(util::getFilenameFromPath(filename), ".dict"),policy, filename); Dict * dict = new Dict(util::removeSuffix(util::getFilenameFromPath(filename), ".dict"),policy, filename, ProgramParameters::dictCapacity);
str2dict.insert(std::make_pair(dict->name, std::unique_ptr<Dict>(dict))); str2dict.insert(std::make_pair(dict->name, std::unique_ptr<Dict>(dict)));
...@@ -479,6 +481,7 @@ void Dict::readDicts(const std::string & directory, const std::string & filename ...@@ -479,6 +481,7 @@ void Dict::readDicts(const std::string & directory, const std::string & filename
char modeStr[1024]; char modeStr[1024];
char pretrained[1024]; char pretrained[1024];
int dim; int dim;
int dictCapacity;
File file(filename, "r"); File file(filename, "r");
FILE * fd = file.getDescriptor(); FILE * fd = file.getDescriptor();
...@@ -489,6 +492,8 @@ void Dict::readDicts(const std::string & directory, const std::string & filename ...@@ -489,6 +492,8 @@ void Dict::readDicts(const std::string & directory, const std::string & filename
continue; continue;
if(trainMode) if(trainMode)
{
if(sscanf(buffer, "%s %d %s %s %d\n", name, &dim, modeStr, pretrained, &dictCapacity) != 5)
{ {
if(sscanf(buffer, "%s %d %s %s\n", name, &dim, modeStr, pretrained) != 4) if(sscanf(buffer, "%s %d %s %s\n", name, &dim, modeStr, pretrained) != 4)
{ {
...@@ -500,6 +505,9 @@ void Dict::readDicts(const std::string & directory, const std::string & filename ...@@ -500,6 +505,9 @@ void Dict::readDicts(const std::string & directory, const std::string & filename
sprintf(pretrained, "_"); sprintf(pretrained, "_");
} }
dictCapacity = ProgramParameters::dictCapacity;
}
auto it = str2dict.find(name); auto it = str2dict.find(name);
if(it != str2dict.end()) if(it != str2dict.end())
{ {
...@@ -512,16 +520,16 @@ void Dict::readDicts(const std::string & directory, const std::string & filename ...@@ -512,16 +520,16 @@ void Dict::readDicts(const std::string & directory, const std::string & filename
if (ProgramParameters::newTemplatePath == ProgramParameters::expPath) if (ProgramParameters::newTemplatePath == ProgramParameters::expPath)
{ {
std::string probableFilename = ProgramParameters::expPath + name + std::string(".dict"); std::string probableFilename = ProgramParameters::expPath + name + std::string(".dict");
str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Modifiable, probableFilename)))); str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Modifiable, probableFilename, dictCapacity))));
} }
else else
{ {
str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, dim, str2mode(modeStr))))); str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, dim, str2mode(modeStr), dictCapacity))));
} }
} }
else else
{ {
str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + pretrained)))); str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + pretrained, dictCapacity))));
} }
} }
else else
...@@ -545,11 +553,11 @@ void Dict::readDicts(const std::string & directory, const std::string & filename ...@@ -545,11 +553,11 @@ void Dict::readDicts(const std::string & directory, const std::string & filename
if (std::string(pretrained) == "_") if (std::string(pretrained) == "_")
{ {
str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + name + ".dict")))); str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + name + ".dict", dictCapacity))));
} }
else else
{ {
str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + pretrained)))); str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + pretrained, dictCapacity))));
} }
} }
} }
......
...@@ -113,7 +113,7 @@ int main(int argc, char * argv[]) ...@@ -113,7 +113,7 @@ int main(int argc, char * argv[])
dynet::initialize(argc, argv); dynet::initialize(argc, argv);
dynet::ParameterCollection pc; dynet::ParameterCollection pc;
Dict dict(outputFilename, embeddingsSize, Dict::Mode::Embeddings); Dict dict(outputFilename, embeddingsSize, Dict::Mode::Embeddings, nbEmbeddings+5);
dict.init(pc); dict.init(pc);
while (fscanf(input.getDescriptor(), "%[^\n]\n", buffer) == 1) while (fscanf(input.getDescriptor(), "%[^\n]\n", buffer) == 1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment