Skip to content
Snippets Groups Projects
Commit 10feacb0 authored by Franck Dary's avatar Franck Dary
Browse files

Added new program to convert embeddings file from word2vec format to macaon format

parent 5a987b3e
No related branches found
No related tags found
No related merge requests found
......@@ -5,6 +5,12 @@ target_link_libraries(macaon_compute_l_rules ${Boost_PROGRAM_OPTIONS_LIBRARY})
target_link_libraries(macaon_compute_l_rules maca_common)
install(TARGETS macaon_compute_l_rules DESTINATION bin)
add_executable(macaon_convert_embeddings src/macaon_convert_embeddings.cpp)
target_link_libraries(macaon_convert_embeddings ${Boost_PROGRAM_OPTIONS_LIBRARY})
target_link_libraries(macaon_convert_embeddings maca_common)
target_link_libraries(macaon_convert_embeddings dynet)
install(TARGETS macaon_convert_embeddings DESTINATION bin)
#compiling library
add_library(maca_common STATIC ${SOURCES})
target_link_libraries(maca_common fasttext)
......@@ -168,7 +168,6 @@ class Dict
///
/// @return The lookupParameter index of the newly added entry.
unsigned int addEntry(const std::string & s);
void init(dynet::ParameterCollection & pc);
void initFromFile(dynet::ParameterCollection & pc);
/// @brief Read and construct a new Dict from a file.
///
......@@ -176,12 +175,6 @@ class Dict
/// @param policy The Policy of the new Dict.
/// @param filename The filename we will read the new Dict from.
Dict(const std::string & name, Policy policy, const std::string & filename);
/// @brief Construct a new Dict.
///
/// @param name The name of the Dict to construct.
/// @param dimension The dimension of the vectors in the new Dict.
/// @param mode The Mode of the new Dict.
Dict(const std::string & name, int dimension, Mode mode);
/// @brief Get a pointer to the entry matching s.
///
/// This is used when we need a permanent pointer to a string matching s,
......@@ -197,6 +190,19 @@ class Dict
public :
/// @brief Construct a new Dict.
///
/// @param name The name of the Dict to construct.
/// @param dimension The dimension of the vectors in the new Dict.
/// @param mode The Mode of the new Dict.
Dict(const std::string & name, int dimension, Mode mode);
void init(dynet::ParameterCollection & pc);
void initParameterAsValue(unsigned int index, const std::vector<float> & value);
unsigned int addEntry(const std::string & s, const std::vector<float> & embedding);
/// @brief Get a pointer to a Dict.
///
/// If the Dict doesn't exist, it will be constructed.
......
......@@ -319,6 +319,11 @@ void Dict::initParameterAsEmbedding(const std::string & s, unsigned int index)
initEmbeddingZero(index);
}
void Dict::initParameterAsValue(unsigned int index, const std::vector<float> & value)
{
lookupParameter.initialize(index, value);
}
void Dict::initEmbeddingZero(unsigned int index)
{
lookupParameter.initialize(index, std::vector<float>(dimension, 0.0));
......@@ -390,6 +395,45 @@ unsigned int Dict::addEntry(const std::string & s)
return index;
}
unsigned int Dict::addEntry(const std::string & s, const std::vector<float> & embedding)
{
if (!isInit)
{
fprintf(stderr, "ERROR (%s) : dict \'%s\' is not initialized. Aborting.\n", ERRINFO, name.c_str());
exit(1);
}
if(s.empty())
{
fprintf(stderr, "ERROR (%s) : dict \'%s\' was asked to store an empty entry. Aborting.\n", ERRINFO, name.c_str());
exit(1);
}
auto index = str2index.size();
str2index.emplace(s, index);
if ((int)str2index.size() >= ProgramParameters::dictCapacity)
{
fprintf(stderr, "ERROR (%s) : Dict %s of maximal capacity %d is full. Aborting.\n", ERRINFO, name.c_str(), ProgramParameters::dictCapacity);
exit(1);
}
if(mode == Mode::OneHot)
{
if(oneHotIndex >= dimension)
fprintf(stderr, "WARNING (%s) : Dict %s of dimension %d is asked to store %d elements in one-hot.\n", ERRINFO, name.c_str(), dimension, oneHotIndex+1);
else
{
initParameterAsOneHot(s, index);
oneHotIndex++;
}
}
else
initParameterAsValue(index, embedding);
return index;
}
Dict * Dict::getDict(Policy policy, const std::string & filename)
{
auto it = str2dict.find(filename);
......
/// \file macaon_convert_embeddings.cpp
/// \author Franck Dary
/// @version 1.0
/// @date 2019-04-16
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include "File.hpp"
#include "util.hpp"
#include "Dict.hpp"
#include "ProgramParameters.hpp"
#include <boost/program_options.hpp>
namespace po = boost::program_options;
/// @brief Get the list of mandatory and optional program arguments.
///
/// @return The lists.
po::options_description getOptionsDescription()
{
po::options_description desc("Command-Line Arguments ");
po::options_description req("Required");
req.add_options()
("input,i", po::value<std::string>()->required(),
"File containing the embeddings")
("output,o", po::value<std::string>()->required(),
"Name of the desired output file")
("dictCapacity", po::value<int>()->default_value(30000),
"The maximal size of each Dict (number of differents embeddings).");
po::options_description opt("Optional");
opt.add_options()
("help,h", "Produce this help message")
("debug,d", "Print infos on stderr");
desc.add(req).add(opt);
return desc;
}
/// @brief Store the program arguments inside a variables_map
///
/// @param od The description of all the possible options.
/// @param argc The number of arguments given to this program.
/// @param argv The values of arguments given to this program.
///
/// @return The variables map
po::variables_map checkOptions(po::options_description & od, int argc, char ** argv)
{
po::variables_map vm;
try {po::store(po::parse_command_line(argc, argv, od), vm);}
catch(std::exception& e)
{
std::cerr << "Error: " << e.what() << "\n";
od.print(std::cerr);
exit(1);
}
if (vm.count("help"))
{
std::cout << od << "\n";
exit(0);
}
try {po::notify(vm);}
catch(std::exception& e)
{
std::cerr << "Error: " << e.what() << "\n";
od.print(std::cerr);
exit(1);
}
return vm;
}
/// @brief Given a fplm file (pairs of word / lemma), compute rules that will transform these words into lemmas, as well as exceptions.
///
/// @param argc The number of arguments given to this program.
/// @param argv[] Array of arguments given to this program.
///
/// @return 0 if there was no crash.
int main(int argc, char * argv[])
{
auto od = getOptionsDescription();
po::variables_map vm = checkOptions(od, argc, argv);
std::string inputFilename = vm["input"].as<std::string>();
std::string outputFilename = vm["output"].as<std::string>();
ProgramParameters::dictCapacity = vm["dictCapacity"].as<int>();
File input(inputFilename, "r");
int nbEmbeddings;
int embeddingsSize;
char buffer[100000];
if (fscanf(input.getDescriptor(), "%d %d\n", &nbEmbeddings, &embeddingsSize) != 2)
{
fprintf(stderr, "ERROR (%s) : Wrong format, expected numberOfEmbeddings and embeddingsSize. Aborting.\n", ERRINFO);
exit(1);
}
std::vector<float> embedding;
dynet::initialize(argc, argv);
dynet::ParameterCollection pc;
Dict dict(outputFilename, embeddingsSize, Dict::Mode::Embeddings);
dict.init(pc);
while (fscanf(input.getDescriptor(), "%[^\n]\n", buffer) == 1)
{
embedding.clear();
auto splited = split(buffer, ' ');
if ((int)splited.size() != embeddingsSize+1)
{
fprintf(stderr, "ERROR (%s) : line \'%s\' wrong format. Aborting.\n", ERRINFO, buffer);
exit(1);
}
for (unsigned int i = 1; i < splited.size(); i++)
embedding.emplace_back(std::stof(splited[i]));
dict.addEntry(splited[0], embedding);
}
dict.save();
return 0;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment