From 1efb791a9fb0c8e23269ea95cef5adf0ff4d31c0 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 9 Mar 2020 21:20:38 +0100
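Subject: [PATCH] Neural network now sees multiwords, also ID can now be a
 focused column

Context extraction now keeps every non-comment line instead of tokens only.
Also retunes hyperparameters: hidden size 512 -> 1024, batch size 50 -> 64,
Adam learning rate 0.001 -> 0.0005.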

---
 torch_modules/src/CNNNetwork.cpp    | 11 ++++++++++-
 torch_modules/src/NeuralNetwork.cpp |  4 ++--
 trainer/include/Trainer.hpp         |  2 +-
 trainer/src/Trainer.cpp             |  2 +-
 4 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
index 2f0cc9f..285ad8c 100644
--- a/torch_modules/src/CNNNetwork.cpp
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -3,7 +3,7 @@
 CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
 {
   constexpr int embeddingsSize = 64;
-  constexpr int hiddenSize = 512;
+  constexpr int hiddenSize = 1024;
   constexpr int nbFiltersContext = 512;
   constexpr int nbFiltersFocused = 64;
 
@@ -152,6 +152,15 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c
           else
             elements.emplace_back(Dict::nullValueStr);
       }
+      else if (col == "ID")
+      {
+        if (config.isTokenPredicted(index))
+          elements.emplace_back("ID(TOKEN)");
+        else if (config.isMultiwordPredicted(index))
+          elements.emplace_back("ID(MULTIWORD)");
+        else if (config.isEmptyNodePredicted(index))
+          elements.emplace_back("ID(EMPTYNODE)");
+      }
       else
       {
         elements.emplace_back(config.getAsFeature(col, index));
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 3f69b4a..0ef9f8d 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -6,7 +6,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
 {
   std::stack<long> leftContext;
   for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index)
-    if (config.isToken(index))
+    if (!config.isComment(index))
       leftContext.push(index);
 
   std::vector<long> context;
@@ -20,7 +20,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
   }
 
   for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index)
-    if (config.isToken(index))
+    if (!config.isComment(index))
       context.emplace_back(index);
 
   while (context.size() < leftBorder+rightBorder+1)
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index e04f3e3..259a150 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -19,7 +19,7 @@ class Trainer
   DataLoader devDataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
-  int batchSize{50};
+  int batchSize{64};
   int nbExamples{0};
 
   private :
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 450c416..2b68072 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -16,7 +16,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
 
   dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
-  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999)));
+  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999)));
 }
 
 void Trainer::createDevDataset(SubConfig & config, bool debug)
-- 
GitLab