Skip to content
Snippets Groups Projects
Commit b2f49689 authored by Franck Dary's avatar Franck Dary
Browse files

Added Fasttext support

parent 77f3e1b4
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,7 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/Modules") ...@@ -6,6 +6,7 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/Modules")
find_package(dynet REQUIRED) find_package(dynet REQUIRED)
find_package(eigen3 REQUIRED) find_package(eigen3 REQUIRED)
find_package(Boost REQUIRED COMPONENTS program_options) find_package(Boost REQUIRED COMPONENTS program_options)
find_package(fasttext REQUIRED)
set(CMAKE_VERBOSE_MAKEFILE 0) set(CMAKE_VERBOSE_MAKEFILE 0)
set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD 11)
......
...@@ -2,4 +2,5 @@ FILE(GLOB SOURCES src/*.cpp) ...@@ -2,4 +2,5 @@ FILE(GLOB SOURCES src/*.cpp)
#compiling library #compiling library
add_library(maca_common STATIC ${SOURCES}) add_library(maca_common STATIC ${SOURCES})
target_link_libraries(maca_common fasttext)
#target_link_libraries(maca_common dynet) #target_link_libraries(maca_common dynet)
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <map> #include <map>
#include <set>
#include <memory> #include <memory>
#include <fasttext/fasttext.h>
class Dict class Dict
{ {
...@@ -42,11 +44,19 @@ class Dict ...@@ -42,11 +44,19 @@ class Dict
std::string filename; std::string filename;
int oneHotIndex; int oneHotIndex;
std::string ftFilename;
std::unique_ptr<fasttext::FastText> ftEmbeddings;
std::unique_ptr<fasttext::Vector> ftVector;
std::map< std::string, std::vector<float> > ftVocab;
static std::map< std::string, std::unique_ptr<Dict> > str2dict; static std::map< std::string, std::unique_ptr<Dict> > str2dict;
private : private :
void initEmbedding(std::vector<float> & vec); std::vector<float> * getValueFasttext(const std::string & s);
void initEmbedding(const std::string & s, std::vector<float> & vec);
void initEmbeddingRandom(std::vector<float> & vec);
void initEmbeddingFromFasttext(const std::string & s, std::vector<float> & vec);
std::vector<float> * addEntry(const std::string & s); std::vector<float> * addEntry(const std::string & s);
Dict(Policy policy, const std::string & filename); Dict(Policy policy, const std::string & filename);
...@@ -58,6 +68,7 @@ class Dict ...@@ -58,6 +68,7 @@ class Dict
void save(); void save();
std::vector<float> * getValue(const std::string & s); std::vector<float> * getValue(const std::string & s);
const std::string * getStr(const std::string & s); const std::string * getStr(const std::string & s);
const std::string * getStrFasttext(const std::string & s);
std::vector<float> * getNullValue(); std::vector<float> * getNullValue();
int getDimension(); int getDimension();
void printForDebug(FILE * output); void printForDebug(FILE * output);
......
...@@ -67,6 +67,24 @@ Dict::Dict(Policy policy, const std::string & filename) ...@@ -67,6 +67,24 @@ Dict::Dict(Policy policy, const std::string & filename)
addEntry(nullValueStr); addEntry(nullValueStr);
// If a fasttext pretrained embedding file is specified
if(fscanf(fd, "Fasttext : %s\n", b1) == 1)
{
static_assert(std::is_same<float, fasttext::real>::value, "ERROR : fasttext::real is not float on this machine, it needs to be. Aborting.\n");
ftEmbeddings.reset(new fasttext::FastText);
ftEmbeddings->loadModel(b1);
ftFilename = b1;
if(ftEmbeddings->getDimension() != dimension)
{
fprintf(stderr, "ERROR (%s) : tried to load fasttext embeddings of dimension %d into dict \'%s\' of dimension %d. Aborting.\n", ERRINFO, ftEmbeddings->getDimension(), name.c_str(), dimension);
exit(1);
}
ftVector.reset(new fasttext::Vector(dimension));
}
// If policy is FromZero, we don't need to read the current entries // If policy is FromZero, we don't need to read the current entries
if(this->policy == Policy::FromZero) if(this->policy == Policy::FromZero)
return; return;
...@@ -110,6 +128,9 @@ void Dict::save() ...@@ -110,6 +128,9 @@ void Dict::save()
fprintf(fd, "%s\n%d\n%s\n", name.c_str(), dimension, mode2str(mode)); fprintf(fd, "%s\n%d\n%s\n", name.c_str(), dimension, mode2str(mode));
if(ftEmbeddings.get())
fprintf(fd, "Fasttext : %s\n", ftFilename.c_str());
for(auto & it : str2vec) for(auto & it : str2vec)
{ {
fprintf(fd, "%s\t", it.first.c_str()); fprintf(fd, "%s\t", it.first.c_str());
...@@ -139,11 +160,36 @@ std::vector<float> * Dict::getValue(const std::string & s) ...@@ -139,11 +160,36 @@ std::vector<float> * Dict::getValue(const std::string & s)
return &(it->second); return &(it->second);
if(policy == Policy::Final) if(policy == Policy::Final)
{
if(ftEmbeddings.get())
return getValueFasttext(s);
return getNullValue(); return getNullValue();
}
return addEntry(s); return addEntry(s);
} }
std::vector<float> * Dict::getValueFasttext(const std::string & s)
{
auto it = ftVocab.find(s);
if(it != ftVocab.end())
return &(it->second);
if(s.empty())
{
fprintf(stderr, "ERROR (%s) : dict \'%s\' was asked to store an empty entry. Aborting.\n", ERRINFO, name.c_str());
exit(1);
}
ftVocab.emplace(s, std::vector<float>(dimension, 0.0));
auto & vec = ftVocab[s];
ftEmbeddings->getWordVector(*ftVector.get(), s);
memcpy(vec.data(), ftVector.get()->data(), dimension * sizeof vec[0]);
return &vec;
}
const std::string * Dict::getStr(const std::string & s) const std::string * Dict::getStr(const std::string & s)
{ {
auto it = str2vec.find(s); auto it = str2vec.find(s);
...@@ -151,7 +197,11 @@ const std::string * Dict::getStr(const std::string & s) ...@@ -151,7 +197,11 @@ const std::string * Dict::getStr(const std::string & s)
return &(it->first); return &(it->first);
if(policy == Policy::Final) if(policy == Policy::Final)
{
if(ftEmbeddings.get())
return getStrFasttext(s);
return &nullValueStr; return &nullValueStr;
}
addEntry(s); addEntry(s);
...@@ -160,10 +210,54 @@ const std::string * Dict::getStr(const std::string & s) ...@@ -160,10 +210,54 @@ const std::string * Dict::getStr(const std::string & s)
return &(it->first); return &(it->first);
} }
void Dict::initEmbedding(std::vector<float> & vec) const std::string * Dict::getStrFasttext(const std::string & s)
{
auto it = ftVocab.find(s);
if(it != ftVocab.end())
return &(it->first);
if(s.empty())
{
fprintf(stderr, "ERROR (%s) : dict \'%s\' was asked to store an empty entry. Aborting.\n", ERRINFO, name.c_str());
exit(1);
}
ftVocab.emplace(s, std::vector<float>(dimension, 0.0));
auto & vec = ftVocab[s];
ftEmbeddings->getWordVector(*ftVector.get(), s);
memcpy(vec.data(), ftVector.get()->data(), dimension * sizeof vec[0]);
it = ftVocab.find(s);
return &(it->first);
}
void Dict::initEmbedding(const std::string & s, std::vector<float> & vec)
{ {
vec[0] = 0.0; // just to shut warning up vec[0] = 0.0; // just to shut warning up
// Here initialize a new embedding, doing nothing = all zeroes // Here initialize a new embedding, doing nothing = all zeroes
//initEmbeddingRandom(vec);
if(ftEmbeddings.get())
initEmbeddingFromFasttext(s, vec);
}
void Dict::initEmbeddingRandom(std::vector<float> & vec)
{
int range = 1;
for (auto & val : vec)
{
float sign = (rand() % 100000) >= 50000 ? 1.0 : -1.0;
float result = ((rand() % range) + 1) * sign;
float decimal = (rand() % 100000) / 100000.0;
result += decimal;
val = result;
}
}
void Dict::initEmbeddingFromFasttext(const std::string & s, std::vector<float> & vec)
{
ftEmbeddings->getWordVector(*ftVector.get(), s);
memcpy(vec.data(), ftVector.get()->data(), dimension * sizeof vec[0]);
} }
Dict::~Dict() Dict::~Dict()
...@@ -197,7 +291,7 @@ std::vector<float> * Dict::addEntry(const std::string & s) ...@@ -197,7 +291,7 @@ std::vector<float> * Dict::addEntry(const std::string & s)
oneHotIndex++; oneHotIndex++;
} }
else else
initEmbedding(vec); initEmbedding(s, vec);
return &vec; return &vec;
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment