Skip to content
Snippets Groups Projects
Commit 8946a076 authored by Franck Dary's avatar Franck Dary
Browse files

Added RandomNetwork

parent a43b9993
No related branches found
No related tags found
No related merge requests found
......@@ -5,6 +5,7 @@
#include "RLTNetwork.hpp"
#include "CNNNetwork.hpp"
#include "LSTMNetwork.hpp"
#include "RandomNetwork.hpp"
Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
{
......@@ -32,6 +33,14 @@ void Classifier::initNeuralNetwork(const std::string & topology)
{
static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers
{
{
std::regex("Random"),
"Random : Output is chosen at random.",
[this,topology](auto sm)
{
this->nn.reset(new RandomNetworkImpl(this->transitionSet->size()));
}
},
{
std::regex("OneWord\\(([+\\-]?\\d+)\\)"),
"OneWord(focusedIndex) : Only use the word embedding of the focused word.",
......
#ifndef RANDOMNETWORK__H
#define RANDOMNETWORK__H
#include "NeuralNetwork.hpp"
class RandomNetworkImpl : public NeuralNetworkImpl
{
private :
long outputSize;
public :
RandomNetworkImpl(long outputSize);
torch::Tensor forward(torch::Tensor input) override;
};
#endif
#include "RandomNetwork.hpp"
RandomNetworkImpl::RandomNetworkImpl(long outputSize) : outputSize(outputSize)
{
setBufferContext({0});
setStackContext({});
setBufferFocused({});
setStackFocused({});
setColumns({"FORM"});
}
torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
{
if (input.dim() == 1)
input = input.unsqueeze(0);
return torch::randn({input.size(0), outputSize}, torch::TensorOptions().device(device).requires_grad(true));
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment