From 0f5a864fcacafb44eb83afdee142d0a0247f0e6c Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 20 Jan 2020 22:31:57 +0100
Subject: [PATCH] Starting to build neural network

---
 dev/CMakeLists.txt                    |  1 +
 dev/src/dev.cpp                       | 11 +++---
 torch_modules/include/TestNetwork.hpp | 22 ++++++++++++
 torch_modules/src/TestNetwork.cpp     | 52 +++++++++++++++++++++++++++
 4 files changed, 81 insertions(+), 5 deletions(-)
 create mode 100644 torch_modules/include/TestNetwork.hpp
 create mode 100644 torch_modules/src/TestNetwork.cpp

diff --git a/dev/CMakeLists.txt b/dev/CMakeLists.txt
index e7c5cc7..a473806 100644
--- a/dev/CMakeLists.txt
+++ b/dev/CMakeLists.txt
@@ -3,3 +3,4 @@ FILE(GLOB SOURCES src/*.cpp)
 add_executable(dev src/dev.cpp)
 target_link_libraries(dev common)
 target_link_libraries(dev reading_machine)
+target_link_libraries(dev torch_modules) # dev.cpp now includes TestNetwork.hpp from torch_modules/
diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index 889b473..ccf3162 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -5,6 +5,7 @@
 #include "SubConfig.hpp"
 #include "TransitionSet.hpp"
 #include "ReadingMachine.hpp"
+#include "TestNetwork.hpp"
 
 int main(int argc, char * argv[])
 {
@@ -26,8 +27,7 @@ int main(int argc, char * argv[])
   SubConfig config(goldConfig);
 
   config.setState(machine.getStrategy().getInitialState());
-
-  std::vector<std::pair<SubConfig, Transition*>> trainingExamples;
+  TestNetwork nn;
 
   while (true)
   {
@@ -35,7 +35,10 @@ int main(int argc, char * argv[])
     if (!transition)
       util::myThrow("No transition appliable !");
 
-    trainingExamples.emplace_back(config, transition);
+    //here train
+    auto testo = nn(config);
+
+//    std::cout << testo << std::endl;
 
     transition->apply(config);
     config.addToHistory(transition->getName());
@@ -52,8 +55,6 @@ int main(int argc, char * argv[])
       config.update();
   }
 
-  trainingExamples[10000].first.printForDebug(stderr);
-
   return 0;
 }
 
diff --git a/torch_modules/include/TestNetwork.hpp b/torch_modules/include/TestNetwork.hpp
new file mode 100644
index 0000000..f286a40
--- /dev/null
+++ b/torch_modules/include/TestNetwork.hpp
@@ -0,0 +1,22 @@
+#ifndef TESTNETWORK__H
+#define TESTNETWORK__H
+
+#include <torch/torch.h>
+#include "Config.hpp"
+
+class TestNetworkImpl : public torch::nn::Module
+{
+  // Experimental module embedding the FORM values found around the current word of a Config.
+  // The base must be public: libtorch reaches Module (parameters(), shared_from_this, ...) through it.
+  private :
+  std::map<Config::String, std::size_t> dict; // word form -> index returned by getOrAddDictValue
+  torch::nn::Embedding wordEmbeddings{nullptr}; // empty holder until the constructor registers the layer
+
+  public :
+  TestNetworkImpl();
+  torch::Tensor forward(const Config & config); // embeddings of the window around the current word
+  std::size_t getOrAddDictValue(Config::String s); // index of s in dict, added on first use
+};
+TORCH_MODULE(TestNetwork); // declares TestNetwork, the holder type wrapping TestNetworkImpl
+
+#endif
diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/TestNetwork.cpp
new file mode 100644
index 0000000..fc3d0e6
--- /dev/null
+++ b/torch_modules/src/TestNetwork.cpp
@@ -0,0 +1,52 @@
+#include "TestNetwork.hpp"
+
+TestNetworkImpl::TestNetworkImpl()
+{
+  // Seed the dictionary with the special entries, in this order, before any real word form.
+  for (const char * reserved : {"_null_", "_unknown_", "_S_"})
+    getOrAddDictValue(Config::String(reserved));
+  // Embedding(200000, 100) is created here and registered as the submodule "word_embeddings".
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100));
+}
+
+// Embeds the FORM of each line in a window around config's current word.
+// The window reaches at most windowSize lines to each side and stops early at
+// the first neighbour for which config.has(0, line, 0) is false. Returns
+// wordEmbeddings applied to a {1, nbWords} tensor of dictionary indexes;
+// calls util::myThrow if a line of the window is missing from config.
+torch::Tensor TestNetworkImpl::forward(const Config & config)
+{
+  constexpr int windowSize = 5;
+  const int wordIndex = config.getWordIndex();
+  int startIndex = wordIndex;
+  while (config.has(0,startIndex-1,0) and wordIndex-startIndex < windowSize)
+    startIndex--;
+  int endIndex = wordIndex;
+  while (config.has(0,endIndex+1,0) and endIndex-wordIndex < windowSize)
+    endIndex++;
+
+  // at::kLong is a signed 64-bit type, so store the indexes as int64_t, not std::size_t.
+  std::vector<int64_t> words;
+  for (int i = startIndex; i <= endIndex; ++i)
+  {
+    if (!config.has(0, i, 0))
+      util::myThrow(fmt::format("Config do not have line {}", i)); // fmt wants {}, not printf's %d
+    words.push_back(static_cast<int64_t>(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i))));
+  }
+  // words always holds the current word here; from_blob does not copy, so words must outlive the lookup.
+  return wordEmbeddings(torch::from_blob(words.data(), {1, static_cast<int64_t>(words.size())}, at::kLong));
+}
+
+// Returns the dictionary index of s, assigning it the next free index (the
+// current dictionary size) on first sight. An empty s maps to "_null_".
+std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
+{
+  if (s.get().empty())
+    return dict[Config::String("_null_")];
+
+  // dict.size() is evaluated as an argument, i.e. before any insertion, whereas the
+  // former `dict[s] = dict.size()` had an unspecified evaluation order before C++17.
+  // NOTE(review): nothing caps the index at the 200000 rows of wordEmbeddings.
+  return dict.emplace(s, dict.size()).first->second;
+}
+
-- 
GitLab