From 0b1e248c13d82bce28c14d37d75e297ab209d23c Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@etu.univ-amu.fr>
Date: Wed, 2 Jan 2019 14:12:59 +0100
Subject: [PATCH] Added NeuralNetwork interface

---
 CMakeLists.txt                           |  2 +
 neural_network/CMakeLists.txt            |  4 ++
 neural_network/include/NeuralNetwork.hpp | 55 ++++++++++++++++++++++++
 neural_network/src/NeuralNetwork.cpp     | 17 ++++++++
 4 files changed, 78 insertions(+)
 create mode 100644 neural_network/CMakeLists.txt
 create mode 100644 neural_network/include/NeuralNetwork.hpp
 create mode 100644 neural_network/src/NeuralNetwork.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9f9c4ab..ec5e2a6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -27,6 +27,7 @@ include_directories(maca_common/include)
 include_directories(transition_machine/include)
 include_directories(trainer/include)
 include_directories(decoder/include)
+include_directories(neural_network/include)
 include_directories(MLP/include)
 include_directories(error_correction/include)
 
@@ -34,6 +35,7 @@ add_subdirectory(maca_common)
 add_subdirectory(transition_machine)
 add_subdirectory(trainer)
 add_subdirectory(decoder)
+add_subdirectory(neural_network)
 add_subdirectory(MLP)
 add_subdirectory(error_correction)
 
diff --git a/neural_network/CMakeLists.txt b/neural_network/CMakeLists.txt
new file mode 100644
index 0000000..a423734
--- /dev/null
+++ b/neural_network/CMakeLists.txt
@@ -0,0 +1,4 @@
+FILE(GLOB SOURCES src/*.cpp)
+
+#compiling library
+add_library(neural_network STATIC ${SOURCES})
diff --git a/neural_network/include/NeuralNetwork.hpp b/neural_network/include/NeuralNetwork.hpp
new file mode 100644
index 0000000..46e9c1b
--- /dev/null
+++ b/neural_network/include/NeuralNetwork.hpp
@@ -0,0 +1,55 @@
+#ifndef NEURALNETWORK__H
+#define NEURALNETWORK__H
+
+#include <dynet/nodes.h>
+#include <dynet/dynet.h>
+#include <dynet/training.h>
+#include <dynet/timing.h>
+#include <dynet/expr.h>
+#include "FeatureModel.hpp"
+
+class NeuralNetwork
+{
+  public :
+
+  /// @brief Convert a dynet expression to a string (usefull for debug purposes)
+  ///
+  /// @param expr The expression to convert.
+  ///
+  /// @return A string representing the expression.
+  static std::string expression2str(dynet::Expression & expr);
+
+  /// @brief initialize a new untrained NeuralNetwork from a desired topology.
+  ///
+  /// @param nbInputs The size of the input layer of the NeuralNetwork.
+  /// @param topology Description of the NeuralNetwork.
+  /// @param nbOutputs The size of the output layer of the NeuralNetwork.
+  virtual void init(int nbInputs, const std::string & topology, int nbOutputs) = 0;
+
+  /// @brief Give a score to each possible class, given an input.
+  ///
+  /// @param fd The input to use.
+  ///
+  /// @return A vector containing one score per possible class.
+  virtual std::vector<float> predict(FeatureModel::FeatureDescription & fd) = 0;
+
+  /// @brief Update the parameters according to the given gold class.
+  ///
+  /// @param fd The input to use.
+  /// @param gold The gold class of this input.
+  ///
+  /// @return The loss.
+  virtual float update(FeatureModel::FeatureDescription & fd, int gold) = 0;
+
+  /// @brief Save the NeuralNetwork to a file.
+  /// 
+  /// @param filename The file to write the NeuralNetwork to.
+  virtual void save(const std::string & filename) = 0;
+
+  /// @brief Print the topology of the NeuralNetwork.
+  ///
+  /// @param output Where the topology will be printed.
+  virtual void printTopology(FILE * output) = 0;
+};
+
+#endif
diff --git a/neural_network/src/NeuralNetwork.cpp b/neural_network/src/NeuralNetwork.cpp
new file mode 100644
index 0000000..8be87b8
--- /dev/null
+++ b/neural_network/src/NeuralNetwork.cpp
@@ -0,0 +1,17 @@
+#include "NeuralNetwork.hpp"
+
+std::string NeuralNetwork::expression2str(dynet::Expression & expr)
+{
+  std::string result = "";
+
+  auto elem = dynet::as_vector(expr.value());
+
+  for (auto & f : elem)
+    result += float2str(f, "%f") + " ";
+
+  if (!result.empty())
+  result.pop_back();
+
+  return result;
+}
+
-- 
GitLab