From 72c543b6fb9bc50f84af4aaf353573f61a640c4d Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 11 Mar 2020 10:31:10 +0100 Subject: [PATCH] During trainning, convert rare forms to unknownValue to train the corresponding embedding --- common/include/Dict.hpp | 2 ++ common/src/Dict.cpp | 14 ++++++++++++++ reading_machine/include/ReadingMachine.hpp | 1 + reading_machine/src/ReadingMachine.cpp | 5 +++++ torch_modules/include/CNNNetwork.hpp | 3 +++ torch_modules/src/CNNNetwork.cpp | 14 ++++++++++++-- trainer/src/macaon_train.cpp | 18 ++++++++++++++++++ 7 files changed, 55 insertions(+), 2 deletions(-) diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index 8df818f..e88d07f 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -44,6 +44,8 @@ class Dict void save(std::FILE * destination, Encoding encoding) const; bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding); void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const; + std::size_t size() const; + int getNbOccs(int index) const; }; #endif diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index e645083..ffaca44 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -75,6 +75,8 @@ int Dict::getIndexOrInsert(const std::string & element) if (state == State::Open) { insert(element); + if (isCountingOccs) + nbOccs[elementsToIndexes[element]]++; return elementsToIndexes[element]; } if (isCountingOccs) @@ -146,3 +148,15 @@ void Dict::countOcc(bool isCountingOccs) this->isCountingOccs = isCountingOccs; } +std::size_t Dict::size() const +{ + return elementsToIndexes.size(); +} + +int Dict::getNbOccs(int index) const +{ + if (index < 0 || index >= (int)nbOccs.size()) + return 0; + return nbOccs[index]; +} + diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 8e3f047..4ce25aa 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -36,6 +36,7 @@ class ReadingMachine TransitionSet & getTransitionSet(); Strategy & getStrategy(); Dict & getDict(const std::string & state); + std::map<std::string, Dict> & getDicts(); Classifier * getClassifier(); void save() const; bool isPredicted(const std::string & columnName) const; diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index a32b5b8..84314a3 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -131,3 +131,8 @@ void ReadingMachine::trainMode(bool isTrainMode) it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed); } +std::map<std::string, Dict> & ReadingMachine::getDicts() +{ + return dicts; +} + diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index cab39f0..1f60c73 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -8,6 +8,9 @@ class CNNNetworkImpl : public NeuralNetworkImpl { private : + static constexpr int maxNbEmbeddings = 50000; + static constexpr int unknownValueThreshold = 0; + std::vector<int> focusedBufferIndexes; std::vector<int> focusedStackIndexes; std::vector<std::string> focusedColumns; diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 8889fe0..130c092 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -19,7 +19,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize)); int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize(); - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); + wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3)); hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3)); @@ -76,6 +76,9 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const { + if (dict.size() >= maxNbEmbeddings) + util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings)); + std::vector<long> contextIndexes = extractContextIndexes(config); std::vector<long> context; @@ -100,7 +103,14 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c if (index == -1) context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); else - context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index))); + { + int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index)); + if (col == "FORM" || col == "LEMMA") + if (dict.getNbOccs(dictIndex) < unknownValueThreshold) + dictIndex = dict.getIndexOrInsert(Dict::unknownValueStr); + + context.push_back(dictIndex); + } for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++) { diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index 9dbceaa..2b378eb 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -61,6 +61,22 @@ po::variables_map checkOptions(po::options_description & od, int argc, char ** a return vm; } +void fillDicts(ReadingMachine & rm, const Config & config) +{ + static std::vector<std::string> interestingColumns{"FORM", "LEMMA"}; + + for (auto & col : interestingColumns) + if (config.has(col,0,0)) + for (auto & it : rm.getDicts()) + { + it.second.countOcc(true); + for (unsigned int j = 0; j < config.getNbLines(); j++) + for (unsigned int k = 0; k < Config::nbHypothesesMax; k++) + it.second.getIndexOrInsert(config.getConst(col,j,k)); + it.second.countOcc(false); + } +} + int main(int argc, char * argv[]) { auto od = getOptionsDescription(); @@ -89,6 +105,8 @@ int main(int argc, char * argv[]) BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); SubConfig config(goldConfig); + fillDicts(machine, goldConfig); + Trainer trainer(machine); trainer.createDataset(config, debug); if (!computeDevScore) -- GitLab