Commit d08bf04c authored by Franck Dary

Each SubModule has its own Dict

parent ea3b87d6
@@ -4,6 +4,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include <filesystem>
class Dict
{
@@ -43,7 +44,7 @@ class Dict
int getIndexOrInsert(const std::string & element);
void setState(State state);
State getState() const;
void save(std::FILE * destination, Encoding encoding) const;
void save(std::filesystem::path path, Encoding encoding) const;
bool readEntry(std::FILE * file, int * index, int * nbOccsEntry, char * entry, Encoding encoding);
void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const;
std::size_t size() const;
......
@@ -107,12 +107,18 @@ Dict::State Dict::getState() const
return state;
}
void Dict::save(std::FILE * destination, Encoding encoding) const
void Dict::save(std::filesystem::path path, Encoding encoding) const
{
std::FILE * destination = std::fopen(path.c_str(), "w");
if (!destination)
util::myThrow(fmt::format("could not write file '{}'", path.string()));
fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
fprintf(destination, "Nb entries : %lu\n", elementsToIndexes.size());
for (auto & it : elementsToIndexes)
printEntry(destination, it.second, it.first, encoding);
std::fclose(destination);
}
bool Dict::readEntry(std::FILE * file, int * index, int * nbOccsEntry, char * entry, Encoding encoding)
......
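With the new signature, Dict::save owns the whole file lifecycle (open, write, close), so call sites only pass a destination path instead of an already-opened FILE*. A minimal usage sketch; "model" and "example.dict" are placeholders, and the Dict(State) constructor is assumed from its use elsewhere in this commit:

#include <filesystem>
#include "Dict.hpp"

// Sketch only: shows the new call pattern, not code from this commit.
void saveExampleDict()
{
  std::filesystem::path modelDir = "model";
  Dict dict(Dict::State::Open);                // constructor assumed from its use in ReadingMachine
  dict.getIndexOrInsert("example");            // populate the dict
  dict.save(modelDir / "example.dict", Dict::Encoding::Ascii);  // save() now opens, writes and closes the file itself
}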
@@ -30,7 +30,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back();
auto context = machine.getClassifier()->getNN()->extractContext(config).back();
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
......
@@ -65,7 +65,6 @@ int MacaonDecode::main()
std::filesystem::path modelPath(variables["model"].as<std::string>());
auto machinePath = modelPath / ReadingMachine::defaultMachineFilename;
auto dictPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultDictFilename, ""));
auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
@@ -75,8 +74,6 @@ int MacaonDecode::main()
torch::globalContext().setBenchmarkCuDNN(true);
if (dictPaths.empty())
util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
if (modelPaths.empty())
util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
@@ -84,7 +81,7 @@
try
{
ReadingMachine machine(machinePath, modelPaths, dictPaths);
ReadingMachine machine(machinePath, modelPaths);
Decoder decoder(machine);
BaseConfig config(mcdFile, inputTSV, inputTXT);
......
@@ -5,7 +5,6 @@
#include <memory>
#include "Classifier.hpp"
#include "Strategy.hpp"
#include "Dict.hpp"
class ReadingMachine
{
@@ -14,8 +13,6 @@ class ReadingMachine
static inline const std::string defaultMachineFilename = "machine.rm";
static inline const std::string defaultModelFilename = "{}.pt";
static inline const std::string lastModelFilename = "{}.last";
static inline const std::string defaultDictFilename = "{}.dict";
static inline const std::string defaultDictName = "_default_";
private :
@@ -23,9 +20,7 @@ class ReadingMachine
std::filesystem::path path;
std::unique_ptr<Classifier> classifier;
std::unique_ptr<Strategy> strategy;
std::map<std::string, Dict> dicts;
std::set<std::string> predicted;
bool _dictsAreNew{false};
std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
@@ -37,13 +32,11 @@
public :
ReadingMachine(std::filesystem::path path);
ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts);
ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models);
TransitionSet & getTransitionSet();
TransitionSet & getSplitWordTransitionSet();
bool hasSplitWordTransitionSet() const;
Strategy & getStrategy();
Dict & getDict(const std::string & state);
std::map<std::string, Dict> & getDicts();
Classifier * getClassifier();
bool isPredicted(const std::string & columnName) const;
const std::set<std::string> & getPredicted() const;
@@ -52,8 +45,10 @@
void saveBest() const;
void saveLast() const;
void saveDicts() const;
bool dictsAreNew() const;
void loadDicts();
void loadLastSaved();
void setCountOcc(bool countOcc);
void removeRareDictElements(float rarityThreshold);
};
#endif
@@ -84,7 +84,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Network type :) networkType"));
if (networkType == "Random")
this->nn.reset(new RandomNetworkImpl(nbOutputsPerState));
this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState));
else if (networkType == "Modular")
initModular(definition, curIndex, nbOutputsPerState);
else
@@ -135,7 +135,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s
modulesDefinitions.emplace_back(definition[curIndex]);
}
this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, modulesDefinitions));
this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions));
}
void Classifier::resetOptimizer()
......
@@ -4,30 +4,14 @@
ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
{
readFromFile(path);
auto savedDicts = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::defaultDictFilename, ""));
for (auto path : savedDicts)
this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open});
if (dicts.count(defaultDictName) == 0)
{
_dictsAreNew = true;
dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));
}
}
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts)
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models) : path(path)
{
readFromFile(path);
std::size_t maxDictSize = 0;
for (auto path : dicts)
{
this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Closed});
maxDictSize = std::max<std::size_t>(maxDictSize, this->dicts.at(path.stem().string()).size());
}
classifier->getNN()->registerEmbeddings(maxDictSize);
loadDicts();
classifier->getNN()->registerEmbeddings();
classifier->getNN()->to(NeuralNetworkImpl::device);
if (models.size() > 1)
@@ -143,19 +127,6 @@ Strategy & ReadingMachine::getStrategy()
return *strategy;
}
Dict & ReadingMachine::getDict(const std::string & state)
{
auto found = dicts.find(state);
try
{
if (found == dicts.end())
return dicts.at(defaultDictName);
} catch (std::exception & e) {util::myThrow(fmt::format("can't find dict '{}'", defaultDictName));}
return found->second;
}
Classifier * ReadingMachine::getClassifier()
{
return classifier.get();
@@ -163,17 +134,12 @@ Classifier * ReadingMachine::getClassifier()
void ReadingMachine::saveDicts() const
{
for (auto & it : dicts)
{
auto pathToDict = path.parent_path() / fmt::format(defaultDictFilename, it.first);
std::FILE * file = std::fopen(pathToDict.c_str(), "w");
if (!file)
util::myThrow(fmt::format("couldn't create file '{}'", pathToDict.c_str()));
it.second.save(file, Dict::Encoding::Ascii);
classifier->getNN()->saveDicts(path.parent_path());
}
std::fclose(file);
}
void ReadingMachine::loadDicts()
{
classifier->getNN()->loadDicts(path.parent_path());
}
void ReadingMachine::save(const std::string & modelNameTemplate) const
@@ -211,24 +177,23 @@ void ReadingMachine::trainMode(bool isTrainMode)
void ReadingMachine::setDictsState(Dict::State state)
{
for (auto & it : dicts)
it.second.setState(state);
classifier->getNN()->setDictsState(state);
}
std::map<std::string, Dict> & ReadingMachine::getDicts()
void ReadingMachine::loadLastSaved()
{
return dicts;
auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
if (!lastSavedModel.empty())
torch::load(classifier->getNN(), lastSavedModel[0]);
}
bool ReadingMachine::dictsAreNew() const
void ReadingMachine::setCountOcc(bool countOcc)
{
return _dictsAreNew;
classifier->getNN()->setCountOcc(countOcc);
}
void ReadingMachine::loadLastSaved()
void ReadingMachine::removeRareDictElements(float rarityThreshold)
{
auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
if (!lastSavedModel.empty())
torch::load(classifier->getNN(), lastSavedModel[0]);
classifier->getNN()->removeRareDictElements(rarityThreshold);
}
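ReadingMachine now forwards every dict-related operation to the neural network, which owns the dicts through its submodules. The new setCountOcc / removeRareDictElements pair suggests a training flow along these lines; this is a hypothetical fragment (the Trainer changes are not part of this excerpt, and the 0.1 threshold is purely illustrative):

// Hypothetical training-side fragment, not code from this commit.
machine.setCountOcc(true);             // dicts start counting how often each entry is seen
// ... run one pass over the training data so the dicts fill up ...
machine.removeRareDictElements(0.1);   // drop the rarest entries
machine.setCountOcc(false);
machine.saveDicts();                   // each submodule writes its own "<name>.dict"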
@@ -2,7 +2,6 @@
#define CNN__H
#include <torch/torch.h>
#include "fmt/core.h"
class CNNImpl : public torch::nn::Module
{
......
@@ -20,12 +20,12 @@ class ContextModuleImpl : public Submodule
public :
ContextModuleImpl(const std::string & definition);
ContextModuleImpl(std::string name, const std::string & definition);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void registerEmbeddings(std::size_t nbElements) override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(ContextModule);
......
@@ -21,12 +21,12 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
public :
DepthLayerTreeEmbeddingModuleImpl(const std::string & definition);
DepthLayerTreeEmbeddingModuleImpl(std::string name, const std::string & definition);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void registerEmbeddings(std::size_t nbElements) override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(DepthLayerTreeEmbeddingModule);
......
#ifndef DICTHOLDER__H
#define DICTHOLDER__H
#include <memory>
#include <filesystem>
#include "Dict.hpp"
#include "NameHolder.hpp"
class DictHolder : public NameHolder
{
private :
static constexpr char * filenameTemplate = "{}.dict";
std::unique_ptr<Dict> dict;
private :
std::string filename() const;
public :
DictHolder();
void saveDict(std::filesystem::path path);
void loadDict(std::filesystem::path path);
Dict & getDict();
};
#endif
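DictHolder is the new piece that gives every Submodule its own Dict, saved under a per-holder filename. The corresponding DictHolder.cpp is not part of this excerpt; the following is a plausible sketch only, assuming fmt::format and the Dict API shown above, not the committed implementation:

#include "DictHolder.hpp"
#include "fmt/core.h"

// Assumed default state; the real constructor is not shown in this commit excerpt.
DictHolder::DictHolder() : dict(new Dict(Dict::State::Open))
{
}

std::string DictHolder::filename() const
{
  return fmt::format(filenameTemplate, getName());  // e.g. "ContextModule.dict"
}

void DictHolder::saveDict(std::filesystem::path path)
{
  dict->save(path / filename(), Dict::Encoding::Ascii);
}

void DictHolder::loadDict(std::filesystem::path path)
{
  dict.reset(new Dict((path / filename()).c_str(), Dict::State::Open));
}

Dict & DictHolder::getDict()
{
  return *dict;
}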
@@ -20,12 +20,12 @@ class FocusedColumnModuleImpl : public Submodule
public :
FocusedColumnModuleImpl(const std::string & definition);
FocusedColumnModuleImpl(std::string name, const std::string & definition);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void registerEmbeddings(std::size_t nbElements) override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(FocusedColumnModule);
......
@@ -22,10 +22,15 @@ class ModularNetworkImpl : public NeuralNetworkImpl
public :
ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
void registerEmbeddings(std::size_t nbElements) override;
std::vector<std::vector<long>> extractContext(Config & config) override;
void registerEmbeddings() override;
void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override;
void setDictsState(Dict::State state) override;
void setCountOcc(bool countOcc) override;
void removeRareDictElements(float rarityThreshold) override;
};
#endif
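ModularNetworkImpl now exposes the dict operations that ReadingMachine used to perform itself, delegating them to its submodules. A hedged sketch of that delegation, assuming the network keeps its submodules in a std::vector<std::shared_ptr<Submodule>> named modules (that member and the bodies are assumptions, not from this commit):

// Assumed member: std::vector<std::shared_ptr<Submodule>> modules;
void ModularNetworkImpl::saveDicts(std::filesystem::path path)
{
  for (auto & mod : modules)    // each Submodule is a DictHolder
    mod->saveDict(path);        // writes "<moduleName>.dict" next to the machine file
}

void ModularNetworkImpl::setDictsState(Dict::State state)
{
  for (auto & mod : modules)
    mod->getDict().setState(state);
}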
#ifndef NAMEHOLDER__H
#define NAMEHOLDER__H
#include <string>
class NameHolder
{
private :
std::string name;
public :
const std::string & getName() const;
void setName(const std::string & name);
};
#endif
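NameHolder is the small mixin that gives networks and submodules the stable name DictHolder needs to build its "<name>.dict" filename. NameHolder.cpp is not shown in this excerpt; the accessors are presumably trivial, along these lines:

#include "NameHolder.hpp"

const std::string & NameHolder::getName() const
{
  return name;
}

void NameHolder::setName(const std::string & name)
{
  this->name = name;
}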
@@ -2,10 +2,11 @@
#define NEURALNETWORK__H
#include <torch/torch.h>
#include <filesystem>
#include "Config.hpp"
#include "Dict.hpp"
#include "NameHolder.hpp"
class NeuralNetworkImpl : public torch::nn::Module
class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
{
public :
@@ -15,17 +16,18 @@ class NeuralNetworkImpl : public torch::nn::Module
std::string state;
protected :
static constexpr int maxNbEmbeddings = 150000;
public :
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
virtual void registerEmbeddings(std::size_t nbElements) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
virtual void registerEmbeddings() = 0;
void setState(const std::string & state);
const std::string & getState() const;
virtual void saveDicts(std::filesystem::path path) = 0;
virtual void loadDicts(std::filesystem::path path) = 0;
virtual void setDictsState(Dict::State state) = 0;
virtual void setCountOcc(bool countOcc) = 0;
virtual void removeRareDictElements(float rarityThreshold) = 0;
};
TORCH_MODULE(NeuralNetwork);
......
@@ -11,10 +11,15 @@ class RandomNetworkImpl : public NeuralNetworkImpl
public :
RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState);
RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config &, Dict &) const override;
void registerEmbeddings(std::size_t nbElements) override;
std::vector<std::vector<long>> extractContext(Config &) override;
void registerEmbeddings() override;
void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override;
void setDictsState(Dict::State state) override;
void setCountOcc(bool countOcc) override;
void removeRareDictElements(float rarityThreshold) override;
};
#endif
@@ -18,12 +18,12 @@ class RawInputModuleImpl : public Submodule
public :
RawInputModuleImpl(const std::string & definition);
RawInputModuleImpl(std::string name, const std::string & definition);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void registerEmbeddings(std::size_t nbElements) override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(RawInputModule);
......
@@ -18,12 +18,12 @@ class SplitTransModuleImpl : public Submodule
public :
SplitTransModuleImpl(int maxNbTrans, const std::string & definition);
SplitTransModuleImpl(std::string name, int maxNbTrans, const std::string & definition);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void registerEmbeddings(std::size_t nbElements) override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(SplitTransModule);
......
@@ -11,18 +11,17 @@ class StateNameModuleImpl : public Submodule
{
private :
std::map<std::string,int> state2index;
torch::nn::Embedding embeddings{nullptr};
int outSize;
public :
StateNameModuleImpl(const std::string & definition);
StateNameModuleImpl(std::string name, const std::string & definition);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void registerEmbeddings(std::size_t nbElements) override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(StateNameModule);
......
@@ -2,10 +2,10 @@
#define SUBMODULE__H
#include <torch/torch.h>
#include "Dict.hpp"
#include "Config.hpp"
#include "DictHolder.hpp"
class Submodule : public torch::nn::Module
class Submodule : public torch::nn::Module, public DictHolder
{
protected :
@@ -16,9 +16,9 @@ class Submodule : public torch::nn::Module
void setFirstInputIndex(std::size_t firstInputIndex);
virtual std::size_t getOutputSize() = 0;
virtual std::size_t getInputSize() = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual void registerEmbeddings(std::size_t nbElements) = 0;
virtual void registerEmbeddings() = 0;
};
#endif
......
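The signature change above is the heart of the commit: addToContext no longer receives a shared Dict, because every Submodule carries its own through DictHolder. An illustrative fragment of how a concrete module such as ContextModuleImpl can use it under the new interface; the body is made up for illustration and is not the committed code:

void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
  auto & dict = getDict();  // inherited from DictHolder; no Dict parameter anymore
  for (auto & contextElement : context)
    contextElement.emplace_back(dict.getIndexOrInsert(config.getState()));  // illustrative value only
}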