Skip to content
Snippets Groups Projects
Commit 086f52fe authored by Franck Dary's avatar Franck Dary
Browse files

Added support for pre-trained word embeddings

parent 1e60e1ba
Branches
No related tags found
No related merge requests found
......@@ -127,7 +127,10 @@ class Dict
/// When getDict is called it will find the requested Dict here,
/// or construct it.
static std::map< std::string, std::unique_ptr<Dict> > str2dict;
/// @brief Whether or not the init function has been called.
bool isInit;
/// @brief Whether this Dict is a new one or a trained one.
bool isTrained;
private :
......@@ -168,9 +171,10 @@ class Dict
void initFromFile(dynet::ParameterCollection & pc);
/// @brief Read and construct a new Dict from a file.
///
/// @param name The name of the new Dict.
/// @param policy The Policy of the new Dict.
/// @param filename The filename we will read the new Dict from.
Dict(Policy policy, const std::string & filename);
Dict(const std::string & name, Policy policy, const std::string & filename);
/// @brief Construct a new Dict.
///
/// @param name The name of the Dict to construct.
......@@ -224,7 +228,6 @@ class Dict
/// @param namePrefix The prefix that Dict must match.
static void createFiles(const std::string & directory, const std::string & namePrefix);
static void initDicts(dynet::ParameterCollection & pc, const std::string & namePrefix);
static void initDictsFromFile(dynet::ParameterCollection & pc, const std::string & namePrefix);
/// @brief Delete all Dicts.
static void deleteDicts();
/// @brief Save the current Dict in the corresponding file.
......
......@@ -48,6 +48,7 @@ const char * Dict::mode2str(Mode mode)
Dict::Dict(const std::string & name, int dimension, Mode mode)
{
this->isInit = false;
this->isTrained = false;
this->policy = Policy::FromZero;
this->filename = name;
this->name = name;
......@@ -57,13 +58,14 @@ Dict::Dict(const std::string & name, int dimension, Mode mode)
this->dimension = dimension;
}
Dict::Dict(Policy policy, const std::string & filename)
Dict::Dict(const std::string & name, Policy policy, const std::string & filename)
{
this->isInit = false;
this->isTrained = true;
this->policy = policy;
this->filename = filename;
this->oneHotIndex = 0;
this->name = removeSuffix(getFilenameFromPath(filename), ".dict");
this->name = name;
}
void Dict::init(dynet::ParameterCollection & pc)
......@@ -103,7 +105,6 @@ void Dict::initFromFile(dynet::ParameterCollection & pc)
if(fscanf(fd, "%s\n%d\n%s\n", b1, &dimension, b2) != 3)
badFormatAndAbort(ERRINFO);
name = b1;
mode = str2mode(b2);
isInit = true;
......@@ -145,24 +146,30 @@ void Dict::saveDicts(const std::string & directory, const std::string & namePref
for (auto & it : str2dict)
{
if(!strncmp(it.first.c_str(), namePrefix.c_str(), namePrefix.size()))
{
if (!it.second->isTrained)
{
it.second->filename = directory + it.second->name + ".dict";
it.second->save();
}
}
}
}
void Dict::createFiles(const std::string & directory, const std::string & namePrefix)
{
for (auto & it : str2dict)
{
if(!strncmp(it.first.c_str(), namePrefix.c_str(), namePrefix.size()))
{
if (!it.second->isTrained)
{
it.second->filename = directory + it.second->name + ".dict";
it.second->createFile();
}
}
}
}
void Dict::initDicts(dynet::ParameterCollection & pc, const std::string & namePrefix)
{
......@@ -170,17 +177,9 @@ void Dict::initDicts(dynet::ParameterCollection & pc, const std::string & namePr
{
if(!strncmp(it.first.c_str(), namePrefix.c_str(), namePrefix.size()))
{
if (!it.second->isTrained)
it.second->init(pc);
}
}
}
void Dict::initDictsFromFile(dynet::ParameterCollection & pc, const std::string & namePrefix)
{
for (auto & it : str2dict)
{
if(!strncmp(it.first.c_str(), namePrefix.c_str(), namePrefix.size()))
{
else
it.second->initFromFile(pc);
}
}
......@@ -379,7 +378,7 @@ Dict * Dict::getDict(Policy policy, const std::string & filename)
if(it != str2dict.end())
return it->second.get();
Dict * dict = new Dict(policy, filename);
Dict * dict = new Dict(removeSuffix(getFilenameFromPath(filename), ".dict"),policy, filename);
str2dict.insert(std::make_pair(dict->name, std::unique_ptr<Dict>(dict)));
......@@ -393,6 +392,10 @@ Dict * Dict::getDict(const std::string & name)
if(it != str2dict.end())
return it->second.get();
it = str2dict.find(name);
if(it != str2dict.end())
return it->second.get();
fprintf(stderr, "ERROR (%s) : dictionary \'%s\' does not exists. Aborting.\n", ERRINFO, relativeName.c_str());
exit(1);
......@@ -404,6 +407,7 @@ void Dict::readDicts(const std::string & directory, const std::string & filename
char buffer[1024];
char name[1024];
char modeStr[1024];
char pretrained[1024];
int dim;
File file(filename, "r");
......@@ -416,7 +420,7 @@ void Dict::readDicts(const std::string & directory, const std::string & filename
if(trainMode)
{
if(sscanf(buffer, "%s %d %s", name, &dim, modeStr) != 3)
if(sscanf(buffer, "%s %d %s %s\n", name, &dim, modeStr, pretrained) != 4)
{
fprintf(stderr, "ERROR (%s) : line \'%s\' do not describe a dictionary. Aborting.\n", ERRINFO, buffer);
exit(1);
......@@ -429,11 +433,18 @@ void Dict::readDicts(const std::string & directory, const std::string & filename
exit(1);
}
if (std::string(pretrained) == "_")
{
str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, dim, str2mode(modeStr)))));
}
else
{
if(sscanf(buffer, "%s", name) != 1)
str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + pretrained))));
}
}
else
{
if(sscanf(buffer, "%s %d %s %s\n", name, &dim, modeStr, pretrained) != 4)
{
fprintf(stderr, "ERROR (%s) : line \'%s\' do not describe a dictionary. Aborting.\n", ERRINFO, buffer);
exit(1);
......@@ -446,7 +457,14 @@ void Dict::readDicts(const std::string & directory, const std::string & filename
exit(1);
}
getDict(Policy::Final, directory + name + ".dict");
if (std::string(pretrained) == "_")
{
str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + name + ".dict"))));
}
else
{
str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, Policy::Final, directory + pretrained))));
}
}
}
}
......
......@@ -123,7 +123,7 @@ void Classifier::initClassifier(Config & config)
if(!trainMode)
{
mlp.reset(new MLP(ProgramParameters::expPath + name + ".model"));
Dict::initDictsFromFile(mlp->getModel(), name);
Dict::initDicts(mlp->getModel(), name);
return;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment