diff --git a/neural_network/include/NeuralNetwork.hpp b/neural_network/include/NeuralNetwork.hpp index 984233801c0dbc7c60d6be0d987ef8ab4d711256..19751c44bb190da97df7d343682ed794034009cb 100644 --- a/neural_network/include/NeuralNetwork.hpp +++ b/neural_network/include/NeuralNetwork.hpp @@ -159,6 +159,13 @@ class NeuralNetwork /// /// @return The model of this NeuralNetwork. dynet::ParameterCollection & getModel(); + + /// @brief How many input neurons a certain feature will take. + /// + /// @param fv The FeatureValue to measure the size of. + /// + /// @return The number of input neurons taken by fv. + static unsigned int featureSize(const FeatureModel::FeatureValue & fv); }; #endif diff --git a/neural_network/src/NeuralNetwork.cpp b/neural_network/src/NeuralNetwork.cpp index 22fb0071003880b86cf84bedb3f97e0a1ce8121a..a377b8c7be2ee389aecb99fe3b22e2c87098c8c4 100644 --- a/neural_network/src/NeuralNetwork.cpp +++ b/neural_network/src/NeuralNetwork.cpp @@ -47,6 +47,9 @@ dynet::Expression NeuralNetwork::featValue2Expression(dynet::ComputationGraph & expressions.emplace_back(dynet::lookup(cg, lu, index)); } + if (fv.func == FeatureModel::Function::Mean) + return dynet::average(expressions); + return dynet::concatenate(expressions); } @@ -172,3 +175,16 @@ dynet::Expression NeuralNetwork::activate(dynet::Expression h, Activation f) return h; } +unsigned int NeuralNetwork::featureSize(const FeatureModel::FeatureValue & fv) +{ + unsigned int res = 0; + + if (fv.func == FeatureModel::Function::Concat) + for (auto dict : fv.dicts) + res += dict->getDimension(); + else if (fv.func == FeatureModel::Function::Mean) + res = fv.dicts[0]->getDimension(); + + return res; +} + diff --git a/transition_machine/src/Classifier.cpp b/transition_machine/src/Classifier.cpp index 6dcf278a9ae881de80be0bc0b4f7d3c65d16257b..06641a9eb56727d25d5a701e4b2406ea4252630b 100644 --- a/transition_machine/src/Classifier.cpp +++ b/transition_machine/src/Classifier.cpp @@ -145,8 +145,7 @@ void 
Classifier::initClassifier(Config & config) int nbOutputs = as->actions.size(); for (auto feat : fd.values) - for (auto dict : feat.dicts) - nbInputs += dict->getDimension(); + nbInputs += NeuralNetwork::featureSize(feat); nn->init(nbInputs, topology, nbOutputs); } diff --git a/transition_machine/src/FeatureBank.cpp b/transition_machine/src/FeatureBank.cpp index cad942aa8728668ac3069e314be58710c82a038e..9e3e796e83e3013431c41f78066ed08f8108afb0 100644 --- a/transition_machine/src/FeatureBank.cpp +++ b/transition_machine/src/FeatureBank.cpp @@ -519,7 +519,6 @@ FeatureModel::FeatureValue FeatureBank::aggregateStack(Config & c, int from, con return result; } -//TODO : ne pas utiliser une feature value pour word mais un string, pour que ça marche avec les mots inconnus FeatureModel::FeatureValue FeatureBank::fasttext(Config & c, const FeatureModel::FeatureValue & word) { FeatureModel::FeatureValue result(FeatureModel::Function::Mean); @@ -531,7 +530,9 @@ FeatureModel::FeatureValue FeatureBank::fasttext(Config & c, const FeatureModel: return {lettersDict, word.names[0], Dict::nullValueStr, policy}; unsigned int wordLength = getNbSymbols(word.values[0]); - unsigned int gramLength = 2; + unsigned int gramLength = 4; + + bool slidingMode = false; if (wordLength < gramLength) { @@ -543,13 +544,30 @@ FeatureModel::FeatureValue FeatureBank::fasttext(Config & c, const FeatureModel: } else { - for (unsigned int i = 0; i+gramLength-1 < wordLength; i++) + if (!slidingMode) + { + int nbGrams = wordLength / gramLength + (wordLength % gramLength ? 1 : 0); + for (int i = 0; i < nbGrams; i++) + { + int from = i * gramLength; + int to = i == nbGrams-1 ? 
wordLength-1 : (i+1)*gramLength-1; + auto value = getLetters(c, word, from, to); + result.dicts.emplace_back(value.dicts[0]); + result.names.emplace_back(value.names[0]); + result.values.emplace_back(value.values[0]); + result.policies.emplace_back(value.policies[0]); + } + } + else { - auto value = getLetters(c, word, i, i+gramLength-1); - result.dicts.emplace_back(value.dicts[0]); - result.names.emplace_back(value.names[0]); - result.values.emplace_back(value.values[0]); - result.policies.emplace_back(value.policies[0]); + for (unsigned int i = 0; i+gramLength-1 < wordLength; i++) + { + auto value = getLetters(c, word, i, i+gramLength-1); + result.dicts.emplace_back(value.dicts[0]); + result.names.emplace_back(value.names[0]); + result.values.emplace_back(value.values[0]); + result.policies.emplace_back(value.policies[0]); + } } }