From 72c543b6fb9bc50f84af4aaf353573f61a640c4d Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 11 Mar 2020 10:31:10 +0100
Subject: [PATCH] During trainning, convert rare forms to unknownValue to train
 the corresponding embedding

---
 common/include/Dict.hpp                    |  2 ++
 common/src/Dict.cpp                        | 14 ++++++++++++++
 reading_machine/include/ReadingMachine.hpp |  1 +
 reading_machine/src/ReadingMachine.cpp     |  5 +++++
 torch_modules/include/CNNNetwork.hpp       |  3 +++
 torch_modules/src/CNNNetwork.cpp           | 14 ++++++++++++--
 trainer/src/macaon_train.cpp               | 18 ++++++++++++++++++
 7 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index 8df818f..e88d07f 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -44,6 +44,8 @@ class Dict
   void save(std::FILE * destination, Encoding encoding) const;
   bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding);
   void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const;
+  std::size_t size() const;
+  int getNbOccs(int index) const;
 };
 
 #endif
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index e645083..ffaca44 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -75,6 +75,8 @@ int Dict::getIndexOrInsert(const std::string & element)
     if (state == State::Open)
     {
       insert(element);
+      if (isCountingOccs)
+        nbOccs[elementsToIndexes[element]]++;
       return elementsToIndexes[element];
     }
     if (isCountingOccs)
@@ -146,3 +148,15 @@ void Dict::countOcc(bool isCountingOccs)
   this->isCountingOccs = isCountingOccs;
 }
 
+std::size_t Dict::size() const
+{
+  return elementsToIndexes.size();
+}
+
+int Dict::getNbOccs(int index) const
+{
+  if (index < 0 || index >= (int)nbOccs.size())
+    return 0;
+  return nbOccs[index];
+}
+
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 8e3f047..4ce25aa 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -36,6 +36,7 @@ class ReadingMachine
   TransitionSet & getTransitionSet();
   Strategy & getStrategy();
   Dict & getDict(const std::string & state);
+  std::map<std::string, Dict> & getDicts();
   Classifier * getClassifier();
   void save() const;
   bool isPredicted(const std::string & columnName) const;
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index a32b5b8..84314a3 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -131,3 +131,8 @@ void ReadingMachine::trainMode(bool isTrainMode)
     it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed);
 }
 
+std::map<std::string, Dict> & ReadingMachine::getDicts()
+{
+  return dicts;
+}
+
diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
index cab39f0..1f60c73 100644
--- a/torch_modules/include/CNNNetwork.hpp
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -8,6 +8,9 @@ class CNNNetworkImpl : public NeuralNetworkImpl
 {
   private :
 
+  static constexpr int maxNbEmbeddings = 50000;
+  static constexpr int unknownValueThreshold = 0;
+
   std::vector<int> focusedBufferIndexes;
   std::vector<int> focusedStackIndexes;
   std::vector<std::string> focusedColumns;
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
index 8889fe0..130c092 100644
--- a/torch_modules/src/CNNNetwork.cpp
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -19,7 +19,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
     rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize));
   int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();
 
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
   embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
   cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3));
   hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
@@ -76,6 +76,9 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
 
 std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
+  if (dict.size() >= maxNbEmbeddings)
+    util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
+
   std::vector<long> contextIndexes = extractContextIndexes(config);
   std::vector<long> context;
 
@@ -100,7 +103,14 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c
       if (index == -1)
         context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
       else
-        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));
+      {
+        int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
+        if (col == "FORM" || col == "LEMMA")
+          if (dict.getNbOccs(dictIndex) < unknownValueThreshold)
+            dictIndex = dict.getIndexOrInsert(Dict::unknownValueStr);
+
+        context.push_back(dictIndex);
+      }
 
   for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
   {
diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp
index 9dbceaa..2b378eb 100644
--- a/trainer/src/macaon_train.cpp
+++ b/trainer/src/macaon_train.cpp
@@ -61,6 +61,22 @@ po::variables_map checkOptions(po::options_description & od, int argc, char ** a
   return vm;
 }
 
+void fillDicts(ReadingMachine & rm, const Config & config)
+{
+  static std::vector<std::string> interestingColumns{"FORM", "LEMMA"};
+
+  for (auto & col : interestingColumns)
+    if (config.has(col,0,0))
+      for (auto & it : rm.getDicts())
+      {
+        it.second.countOcc(true);
+        for (unsigned int j = 0; j < config.getNbLines(); j++)
+          for (unsigned int k = 0; k < Config::nbHypothesesMax; k++)
+            it.second.getIndexOrInsert(config.getConst(col,j,k));
+        it.second.countOcc(false);
+      }
+}
+
 int main(int argc, char * argv[])
 {
   auto od = getOptionsDescription();
@@ -89,6 +105,8 @@ int main(int argc, char * argv[])
   BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
   SubConfig config(goldConfig);
 
+  fillDicts(machine, goldConfig);
+
   Trainer trainer(machine);
   trainer.createDataset(config, debug);
   if (!computeDevScore)
-- 
GitLab