From a31419f9b423d3fa3c5ec2b67321f22773976d3c Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 20 Apr 2020 15:58:55 +0200 Subject: [PATCH] Each classifier can have a different output layer in neural networks --- decoder/src/Decoder.cpp | 2 + reading_machine/include/Classifier.hpp | 6 ++- reading_machine/src/Classifier.cpp | 49 ++++++++++++++++++++----- torch_modules/include/ConfigDataset.hpp | 6 +-- torch_modules/include/LSTMNetwork.hpp | 3 +- torch_modules/include/MLP.hpp | 4 +- torch_modules/include/NeuralNetwork.hpp | 3 ++ torch_modules/include/RandomNetwork.hpp | 4 +- torch_modules/src/ConfigDataset.cpp | 25 ++++++++----- torch_modules/src/LSTMNetwork.cpp | 9 +++-- torch_modules/src/MLP.cpp | 14 ++++--- torch_modules/src/NeuralNetwork.cpp | 10 +++++ torch_modules/src/RandomNetwork.cpp | 4 +- trainer/include/Trainer.hpp | 13 ++++++- trainer/src/Trainer.cpp | 45 ++++++++++++++--------- 15 files changed, 141 insertions(+), 56 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 33c4837..352a9e8 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -21,6 +21,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool try { config.setState(machine.getStrategy().getInitialState()); + machine.getClassifier()->setState(machine.getStrategy().getInitialState()); while (true) { @@ -94,6 +95,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool break; config.setState(movement.first); + machine.getClassifier()->setState(movement.first); config.moveWordIndexRelaxed(movement.second); } } catch(std::exception & e) {util::myThrow(e.what());} diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 2bb3af9..4951e5d 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -10,14 +10,15 @@ class Classifier private : std::string name; - std::unique_ptr<TransitionSet> transitionSet; + std::map<std::string,std::unique_ptr<TransitionSet>> transitionSets; std::shared_ptr<NeuralNetworkImpl> nn; std::unique_ptr<torch::optim::Adam> optimizer; + std::string state; private : void initNeuralNetwork(const std::vector<std::string> & definition); - void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex); + void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState); public : @@ -29,6 +30,7 @@ class Classifier void loadOptimizer(std::filesystem::path path); void saveOptimizer(std::filesystem::path path); torch::optim::Adam & getOptimizer(); + void setState(const std::string & state); }; #endif diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index b3e0710..d8418c9 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -8,12 +8,31 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std this->name = name; if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm) { - std::vector<std::string> tsFiles; + auto splited = util::split(sm.str(1), ' '); - for (auto & tsFilename : util::split(sm.str(1), ' ')) - tsFiles.emplace_back(path.parent_path() / tsFilename); + for (auto & ss : splited) + { + std::vector<std::string> tsFiles; + std::vector<std::string> states; + for (auto & elem : util::split(ss, ',')) + if (std::filesystem::path(elem).extension().empty()) + states.emplace_back(elem); + else + tsFiles.emplace_back(path.parent_path() / elem); + if (tsFiles.empty()) + util::myThrow(fmt::format("invalid '{}' no .ts files specified", ss)); + if (states.empty()) + util::myThrow(fmt::format("invalid '{}' no states specified", ss)); + + for (auto & stateName : states) + { + if (transitionSets.count(stateName)) + util::myThrow(fmt::format("state '{}' already assigned", stateName)); + + this->transitionSets.emplace(stateName, new TransitionSet(tsFiles)); + } + } - this->transitionSet.reset(new TransitionSet(tsFiles)); })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}")); @@ -32,7 +51,10 @@ int Classifier::getNbParameters() const TransitionSet & Classifier::getTransitionSet() { - return *transitionSet; + if (!transitionSets.count(state)) + util::myThrow(fmt::format("cannot find transition set for state '{}'", state)); + + return *transitionSets[state]; } NeuralNetwork & Classifier::getNN() @@ -47,6 +69,10 @@ const std::string & Classifier::getName() const void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) { + std::map<std::string,std::size_t> nbOutputsPerState; + for (auto & it : this->transitionSets) + nbOutputsPerState[it.first] = it.second->size(); + std::size_t curIndex = 1; std::string networkType; @@ -58,9 +84,9 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Network type :) networkType")); if (networkType == "Random") - this->nn.reset(new RandomNetworkImpl(this->transitionSet->size())); + this->nn.reset(new RandomNetworkImpl(nbOutputsPerState)); else if (networkType == "LSTM") - initLSTM(definition, curIndex); + initLSTM(definition, curIndex, nbOutputsPerState); else util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType)); @@ -83,7 +109,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}")); } -void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex) +void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState) { int unknownValueThreshold; std::vector<int> bufferContext, stackContext; @@ -299,7 +325,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value")); - this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d)); + this->nn.reset(new LSTMNetworkImpl(nbOutputsPerState, unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d)); } void Classifier::loadOptimizer(std::filesystem::path path) @@ -317,4 +343,9 @@ torch::optim::Adam & Classifier::getOptimizer() return *optimizer; } +void Classifier::setState(const std::string & state) +{ + this->state = state; + nn->setState(state); +} diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp index 59289f4..ee8d46c 100644 --- a/torch_modules/include/ConfigDataset.hpp +++ b/torch_modules/include/ConfigDataset.hpp @@ -4,12 +4,12 @@ #include <torch/torch.h> #include "Config.hpp" -class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDataset, std::pair<torch::Tensor,torch::Tensor>> +class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDataset, std::tuple<torch::Tensor,torch::Tensor,std::string>> { private : std::size_t size_{0}; - std::vector<std::tuple<int,int,std::filesystem::path>> exampleLocations; + std::vector<std::tuple<int,int,std::filesystem::path,std::string>> exampleLocations; torch::Tensor loadedTensor; std::optional<std::size_t> loadedTensorIndex; std::size_t nextIndexToGive{0}; @@ -18,7 +18,7 @@ class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDatase explicit ConfigDataset(std::filesystem::path dir); c10::optional<std::size_t> size() const override; - c10::optional<std::pair<torch::Tensor,torch::Tensor>> get_batch(std::size_t batchSize) override; + c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> get_batch(std::size_t batchSize) override; void reset() override; void load(torch::serialize::InputArchive &) override; void save(torch::serialize::OutputArchive &) const override; diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp index e742b97..76b9303 100644 --- a/torch_modules/include/LSTMNetwork.hpp +++ b/torch_modules/include/LSTMNetwork.hpp @@ -24,10 +24,11 @@ class LSTMNetworkImpl : public NeuralNetworkImpl SplitTransLSTM splitTransLSTM{nullptr}; DepthLayerTreeEmbedding treeEmbedding{nullptr}; std::vector<FocusedColumnLSTM> focusedLstms; + std::map<std::string,torch::nn::Linear> outputLayersPerState; public : - LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d); + LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; }; diff --git a/torch_modules/include/MLP.hpp b/torch_modules/include/MLP.hpp index 71520f2..be272f1 100644 --- a/torch_modules/include/MLP.hpp +++ b/torch_modules/include/MLP.hpp @@ -9,11 +9,13 @@ class MLPImpl : public torch::nn::Module std::vector<torch::nn::Linear> layers; std::vector<torch::nn::Dropout> dropouts; + std::size_t outSize{0}; public : - MLPImpl(int inputSize, int outputSize, std::vector<std::pair<int, float>> params); + MLPImpl(int inputSize, std::vector<std::pair<int, float>> params); torch::Tensor forward(torch::Tensor input); + std::size_t outputSize() const; }; TORCH_MODULE(MLP); diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 1237f09..3db8651 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -14,6 +14,7 @@ class NeuralNetworkImpl : public torch::nn::Module private : bool splitUnknown{false}; + std::string state; protected : @@ -25,6 +26,8 @@ class NeuralNetworkImpl : public torch::nn::Module virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0; bool mustSplitUnknown() const; void setSplitUnknown(bool splitUnknown); + void setState(const std::string & state); + const std::string & getState() const; }; TORCH_MODULE(NeuralNetwork); diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp index 8f58d7b..f40c1b0 100644 --- a/torch_modules/include/RandomNetwork.hpp +++ b/torch_modules/include/RandomNetwork.hpp @@ -7,11 +7,11 @@ class RandomNetworkImpl : public NeuralNetworkImpl { private : - long outputSize; + std::map<std::string,std::size_t> nbOutputsPerState; public : - RandomNetworkImpl(long outputSize); + RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config &, Dict &) const override; }; diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp index e73e88f..3a4a2c5 100644 --- a/torch_modules/src/ConfigDataset.cpp +++ b/torch_modules/src/ConfigDataset.cpp @@ -6,12 +6,15 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir) for (auto & entry : std::filesystem::directory_iterator(dir)) if (entry.is_regular_file()) { - auto splited = util::split(entry.path().stem().string(), '-'); - if (splited.size() != 2) + auto stem = entry.path().stem().string(); + if (stem == "extracted") continue; - exampleLocations.emplace_back(std::make_tuple(std::stoi(splited[0]), std::stoi(splited[1]), entry.path())); + auto state = util::split(stem, '_')[0]; + auto splited = util::split(util::split(stem, '_')[1], '-'); + exampleLocations.emplace_back(std::make_tuple(std::stoi(splited[0]), std::stoi(splited[1]), entry.path(), state)); size_ += 1 + std::get<1>(exampleLocations.back()) - std::get<0>(exampleLocations.back()); } + } c10::optional<std::size_t> ConfigDataset::size() const @@ -19,7 +22,7 @@ c10::optional<std::size_t> ConfigDataset::size() const return size_; } -c10::optional<std::pair<torch::Tensor,torch::Tensor>> ConfigDataset::get_batch(std::size_t batchSize) +c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize) { if (!loadedTensorIndex.has_value()) { @@ -33,25 +36,27 @@ c10::optional<std::pair<torch::Tensor,torch::Tensor>> ConfigDataset::get_batch(s loadedTensorIndex = loadedTensorIndex.value() + 1; if (loadedTensorIndex >= exampleLocations.size()) - return c10::optional<std::pair<torch::Tensor,torch::Tensor>>(); + return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>(); torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device); } - std::pair<torch::Tensor, torch::Tensor> batch; + std::tuple<torch::Tensor, torch::Tensor, std::string> batch; if ((int)nextIndexToGive + (int)batchSize < loadedTensor.size(0)) { - batch.first = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, 0, loadedTensor.size(1)-1); - batch.second = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, loadedTensor.size(1)-1, 1); + std::get<0>(batch) = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, 0, loadedTensor.size(1)-1); + std::get<1>(batch) = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, loadedTensor.size(1)-1, 1); nextIndexToGive += batchSize; } else { - batch.first = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, 0, loadedTensor.size(1)-1); - batch.second = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, loadedTensor.size(1)-1, 1); + std::get<0>(batch) = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, 0, loadedTensor.size(1)-1); + std::get<1>(batch) = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, loadedTensor.size(1)-1, 1); nextIndexToGive = loadedTensor.size(0); } + std::get<2>(batch) = std::get<3>(exampleLocations[loadedTensorIndex.value()]); + return batch; } diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index a4f5863..476e24d 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -1,6 +1,6 @@ #include "LSTMNetwork.hpp" -LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d) +LSTMNetworkImpl::LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d) { LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false}; auto lstmOptionsAll = lstmOptions; @@ -50,7 +50,10 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue)); inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout)); - mlp = register_module("mlp", MLP(currentOutputSize, nbOutputs, mlpParams)); + mlp = register_module("mlp", MLP(currentOutputSize, mlpParams)); + + for (auto & it : nbOutputsPerState) + outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second))); } torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) @@ -81,7 +84,7 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) auto totalInput = inputDropout(torch::cat(outputs, 1)); - return mlp(totalInput); + return outputLayersPerState.at(getState())(mlp(totalInput)); } std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const diff --git a/torch_modules/src/MLP.cpp b/torch_modules/src/MLP.cpp index b886245..03880ec 100644 --- a/torch_modules/src/MLP.cpp +++ b/torch_modules/src/MLP.cpp @@ -1,7 +1,7 @@ #include "MLP.hpp" #include "fmt/core.h" -MLPImpl::MLPImpl(int inputSize, int outputSize, std::vector<std::pair<int, float>> params) +MLPImpl::MLPImpl(int inputSize, std::vector<std::pair<int, float>> params) { int inSize = inputSize; @@ -10,17 +10,21 @@ MLPImpl::MLPImpl(int inputSize, int outputSize, std::vector<std::pair<int, float layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(inSize, param.first))); dropouts.emplace_back(register_module(fmt::format("dropout_{}", dropouts.size()), torch::nn::Dropout(param.second))); inSize = param.first; + outSize = inSize; } - - layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(inSize, outputSize))); } torch::Tensor MLPImpl::forward(torch::Tensor input) { torch::Tensor output = input; - for (unsigned int i = 0; i < layers.size()-1; i++) + for (unsigned int i = 0; i < layers.size(); i++) output = dropouts[i](torch::relu(layers[i](output))); - return layers.back()(output); + return output; +} + +std::size_t MLPImpl::outputSize() const +{ + return outSize; } diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index 235c677..987cfcb 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -12,3 +12,13 @@ void NeuralNetworkImpl::setSplitUnknown(bool splitUnknown) this->splitUnknown = splitUnknown; } +void NeuralNetworkImpl::setState(const std::string & state) +{ + this->state = state; +} + +const std::string & NeuralNetworkImpl::getState() const +{ + return state; +} + diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index 7f0137d..8dfafc2 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -1,6 +1,6 @@ #include "RandomNetwork.hpp" -RandomNetworkImpl::RandomNetworkImpl(long outputSize) : outputSize(outputSize) +RandomNetworkImpl::RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState) : nbOutputsPerState(nbOutputsPerState) { } @@ -9,7 +9,7 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input) if (input.dim() == 1) input = input.unsqueeze(0); - return torch::randn({input.size(0), outputSize}, torch::TensorOptions().device(device).requires_grad(true)); + return torch::randn({input.size(0), (long)nbOutputsPerState[getState()]}, torch::TensorOptions().device(device).requires_grad(true)); } std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &, Dict &) const diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 7994a2f..91aedbe 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -9,6 +9,17 @@ class Trainer { private : + struct Examples + { + std::vector<torch::Tensor> contexts; + std::vector<torch::Tensor> classes; + + int currentExampleIndex{0}; + int lastSavedIndex{0}; + }; + + private : + using Dataset = ConfigDataset; using DataLoader = std::unique_ptr<torch::data::StatefulDataLoader<Dataset>>; @@ -26,7 +37,7 @@ class Trainer void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples); - void saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir); +void saveExamples(std::string state, Examples & examples, std::filesystem::path dir); public : diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 95e98eb..f76f220 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -33,30 +33,28 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); } -void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir) +void Trainer::saveExamples(std::string state, Examples & examples, std::filesystem::path dir) { - auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1); - auto filename = fmt::format("{}-{}.tensor", lastSavedIndex, currentExampleIndex-1); + auto tensorToSave = torch::cat({torch::stack(examples.contexts), torch::stack(examples.classes)}, 1); + auto filename = fmt::format("{}_{}-{}.tensor", state, examples.lastSavedIndex, examples.currentExampleIndex-1); torch::save(tensorToSave, dir/filename); - lastSavedIndex = currentExampleIndex; - contexts.clear(); - classes.clear(); + examples.lastSavedIndex = examples.currentExampleIndex; + examples.contexts.clear(); + examples.classes.clear(); } void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) { torch::AutoGradMode useGrad(false); - int maxNbExamplesPerFile = 250000; - int currentExampleIndex = 0; - int lastSavedIndex = 0; - std::vector<torch::Tensor> contexts; - std::vector<torch::Tensor> classes; + int maxNbExamplesPerFile = 50000; + std::map<std::string, Examples> examplesPerState; std::filesystem::create_directories(dir); config.addPredicted(machine.getPredicted()); config.setState(machine.getStrategy().getInitialState()); + machine.getClassifier()->setState(machine.getStrategy().getInitialState()); machine.getStrategy().reset(); auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch); @@ -75,6 +73,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p if (entry.is_regular_file()) std::filesystem::remove(entry.path()); + int totalNbExamples = 0; + while (true) { if (debug) @@ -85,6 +85,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p std::vector<std::vector<long>> context; + auto & contexts = examplesPerState[config.getState()].contexts; + auto & classes = examplesPerState[config.getState()].classes; + auto & currentExampleIndex = examplesPerState[config.getState()].currentExampleIndex; + auto & lastSavedIndex = examplesPerState[config.getState()].lastSavedIndex; + try { context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); @@ -133,10 +138,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p gold[0] = goldIndex; currentExampleIndex += context.size(); + totalNbExamples += context.size(); classes.insert(classes.end(), context.size(), gold); if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile) - saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir); + saveExamples(config.getState(), examplesPerState[config.getState()], dir); transition->apply(config); config.addToHistory(transition->getName()); @@ -148,14 +154,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p break; config.setState(movement.first); + machine.getClassifier()->setState(movement.first); config.moveWordIndexRelaxed(movement.second); if (config.needsUpdate()) config.update(); } - if (!contexts.empty()) - saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir); + for (auto & it : examplesPerState) + if (!it.second.contexts.empty()) + saveExamples(it.first, it.second, dir); std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w"); if (!f) @@ -164,7 +172,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p machine.saveDicts(); - fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex)); + fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples)); } float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples) @@ -188,8 +196,11 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance if (train) machine.getClassifier()->getOptimizer().zero_grad(); - auto data = batch.first; - auto labels = batch.second; + auto data = std::get<0>(batch); + auto labels = std::get<1>(batch); + auto state = std::get<2>(batch); + + machine.getClassifier()->setState(state); auto prediction = machine.getClassifier()->getNN()(data); if (prediction.dim() == 1) -- GitLab