From 0f5a864fcacafb44eb83afdee142d0a0247f0e6c Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 20 Jan 2020 22:31:57 +0100 Subject: [PATCH] Starting to build neural network --- dev/CMakeLists.txt | 1 + dev/src/dev.cpp | 11 +++--- torch_modules/include/TestNetwork.hpp | 22 ++++++++++++ torch_modules/src/TestNetwork.cpp | 52 +++++++++++++++++++++++++++ 4 files changed, 81 insertions(+), 5 deletions(-) create mode 100644 torch_modules/include/TestNetwork.hpp create mode 100644 torch_modules/src/TestNetwork.cpp diff --git a/dev/CMakeLists.txt b/dev/CMakeLists.txt index e7c5cc7..a473806 100644 --- a/dev/CMakeLists.txt +++ b/dev/CMakeLists.txt @@ -3,3 +3,4 @@ FILE(GLOB SOURCES src/*.cpp) add_executable(dev src/dev.cpp) target_link_libraries(dev common) target_link_libraries(dev reading_machine) +target_link_libraries(dev torch_modules) diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp index 889b473..ccf3162 100644 --- a/dev/src/dev.cpp +++ b/dev/src/dev.cpp @@ -5,6 +5,7 @@ #include "SubConfig.hpp" #include "TransitionSet.hpp" #include "ReadingMachine.hpp" +#include "TestNetwork.hpp" int main(int argc, char * argv[]) { @@ -26,8 +27,7 @@ int main(int argc, char * argv[]) SubConfig config(goldConfig); config.setState(machine.getStrategy().getInitialState()); - - std::vector<std::pair<SubConfig, Transition*>> trainingExamples; + TestNetwork nn; while (true) { @@ -35,7 +35,10 @@ int main(int argc, char * argv[]) if (!transition) util::myThrow("No transition appliable !"); - trainingExamples.emplace_back(config, transition); + //here train + auto testo = nn(config); + +// std::cout << testo << std::endl; transition->apply(config); config.addToHistory(transition->getName()); @@ -52,8 +55,6 @@ int main(int argc, char * argv[]) config.update(); } - trainingExamples[10000].first.printForDebug(stderr); - return 0; } diff --git a/torch_modules/include/TestNetwork.hpp b/torch_modules/include/TestNetwork.hpp new file mode 100644 index 0000000..f286a40 --- 
/dev/null
+++ b/torch_modules/include/TestNetwork.hpp
@@ -0,0 +1,36 @@
+#ifndef TESTNETWORK__H
+#define TESTNETWORK__H
+
+#include <cstddef>
+#include <map>
+#include <torch/torch.h>
+#include "Config.hpp"
+
+// Experimental network: looks up an embedding for each word of a window
+// centered on the current word of a Config.
+// Inheritance must be public: libtorch needs to see the Module base from
+// outside the class (e.g. to reach parameters() through the TestNetwork holder).
+class TestNetworkImpl : public torch::nn::Module
+{
+  private :
+
+  // Maps each word form to its row in the embedding table.
+  std::map<Config::String, std::size_t> dict;
+  // Created in the constructor (nullptr defers the construction until then).
+  torch::nn::Embedding wordEmbeddings{nullptr};
+
+  public :
+
+  // Registers the reserved dictionary entries and the embedding table.
+  TestNetworkImpl();
+  // Returns the embeddings of the words around config's current word.
+  // Reports an error through util::myThrow when a line of the window is missing.
+  torch::Tensor forward(const Config & config);
+  // Returns the index of s in the dictionary, adding s if it was absent.
+  // An empty s is mapped to the index of "_null_".
+  std::size_t getOrAddDictValue(Config::String s);
+};
+// Defines TestNetwork, the holder wrapping a shared TestNetworkImpl.
+TORCH_MODULE(TestNetwork);
+
+#endif
diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/TestNetwork.cpp
new file mode 100644
index 0000000..fc3d0e6
--- /dev/null
+++ b/torch_modules/src/TestNetwork.cpp
@@ -0,0 +1,77 @@
+#include "TestNetwork.hpp"
+
+#include <cstdint>
+#include <vector>
+
+namespace
+{
+// Number of rows of the embedding table, i.e. the largest dictionary it can serve.
+constexpr int maxNbWords = 200000;
+// Size of each word embedding.
+constexpr int embeddingSize = 100;
+}
+
+// Registers the reserved entries first so that "_null_" gets index 0, then
+// creates the embedding table.
+TestNetworkImpl::TestNetworkImpl()
+{
+  getOrAddDictValue(Config::String("_null_"));
+  getOrAddDictValue(Config::String("_unknown_"));
+  getOrAddDictValue(Config::String("_S_"));
+
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(maxNbWords, embeddingSize));
+}
+
+// Embeds the FORM of every line in a window of at most windowSize lines on
+// each side of the current word.
+// Words seen for the first time are added to the dictionary on the fly.
+torch::Tensor TestNetworkImpl::forward(const Config & config)
+{
+  constexpr int windowSize = 5;
+  const int wordIndex = config.getWordIndex();
+
+  // Grow the window to the left, then to the right, stopping at the first
+  // line the config does not have.
+  int startIndex = wordIndex;
+  while (config.has(0, startIndex-1, 0) and wordIndex-startIndex < windowSize)
+    startIndex--;
+  int endIndex = wordIndex;
+  while (config.has(0, endIndex+1, 0) and endIndex-wordIndex < windowSize)
+    endIndex++;
+
+  // kLong tensors hold 64-bit signed integers, so store the indexes as such.
+  std::vector<std::int64_t> words;
+  for (int i = startIndex; i <= endIndex; ++i)
+  {
+    if (!config.has(0, i, 0))
+      util::myThrow(fmt::format("Config do not have line {}", i));
+
+    words.emplace_back(static_cast<std::int64_t>(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i))));
+  }
+
+  if (words.empty())
+    util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), config.getWordIndex(), startIndex, endIndex));
+
+  // from_blob neither copies nor owns the buffer, so the index tensor is
+  // cloned: it must stay valid for whatever holds on to it after `words` is
+  // destroyed (e.g. autograd, for the backward pass of the lookup).
+  const auto indexes = torch::from_blob(words.data(), {1, static_cast<std::int64_t>(words.size())}, at::kLong).clone();
+
+  return wordEmbeddings(indexes);
+}
+
+// Returns the dictionary index of s, inserting s with the next free index
+// when it is not known yet. An empty s maps to the "_null_" entry.
+std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
+{
+  if (s.get().empty())
+    return dict[Config::String("_null_")];
+
+  // Read the size before inserting: `dict[s] = dict.size()` leaves the order
+  // of the two operations unspecified before C++17, so a new entry could be
+  // counted in its own index.
+  const std::size_t nextIndex = dict.size();
+
+  // emplace leaves an existing entry untouched and returns it.
+  return dict.emplace(s, nextIndex).first->second;
+}
-- 
GitLab