Commit 0c86cb53 authored by Franck Dary

Removed state from neuralnetwork

parent cdc9ed54
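
In short, this refactor drops the StateHolder plumbing: rather than storing the active state on the classifier and its network with setState() and reading it back inside forward(), callers now pass the state explicitly, so forward(), getTransitionSet() and getLossMultiplier() all take a state argument. A minimal sketch of the resulting shape, with hypothetical names (PerStateNetworkImpl is a stand-in for illustration, not a class from this repository):

#include <torch/torch.h>
#include <map>
#include <string>

// Hypothetical stand-in for a per-state network: one output layer per parser
// state, but no member remembering which state is currently active.
struct PerStateNetworkImpl : torch::nn::Module
{
  std::map<std::string, torch::nn::Linear> outputLayersPerState;

  PerStateNetworkImpl(const std::map<std::string, std::size_t> & nbOutputsPerState, long inputSize)
  {
    for (auto & it : nbOutputsPerState)
      outputLayersPerState.emplace(it.first,
        register_module("output_" + it.first, torch::nn::Linear(inputSize, it.second)));
  }

  // The caller names the state on every call instead of relying on a prior setState().
  torch::Tensor forward(torch::Tensor input, const std::string & state)
  {
    if (input.dim() == 1)
      input = input.unsqueeze(0);
    return outputLayersPerState.at(state)->forward(input);
  }
};

The hunks below make every call site supply that argument, e.g. classifier.getNN()->forward(neuralInput, config.getState()), and the intermediate setState() calls disappear.
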
@@ -39,8 +39,6 @@ void Beam::update(ReadingMachine & machine, bool debug)
auto & classifier = *machine.getClassifier(elements[index].config.getState());
classifier.setState(elements[index].config.getState());
if (machine.hasSplitWordTransitionSet())
elements[index].config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(elements[index].config, Config::maxNbAppliableSplitTransitions));
@@ -50,7 +48,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
auto context = classifier.getNN()->extractContext(elements[index].config).back();
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0), 0);
float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction);
std::vector<std::pair<float, int>> scoresOfTransitions;
for (unsigned int i = 0; i < prediction.size(0); i++)
@@ -123,9 +121,6 @@ void Beam::update(ReadingMachine & machine, bool debug)
continue;
auto & config = element.config;
auto & classifier = *machine.getClassifier(config.getState());
classifier.setState(config.getState());
auto * transition = machine.getTransitionSet(config.getState()).getTransition(element.nextTransition);
......
@@ -39,7 +39,6 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
} catch(std::exception & e) {util::myThrow(e.what());}
baseConfig = beam[0].config;
machine.getClassifier(baseConfig.getState())->setState(baseConfig.getState());
if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
{
......
@@ -22,7 +22,6 @@ class Classifier
std::shared_ptr<NeuralNetworkImpl> nn;
std::unique_ptr<torch::optim::Optimizer> optimizer;
std::string optimizerType, optimizerParameters;
std::string state;
std::vector<std::string> states;
std::filesystem::path path;
bool regression{false};
@@ -39,7 +38,7 @@ class Classifier
public :
Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train);
TransitionSet & getTransitionSet();
TransitionSet & getTransitionSet(const std::string & state);
NeuralNetwork & getNN();
const std::string & getName() const;
int getNbParameters() const;
@@ -47,8 +46,7 @@ class Classifier
void loadOptimizer();
void saveOptimizer();
torch::optim::Optimizer & getOptimizer();
void setState(const std::string & state);
float getLossMultiplier();
float getLossMultiplier(const std::string & state);
const std::vector<std::string> & getStates() const;
void saveDicts();
void saveBest();
......
@@ -110,7 +110,7 @@ int Classifier::getNbParameters() const
return nbParameters;
}
TransitionSet & Classifier::getTransitionSet()
TransitionSet & Classifier::getTransitionSet(const std::string & state)
{
if (!transitionSets.count(state))
util::myThrow(fmt::format("cannot find transition set for state '{}'", state));
@@ -196,12 +196,6 @@ torch::optim::Optimizer & Classifier::getOptimizer()
return *optimizer;
}
void Classifier::setState(const std::string & state)
{
this->state = state;
nn->setState(state);
}
void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
{
std::string anyBlanks = "(?:(?:\\s|\\t)*)";
@@ -244,7 +238,7 @@ void Classifier::resetOptimizer()
util::myThrow(expected);
}
float Classifier::getLossMultiplier()
float Classifier::getLossMultiplier(const std::string & state)
{
return lossMultipliers.at(state);
}
......
@@ -95,7 +95,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
TransitionSet & ReadingMachine::getTransitionSet(const std::string & state)
{
return classifiers[state2classifier.at(state)]->getTransitionSet();
return classifiers[state2classifier.at(state)]->getTransitionSet(state);
}
bool ReadingMachine::hasSplitWordTransitionSet() const
......
@@ -29,7 +29,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
public :
ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path);
torch::Tensor forward(torch::Tensor input) override;
torch::Tensor forward(torch::Tensor input, const std::string & state) override;
std::vector<std::vector<long>> extractContext(Config & config) override;
void registerEmbeddings() override;
void saveDicts(std::filesystem::path path) override;
@@ -37,7 +37,6 @@ class ModularNetworkImpl : public NeuralNetworkImpl
void setDictsState(Dict::State state) override;
void setCountOcc(bool countOcc) override;
void removeRareDictElements(float rarityThreshold) override;
void setState(const std::string & state);
};
#endif
@@ -5,21 +5,16 @@
#include <filesystem>
#include "Config.hpp"
#include "NameHolder.hpp"
#include "StateHolder.hpp"
class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public StateHolder
class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
{
public :
static torch::Device device;
private :
std::string state;
public :
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
virtual void registerEmbeddings() = 0;
virtual void saveDicts(std::filesystem::path path) = 0;
......
@@ -12,7 +12,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
public :
RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
torch::Tensor forward(torch::Tensor input) override;
torch::Tensor forward(torch::Tensor input, const std::string & state) override;
std::vector<std::vector<long>> extractContext(Config &) override;
void registerEmbeddings() override;
void saveDicts(std::filesystem::path path) override;
......
@@ -5,9 +5,8 @@
#include <filesystem>
#include "Config.hpp"
#include "DictHolder.hpp"
#include "StateHolder.hpp"
class Submodule : public torch::nn::Module, public DictHolder, public StateHolder
class Submodule : public torch::nn::Module, public DictHolder
{
private :
......
@@ -80,7 +80,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
}
torch::Tensor ModularNetworkImpl::forward(torch::Tensor input)
torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string & state)
{
if (input.dim() == 1)
input = input.unsqueeze(0);
@@ -92,7 +92,7 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input)
auto totalInput = inputDropout(torch::cat(outputs, 1));
return outputLayersPerState.at(getState())(mlp(totalInput));
return outputLayersPerState.at(state)(mlp(totalInput));
}
std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config)
@@ -149,10 +149,3 @@ void ModularNetworkImpl::removeRareDictElements(float rarityThreshold)
}
}
void ModularNetworkImpl::setState(const std::string & state)
{
NeuralNetworkImpl::setState(state);
for (auto & mod : modules)
mod->setState(state);
}
@@ -5,12 +5,12 @@ RandomNetworkImpl::RandomNetworkImpl(std::string name, std::map<std::string,std:
setName(name);
}
torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
torch::Tensor RandomNetworkImpl::forward(torch::Tensor input, const std::string & state)
{
if (input.dim() == 1)
input = input.unsqueeze(0);
return torch::randn({input.size(0), (long)nbOutputsPerState[getState()]}, torch::TensorOptions().device(device).requires_grad(true));
return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(device).requires_grad(true));
}
std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &)
......
@@ -53,7 +53,6 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
config.addPredicted(machine.getPredicted());
config.setStrategy(machine.getStrategyDefinition());
config.setState(config.getStrategy().getInitialState());
machine.getClassifier(config.getState())->setState(config.getState());
while (true)
{
@@ -94,7 +93,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
{
auto & classifier = *machine.getClassifier(config.getState());
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0), 0);
entropy = NeuralNetworkImpl::entropy(prediction);
std::vector<int> candidates;
@@ -176,7 +175,6 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
break;
config.setState(movement.first);
machine.getClassifier(config.getState())->setState(movement.first);
config.moveWordIndexRelaxed(movement.second);
if (config.needsUpdate())
@@ -217,9 +215,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
if (train)
machine.getClassifier(state)->getOptimizer().zero_grad();
machine.getClassifier(state)->setState(state);
auto prediction = machine.getClassifier(state)->getNN()(data);
auto prediction = machine.getClassifier(state)->getNN()->forward(data, state);
if (prediction.dim() == 1)
prediction = prediction.unsqueeze(0);
@@ -229,7 +225,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
labels /= util::float2longScale;
}
auto loss = machine.getClassifier(state)->getLossMultiplier()*machine.getClassifier(state)->getLossFunction()(prediction, labels);
auto loss = machine.getClassifier(state)->getLossMultiplier(state)*machine.getClassifier(state)->getLossFunction()(prediction, labels);
float lossAsFloat = 0.0;
try
{
@@ -340,7 +336,6 @@ void Trainer::extractActionSequence(BaseConfig & config)
config.addPredicted(machine.getPredicted());
config.setStrategy(machine.getStrategyDefinition());
config.setState(config.getStrategy().getInitialState());
machine.getClassifier(config.getState())->setState(config.getState());
int curSeq = 0;
int curSeqStartIndex = -1;
@@ -403,7 +398,6 @@ void Trainer::extractActionSequence(BaseConfig & config)
break;
config.setState(movement.first);
machine.getClassifier(config.getState())->setState(movement.first);
config.moveWordIndexRelaxed(movement.second);
}
......