Commit 82827388 authored by Franck Dary

Added base class to neural network

parent 71909cd6
@@ -19,7 +19,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
config.printForDebug(stderr);
auto dictState = machine.getDict(config.getState()).getState();
auto context = config.extractContext(5,5,machine.getDict(config.getState()));
auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
machine.getDict(config.getState()).setState(dictState);
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong);
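// Note: torch::from_blob does not copy; neuralInput is only a view over 'context',
// so it has to be consumed before 'context' goes out of scope (presumably the case here).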
......
@@ -3,7 +3,7 @@
#include <string>
#include "TransitionSet.hpp"
#include "TestNetwork.hpp"
#include "NeuralNetwork.hpp"
class Classifier
{
@@ -11,13 +11,17 @@ class Classifier
std::string name;
std::unique_ptr<TransitionSet> transitionSet;
TestNetwork nn{nullptr};
std::shared_ptr<NeuralNetworkImpl> nn;
private :
void initNeuralNetwork(const std::string & topology);
public :
Classifier(const std::string & name, const std::string & topology, const std::string & tsFile);
TransitionSet & getTransitionSet();
TestNetwork & getNN();
NeuralNetwork & getNN();
const std::string & getName() const;
};
......
@@ -107,7 +107,6 @@ class Config
String getState() const;
void setState(const std::string state);
bool stateIsDone() const;
std::vector<long> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
void addPredicted(const std::set<std::string> & predicted);
bool isPredicted(const std::string & colName) const;
int getLastPoppedStack() const;
......
#include "Classifier.hpp"
#include "util.hpp"
#include "OneWordNetwork.hpp"
#include "ConcatWordsNetwork.hpp"
Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
{
this->name = name;
this->transitionSet.reset(new TransitionSet(tsFile));
this->nn = TestNetwork(transitionSet->size(), 5);
initNeuralNetwork(topology);
}
TransitionSet & Classifier::getTransitionSet()
@@ -12,9 +12,9 @@ TransitionSet & Classifier::getTransitionSet()
return *transitionSet;
}
TestNetwork & Classifier::getNN()
NeuralNetwork & Classifier::getNN()
{
return nn;
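// NOTE: assumes the NeuralNetwork holder generated by TORCH_MODULE wraps exactly one
// std::shared_ptr<NeuralNetworkImpl>; storing a NeuralNetwork holder member would avoid this cast.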
return reinterpret_cast<NeuralNetwork&>(nn);
}
const std::string & Classifier::getName() const
@@ -22,3 +25,36 @@ const std::string & Classifier::getName() const
return name;
}
void Classifier::initNeuralNetwork(const std::string & topology)
{
static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers
{
{
std::regex("OneWord\\((\\d+)\\)"),
"OneWord(focusedIndex) : Only use the word embedding of the focused word.",
[this,topology](auto sm)
{
this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm[1])));
}
},
{
std::regex("ConcatWords"),
"ConcatWords : Concatenate embeddings of words in context.",
[this,topology](auto)
{
this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size()));
}
},
};
for (auto & initializer : initializers)
if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
return;
std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
for (auto & initializer : initializers)
errorMessage += std::get<1>(initializer) + "\n";
util::myThrow(errorMessage);
}
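As a hedged illustration of how the topology string is dispatched (the classifier names and machine file below are hypothetical, not taken from the repository):

// Selects OneWordNetworkImpl, focused on the word at index 0:
Classifier tagger("tagger", "OneWord(0)", "machine.ts");
// Selects ConcatWordsNetworkImpl over the whole context window:
Classifier parser("parser", "ConcatWords", "machine.ts");

Any string that matches neither regex falls through to util::myThrow with the list of available networks.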
@@ -455,33 +455,6 @@ bool Config::stateIsDone() const
return !has(0, wordIndex+1, 0) and !hasStack(0);
}
std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
{
std::stack<int> leftContext;
for (int index = wordIndex-1; has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
if (isToken(index))
leftContext.push(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", index)));
std::vector<long> context;
while ((int)context.size() < leftBorder-(int)leftContext.size())
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
while (!leftContext.empty())
{
context.emplace_back(leftContext.top());
leftContext.pop();
}
for (int index = wordIndex; has(0,index,0) && (int)context.size() < leftBorder+rightBorder+1; ++index)
if (isToken(index))
context.emplace_back(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", index)));
while ((int)context.size() < leftBorder+rightBorder+1)
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
return context;
}
void Config::addPredicted(const std::set<std::string> & predicted)
{
this->predicted.insert(predicted.begin(), predicted.end());
......
#ifndef CONCATWORDSNETWORK__H
#define CONCATWORDSNETWORK__H
#include "NeuralNetwork.hpp"
class ConcatWordsNetworkImpl : public NeuralNetworkImpl
{
private :
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear{nullptr};
std::vector<torch::Tensor> _denseParameters;
std::vector<torch::Tensor> _sparseParameters;
public :
ConcatWordsNetworkImpl(int nbOutputs);
torch::Tensor forward(torch::Tensor input) override;
std::vector<torch::Tensor> & denseParameters() override;
std::vector<torch::Tensor> & sparseParameters() override;
};
#endif
#ifndef NEURALNETWORK__H
#define NEURALNETWORK__H
#include <torch/torch.h>
#include "Config.hpp"
#include "Dict.hpp"
class NeuralNetworkImpl : public torch::nn::Module
{
private :
int leftBorder{5};
int rightBorder{5};
public :
virtual std::vector<torch::Tensor> & denseParameters() = 0;
virtual std::vector<torch::Tensor> & sparseParameters() = 0;
virtual torch::Tensor forward(torch::Tensor input) = 0;
std::vector<long> extractContext(Config & config, Dict & dict) const;
int getContextSize() const;
};
TORCH_MODULE(NeuralNetwork);
#endif
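For reference, TORCH_MODULE(NeuralNetwork) generates the NeuralNetwork holder type used throughout this commit. A simplified sketch of what the libtorch macro expands to (not code from this repository):

class NeuralNetwork : public torch::nn::ModuleHolder<NeuralNetworkImpl>
{
  public :
  using torch::nn::ModuleHolder<NeuralNetworkImpl>::ModuleHolder;
};

The holder wraps a std::shared_ptr<NeuralNetworkImpl>, which is why Classifier can store the implementation pointer directly and still expose a NeuralNetwork & through getNN().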
#ifndef TESTNETWORK__H
#define TESTNETWORK__H
#ifndef ONEWORDNETWORK__H
#define ONEWORDNETWORK__H
#include <torch/torch.h>
#include "Config.hpp"
#include "NeuralNetwork.hpp"
class TestNetworkImpl : public torch::nn::Module
class OneWordNetworkImpl : public NeuralNetworkImpl
{
private :
@@ -17,11 +16,10 @@ class TestNetworkImpl : public torch::nn::Module
public :
TestNetworkImpl(int nbOutputs, int focusedIndex);
torch::Tensor forward(torch::Tensor input);
std::vector<torch::Tensor> & denseParameters();
std::vector<torch::Tensor> & sparseParameters();
OneWordNetworkImpl(int nbOutputs, int focusedIndex);
torch::Tensor forward(torch::Tensor input) override;
std::vector<torch::Tensor> & denseParameters() override;
std::vector<torch::Tensor> & sparseParameters() override;
};
TORCH_MODULE(TestNetwork);
#endif
#include "ConcatWordsNetwork.hpp"
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs)
{
constexpr int embeddingsSize = 30;
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
auto params = wordEmbeddings->parameters();
_sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
linear = register_module("linear", torch::nn::Linear(getContextSize()*embeddingsSize, nbOutputs));
params = linear->parameters();
_denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
}
std::vector<torch::Tensor> & ConcatWordsNetworkImpl::denseParameters()
{
return _denseParameters;
}
std::vector<torch::Tensor> & ConcatWordsNetworkImpl::sparseParameters()
{
return _sparseParameters;
}
torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
{
// input dim = {batch, sequence} (word indices)
auto wordsAsEmb = wordEmbeddings(input);
// reshaped dim = {batch, sequence of embeddings}
auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
auto res = torch::softmax(linear(reshaped), reshaped.dim() == 2 ? 1 : 0);
return res;
}
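A hedged shape walk-through of this forward pass, assuming the default window of leftBorder = rightBorder = 5 (getContextSize() == 11) and embeddingsSize = 30:

// input      : {batch, 11}          word indices produced by extractContext
// wordsAsEmb : {batch, 11, 30}      after the embedding lookup
// reshaped   : {batch, 330}         the 11 embeddings concatenated
// res        : {batch, nbOutputs}   softmax scores, one per transition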
#include "NeuralNetwork.hpp"
std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
{
std::stack<int> leftContext;
for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
if (config.isToken(index))
leftContext.push(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
std::vector<long> context;
while ((int)context.size() < leftBorder-(int)leftContext.size())
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
while (!leftContext.empty())
{
context.emplace_back(leftContext.top());
leftContext.pop();
}
for (int index = config.getWordIndex(); config.has(0,index,0) && (int)context.size() < leftBorder+rightBorder+1; ++index)
if (config.isToken(index))
context.emplace_back(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
while ((int)context.size() < leftBorder+rightBorder+1)
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
return context;
}
int NeuralNetworkImpl::getContextSize() const
{
return 1 + leftBorder + rightBorder;
}
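A minimal usage sketch (the call mirrors Decoder::decode above; variable names are illustrative): extractContext always returns exactly getContextSize() indices, padding with Dict::nullValueStr whenever the window runs past either end of the sentence:

// Hypothetical call site:
auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
// context.size() == getContextSize() == 1 + leftBorder + rightBorder == 11 here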
#include "TestNetwork.hpp"
#include "OneWordNetwork.hpp"
TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
{
constexpr int embeddingsSize = 30;
@@ -15,17 +15,17 @@ TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
this->focusedIndex = focusedIndex;
}
std::vector<torch::Tensor> & TestNetworkImpl::denseParameters()
std::vector<torch::Tensor> & OneWordNetworkImpl::denseParameters()
{
return _denseParameters;
}
std::vector<torch::Tensor> & TestNetworkImpl::sparseParameters()
std::vector<torch::Tensor> & OneWordNetworkImpl::sparseParameters()
{
return _sparseParameters;
}
torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
{
// input dim = {batch, sequence} (word indices)
auto wordsAsEmb = wordEmbeddings(input);
......
@@ -4,7 +4,6 @@
#include "ReadingMachine.hpp"
#include "ConfigDataset.hpp"
#include "SubConfig.hpp"
#include "TestNetwork.hpp"
class Trainer
{
......
@@ -25,7 +25,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
util::myThrow("No transition applicable !");
}
auto context = config.extractContext(5,5,machine.getDict(config.getState()));
auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
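// .clone() is needed here: torch::from_blob only wraps the local 'context' buffer,
// so the stored tensor must own its own copy once 'context' is destroyed.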
int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
......