Commit d08bf04c authored by Franck Dary

Each SubModule has its own Dict

parent ea3b87d6
Showing 131 additions and 108 deletions
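In short, dictionaries move out of ReadingMachine (previously one shared std::map<std::string, Dict>) and into each Submodule, which now inherits from the new DictHolder/NameHolder helpers and resolves its own feature indices. A minimal, hypothetical sketch of the resulting call pattern inside a submodule (this is not code from the commit; featureValue is a placeholder):

// Hedged sketch only: with DictHolder as a base class, a submodule indexes
// feature strings through its own Dict rather than one passed in by the machine.
void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
  std::string featureValue = "example";            // would be read from config; elided here
  for (auto & contextElement : context)
    contextElement.push_back(getDict().getIndexOrInsert(featureValue));
}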
@@ -4,6 +4,7 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include <filesystem>
 
 class Dict
 {
@@ -43,7 +44,7 @@ class Dict
   int getIndexOrInsert(const std::string & element);
   void setState(State state);
   State getState() const;
-  void save(std::FILE * destination, Encoding encoding) const;
+  void save(std::filesystem::path path, Encoding encoding) const;
   bool readEntry(std::FILE * file, int * index, int * nbOccsEntry, char * entry, Encoding encoding);
   void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const;
   std::size_t size() const;
...
@@ -107,12 +107,18 @@ Dict::State Dict::getState() const
   return state;
 }
 
-void Dict::save(std::FILE * destination, Encoding encoding) const
+void Dict::save(std::filesystem::path path, Encoding encoding) const
 {
+  std::FILE * destination = std::fopen(path.c_str(), "w");
+  if (!destination)
+    util::myThrow(fmt::format("could not write file '{}'", path.string()));
+
   fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
   fprintf(destination, "Nb entries : %lu\n", elementsToIndexes.size());
   for (auto & it : elementsToIndexes)
     printEntry(destination, it.second, it.first, encoding);
+
+  std::fclose(destination);
 }
 
 bool Dict::readEntry(std::FILE * file, int * index, int * nbOccsEntry, char * entry, Encoding encoding)
...
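A hypothetical call site for the new signature (the directory and file name here are just examples); the caller no longer opens or closes the FILE itself:

// Assumed usage of the new API: Dict handles fopen/fclose internally.
std::filesystem::path modelDir = "model";
someDict.save(modelDir / "ContextModule.dict", Dict::Encoding::Ascii);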
@@ -30,7 +30,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
   if (machine.hasSplitWordTransitionSet())
     config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
 
-  auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back();
+  auto context = machine.getClassifier()->getNN()->extractContext(config).back();
   auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
   auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
...
@@ -65,7 +65,6 @@ int MacaonDecode::main()
   std::filesystem::path modelPath(variables["model"].as<std::string>());
   auto machinePath = modelPath / ReadingMachine::defaultMachineFilename;
-  auto dictPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultDictFilename, ""));
   auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
   auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
   auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
@@ -75,8 +74,6 @@ int MacaonDecode::main()
   torch::globalContext().setBenchmarkCuDNN(true);
 
-  if (dictPaths.empty())
-    util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
   if (modelPaths.empty())
     util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
@@ -84,7 +81,7 @@ int MacaonDecode::main()
   try
   {
-    ReadingMachine machine(machinePath, modelPaths, dictPaths);
+    ReadingMachine machine(machinePath, modelPaths);
     Decoder decoder(machine);
     BaseConfig config(mcdFile, inputTSV, inputTXT);
...
@@ -5,7 +5,6 @@
 #include <memory>
 #include "Classifier.hpp"
 #include "Strategy.hpp"
-#include "Dict.hpp"
 
 class ReadingMachine
 {
@@ -14,8 +13,6 @@ class ReadingMachine
   static inline const std::string defaultMachineFilename = "machine.rm";
   static inline const std::string defaultModelFilename = "{}.pt";
   static inline const std::string lastModelFilename = "{}.last";
-  static inline const std::string defaultDictFilename = "{}.dict";
-  static inline const std::string defaultDictName = "_default_";
 
   private :
@@ -23,9 +20,7 @@ class ReadingMachine
   std::filesystem::path path;
   std::unique_ptr<Classifier> classifier;
   std::unique_ptr<Strategy> strategy;
-  std::map<std::string, Dict> dicts;
   std::set<std::string> predicted;
-  bool _dictsAreNew{false};
 
   std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
@@ -37,13 +32,11 @@ class ReadingMachine
   public :
 
   ReadingMachine(std::filesystem::path path);
-  ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts);
+  ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models);
   TransitionSet & getTransitionSet();
   TransitionSet & getSplitWordTransitionSet();
   bool hasSplitWordTransitionSet() const;
   Strategy & getStrategy();
-  Dict & getDict(const std::string & state);
-  std::map<std::string, Dict> & getDicts();
   Classifier * getClassifier();
   bool isPredicted(const std::string & columnName) const;
   const std::set<std::string> & getPredicted() const;
@@ -52,8 +45,10 @@ class ReadingMachine
   void saveBest() const;
   void saveLast() const;
   void saveDicts() const;
-  bool dictsAreNew() const;
+  void loadDicts();
   void loadLastSaved();
+  void setCountOcc(bool countOcc);
+  void removeRareDictElements(float rarityThreshold);
 };
 
 #endif
@@ -84,7 +84,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Network type :) networkType"));
 
   if (networkType == "Random")
-    this->nn.reset(new RandomNetworkImpl(nbOutputsPerState));
+    this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState));
   else if (networkType == "Modular")
     initModular(definition, curIndex, nbOutputsPerState);
   else
@@ -135,7 +135,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s
     modulesDefinitions.emplace_back(definition[curIndex]);
   }
 
-  this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, modulesDefinitions));
+  this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions));
 }
 
 void Classifier::resetOptimizer()
...
@@ -4,30 +4,14 @@
 ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
 {
   readFromFile(path);
-
-  auto savedDicts = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::defaultDictFilename, ""));
-  for (auto path : savedDicts)
-    this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open});
-
-  if (dicts.count(defaultDictName) == 0)
-  {
-    _dictsAreNew = true;
-    dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));
-  }
 }
 
-ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts)
+ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models) : path(path)
 {
   readFromFile(path);
 
-  std::size_t maxDictSize = 0;
-  for (auto path : dicts)
-  {
-    this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Closed});
-    maxDictSize = std::max<std::size_t>(maxDictSize, this->dicts.at(path.stem().string()).size());
-  }
-
-  classifier->getNN()->registerEmbeddings(maxDictSize);
+  loadDicts();
+  classifier->getNN()->registerEmbeddings();
   classifier->getNN()->to(NeuralNetworkImpl::device);
 
   if (models.size() > 1)
@@ -143,19 +127,6 @@ Strategy & ReadingMachine::getStrategy()
   return *strategy;
 }
 
-Dict & ReadingMachine::getDict(const std::string & state)
-{
-  auto found = dicts.find(state);
-
-  try
-  {
-    if (found == dicts.end())
-      return dicts.at(defaultDictName);
-  } catch (std::exception & e) {util::myThrow(fmt::format("can't find dict '{}'", defaultDictName));}
-
-  return found->second;
-}
-
 Classifier * ReadingMachine::getClassifier()
 {
   return classifier.get();
@@ -163,17 +134,12 @@ Classifier * ReadingMachine::getClassifier()
 
 void ReadingMachine::saveDicts() const
 {
-  for (auto & it : dicts)
-  {
-    auto pathToDict = path.parent_path() / fmt::format(defaultDictFilename, it.first);
-    std::FILE * file = std::fopen(pathToDict.c_str(), "w");
-    if (!file)
-      util::myThrow(fmt::format("couldn't create file '{}'", pathToDict.c_str()));
-    it.second.save(file, Dict::Encoding::Ascii);
-    std::fclose(file);
-  }
+  classifier->getNN()->saveDicts(path.parent_path());
+}
+
+void ReadingMachine::loadDicts()
+{
+  classifier->getNN()->loadDicts(path.parent_path());
 }
 
 void ReadingMachine::save(const std::string & modelNameTemplate) const
@@ -211,24 +177,23 @@ void ReadingMachine::trainMode(bool isTrainMode)
 
 void ReadingMachine::setDictsState(Dict::State state)
 {
-  for (auto & it : dicts)
-    it.second.setState(state);
+  classifier->getNN()->setDictsState(state);
 }
 
-std::map<std::string, Dict> & ReadingMachine::getDicts()
+void ReadingMachine::loadLastSaved()
 {
-  return dicts;
+  auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
+  if (!lastSavedModel.empty())
+    torch::load(classifier->getNN(), lastSavedModel[0]);
 }
 
-bool ReadingMachine::dictsAreNew() const
+void ReadingMachine::setCountOcc(bool countOcc)
 {
-  return _dictsAreNew;
+  classifier->getNN()->setCountOcc(countOcc);
 }
 
-void ReadingMachine::loadLastSaved()
+void ReadingMachine::removeRareDictElements(float rarityThreshold)
 {
-  auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
-  if (!lastSavedModel.empty())
-    torch::load(classifier->getNN(), lastSavedModel[0]);
+  classifier->getNN()->removeRareDictElements(rarityThreshold);
 }
@@ -2,7 +2,6 @@
 #define CNN__H
 
 #include <torch/torch.h>
-#include "fmt/core.h"
 
 class CNNImpl : public torch::nn::Module
 {
...
@@ -20,12 +20,12 @@ class ContextModuleImpl : public Submodule
   public :
 
-  ContextModuleImpl(const std::string & definition);
+  ContextModuleImpl(std::string name, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
 };
 
 TORCH_MODULE(ContextModule);
...
@@ -21,12 +21,12 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
   public :
 
-  DepthLayerTreeEmbeddingModuleImpl(const std::string & definition);
+  DepthLayerTreeEmbeddingModuleImpl(std::string name, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
 };
 
 TORCH_MODULE(DepthLayerTreeEmbeddingModule);
...
+#ifndef DICTHOLDER__H
+#define DICTHOLDER__H
+
+#include <memory>
+#include <filesystem>
+#include "Dict.hpp"
+#include "NameHolder.hpp"
+
+class DictHolder : public NameHolder
+{
+  private :
+
+  static constexpr char * filenameTemplate = "{}.dict";
+
+  std::unique_ptr<Dict> dict;
+
+  private :
+
+  std::string filename() const;
+
+  public :
+
+  DictHolder();
+  void saveDict(std::filesystem::path path);
+  void loadDict(std::filesystem::path path);
+  Dict & getDict();
+};
+
+#endif
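The corresponding DictHolder.cpp is not part of this excerpt; a plausible sketch, inferred only from the header above and from the Dict constructors visible elsewhere in this diff (method bodies are assumptions, not the commit's actual code), could look like this:

#include "DictHolder.hpp"
#include "fmt/core.h"

DictHolder::DictHolder()
{
  // Assumption: each holder starts with its own fresh, open Dict.
  dict.reset(new Dict(Dict::State::Open));
}

std::string DictHolder::filename() const
{
  // e.g. "ContextModule.dict" when the holder's name is "ContextModule"
  return fmt::format(filenameTemplate, getName());
}

void DictHolder::saveDict(std::filesystem::path path)
{
  dict->save(path / filename(), Dict::Encoding::Ascii);
}

void DictHolder::loadDict(std::filesystem::path path)
{
  dict.reset(new Dict((path / filename()).c_str(), Dict::State::Closed));
}

Dict & DictHolder::getDict()
{
  return *dict;
}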
@@ -20,12 +20,12 @@ class FocusedColumnModuleImpl : public Submodule
   public :
 
-  FocusedColumnModuleImpl(const std::string & definition);
+  FocusedColumnModuleImpl(std::string name, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
 };
 
 TORCH_MODULE(FocusedColumnModule);
...
@@ -22,10 +22,15 @@ class ModularNetworkImpl : public NeuralNetworkImpl
   public :
 
-  ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
+  ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  std::vector<std::vector<long>> extractContext(Config & config) override;
+  void registerEmbeddings() override;
+  void saveDicts(std::filesystem::path path) override;
+  void loadDicts(std::filesystem::path path) override;
+  void setDictsState(Dict::State state) override;
+  void setCountOcc(bool countOcc) override;
+  void removeRareDictElements(float rarityThreshold) override;
 };
 
 #endif
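ModularNetwork.cpp is not shown here; one plausible way (an assumption, not necessarily the commit's implementation) to satisfy the new dict-related virtuals is to fan each call out to the submodules, e.g.:

// Hedged sketch: 'modules' is a hypothetical container of Submodule handles.
void ModularNetworkImpl::saveDicts(std::filesystem::path path)
{
  for (auto & module : modules)
    module->saveDict(path);
}

void ModularNetworkImpl::setDictsState(Dict::State state)
{
  for (auto & module : modules)
    module->getDict().setState(state);
}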
+#ifndef NAMEHOLDER__H
+#define NAMEHOLDER__H
+
+#include <string>
+
+class NameHolder
+{
+  private :
+
+  std::string name;
+
+  public :
+
+  const std::string & getName() const;
+  void setName(const std::string & name);
+};
+
+#endif
@@ -2,10 +2,11 @@
 #define NEURALNETWORK__H
 
 #include <torch/torch.h>
+#include <filesystem>
 
 #include "Config.hpp"
-#include "Dict.hpp"
+#include "NameHolder.hpp"
 
-class NeuralNetworkImpl : public torch::nn::Module
+class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
 {
   public :
@@ -15,17 +16,18 @@ class NeuralNetworkImpl : public torch::nn::Module
   std::string state;
 
-  protected :
-
-  static constexpr int maxNbEmbeddings = 150000;
-
   public :
 
   virtual torch::Tensor forward(torch::Tensor input) = 0;
-  virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
-  virtual void registerEmbeddings(std::size_t nbElements) = 0;
+  virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
+  virtual void registerEmbeddings() = 0;
   void setState(const std::string & state);
   const std::string & getState() const;
+  virtual void saveDicts(std::filesystem::path path) = 0;
+  virtual void loadDicts(std::filesystem::path path) = 0;
+  virtual void setDictsState(Dict::State state) = 0;
+  virtual void setCountOcc(bool countOcc) = 0;
+  virtual void removeRareDictElements(float rarityThreshold) = 0;
 };
 
 TORCH_MODULE(NeuralNetwork);
...
@@ -11,10 +11,15 @@ class RandomNetworkImpl : public NeuralNetworkImpl
   public :
 
-  RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState);
+  RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<std::vector<long>> extractContext(Config &, Dict &) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  std::vector<std::vector<long>> extractContext(Config &) override;
+  void registerEmbeddings() override;
+  void saveDicts(std::filesystem::path path) override;
+  void loadDicts(std::filesystem::path path) override;
+  void setDictsState(Dict::State state) override;
+  void setCountOcc(bool countOcc) override;
+  void removeRareDictElements(float rarityThreshold) override;
 };
 
 #endif
@@ -18,12 +18,12 @@ class RawInputModuleImpl : public Submodule
   public :
 
-  RawInputModuleImpl(const std::string & definition);
+  RawInputModuleImpl(std::string name, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
 };
 
 TORCH_MODULE(RawInputModule);
...
@@ -18,12 +18,12 @@ class SplitTransModuleImpl : public Submodule
   public :
 
-  SplitTransModuleImpl(int maxNbTrans, const std::string & definition);
+  SplitTransModuleImpl(std::string name, int maxNbTrans, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
 };
 
 TORCH_MODULE(SplitTransModule);
...
@@ -11,18 +11,17 @@ class StateNameModuleImpl : public Submodule
 {
   private :
 
-  std::map<std::string,int> state2index;
   torch::nn::Embedding embeddings{nullptr};
   int outSize;
 
   public :
 
-  StateNameModuleImpl(const std::string & definition);
+  StateNameModuleImpl(std::string name, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
-  void registerEmbeddings(std::size_t nbElements) override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
 };
 
 TORCH_MODULE(StateNameModule);
...
@@ -2,10 +2,10 @@
 #define SUBMODULE__H
 
 #include <torch/torch.h>
-#include "Dict.hpp"
 #include "Config.hpp"
+#include "DictHolder.hpp"
 
-class Submodule : public torch::nn::Module
+class Submodule : public torch::nn::Module, public DictHolder
 {
   protected :
@@ -16,9 +16,9 @@ class Submodule : public torch::nn::Module
   void setFirstInputIndex(std::size_t firstInputIndex);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
-  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
+  virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
   virtual torch::Tensor forward(torch::Tensor input) = 0;
-  virtual void registerEmbeddings(std::size_t nbElements) = 0;
+  virtual void registerEmbeddings() = 0;
 };
 
 #endif
...