From 8946a076bb67d86e69407141cd95690ac4f81283 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 30 Mar 2020 18:08:12 +0200 Subject: [PATCH] Added RandomNetwork --- reading_machine/src/Classifier.cpp | 9 +++++++++ torch_modules/include/RandomNetwork.hpp | 18 ++++++++++++++++++ torch_modules/src/RandomNetwork.cpp | 19 +++++++++++++++++++ 3 files changed, 46 insertions(+) create mode 100644 torch_modules/include/RandomNetwork.hpp create mode 100644 torch_modules/src/RandomNetwork.cpp diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 5cba170..41f9f15 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -5,6 +5,7 @@ #include "RLTNetwork.hpp" #include "CNNNetwork.hpp" #include "LSTMNetwork.hpp" +#include "RandomNetwork.hpp" Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile) { @@ -32,6 +33,14 @@ void Classifier::initNeuralNetwork(const std::string & topology) { static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers { + { + std::regex("Random"), + "Random : Output is chosen at random.", + [this,topology](auto sm) + { + this->nn.reset(new RandomNetworkImpl(this->transitionSet->size())); + } + }, { std::regex("OneWord\\(([+\\-]?\\d+)\\)"), "OneWord(focusedIndex) : Only use the word embedding of the focused word.", diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp new file mode 100644 index 0000000..e715cc4 --- /dev/null +++ b/torch_modules/include/RandomNetwork.hpp @@ -0,0 +1,18 @@ +#ifndef RANDOMNETWORK__H +#define RANDOMNETWORK__H + +#include "NeuralNetwork.hpp" + +class RandomNetworkImpl : public NeuralNetworkImpl +{ + private : + + long outputSize; + + public : + + RandomNetworkImpl(long outputSize); + torch::Tensor forward(torch::Tensor input) override; +}; + +#endif diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp new file mode 100644 index 0000000..8e973e4 --- /dev/null +++ b/torch_modules/src/RandomNetwork.cpp @@ -0,0 +1,19 @@ +#include "RandomNetwork.hpp" + +RandomNetworkImpl::RandomNetworkImpl(long outputSize) : outputSize(outputSize) +{ + setBufferContext({0}); + setStackContext({}); + setBufferFocused({}); + setStackFocused({}); + setColumns({"FORM"}); +} + +torch::Tensor RandomNetworkImpl::forward(torch::Tensor input) +{ + if (input.dim() == 1) + input = input.unsqueeze(0); + + return torch::randn({input.size(0), outputSize}, torch::TensorOptions().device(device).requires_grad(true)); +} + -- GitLab