Commit 1b96cca2 authored by Franck Dary

Added AppliableTransModule

parent 401e71e6
Showing changed files with 154 additions and 16 deletions
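In short, this commit computes a 0/1 mask over the whole transition set once per step (TransitionSet::getAppliableTransitions), caches it on the Config, and uses it in Decoder and Trainer in place of repeated Transition::appliable(config) calls; a new AppliableTransModule can additionally feed that mask to the network as input features, and setState/getState are factored into a shared StateHolder base. The toy below only illustrates the mask idea: FakeTransition, the lambdas and the hard-coded scores are made up for the example, not part of the project.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for Transition: appliable() decides whether the action is legal now.
struct FakeTransition
{
  std::string name;
  std::function<bool()> appliable;
};

int main()
{
  std::vector<FakeTransition> transitions{
    {"SHIFT",  []{ return true;  }},
    {"REDUCE", []{ return false; }},
    {"LEFT",   []{ return true;  }}};

  // Mirrors TransitionSet::getAppliableTransitions: one 0/1 entry per transition.
  std::vector<int> mask;
  for (auto & t : transitions)
    mask.emplace_back(t.appliable() ? 1 : 0);

  // Mirrors the Decoder change: keep the best-scoring transition whose mask bit is set.
  std::vector<float> scores{0.2f, 0.9f, 0.7f};
  int chosen = -1;
  for (unsigned int i = 0; i < scores.size(); i++)
    if ((chosen == -1 or scores[i] > scores[chosen]) and mask[i])
      chosen = i;

  std::cout << transitions[chosen].name << std::endl; // prints LEFT (REDUCE is masked out)
}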
@@ -29,6 +29,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
config.setAppliableTransitions(appliableTransitions);
auto context = machine.getClassifier()->getNN()->extractContext(config).back();
@@ -45,7 +47,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
for (unsigned int i = 0; i < softmaxed.size(0); i++)
{
float score = softmaxed[i].item<float>();
std::string nicePrint = fmt::format("{} {:7.2f} {}", machine.getTransitionSet().getTransition(i)->appliable(config) ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName());
std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[i] ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName());
toPrint.emplace_back(std::make_pair(score,nicePrint));
}
std::sort(toPrint.rbegin(), toPrint.rend());
@@ -58,7 +60,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
if ((chosenTransition == -1 or score > bestScore) and appliableTransitions[i])
{
chosenTransition = i;
bestScore = score;
......
@@ -47,6 +47,7 @@ class Config
int lastPoppedStack{-1};
int currentWordId{0};
std::vector<Transition *> appliableSplitTransitions;
std::vector<int> appliableTransitions;
protected :
@@ -145,7 +146,9 @@ class Config
void addMissingColumns();
void addComment();
void setAppliableSplitTransitions(const std::vector<Transition *> & appliableSplitTransitions);
void setAppliableTransitions(const std::vector<int> & appliableTransitions);
const std::vector<Transition *> & getAppliableSplitTransitions() const;
const std::vector<int> & getAppliableTransitions() const;
bool isExtraColumn(const std::string & colName) const;
};
......
@@ -23,6 +23,7 @@ class TransitionSet
std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c);
Transition * getBestAppliableTransition(const Config & c);
std::vector<Transition *> getNAppliableTransitions(const Config & c, int n);
std::vector<int> getAppliableTransitions(const Config & c);
std::size_t getTransitionIndex(const Transition * transition) const;
Transition * getTransition(std::size_t index);
Transition * getTransition(const std::string & name);
......
@@ -662,11 +662,21 @@ void Config::setAppliableSplitTransitions(const std::vector<Transition *> & appl
this->appliableSplitTransitions = appliableSplitTransitions;
}
void Config::setAppliableTransitions(const std::vector<int> & appliableTransitions)
{
this->appliableTransitions = appliableTransitions;
}
const std::vector<Transition *> & Config::getAppliableSplitTransitions() const
{
return appliableSplitTransitions;
}
const std::vector<int> & Config::getAppliableTransitions() const
{
return appliableTransitions;
}
Config::Object Config::str2object(const std::string & s)
{
if (s == "b")
......
@@ -67,6 +67,19 @@ std::vector<Transition *> TransitionSet::getNAppliableTransitions(const Config &
return result;
}
std::vector<int> TransitionSet::getAppliableTransitions(const Config & c)
{
std::vector<int> result;
for (unsigned int i = 0; i < transitions.size(); i++)
if (transitions[i].appliable(c))
result.emplace_back(1);
else
result.emplace_back(0);
return result;
}
Transition * TransitionSet::getBestAppliableTransition(const Config & c)
{
Transition * result = nullptr;
......
#ifndef APPLIABLETRANSMODULE__H
#define APPLIABLETRANSMODULE__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
class AppliableTransModuleImpl : public Submodule
{
private :
int nbTrans;
public :
AppliableTransModuleImpl(std::string name, int nbTrans);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(AppliableTransModule);
#endif
@@ -5,6 +5,7 @@
#include "ContextModule.hpp"
#include "RawInputModule.hpp"
#include "SplitTransModule.hpp"
#include "AppliableTransModule.hpp"
#include "FocusedColumnModule.hpp"
#include "DepthLayerTreeEmbeddingModule.hpp"
#include "StateNameModule.hpp"
@@ -33,6 +34,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
void setDictsState(Dict::State state) override;
void setCountOcc(bool countOcc) override;
void removeRareDictElements(float rarityThreshold) override;
void setState(const std::string & state);
};
#endif
@@ -5,8 +5,9 @@
#include <filesystem>
#include "Config.hpp"
#include "NameHolder.hpp"
#include "StateHolder.hpp"
class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public StateHolder
{
public :
@@ -21,8 +22,6 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
virtual void registerEmbeddings() = 0;
void setState(const std::string & state);
const std::string & getState() const;
virtual void saveDicts(std::filesystem::path path) = 0;
virtual void loadDicts(std::filesystem::path path) = 0;
virtual void setDictsState(Dict::State state) = 0;
......
#ifndef STATEHOLDER__H
#define STATEHOLDER__H
#include <string>
class StateHolder
{
private :
std::string state;
public :
const std::string & getState() const;
void setState(const std::string & state);
};
#endif
@@ -4,8 +4,9 @@
#include <torch/torch.h>
#include "Config.hpp"
#include "DictHolder.hpp"
#include "StateHolder.hpp"
class Submodule : public torch::nn::Module, public DictHolder
class Submodule : public torch::nn::Module, public DictHolder, public StateHolder
{
protected :
......
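Both NeuralNetworkImpl and Submodule now inherit the new StateHolder mixin, so the current classifier state goes through one shared getState/setState pair, and ModularNetworkImpl::setState (further down) can push it to every submodule. Below is a simplified, self-contained sketch of that pattern; the Fake* names and std::runtime_error stand in for the real classes and util::myThrow.

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Simplified StateHolder: same API as StateHolder.hpp above.
class StateHolder
{
  private :
  std::string state;
  public :
  const std::string & getState() const
  {
    if (state.empty())
      throw std::runtime_error("state is empty");
    return state;
  }
  void setState(const std::string & newState) { state = newState; }
};

// Stand-ins for Submodule and ModularNetworkImpl.
struct FakeSubmodule : public StateHolder {};
struct FakeNetwork : public StateHolder
{
  std::vector<std::shared_ptr<FakeSubmodule>> modules;
  void setState(const std::string & s)   // hides StateHolder::setState, like ModularNetworkImpl::setState
  {
    StateHolder::setState(s);            // mirrors the NeuralNetworkImpl::setState(state) call
    for (auto & mod : modules)           // propagate to every submodule
      mod->setState(s);
  }
};

int main()
{
  FakeNetwork net;
  net.modules.emplace_back(std::make_shared<FakeSubmodule>());
  net.setState("parser");
  std::cout << net.modules[0]->getState() << std::endl; // prints parser
}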
#include "AppliableTransModule.hpp"
AppliableTransModuleImpl::AppliableTransModuleImpl(std::string name, int nbTrans) : nbTrans(nbTrans)
{
setName(name);
}
torch::Tensor AppliableTransModuleImpl::forward(torch::Tensor input)
{
return input.narrow(1, firstInputIndex, getInputSize()).to(torch::kFloat);
}
std::size_t AppliableTransModuleImpl::getOutputSize()
{
return nbTrans;
}
std::size_t AppliableTransModuleImpl::getInputSize()
{
return nbTrans;
}
void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
auto & appliableTrans = config.getAppliableTransitions();
for (auto & contextElement : context)
for (int i = 0; i < nbTrans; i++)
if (i < (int)appliableTrans.size())
contextElement.emplace_back(appliableTrans[i]);
else
contextElement.emplace_back(0);
}
void AppliableTransModuleImpl::registerEmbeddings()
{
}
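The module itself only forwards the mask: addToContext writes the 0/1 values into each context row, zero-padding up to nbTrans, and forward slices those nbTrans values back out of the batched input and casts them to float for the MLP. A standalone sketch of that behaviour; the context layout and the firstInputIndex value are invented here (in the real network they are assigned by the framework):

#include <iostream>
#include <torch/torch.h>
#include <vector>

int main()
{
  int nbTrans = 4;                          // padded size (maxNbOutputs over all states)
  std::vector<int> appliableTrans{1, 0, 1}; // mask from the Config, shorter than nbTrans

  // addToContext: append the mask to the context row, zero-padding to nbTrans.
  std::vector<long> contextRow{42, 7};      // pretend other submodules wrote these features
  long firstInputIndex = contextRow.size();
  for (int i = 0; i < nbTrans; i++)
    contextRow.emplace_back(i < (int)appliableTrans.size() ? appliableTrans[i] : 0);

  // forward: slice the nbTrans mask values back out of the batched input as floats.
  auto input = torch::tensor(contextRow).unsqueeze(0);
  auto out = input.narrow(1, firstInputIndex, nbTrans).to(torch::kFloat);
  std::cout << out << std::endl;            // 1 0 1 0, shape {1,4}
}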
@@ -15,6 +15,10 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
return result;
};
std::size_t maxNbOutputs = 0;
for (auto & it : nbOutputsPerState)
maxNbOutputs = std::max<std::size_t>(it.second, maxNbOutputs);
int currentInputSize = 0;
int currentOutputSize = 0;
std::string mlpDef;
@@ -37,6 +41,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second)));
else if (splited.first == "SplitTrans")
modules.emplace_back(register_module(name, SplitTransModule(nameH, Config::maxNbAppliableSplitTransitions, splited.second)));
else if (splited.first == "AppliableTrans")
modules.emplace_back(register_module(name, AppliableTransModule(nameH, maxNbOutputs)));
else if (splited.first == "DepthLayerTree")
modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(nameH, splited.second)));
else if (splited.first == "MLP")
@@ -134,3 +140,10 @@ void ModularNetworkImpl::removeRareDictElements(float rarityThreshold)
}
}
void ModularNetworkImpl::setState(const std::string & state)
{
NeuralNetworkImpl::setState(state);
for (auto & mod : modules)
mod->setState(state);
}
@@ -2,13 +2,3 @@
torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
void NeuralNetworkImpl::setState(const std::string & state)
{
this->state = state;
}
const std::string & NeuralNetworkImpl::getState() const
{
return state;
}
#include "StateHolder.hpp"
#include "util.hpp"
const std::string & StateHolder::getState() const
{
if (state.empty())
util::myThrow("state is empty");
return state;
}
void StateHolder::setState(const std::string & state)
{
this->state = state;
}
@@ -70,6 +70,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
config.setAppliableTransitions(appliableTransitions);
std::vector<std::vector<long>> context;
@@ -300,6 +302,8 @@ void Trainer::fillDicts(SubConfig & config, bool debug)
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
config.setAppliableTransitions(appliableTransitions);
try
{
......