Skip to content
Snippets Groups Projects
Commit 408336d8 authored by Franck Dary's avatar Franck Dary
Browse files

Working fasttext style embeddings learning

parent 868799b9
Branches
No related tags found
No related merge requests found
...@@ -159,6 +159,13 @@ class NeuralNetwork ...@@ -159,6 +159,13 @@ class NeuralNetwork
/// ///
/// @return The model of this NeuralNetwork. /// @return The model of this NeuralNetwork.
dynet::ParameterCollection & getModel(); dynet::ParameterCollection & getModel();
/// \brief How much input neurons a certain feature will take.
///
/// \param fv The FeatureValue to measure the size of.
///
/// \return The number of input neurons taken by fv.
static unsigned int featureSize(const FeatureModel::FeatureValue & fv);
}; };
#endif #endif
...@@ -47,6 +47,9 @@ dynet::Expression NeuralNetwork::featValue2Expression(dynet::ComputationGraph & ...@@ -47,6 +47,9 @@ dynet::Expression NeuralNetwork::featValue2Expression(dynet::ComputationGraph &
expressions.emplace_back(dynet::lookup(cg, lu, index)); expressions.emplace_back(dynet::lookup(cg, lu, index));
} }
if (fv.func == FeatureModel::Function::Mean)
return dynet::average(expressions);
return dynet::concatenate(expressions); return dynet::concatenate(expressions);
} }
...@@ -172,3 +175,16 @@ dynet::Expression NeuralNetwork::activate(dynet::Expression h, Activation f) ...@@ -172,3 +175,16 @@ dynet::Expression NeuralNetwork::activate(dynet::Expression h, Activation f)
return h; return h;
} }
/// \brief How many input neurons a certain feature will take.
///
/// \param fv The FeatureValue to measure the size of.
///
/// \return The number of input neurons taken by fv. Returns 0 for
/// functions other than Concat / Mean, and for a Mean feature with no dicts.
unsigned int NeuralNetwork::featureSize(const FeatureModel::FeatureValue & fv)
{
  unsigned int res = 0;

  if (fv.func == FeatureModel::Function::Concat)
  {
    // Concatenation : every dict contributes its own embedding dimension.
    for (const auto & dict : fv.dicts)
      res += dict->getDimension();
  }
  else if (fv.func == FeatureModel::Function::Mean)
  {
    // Averaging : all averaged embeddings share a single dimension, so the
    // first dict's dimension is the whole size. Guard against an empty
    // dicts vector (the original indexed dicts[0] unconditionally, which
    // is undefined behavior when no dict is present).
    if (!fv.dicts.empty())
      res = fv.dicts[0]->getDimension();
  }

  return res;
}
...@@ -145,8 +145,7 @@ void Classifier::initClassifier(Config & config) ...@@ -145,8 +145,7 @@ void Classifier::initClassifier(Config & config)
int nbOutputs = as->actions.size(); int nbOutputs = as->actions.size();
for (auto feat : fd.values) for (auto feat : fd.values)
for (auto dict : feat.dicts) nbInputs += NeuralNetwork::featureSize(feat);
nbInputs += dict->getDimension();
nn->init(nbInputs, topology, nbOutputs); nn->init(nbInputs, topology, nbOutputs);
} }
......
...@@ -519,7 +519,6 @@ FeatureModel::FeatureValue FeatureBank::aggregateStack(Config & c, int from, con ...@@ -519,7 +519,6 @@ FeatureModel::FeatureValue FeatureBank::aggregateStack(Config & c, int from, con
return result; return result;
} }
//TODO : do not use a FeatureValue for word but a plain string, so that it also works with unknown words
FeatureModel::FeatureValue FeatureBank::fasttext(Config & c, const FeatureModel::FeatureValue & word) FeatureModel::FeatureValue FeatureBank::fasttext(Config & c, const FeatureModel::FeatureValue & word)
{ {
FeatureModel::FeatureValue result(FeatureModel::Function::Mean); FeatureModel::FeatureValue result(FeatureModel::Function::Mean);
...@@ -531,7 +530,9 @@ FeatureModel::FeatureValue FeatureBank::fasttext(Config & c, const FeatureModel: ...@@ -531,7 +530,9 @@ FeatureModel::FeatureValue FeatureBank::fasttext(Config & c, const FeatureModel:
return {lettersDict, word.names[0], Dict::nullValueStr, policy}; return {lettersDict, word.names[0], Dict::nullValueStr, policy};
unsigned int wordLength = getNbSymbols(word.values[0]); unsigned int wordLength = getNbSymbols(word.values[0]);
unsigned int gramLength = 2; unsigned int gramLength = 4;
bool slidingMode = false;
if (wordLength < gramLength) if (wordLength < gramLength)
{ {
...@@ -542,6 +543,22 @@ FeatureModel::FeatureValue FeatureBank::fasttext(Config & c, const FeatureModel: ...@@ -542,6 +543,22 @@ FeatureModel::FeatureValue FeatureBank::fasttext(Config & c, const FeatureModel:
result.policies.emplace_back(value.policies[0]); result.policies.emplace_back(value.policies[0]);
} }
else else
{
if (!slidingMode)
{
int nbGrams = wordLength / gramLength + (wordLength % gramLength ? 1 : 0);
for (int i = 0; i < nbGrams; i++)
{
int from = i * gramLength;
int to = i == nbGrams-1 ? wordLength-1 : (i+1)*gramLength-1;
auto value = getLetters(c, word, from, to);
result.dicts.emplace_back(value.dicts[0]);
result.names.emplace_back(value.names[0]);
result.values.emplace_back(value.values[0]);
result.policies.emplace_back(value.policies[0]);
}
}
else
{ {
for (unsigned int i = 0; i+gramLength-1 < wordLength; i++) for (unsigned int i = 0; i+gramLength-1 < wordLength; i++)
{ {
...@@ -552,6 +569,7 @@ FeatureModel::FeatureValue FeatureBank::fasttext(Config & c, const FeatureModel: ...@@ -552,6 +569,7 @@ FeatureModel::FeatureValue FeatureBank::fasttext(Config & c, const FeatureModel:
result.policies.emplace_back(value.policies[0]); result.policies.emplace_back(value.policies[0]);
} }
} }
}
return result; return result;
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment