From 1cf6cf2facc743b377c2734d9cdf846154b99c21 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 1 Apr 2020 14:07:02 +0200
Subject: [PATCH] Deleted OneWordNetwork

---
 reading_machine/src/Classifier.cpp       |  9 ---------
 torch_modules/include/OneWordNetwork.hpp | 19 -------------------
 torch_modules/src/OneWordNetwork.cpp     | 24 ------------------------
 3 files changed, 52 deletions(-)
 delete mode 100644 torch_modules/include/OneWordNetwork.hpp
 delete mode 100644 torch_modules/src/OneWordNetwork.cpp

diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 41f9f15..a58c3b1 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -1,6 +1,5 @@
 #include "Classifier.hpp"
 #include "util.hpp"
-#include "OneWordNetwork.hpp"
 #include "ConcatWordsNetwork.hpp"
 #include "RLTNetwork.hpp"
 #include "CNNNetwork.hpp"
@@ -41,14 +40,6 @@ void Classifier::initNeuralNetwork(const std::string & topology)
         this->nn.reset(new RandomNetworkImpl(this->transitionSet->size()));
       }
     },
-    {
-      std::regex("OneWord\\(([+\\-]?\\d+)\\)"),
-      "OneWord(focusedIndex) : Only use the word embedding of the focused word.",
-      [this,topology](auto sm)
-      {
-        this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1))));
-      }
-    },
     {
       std::regex("ConcatWords\\(\\{(.*)\\},\\{(.*)\\}\\)"),
       "ConcatWords({bufferContext},{stackContext}) : Concatenate embeddings of words in context.",
diff --git a/torch_modules/include/OneWordNetwork.hpp b/torch_modules/include/OneWordNetwork.hpp
deleted file mode 100644
index 9882b62..0000000
--- a/torch_modules/include/OneWordNetwork.hpp
+++ /dev/null
@@ -1,19 +0,0 @@
-#ifndef ONEWORDNETWORK__H
-#define ONEWORDNETWORK__H
-
-#include "NeuralNetwork.hpp"
-
-class OneWordNetworkImpl : public NeuralNetworkImpl
-{
-  private :
-
-  torch::nn::Embedding wordEmbeddings{nullptr};
-  torch::nn::Linear linear{nullptr};
-
-  public :
-
-  OneWordNetworkImpl(int nbOutputs, int focusedIndex);
-  torch::Tensor forward(torch::Tensor input) override;
-};
-
-#endif
diff --git a/torch_modules/src/OneWordNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp
deleted file mode 100644
index e3ed3d5..0000000
--- a/torch_modules/src/OneWordNetwork.cpp
+++ /dev/null
@@ -1,24 +0,0 @@
-#include "OneWordNetwork.hpp"
-
-OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
-{
-  constexpr int embeddingsSize = 64;
-
-  setBufferContext({focusedIndex});
-  setStackContext({});
-  setColumns({"FORM", "UPOS"});
-
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
-  linear = register_module("linear", torch::nn::Linear(getContextSize()*embeddingsSize, nbOutputs));
-}
-
-torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
-{
-  if (input.dim() == 1)
-    input = input.unsqueeze(0);
-  auto wordAsEmb = wordEmbeddings(input).view({input.size(0),-1});
-  auto res = linear(wordAsEmb);
-
-  return res;
-}
-
-- 
GitLab