diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 5cba1706646caf2ec4142b70b893abc313ad5dde..41f9f1556faf20e5204e6382020c7c514a669a1a 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -5,6 +5,7 @@ #include "RLTNetwork.hpp" #include "CNNNetwork.hpp" #include "LSTMNetwork.hpp" +#include "RandomNetwork.hpp" Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile) { @@ -32,6 +33,14 @@ void Classifier::initNeuralNetwork(const std::string & topology) { static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers { + { + std::regex("Random"), + "Random : Output is chosen at random.", + [this,topology](auto sm) + { + this->nn.reset(new RandomNetworkImpl(this->transitionSet->size())); + } + }, { std::regex("OneWord\\(([+\\-]?\\d+)\\)"), "OneWord(focusedIndex) : Only use the word embedding of the focused word.", diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e715cc4b920ed5aefe3258244f3ea0d8d888f938 --- /dev/null +++ b/torch_modules/include/RandomNetwork.hpp @@ -0,0 +1,18 @@ +#ifndef RANDOMNETWORK__H +#define RANDOMNETWORK__H + +#include "NeuralNetwork.hpp" + +class RandomNetworkImpl : public NeuralNetworkImpl +{ + private : + + long outputSize; + + public : + + RandomNetworkImpl(long outputSize); + torch::Tensor forward(torch::Tensor input) override; +}; + +#endif diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8e973e4c1af83ba3d2b998956da3172f74e9b572 --- /dev/null +++ b/torch_modules/src/RandomNetwork.cpp @@ -0,0 +1,19 @@ +#include "RandomNetwork.hpp" + +RandomNetworkImpl::RandomNetworkImpl(long outputSize) : outputSize(outputSize) +{ + setBufferContext({0}); + setStackContext({}); + setBufferFocused({}); + setStackFocused({}); + setColumns({"FORM"}); +} + +torch::Tensor RandomNetworkImpl::forward(torch::Tensor input) +{ + if (input.dim() == 1) + input = input.unsqueeze(0); + + return torch::randn({input.size(0), outputSize}, torch::TensorOptions().device(device).requires_grad(true)); +} +