From 408336d8ff582215dd9af2e69af05f0fe4d03f97 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 3 Apr 2019 15:58:35 +0200 Subject: [PATCH] Working fasttext style embeddings learning --- neural_network/include/NeuralNetwork.hpp | 7 +++++ neural_network/src/NeuralNetwork.cpp | 16 +++++++++++ transition_machine/src/Classifier.cpp | 3 +-- transition_machine/src/FeatureBank.cpp | 34 ++++++++++++++++++------ 4 files changed, 50 insertions(+), 10 deletions(-) diff --git a/neural_network/include/NeuralNetwork.hpp b/neural_network/include/NeuralNetwork.hpp index 9842338..19751c4 100644 --- a/neural_network/include/NeuralNetwork.hpp +++ b/neural_network/include/NeuralNetwork.hpp @@ -159,6 +159,13 @@ class NeuralNetwork /// /// @return The model of this NeuralNetwork. dynet::ParameterCollection & getModel(); + + /// \brief How many input neurons a certain feature will take. + /// + /// \param fv The FeatureValue to measure the size of. + /// + /// \return The number of input neurons taken by fv. 
+ static unsigned int featureSize(const FeatureModel::FeatureValue & fv); }; #endif diff --git a/neural_network/src/NeuralNetwork.cpp b/neural_network/src/NeuralNetwork.cpp index 22fb007..a377b8c 100644 --- a/neural_network/src/NeuralNetwork.cpp +++ b/neural_network/src/NeuralNetwork.cpp @@ -47,6 +47,9 @@ dynet::Expression NeuralNetwork::featValue2Expression(dynet::ComputationGraph & expressions.emplace_back(dynet::lookup(cg, lu, index)); } + if (fv.func == FeatureModel::Function::Mean) + return dynet::average(expressions); + return dynet::concatenate(expressions); } @@ -172,3 +175,16 @@ dynet::Expression NeuralNetwork::activate(dynet::Expression h, Activation f) return h; } +unsigned int NeuralNetwork::featureSize(const FeatureModel::FeatureValue & fv) +{ /* Number of input neurons taken by fv: the sum of its dictionaries' dimensions for Concat, the first dictionary's dimension for Mean, 0 otherwise. */ + unsigned int res = 0; + + if (fv.func == FeatureModel::Function::Concat) + for (auto dict : fv.dicts) + res += dict->getDimension(); + else if (fv.func == FeatureModel::Function::Mean) + res = fv.dicts.empty() ? 0 : fv.dicts[0]->getDimension(); /* guard: indexing dicts[0] of an empty vector is undefined behaviour */ + + return res; +} + diff --git a/transition_machine/src/Classifier.cpp b/transition_machine/src/Classifier.cpp index 6dcf278..06641a9 100644 --- a/transition_machine/src/Classifier.cpp +++ b/transition_machine/src/Classifier.cpp @@ -145,8 +145,7 @@ void Classifier::initClassifier(Config & config) int nbOutputs = as->actions.size(); for (auto feat : fd.values) - for (auto dict : feat.dicts) - nbInputs += dict->getDimension(); + nbInputs += NeuralNetwork::featureSize(feat); nn->init(nbInputs, topology, nbOutputs); } diff --git a/transition_machine/src/FeatureBank.cpp b/transition_machine/src/FeatureBank.cpp index cad942a..9e3e796 100644 --- a/transition_machine/src/FeatureBank.cpp +++ b/transition_machine/src/FeatureBank.cpp @@ -519,7 +519,6 @@ FeatureModel::FeatureValue FeatureBank::aggregateStack(Config & c, int from, con return result; } -//TODO : ne pas utiliser une feature value pour word mais un string, pour que ça marche avec les mots inconnus FeatureModel::FeatureValue 
FeatureBank::fasttext(Config & c, const FeatureModel::FeatureValue & word) { FeatureModel::FeatureValue result(FeatureModel::Function::Mean); @@ -531,7 +530,9 @@ FeatureModel::FeatureValue FeatureBank::fasttext(Config & c, const FeatureModel: return {lettersDict, word.names[0], Dict::nullValueStr, policy}; unsigned int wordLength = getNbSymbols(word.values[0]); - unsigned int gramLength = 2; + unsigned int gramLength = 4; + + bool slidingMode = false; if (wordLength < gramLength) { @@ -543,13 +544,30 @@ FeatureModel::FeatureValue FeatureBank::fasttext(Config & c, const FeatureModel: } else { - for (unsigned int i = 0; i+gramLength-1 < wordLength; i++) + if (!slidingMode) + { + int nbGrams = wordLength / gramLength + (wordLength % gramLength ? 1 : 0); + for (int i = 0; i < nbGrams; i++) + { + int from = i * gramLength; + int to = i == nbGrams-1 ? wordLength-1 : (i+1)*gramLength-1; + auto value = getLetters(c, word, from, to); + result.dicts.emplace_back(value.dicts[0]); + result.names.emplace_back(value.names[0]); + result.values.emplace_back(value.values[0]); + result.policies.emplace_back(value.policies[0]); + } + } + else { - auto value = getLetters(c, word, i, i+gramLength-1); - result.dicts.emplace_back(value.dicts[0]); - result.names.emplace_back(value.names[0]); - result.values.emplace_back(value.values[0]); - result.policies.emplace_back(value.policies[0]); + for (unsigned int i = 0; i+gramLength-1 < wordLength; i++) + { + auto value = getLetters(c, word, i, i+gramLength-1); + result.dicts.emplace_back(value.dicts[0]); + result.names.emplace_back(value.names[0]); + result.values.emplace_back(value.values[0]); + result.policies.emplace_back(value.policies[0]); + } } } -- GitLab