diff --git a/maca_common/include/Dict.hpp b/maca_common/include/Dict.hpp index e3e93703a5d90c55ba1d65a6c8ea6e9901428924..ca2be0b3a194798fc84bcaef157a79cd80c04bdb 100644 --- a/maca_common/include/Dict.hpp +++ b/maca_common/include/Dict.hpp @@ -65,6 +65,8 @@ class Dict Mode mode; /// @brief The Policy of this Dict. Policy policy; + /// @brief Maximal number of entries. + int dictCapacity; public : @@ -181,7 +183,8 @@ class Dict /// @param name The name the new Dict. /// @param policy The Policy of the new Dict. /// @param filename The filename we will read the new Dict from. - Dict(const std::string & name, Policy policy, const std::string & filename); + /// @param dictCapacity Maximal number of entries. + Dict(const std::string & name, Policy policy, const std::string & filename, int dictCapacity); /// @brief Get a pointer to the entry matching s. /// /// This is used when we need a permanent pointer to a string matching s, @@ -202,7 +205,8 @@ class Dict /// @param name The name of the Dict to construct. /// @param dimension The dimension of the vectors in the new Dict. /// @param mode The Mode of the new Dict. - Dict(const std::string & name, int dimension, Mode mode); + /// @param dictCapacity Maximal number of entries. + Dict(const std::string & name, int dimension, Mode mode, int dictCapacity); void init(dynet::ParameterCollection & pc); diff --git a/maca_common/src/Dict.cpp b/maca_common/src/Dict.cpp index 902ad649d5f06db707db36727aa470a7512fd0e0..d5cf05d6377a65ea81013680e553186e3c860b46 100644 --- a/maca_common/src/Dict.cpp +++ b/maca_common/src/Dict.cpp @@ -52,7 +52,7 @@ const char * Dict::mode2str(Mode mode) return "Embeddings"; } -Dict::Dict(const std::string & name, int dimension, Mode mode) +Dict::Dict(const std::string & name, int dimension, Mode mode, int dictCapacity) { this->isInit = false; this->isTrained = false; @@ -60,12 +60,13 @@ Dict::Dict(const std::string & name, int dimension, Mode mode) this->filename = name; this->name = name; this->oneHotIndex = 0; + this->dictCapacity = dictCapacity; this->mode = mode; this->dimension = dimension; } -Dict::Dict(const std::string & name, Policy policy, const std::string & filename) +Dict::Dict(const std::string & name, Policy policy, const std::string & filename, int dictCapacity) { this->isInit = false; this->isTrained = true; @@ -73,6 +74,7 @@ Dict::Dict(const std::string & name, Policy policy, const std::string & filename this->filename = filename; this->oneHotIndex = 0; this->name = name; + this->dictCapacity = dictCapacity; } void Dict::init(dynet::ParameterCollection & pc) @@ -84,7 +86,7 @@ void Dict::init(dynet::ParameterCollection & pc) } isInit = true; - this->lookupParameter = pc.add_lookup_parameters(ProgramParameters::dictCapacity, {(unsigned int)dimension}); + this->lookupParameter = pc.add_lookup_parameters(dictCapacity, {(unsigned int)dimension}); addEntry(nullValueStr); addEntry(unknownValueStr); } @@ -139,7 +141,7 @@ void Dict::initFromFile(dynet::ParameterCollection & pc) } ftVector.reset(new fasttext::Vector(dimension)); - this->lookupParameter = pc.add_lookup_parameters(ProgramParameters::dictCapacity, {(unsigned int)dimension}); + this->lookupParameter = pc.add_lookup_parameters(dictCapacity, {(unsigned int)dimension}); } // If policy is FromZero, we don't need to read the current entries @@ -154,7 +156,7 @@ void Dict::initFromFile(dynet::ParameterCollection & pc) if (readIndex == -1) // No parameters to read { - this->lookupParameter = pc.add_lookup_parameters(ProgramParameters::dictCapacity, {(unsigned int)dimension}); + this->lookupParameter = pc.add_lookup_parameters(dictCapacity, {(unsigned int)dimension}); addEntry(nullValueStr); addEntry(unknownValueStr); return; @@ -380,9 +382,9 @@ unsigned int Dict::addEntry(const std::string & s) auto index = str2index.size(); str2index.emplace(s, index); - if ((int)str2index.size() >= ProgramParameters::dictCapacity) + if ((int)str2index.size() >= dictCapacity) { - fprintf(stderr, "ERROR (%s) : Dict %s of maximal capacity %d is full. Saving dict then aborting.\n", ERRINFO, name.c_str(), ProgramParameters::dictCapacity); + fprintf(stderr, "ERROR (%s) : Dict %s of maximal capacity %d is full. Saving dict then aborting.\n", ERRINFO, name.c_str(), dictCapacity); save(); exit(1); } @@ -420,9 +422,9 @@ unsigned int Dict::addEntry(const std::string & s, const std::vector<float> & em auto index = str2index.size(); str2index.emplace(s, index); - if ((int)str2index.size() >= ProgramParameters::dictCapacity) + if ((int)str2index.size() >= dictCapacity) { - fprintf(stderr, "ERROR (%s) : Dict %s of maximal capacity %d is full. Aborting.\n", ERRINFO, name.c_str(), ProgramParameters::dictCapacity); + fprintf(stderr, "ERROR (%s) : Dict %s of maximal capacity %d is full. Aborting.\n", ERRINFO, name.c_str(), dictCapacity); exit(1); } @@ -448,7 +450,7 @@ Dict * Dict::getDict(Policy policy, const std::string & filename) if(it != str2dict.end()) return it->second.get(); - Dict * dict = new Dict(util::removeSuffix(util::getFilenameFromPath(filename), ".dict"),policy, filename); + Dict * dict = new Dict(util::removeSuffix(util::getFilenameFromPath(filename), ".dict"),policy, filename, ProgramParameters::dictCapacity); str2dict.insert(std::make_pair(dict->name, std::unique_ptr<Dict>(dict))); @@ -479,6 +481,7 @@ void Dict::readDicts(const std::string & directory, const std::string & filename char modeStr[1024]; char pretrained[1024]; int dim; + int dictCapacity; File file(filename, "r"); FILE * fd = file.getDescriptor(); @@ -490,14 +493,19 @@ void Dict::readDicts(const std::string & directory, const std::string & filename if(trainMode) { - if(sscanf(buffer, "%s %d %s %s\n", name, &dim, modeStr, pretrained) != 4) + if(sscanf(buffer, "%s %d %s %s %d\n", name, &dim, modeStr, pretrained, &dictCapacity) != 5) { - if(sscanf(buffer, "%s %d %s\n", name, &dim, modeStr) != 3) + if(sscanf(buffer, "%s %d %s %s\n", name, &dim, modeStr, pretrained) != 4) { - fprintf(stderr, "ERROR (%s) : line \'%s\' do not describe a dictionary. Aborting.\n", ERRINFO, buffer); - exit(1); + if(sscanf(buffer, "%s %d %s\n", name, &dim, modeStr) != 3) + { + fprintf(stderr, "ERROR (%s) : line \'%s\' do not describe a dictionary. Aborting.\n", ERRINFO, buffer); + exit(1); + } + sprintf(pretrained, "_"); } - sprintf(pretrained, "_"); + + dictCapacity = ProgramParameters::dictCapacity; } auto it = str2dict.find(name); @@ -512,16 +520,16 @@ void Dict::readDicts(const std::string & directory, const std::string & filename if (ProgramParameters::newTemplatePath == ProgramParameters::expPath) { std::string probableFilename = ProgramParameters::expPath + name + std::string(".dict"); - str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Modifiable, probableFilename)))); + str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Modifiable, probableFilename, dictCapacity)))); } else { - str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, dim, str2mode(modeStr))))); + str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, dim, str2mode(modeStr), dictCapacity)))); } } else { - str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + pretrained)))); + str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + pretrained, dictCapacity)))); } } else @@ -545,11 +553,11 @@ void Dict::readDicts(const std::string & directory, const std::string & filename if (std::string(pretrained) == "_") { - str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + name + ".dict")))); + str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + name + ".dict", dictCapacity)))); } else { - str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + pretrained)))); + str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + pretrained, dictCapacity)))); } } } diff --git a/maca_common/src/macaon_convert_embeddings.cpp b/maca_common/src/macaon_convert_embeddings.cpp index 1cbe55c50ff734676dbd72758a5eeafc3ec8a784..23c378da7faeda19aa4dcb2ab54f17dd597694c5 100644 --- a/maca_common/src/macaon_convert_embeddings.cpp +++ b/maca_common/src/macaon_convert_embeddings.cpp @@ -113,7 +113,7 @@ int main(int argc, char * argv[]) dynet::initialize(argc, argv); dynet::ParameterCollection pc; - Dict dict(outputFilename, embeddingsSize, Dict::Mode::Embeddings); + Dict dict(outputFilename, embeddingsSize, Dict::Mode::Embeddings, nbEmbeddings+5); dict.init(pc); while (fscanf(input.getDescriptor(), "%[^\n]\n", buffer) == 1)