diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index 353c333aa8b9c8a76c0b55de28e9d35b7b5ceb24..87741cbf1b6d44660413515796169985768b7912 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -4,6 +4,7 @@ #include <string> #include <unordered_map> #include <vector> +#include <filesystem> class Dict { @@ -43,7 +44,7 @@ class Dict int getIndexOrInsert(const std::string & element); void setState(State state); State getState() const; - void save(std::FILE * destination, Encoding encoding) const; + void save(std::filesystem::path path, Encoding encoding) const; bool readEntry(std::FILE * file, int * index, int * nbOccsEntry, char * entry, Encoding encoding); void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const; std::size_t size() const; diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index 4546702b80a1c2fddb76d75a0bc301039f3c6059..a4c060c1b46135c471e60daf5b7bf744e4324f6e 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -107,12 +107,18 @@ Dict::State Dict::getState() const return state; } -void Dict::save(std::FILE * destination, Encoding encoding) const +void Dict::save(std::filesystem::path path, Encoding encoding) const { + std::FILE * destination = std::fopen(path.c_str(), "w"); + if (!destination) + util::myThrow(fmt::format("could not write file '{}'", path.string())); + fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? 
"Ascii" : "Binary"); fprintf(destination, "Nb entries : %lu\n", elementsToIndexes.size()); for (auto & it : elementsToIndexes) printEntry(destination, it.second, it.first, encoding); + + std::fclose(destination); } bool Dict::readEntry(std::FILE * file, int * index, int * nbOccsEntry, char * entry, Encoding encoding) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 1e24214877537af62f25e99fdbf2742315fae15c..fb3738f9117e1e35e672e35827f474095fde71e8 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -30,7 +30,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool if (machine.hasSplitWordTransitionSet()) config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); - auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back(); + auto context = machine.getClassifier()->getNN()->extractContext(config).back(); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index 39a92937b3fc4820a08a353cf32c16e11c75307a..8aac97762dac362506ccd3ecc1e4330f5a9e1903 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -65,7 +65,6 @@ int MacaonDecode::main() std::filesystem::path modelPath(variables["model"].as<std::string>()); auto machinePath = modelPath / ReadingMachine::defaultMachineFilename; - auto dictPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultDictFilename, "")); auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, "")); auto inputTSV = variables.count("inputTSV") ? 
variables["inputTSV"].as<std::string>() : ""; auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : ""; @@ -75,8 +74,6 @@ int MacaonDecode::main() torch::globalContext().setBenchmarkCuDNN(true); - if (dictPaths.empty()) - util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, ""))); if (modelPaths.empty()) util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, ""))); @@ -84,7 +81,7 @@ int MacaonDecode::main() try { - ReadingMachine machine(machinePath, modelPaths, dictPaths); + ReadingMachine machine(machinePath, modelPaths); Decoder decoder(machine); BaseConfig config(mcdFile, inputTSV, inputTXT); diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 34f174569e1336d7417c29a63dc9b9a54f9f8d34..f63c21aa7899250fdf11865ce6d3e0fe36a5b4ba 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -5,7 +5,6 @@ #include <memory> #include "Classifier.hpp" #include "Strategy.hpp" -#include "Dict.hpp" class ReadingMachine { @@ -14,8 +13,6 @@ class ReadingMachine static inline const std::string defaultMachineFilename = "machine.rm"; static inline const std::string defaultModelFilename = "{}.pt"; static inline const std::string lastModelFilename = "{}.last"; - static inline const std::string defaultDictFilename = "{}.dict"; - static inline const std::string defaultDictName = "_default_"; private : @@ -23,9 +20,7 @@ class ReadingMachine std::filesystem::path path; std::unique_ptr<Classifier> classifier; std::unique_ptr<Strategy> strategy; - std::map<std::string, Dict> dicts; std::set<std::string> predicted; - bool _dictsAreNew{false}; std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr}; @@ -37,13 +32,11 @@ class ReadingMachine public : 
ReadingMachine(std::filesystem::path path); - ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts); + ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models); TransitionSet & getTransitionSet(); TransitionSet & getSplitWordTransitionSet(); bool hasSplitWordTransitionSet() const; Strategy & getStrategy(); - Dict & getDict(const std::string & state); - std::map<std::string, Dict> & getDicts(); Classifier * getClassifier(); bool isPredicted(const std::string & columnName) const; const std::set<std::string> & getPredicted() const; @@ -52,8 +45,10 @@ class ReadingMachine void saveBest() const; void saveLast() const; void saveDicts() const; - bool dictsAreNew() const; + void loadDicts(); void loadLastSaved(); + void setCountOcc(bool countOcc); + void removeRareDictElements(float rarityThreshold); }; #endif diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index b33a2dfcf338ce26ba5439e15a8921234c4a86a4..cb1c29e558170a360bea9ff51a9deef65060a0af 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -84,7 +84,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Network type :) networkType")); if (networkType == "Random") - this->nn.reset(new RandomNetworkImpl(nbOutputsPerState)); + this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState)); else if (networkType == "Modular") initModular(definition, curIndex, nbOutputsPerState); else @@ -135,7 +135,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s modulesDefinitions.emplace_back(definition[curIndex]); } - this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, modulesDefinitions)); + this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions)); } void Classifier::resetOptimizer() diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index f402961779285500730d0a02ecc08d37ee6912b0..2078c66a69ae788c9e292d45df57b98cc2c88415 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -4,30 +4,14 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path) { readFromFile(path); - - auto savedDicts = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::defaultDictFilename, "")); - - for (auto path : savedDicts) - this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open}); - - if (dicts.count(defaultDictName) == 0) - { - _dictsAreNew = true; - dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open)); - } } -ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts) +ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models) : path(path) { readFromFile(path); - std::size_t maxDictSize = 0; - for (auto path : dicts) - { - this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Closed}); - maxDictSize = std::max<std::size_t>(maxDictSize, this->dicts.at(path.stem().string()).size()); - } - 
classifier->getNN()->registerEmbeddings(maxDictSize); + loadDicts(); + classifier->getNN()->registerEmbeddings(); classifier->getNN()->to(NeuralNetworkImpl::device); if (models.size() > 1) @@ -143,19 +127,6 @@ Strategy & ReadingMachine::getStrategy() return *strategy; } -Dict & ReadingMachine::getDict(const std::string & state) -{ - auto found = dicts.find(state); - - try - { - if (found == dicts.end()) - return dicts.at(defaultDictName); - } catch (std::exception & e) {util::myThrow(fmt::format("can't find dict '{}'", defaultDictName));} - - return found->second; -} - Classifier * ReadingMachine::getClassifier() { return classifier.get(); @@ -163,17 +134,12 @@ Classifier * ReadingMachine::getClassifier() void ReadingMachine::saveDicts() const { - for (auto & it : dicts) - { - auto pathToDict = path.parent_path() / fmt::format(defaultDictFilename, it.first); - std::FILE * file = std::fopen(pathToDict.c_str(), "w"); - if (!file) - util::myThrow(fmt::format("couldn't create file '{}'", pathToDict.c_str())); - - it.second.save(file, Dict::Encoding::Ascii); + classifier->getNN()->saveDicts(path.parent_path()); +} - std::fclose(file); - } +void ReadingMachine::loadDicts() +{ + classifier->getNN()->loadDicts(path.parent_path()); } void ReadingMachine::save(const std::string & modelNameTemplate) const @@ -211,24 +177,23 @@ void ReadingMachine::trainMode(bool isTrainMode) void ReadingMachine::setDictsState(Dict::State state) { - for (auto & it : dicts) - it.second.setState(state); + classifier->getNN()->setDictsState(state); } -std::map<std::string, Dict> & ReadingMachine::getDicts() +void ReadingMachine::loadLastSaved() { - return dicts; + auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, "")); + if (!lastSavedModel.empty()) + torch::load(classifier->getNN(), lastSavedModel[0]); } -bool ReadingMachine::dictsAreNew() const +void ReadingMachine::setCountOcc(bool countOcc) { - return _dictsAreNew; + 
classifier->getNN()->setCountOcc(countOcc); } -void ReadingMachine::loadLastSaved() +void ReadingMachine::removeRareDictElements(float rarityThreshold) { - auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, "")); - if (!lastSavedModel.empty()) - torch::load(classifier->getNN(), lastSavedModel[0]); + classifier->getNN()->removeRareDictElements(rarityThreshold); } diff --git a/torch_modules/include/CNN.hpp b/torch_modules/include/CNN.hpp index 66c405ca13bdb6ddf31ddfe34e2da42788fddc79..2c8431c1ff138ef24a9e67f536917cba95f63f89 100644 --- a/torch_modules/include/CNN.hpp +++ b/torch_modules/include/CNN.hpp @@ -2,7 +2,6 @@ #define CNN__H #include <torch/torch.h> -#include "fmt/core.h" class CNNImpl : public torch::nn::Module { diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index c48eb9f860578fda9cb02d75664c8d2d2602e17a..a9116cf1d1d7030bee10bd7007e314375a8503e4 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -20,12 +20,12 @@ class ContextModuleImpl : public Submodule public : - ContextModuleImpl(const std::string & definition); + ContextModuleImpl(std::string name, const std::string & definition); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; - void registerEmbeddings(std::size_t nbElements) override; + void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void registerEmbeddings() override; }; TORCH_MODULE(ContextModule); diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp index 970e3bc535e2f30ce21fbb26c61d68fa7de28ee9..26fc0ed17a529d9d712277cd96ac962ae9fc6651 100644 --- 
a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp +++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp @@ -21,12 +21,12 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule public : - DepthLayerTreeEmbeddingModuleImpl(const std::string & definition); + DepthLayerTreeEmbeddingModuleImpl(std::string name, const std::string & definition); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; - void registerEmbeddings(std::size_t nbElements) override; + void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void registerEmbeddings() override; }; TORCH_MODULE(DepthLayerTreeEmbeddingModule); diff --git a/torch_modules/include/DictHolder.hpp b/torch_modules/include/DictHolder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6edb8e7aaf39ee9c8c879b47a0eee8506f1d5335 --- /dev/null +++ b/torch_modules/include/DictHolder.hpp @@ -0,0 +1,30 @@ +#ifndef DICTHOLDER__H +#define DICTHOLDER__H + +#include <memory> +#include <filesystem> +#include "Dict.hpp" +#include "NameHolder.hpp" + +class DictHolder : public NameHolder +{ + private : + + static constexpr const char * filenameTemplate = "{}.dict"; + + std::unique_ptr<Dict> dict; + + private : + + std::string filename() const; + + public : + + DictHolder(); + void saveDict(std::filesystem::path path); + void loadDict(std::filesystem::path path); + Dict & getDict(); +}; + +#endif + diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp index f7814a08edcc60b52f5db3f5c0be0cc6aed0914a..4e89372410691a101da27c472f875b21e1c0da67 100644 --- a/torch_modules/include/FocusedColumnModule.hpp +++ b/torch_modules/include/FocusedColumnModule.hpp @@ -20,12 +20,12 @@ class FocusedColumnModuleImpl : public Submodule public : -
FocusedColumnModuleImpl(const std::string & definition); + FocusedColumnModuleImpl(std::string name, const std::string & definition); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; - void registerEmbeddings(std::size_t nbElements) override; + void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void registerEmbeddings() override; }; TORCH_MODULE(FocusedColumnModule); diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index 41c8beb4eb19fadcb1f9eaca652e8b310d35cf20..11a161e2229311faa63aa38b481a291629f61cbb 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -22,10 +22,15 @@ class ModularNetworkImpl : public NeuralNetworkImpl public : - ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions); + ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions); torch::Tensor forward(torch::Tensor input) override; - std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; - void registerEmbeddings(std::size_t nbElements) override; + std::vector<std::vector<long>> extractContext(Config & config) override; + void registerEmbeddings() override; + void saveDicts(std::filesystem::path path) override; + void loadDicts(std::filesystem::path path) override; + void setDictsState(Dict::State state) override; + void setCountOcc(bool countOcc) override; + void removeRareDictElements(float rarityThreshold) override; }; #endif diff --git a/torch_modules/include/NameHolder.hpp b/torch_modules/include/NameHolder.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..60e5801fd6a28961ca3cc6f42235829c4278082e --- /dev/null +++ b/torch_modules/include/NameHolder.hpp @@ -0,0 +1,19 @@ +#ifndef NAMEHOLDER__H +#define NAMEHOLDER__H + +#include <string> + +class NameHolder +{ + private : + + std::string name; + + public : + + const std::string & getName() const; + void setName(const std::string & name); +}; + +#endif + diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 5372255039138d0e9a5b50bb796d9d466af3eafb..3cbfe47443d98de4166f551442df31befff1db9c 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -2,10 +2,12 @@ #define NEURALNETWORK__H #include <torch/torch.h> +#include <filesystem> #include "Config.hpp" #include "Dict.hpp" +#include "NameHolder.hpp" -class NeuralNetworkImpl : public torch::nn::Module +class NeuralNetworkImpl : public torch::nn::Module, public NameHolder { public : @@ -15,17 +17,18 @@ class NeuralNetworkImpl : public torch::nn::Module std::string state; - protected : - - static constexpr int maxNbEmbeddings = 150000; - public : virtual torch::Tensor forward(torch::Tensor input) = 0; - virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0; - virtual void registerEmbeddings(std::size_t nbElements) = 0; + virtual std::vector<std::vector<long>> extractContext(Config & config) = 0; + virtual void registerEmbeddings() = 0; void setState(const std::string & state); const std::string & getState() const; + virtual void saveDicts(std::filesystem::path path) = 0; + virtual void loadDicts(std::filesystem::path path) = 0; + virtual void setDictsState(Dict::State state) = 0; + virtual void setCountOcc(bool countOcc) = 0; + virtual void removeRareDictElements(float rarityThreshold) = 0; }; TORCH_MODULE(NeuralNetwork); diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp index
b26c6f427354266a9210626f18efa25bc9ba93f2..b20a779eb5f47979eecd7d67f64af3193492ff53 100644 --- a/torch_modules/include/RandomNetwork.hpp +++ b/torch_modules/include/RandomNetwork.hpp @@ -11,10 +11,15 @@ class RandomNetworkImpl : public NeuralNetworkImpl public : - RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState); + RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState); torch::Tensor forward(torch::Tensor input) override; - std::vector<std::vector<long>> extractContext(Config &, Dict &) const override; - void registerEmbeddings(std::size_t nbElements) override; + std::vector<std::vector<long>> extractContext(Config &) override; + void registerEmbeddings() override; + void saveDicts(std::filesystem::path path) override; + void loadDicts(std::filesystem::path path) override; + void setDictsState(Dict::State state) override; + void setCountOcc(bool countOcc) override; + void removeRareDictElements(float rarityThreshold) override; }; #endif diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp index 02e1dd369cc439bc8e6b3ed79e6c761e2d34ad9f..b043f6cdecfb6ce3b25929b7f03a816ef407ec73 100644 --- a/torch_modules/include/RawInputModule.hpp +++ b/torch_modules/include/RawInputModule.hpp @@ -18,12 +18,12 @@ class RawInputModuleImpl : public Submodule public : - RawInputModuleImpl(const std::string & definition); + RawInputModuleImpl(std::string name, const std::string & definition); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; - void registerEmbeddings(std::size_t nbElements) override; + void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void registerEmbeddings() override; }; TORCH_MODULE(RawInputModule); diff --git 
a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp index f614588131531087af6a8c90216e5756391539ac..764d9c3d4bd39594c53a553d2bb0d955cd6a1c43 100644 --- a/torch_modules/include/SplitTransModule.hpp +++ b/torch_modules/include/SplitTransModule.hpp @@ -18,12 +18,12 @@ class SplitTransModuleImpl : public Submodule public : - SplitTransModuleImpl(int maxNbTrans, const std::string & definition); + SplitTransModuleImpl(std::string name, int maxNbTrans, const std::string & definition); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; - void registerEmbeddings(std::size_t nbElements) override; + void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void registerEmbeddings() override; }; TORCH_MODULE(SplitTransModule); diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp index 8a2ae71682202ba5fe1b070e46cf0d5e8a428063..2e1a7d4a2752f148deb6c7e41a78b49e113faf90 100644 --- a/torch_modules/include/StateNameModule.hpp +++ b/torch_modules/include/StateNameModule.hpp @@ -11,18 +11,17 @@ class StateNameModuleImpl : public Submodule { private : - std::map<std::string,int> state2index; torch::nn::Embedding embeddings{nullptr}; int outSize; public : - StateNameModuleImpl(const std::string & definition); + StateNameModuleImpl(std::string name, const std::string & definition); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; - void registerEmbeddings(std::size_t nbElements) override; + void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void registerEmbeddings() 
override; }; TORCH_MODULE(StateNameModule); diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index 849eb225061e9a29232763c46d7643968a1723d7..135b0f9a781f90e7def6d5453181ffa6d9ce735f 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -2,10 +2,10 @@ #define SUBMODULE__H #include <torch/torch.h> -#include "Dict.hpp" #include "Config.hpp" +#include "DictHolder.hpp" -class Submodule : public torch::nn::Module +class Submodule : public torch::nn::Module, public DictHolder { protected : @@ -16,9 +16,9 @@ class Submodule : public torch::nn::Module void setFirstInputIndex(std::size_t firstInputIndex); virtual std::size_t getOutputSize() = 0; virtual std::size_t getInputSize() = 0; - virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0; + virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0; virtual torch::Tensor forward(torch::Tensor input) = 0; - virtual void registerEmbeddings(std::size_t nbElements) = 0; + virtual void registerEmbeddings() = 0; }; #endif diff --git a/torch_modules/src/CNN.cpp b/torch_modules/src/CNN.cpp index dbc3797b12f7ce6e3e16e8d93e0959e0a3bdb804..35f357e5d86aaa34db95cf408509667a972897a7 100644 --- a/torch_modules/src/CNN.cpp +++ b/torch_modules/src/CNN.cpp @@ -1,4 +1,5 @@ #include "CNN.hpp" +#include "fmt/core.h" CNNImpl::CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize) : windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize) diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 248da9387d4f3770484b6c11bc1b307e2e4a162c..ced9aee62b876e5561a76f8f88f3f5ffc964dcac 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -1,7 +1,8 @@ #include "ContextModule.hpp" -ContextModuleImpl::ContextModuleImpl(const std::string & definition) 
+ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & definition) { + setName(name); std::regex regex("(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) { @@ -49,8 +50,9 @@ std::size_t ContextModuleImpl::getInputSize() return columns.size()*(bufferContext.size()+stackContext.size()); } -void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) { + auto & dict = getDict(); std::vector<long> contextIndexes; for (int index : bufferContext) @@ -87,8 +89,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) return myModule->forward(context); } -void ContextModuleImpl::registerEmbeddings(std::size_t nbElements) +void ContextModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize))); + wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); } diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index df9c2df62ebc87d324c6be4f5ff5cd4528776c96..0c8abed79aa058e8200b03efbc2d0debaf7c8043 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -1,7 +1,8 @@ #include "DepthLayerTreeEmbeddingModule.hpp" -DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(const std::string & definition) +DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string name, const std::string & definition) { + setName(name); std::regex 
regex("(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)LayerSizes\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) { @@ -81,8 +82,9 @@ std::size_t DepthLayerTreeEmbeddingModuleImpl::getInputSize() return inputSize; } -void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) { + auto & dict = getDict(); std::vector<long> focusedIndexes; for (int index : focusedBuffer) @@ -120,8 +122,8 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon } } -void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings(std::size_t nbElements) +void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize))); + wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); } diff --git a/torch_modules/src/DictHolder.cpp b/torch_modules/src/DictHolder.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2f1958f55f2a4359296542359b7decb0deac3f83 --- /dev/null +++ b/torch_modules/src/DictHolder.cpp @@ -0,0 +1,28 @@ +#include "DictHolder.hpp" +#include "fmt/core.h" + +DictHolder::DictHolder() +{ + dict.reset(new Dict(Dict::State::Open)); +} + +std::string DictHolder::filename() const +{ + return fmt::format(filenameTemplate, getName()); +} + +void DictHolder::saveDict(std::filesystem::path path) +{ + dict->save(path / filename(), Dict::Encoding::Ascii); +} + +void DictHolder::loadDict(std::filesystem::path path) +{ + dict.reset(new Dict((path / filename()).c_str(), 
Dict::State::Open)); +} + +Dict & DictHolder::getDict() +{ + return *dict; +} + diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 03cf9b65a6652e9bcd4c55d319f4a865772ee7ad..9a4ce1d225e9dbaddaf15cc369ca2ba15bf7d8b8 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -1,7 +1,8 @@ #include "FocusedColumnModule.hpp" -FocusedColumnModuleImpl::FocusedColumnModuleImpl(const std::string & definition) +FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition) { + setName(name); std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) { @@ -59,8 +60,9 @@ std::size_t FocusedColumnModuleImpl::getInputSize() return (focusedBuffer.size()+focusedStack.size()) * maxNbElements; } -void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) { + auto & dict = getDict(); std::vector<long> focusedIndexes; for (int index : focusedBuffer) @@ -132,8 +134,8 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont } } -void FocusedColumnModuleImpl::registerEmbeddings(std::size_t nbElements) +void FocusedColumnModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize))); + wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); } diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index 
82cebd7ad62203d92fd19a9e1046dd24ed6e0ebf..db8d9d02a23dda70c4b8c945ed0b9bca276084dd 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -1,7 +1,8 @@ #include "ModularNetwork.hpp" -ModularNetworkImpl::ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions) +ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions) { + setName(name); std::string anyBlanks = "(?:(?:\\s|\\t)*)"; auto splitLine = [anyBlanks](std::string line) { @@ -21,18 +22,19 @@ ModularNetworkImpl::ModularNetworkImpl(std::map<std::string,std::size_t> nbOutpu { auto splited = splitLine(line); std::string name = fmt::format("{}_{}", modules.size(), splited.first); + std::string nameH = fmt::format("{}_{}", getName(), name); if (splited.first == "Context") - modules.emplace_back(register_module(name, ContextModule(splited.second))); + modules.emplace_back(register_module(name, ContextModule(nameH, splited.second))); else if (splited.first == "StateName") - modules.emplace_back(register_module(name, StateNameModule(splited.second))); + modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second))); else if (splited.first == "Focused") - modules.emplace_back(register_module(name, FocusedColumnModule(splited.second))); + modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second))); else if (splited.first == "RawInput") - modules.emplace_back(register_module(name, RawInputModule(splited.second))); + modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second))); else if (splited.first == "SplitTrans") - modules.emplace_back(register_module(name, SplitTransModule(Config::maxNbAppliableSplitTransitions, splited.second))); + modules.emplace_back(register_module(name, SplitTransModule(nameH, Config::maxNbAppliableSplitTransitions, splited.second))); else if 
(splited.first == "DepthLayerTree") - modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(splited.second))); + modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(nameH, splited.second))); else if (splited.first == "MLP") { mlpDef = splited.second; @@ -77,17 +79,54 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input) return outputLayersPerState.at(getState())(mlp(totalInput)); } -std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config, Dict & dict) const +std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config) { std::vector<std::vector<long>> context(1); for (auto & mod : modules) - mod->addToContext(context, dict, config); + mod->addToContext(context, config); return context; } -void ModularNetworkImpl::registerEmbeddings(std::size_t nbElements) +void ModularNetworkImpl::registerEmbeddings() { for (auto & mod : modules) - mod->registerEmbeddings(nbElements); + mod->registerEmbeddings(); +} + +void ModularNetworkImpl::saveDicts(std::filesystem::path path) +{ + for (auto & mod : modules) + mod->saveDict(path); +} + +void ModularNetworkImpl::loadDicts(std::filesystem::path path) +{ + for (auto & mod : modules) + mod->loadDict(path); +} + +void ModularNetworkImpl::setDictsState(Dict::State state) +{ + for (auto & mod : modules) + mod->getDict().setState(state); +} + +void ModularNetworkImpl::setCountOcc(bool countOcc) +{ + for (auto & mod : modules) + mod->getDict().countOcc(countOcc); +} + +void ModularNetworkImpl::removeRareDictElements(float rarityThreshold) +{ + std::size_t minNbElems = 1000; + + for (auto & mod : modules) + { + auto & dict = mod->getDict(); + std::size_t originalSize = dict.size(); + while (100.0*(originalSize-dict.size())/originalSize < rarityThreshold and dict.size() > minNbElems) + dict.removeRareElements(); + } } diff --git a/torch_modules/src/NameHolder.cpp b/torch_modules/src/NameHolder.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..ceb33a4512dcf1525f3c7deae85bbd9fc7a72686 --- /dev/null +++ b/torch_modules/src/NameHolder.cpp @@ -0,0 +1,16 @@ +#include "NameHolder.hpp" +#include "util.hpp" + +const std::string & NameHolder::getName() const +{ + if (name.empty()) + util::myThrow("name is empty"); + + return name; +} + +void NameHolder::setName(const std::string & name) +{ + this->name = name; +} + diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index 6622732208f5dfd84cb679e7f0266420046b5191..7a6491b6351a9a20e6d56d6423db44fb268067c2 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -1,7 +1,8 @@ #include "RandomNetwork.hpp" -RandomNetworkImpl::RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState) : nbOutputsPerState(nbOutputsPerState) +RandomNetworkImpl::RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState) : nbOutputsPerState(nbOutputsPerState) { + setName(name); } torch::Tensor RandomNetworkImpl::forward(torch::Tensor input) @@ -12,12 +13,32 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input) return torch::randn({input.size(0), (long)nbOutputsPerState[getState()]}, torch::TensorOptions().device(device).requires_grad(true)); } -std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &, Dict &) const +std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &) { return std::vector<std::vector<long>>{{0}}; } -void RandomNetworkImpl::registerEmbeddings(std::size_t) +void RandomNetworkImpl::registerEmbeddings() +{ +} + +void RandomNetworkImpl::saveDicts(std::filesystem::path) +{ +} + +void RandomNetworkImpl::loadDicts(std::filesystem::path) +{ +} + +void RandomNetworkImpl::setDictsState(Dict::State) +{ +} + +void RandomNetworkImpl::setCountOcc(bool) +{ +} + +void RandomNetworkImpl::removeRareDictElements(float) { } diff --git a/torch_modules/src/RawInputModule.cpp 
b/torch_modules/src/RawInputModule.cpp index ac0f5e45b0e11cc85d8cc709cbeab1983e3a074c..a14b9fc0125435a1264eb9b47148995f9553746f 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -1,7 +1,8 @@ #include "RawInputModule.hpp" -RawInputModuleImpl::RawInputModuleImpl(const std::string & definition) +RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & definition) { + setName(name); std::regex regex("(?:(?:\\s|\\t)*)Left\\{(.*)\\}(?:(?:\\s|\\t)*)Right\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) { @@ -49,11 +50,12 @@ std::size_t RawInputModuleImpl::getInputSize() return leftWindow + rightWindow + 1; } -void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) { if (leftWindow < 0 or rightWindow < 0) return; + auto & dict = getDict(); for (auto & contextElement : context) { for (int i = 0; i < leftWindow; i++) @@ -70,8 +72,8 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, } } -void RawInputModuleImpl::registerEmbeddings(std::size_t nbElements) +void RawInputModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize))); + wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); } diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp index 6fdf54d5908ef83303ba5026c8db3d71d52d2571..315566a765a1ea6af1b2db07449f8508e0e11732 100644 --- a/torch_modules/src/SplitTransModule.cpp +++ b/torch_modules/src/SplitTransModule.cpp @@ -1,8 +1,9 @@ #include "SplitTransModule.hpp" 
#include "Transition.hpp" -SplitTransModuleImpl::SplitTransModuleImpl(int maxNbTrans, const std::string & definition) +SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, const std::string & definition) { + setName(name); this->maxNbTrans = maxNbTrans; std::regex regex("(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) @@ -48,8 +49,9 @@ std::size_t SplitTransModuleImpl::getInputSize() return maxNbTrans; } -void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) { + auto & dict = getDict(); auto & splitTransitions = config.getAppliableSplitTransitions(); for (auto & contextElement : context) for (int i = 0; i < maxNbTrans; i++) @@ -59,8 +61,8 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); } -void SplitTransModuleImpl::registerEmbeddings(std::size_t nbElements) +void SplitTransModuleImpl::registerEmbeddings() { - wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize))); + wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); } diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp index afc572142fb4725753c0c8a1eece8613c33b36ef..42edd50ee4621080b512782378bf87e2c1703235 100644 --- a/torch_modules/src/StateNameModule.cpp +++ b/torch_modules/src/StateNameModule.cpp @@ -1,17 +1,14 @@ #include "StateNameModule.hpp" -StateNameModuleImpl::StateNameModuleImpl(const std::string & definition) +StateNameModuleImpl::StateNameModuleImpl(std::string name, const std::string & definition) 
{ - std::regex regex("(?:(?:\\s|\\t)*)States\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); + setName(name); + std::regex regex("(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) { try { - auto states = util::split(sm.str(1), ' '); - outSize = std::stoi(sm.str(2)); - - for (auto & state : states) - state2index.emplace(state, state2index.size()); + outSize = std::stoi(sm.str(1)); } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} })) util::myThrow(fmt::format("invalid definition '{}'", definition)); @@ -32,14 +29,15 @@ std::size_t StateNameModuleImpl::getInputSize() return 1; } -void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) { + auto & dict = getDict(); for (auto & contextElement : context) - contextElement.emplace_back(state2index.at(config.getState())); + contextElement.emplace_back(dict.getIndexOrInsert(config.getState())); } -void StateNameModuleImpl::registerEmbeddings(std::size_t) +void StateNameModuleImpl::registerEmbeddings() { - embeddings = register_module("embeddings", torch::nn::Embedding(state2index.size(), outSize)); + embeddings = register_module("embeddings", torch::nn::Embedding(getDict().size(), outSize)); } diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 2c9ef09590f49fe6501a1721b1f02753ee05e8b6..6c962bca1838b7d2151f95ef7a99292af65722b5 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -112,28 +112,18 @@ int MacaonTrain::main() Trainer trainer(machine, batchSize); Decoder decoder(machine); - if (machine.dictsAreNew()) + if (util::findFilesByExtension(machinePath.parent_path(), ".dict").empty()) { trainer.fillDicts(goldConfig, debug); - for (auto & it : machine.getDicts()) - { - std::size_t 
originalSize = it.second.size(); - for (;;) - { - std::size_t lastSize = it.second.size(); - it.second.removeRareElements(); - float decrease = 100.0*(originalSize-it.second.size())/originalSize; - if (decrease >= rarityThreshold or lastSize == it.second.size()) - break; - } - } + machine.removeRareDictElements(rarityThreshold); machine.saveDicts(); } + else + { + machine.loadDicts(); + } - std::size_t maxDictSize = 0; - for (auto & it : machine.getDicts()) - maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size()); - machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize); + machine.getClassifier()->getNN()->registerEmbeddings(); machine.loadLastSaved(); machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index f257f05bc4cfc26b13319250bef4f00088bae9fe..0136515a7852a95b873c8693c90167d053117465 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -75,7 +75,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p try { - context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); + context = machine.getClassifier()->getNN()->extractContext(config); } catch(std::exception & e) { util::myThrow(fmt::format("Failed to extract context : {}", e.what())); @@ -274,16 +274,14 @@ void Trainer::fillDicts(BaseConfig & goldConfig, bool debug) { SubConfig config(goldConfig, goldConfig.getNbLines()); - for (auto & it : machine.getDicts()) - it.second.countOcc(true); + machine.setCountOcc(true); machine.trainMode(false); machine.setDictsState(Dict::State::Open); fillDicts(config, debug); - for (auto & it : machine.getDicts()) - it.second.countOcc(false); + machine.setCountOcc(false); } void Trainer::fillDicts(SubConfig & config, bool debug) @@ -305,7 +303,7 @@ void Trainer::fillDicts(SubConfig & config, bool debug) try { - 
machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); + machine.getClassifier()->getNN()->extractContext(config); } catch(std::exception & e) { util::myThrow(fmt::format("Failed to extract context : {}", e.what()));