diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 33c483783f9c02060a40057dba26597735c34dac..352a9e8ad96542f7a9985a3a196d9b1f29d1b567 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -21,6 +21,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool try { config.setState(machine.getStrategy().getInitialState()); + machine.getClassifier()->setState(machine.getStrategy().getInitialState()); while (true) { @@ -94,6 +95,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool break; config.setState(movement.first); + machine.getClassifier()->setState(movement.first); config.moveWordIndexRelaxed(movement.second); } } catch(std::exception & e) {util::myThrow(e.what());} diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 2bb3af9347d2e25f66ac094738cec22978556928..4951e5dce72e57b6d02f74598e426c248ce0ddcb 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -10,14 +10,15 @@ class Classifier private : std::string name; - std::unique_ptr<TransitionSet> transitionSet; + std::map<std::string,std::unique_ptr<TransitionSet>> transitionSets; std::shared_ptr<NeuralNetworkImpl> nn; std::unique_ptr<torch::optim::Adam> optimizer; + std::string state; private : void initNeuralNetwork(const std::vector<std::string> & definition); - void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex); + void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState); public : @@ -29,6 +30,7 @@ class Classifier void loadOptimizer(std::filesystem::path path); void saveOptimizer(std::filesystem::path path); torch::optim::Adam & getOptimizer(); + void setState(const std::string & state); }; #endif diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index b3e0710eb4e23dc155c629618de8a63f95c522fb..d8418c99c563c00dfce55e7ed8f6a1e0fae89fa8 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -8,12 +8,31 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std this->name = name; if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm) { - std::vector<std::string> tsFiles; + auto splited = util::split(sm.str(1), ' '); - for (auto & tsFilename : util::split(sm.str(1), ' ')) - tsFiles.emplace_back(path.parent_path() / tsFilename); + for (auto & ss : splited) + { + std::vector<std::string> tsFiles; + std::vector<std::string> states; + for (auto & elem : util::split(ss, ',')) + if (std::filesystem::path(elem).extension().empty()) + states.emplace_back(elem); + else + tsFiles.emplace_back(path.parent_path() / elem); + if (tsFiles.empty()) + util::myThrow(fmt::format("invalid '{}' no .ts files specified", ss)); + if (states.empty()) + util::myThrow(fmt::format("invalid '{}' no states specified", ss)); + + for (auto & stateName : states) + { + if (transitionSets.count(stateName)) + util::myThrow(fmt::format("state '{}' already assigned", stateName)); + + this->transitionSets.emplace(stateName, new TransitionSet(tsFiles)); + } + } - this->transitionSet.reset(new TransitionSet(tsFiles)); })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}")); @@ -32,7 +51,10 @@ int Classifier::getNbParameters() const TransitionSet & Classifier::getTransitionSet() { - return *transitionSet; + if (!transitionSets.count(state)) + util::myThrow(fmt::format("cannot find transition set for state '{}'", state)); + + return *transitionSets[state]; } NeuralNetwork & Classifier::getNN() @@ -47,6 +69,10 @@ const std::string & Classifier::getName() const void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) { + std::map<std::string,std::size_t> nbOutputsPerState; + for (auto & it : this->transitionSets) + nbOutputsPerState[it.first] = it.second->size(); + std::size_t curIndex = 1; std::string networkType; @@ -58,9 +84,9 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Network type :) networkType")); if (networkType == "Random") - this->nn.reset(new RandomNetworkImpl(this->transitionSet->size())); + this->nn.reset(new RandomNetworkImpl(nbOutputsPerState)); else if (networkType == "LSTM") - initLSTM(definition, curIndex); + initLSTM(definition, curIndex, nbOutputsPerState); else util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType)); @@ -83,7 +109,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}")); } -void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex) +void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState) { int unknownValueThreshold; std::vector<int> bufferContext, stackContext; @@ -299,7 +325,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value")); - this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d)); + this->nn.reset(new LSTMNetworkImpl(nbOutputsPerState, unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d)); } void Classifier::loadOptimizer(std::filesystem::path path) @@ -317,4 +343,9 @@ torch::optim::Adam & Classifier::getOptimizer() return *optimizer; } +void Classifier::setState(const std::string & state) +{ + this->state = state; + nn->setState(state); +} diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp index 59289f4cb50cc501f40a93cace53255b09bbd4ec..ee8d46cc97ae05df8262028836e997a1d285d2bb 100644 --- a/torch_modules/include/ConfigDataset.hpp +++ b/torch_modules/include/ConfigDataset.hpp @@ -4,12 +4,12 @@ #include <torch/torch.h> #include "Config.hpp" -class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDataset, std::pair<torch::Tensor,torch::Tensor>> +class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDataset, std::tuple<torch::Tensor,torch::Tensor,std::string>> { private : std::size_t size_{0}; - std::vector<std::tuple<int,int,std::filesystem::path>> exampleLocations; + std::vector<std::tuple<int,int,std::filesystem::path,std::string>> exampleLocations; torch::Tensor loadedTensor; std::optional<std::size_t> loadedTensorIndex; std::size_t nextIndexToGive{0}; @@ -18,7 +18,7 @@ class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDatase explicit ConfigDataset(std::filesystem::path dir); c10::optional<std::size_t> size() const override; - c10::optional<std::pair<torch::Tensor,torch::Tensor>> get_batch(std::size_t batchSize) override; + c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> get_batch(std::size_t batchSize) override; void reset() override; void load(torch::serialize::InputArchive &) override; void save(torch::serialize::OutputArchive &) const override; diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp index e742b9723a3ba5c9e14485dac5f5f4a9dc46c5ad..76b9303a8cb7245fe27e42d9b7f9673a832ef2e3 100644 --- a/torch_modules/include/LSTMNetwork.hpp +++ b/torch_modules/include/LSTMNetwork.hpp @@ -24,10 +24,11 @@ class LSTMNetworkImpl : public NeuralNetworkImpl SplitTransLSTM splitTransLSTM{nullptr}; DepthLayerTreeEmbedding treeEmbedding{nullptr}; std::vector<FocusedColumnLSTM> focusedLstms; + std::map<std::string,torch::nn::Linear> outputLayersPerState; public : - LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d); + LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; }; diff --git a/torch_modules/include/MLP.hpp b/torch_modules/include/MLP.hpp index 71520f2be3712e78cd9fe4987d50ade32d289825..be272f1cd1369a7b1290aefd1868265111ac00da 100644 --- a/torch_modules/include/MLP.hpp +++ b/torch_modules/include/MLP.hpp @@ -9,11 +9,13 @@ class MLPImpl : public torch::nn::Module std::vector<torch::nn::Linear> layers; std::vector<torch::nn::Dropout> dropouts; + std::size_t outSize{0}; public : - MLPImpl(int inputSize, int outputSize, std::vector<std::pair<int, float>> params); + MLPImpl(int inputSize, std::vector<std::pair<int, float>> params); torch::Tensor forward(torch::Tensor input); + std::size_t outputSize() const; }; TORCH_MODULE(MLP); diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 1237f09e15989dc6534e150e9ca03cfe983f797b..3db86517d391b5ee6384c6e9d8dba4e1b3810837 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -14,6 +14,7 @@ class NeuralNetworkImpl : public torch::nn::Module private : bool splitUnknown{false}; + std::string state; protected : @@ -25,6 +26,8 @@ class NeuralNetworkImpl : public torch::nn::Module virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0; bool mustSplitUnknown() const; void setSplitUnknown(bool splitUnknown); + void setState(const std::string & state); + const std::string & getState() const; }; TORCH_MODULE(NeuralNetwork); diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp index 8f58d7b30859a9cd4130fc506f83fc9c51bce34e..f40c1b02a753999de3b7649d18d6e75765f5506d 100644 --- a/torch_modules/include/RandomNetwork.hpp +++ b/torch_modules/include/RandomNetwork.hpp @@ -7,11 +7,11 @@ class RandomNetworkImpl : public NeuralNetworkImpl { private : - long outputSize; + std::map<std::string,std::size_t> nbOutputsPerState; public : - RandomNetworkImpl(long outputSize); + RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config &, Dict &) const override; }; diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp index e73e88f528a7d5145e8c2e94c53779c277ff1cc5..3a4a2c54e50a0891db61a3e97d9dc2a26554df26 100644 --- a/torch_modules/src/ConfigDataset.cpp +++ b/torch_modules/src/ConfigDataset.cpp @@ -6,12 +6,15 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir) for (auto & entry : std::filesystem::directory_iterator(dir)) if (entry.is_regular_file()) { - auto splited = util::split(entry.path().stem().string(), '-'); - if (splited.size() != 2) + auto stem = entry.path().stem().string(); + if (stem == "extracted") continue; - exampleLocations.emplace_back(std::make_tuple(std::stoi(splited[0]), std::stoi(splited[1]), entry.path())); + auto state = util::split(stem, '_')[0]; + auto splited = util::split(util::split(stem, '_')[1], '-'); + exampleLocations.emplace_back(std::make_tuple(std::stoi(splited[0]), std::stoi(splited[1]), entry.path(), state)); size_ += 1 + std::get<1>(exampleLocations.back()) - std::get<0>(exampleLocations.back()); } + } c10::optional<std::size_t> ConfigDataset::size() const @@ -19,7 +22,7 @@ c10::optional<std::size_t> ConfigDataset::size() const return size_; } -c10::optional<std::pair<torch::Tensor,torch::Tensor>> ConfigDataset::get_batch(std::size_t batchSize) +c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize) { if (!loadedTensorIndex.has_value()) { @@ -33,25 +36,27 @@ c10::optional<std::pair<torch::Tensor,torch::Tensor>> ConfigDataset::get_batch(s loadedTensorIndex = loadedTensorIndex.value() + 1; if (loadedTensorIndex >= exampleLocations.size()) - return c10::optional<std::pair<torch::Tensor,torch::Tensor>>(); + return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>(); torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device); } - std::pair<torch::Tensor, torch::Tensor> batch; + std::tuple<torch::Tensor, torch::Tensor, std::string> batch; if ((int)nextIndexToGive + (int)batchSize < loadedTensor.size(0)) { - batch.first = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, 0, loadedTensor.size(1)-1); - batch.second = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, loadedTensor.size(1)-1, 1); + std::get<0>(batch) = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, 0, loadedTensor.size(1)-1); + std::get<1>(batch) = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, loadedTensor.size(1)-1, 1); nextIndexToGive += batchSize; } else { - batch.first = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, 0, loadedTensor.size(1)-1); - batch.second = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, loadedTensor.size(1)-1, 1); + std::get<0>(batch) = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, 0, loadedTensor.size(1)-1); + std::get<1>(batch) = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, loadedTensor.size(1)-1, 1); nextIndexToGive = loadedTensor.size(0); } + std::get<2>(batch) = std::get<3>(exampleLocations[loadedTensorIndex.value()]); + return batch; } diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index a4f5863ce39dab0fd51730e9edb52a8a0a43e2a2..476e24dc938bd302db62194fae0e4cb665f82dab 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -1,6 +1,6 @@ #include "LSTMNetwork.hpp" -LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d) +LSTMNetworkImpl::LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d) { LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false}; auto lstmOptionsAll = lstmOptions; @@ -50,7 +50,10 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue)); inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout)); - mlp = register_module("mlp", MLP(currentOutputSize, nbOutputs, mlpParams)); + mlp = register_module("mlp", MLP(currentOutputSize, mlpParams)); + + for (auto & it : nbOutputsPerState) + outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second))); } torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) @@ -81,7 +84,7 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) auto totalInput = inputDropout(torch::cat(outputs, 1)); - return mlp(totalInput); + return outputLayersPerState.at(getState())(mlp(totalInput)); } std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const diff --git a/torch_modules/src/MLP.cpp b/torch_modules/src/MLP.cpp index b886245913d33d6ed5e138fddc57d2ca22a6207e..03880ecca0554d1ab4ad1b0c92fca4f1a0ea2481 100644 --- a/torch_modules/src/MLP.cpp +++ b/torch_modules/src/MLP.cpp @@ -1,7 +1,7 @@ #include "MLP.hpp" #include "fmt/core.h" -MLPImpl::MLPImpl(int inputSize, int outputSize, std::vector<std::pair<int, float>> params) +MLPImpl::MLPImpl(int inputSize, std::vector<std::pair<int, float>> params) { int inSize = inputSize; @@ -10,17 +10,21 @@ MLPImpl::MLPImpl(int inputSize, int outputSize, std::vector<std::pair<int, float layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(inSize, param.first))); dropouts.emplace_back(register_module(fmt::format("dropout_{}", dropouts.size()), torch::nn::Dropout(param.second))); inSize = param.first; + outSize = inSize; } - - layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(inSize, outputSize))); } torch::Tensor MLPImpl::forward(torch::Tensor input) { torch::Tensor output = input; - for (unsigned int i = 0; i < layers.size()-1; i++) + for (unsigned int i = 0; i < layers.size(); i++) output = dropouts[i](torch::relu(layers[i](output))); - return layers.back()(output); + return output; +} + +std::size_t MLPImpl::outputSize() const +{ + return outSize; } diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index 235c67793305280d0e09a3f1d45593fa727d13a3..987cfcb44c1e243348d6073da4252b7570658fa1 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -12,3 +12,13 @@ void NeuralNetworkImpl::setSplitUnknown(bool splitUnknown) this->splitUnknown = splitUnknown; } +void NeuralNetworkImpl::setState(const std::string & state) +{ + this->state = state; +} + +const std::string & NeuralNetworkImpl::getState() const +{ + return state; +} + diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index 7f0137d5d25b36eaffcde2cf5e776f16fca0685e..8dfafc2ab2b044523daf5764047f394e2699ff91 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -1,6 +1,6 @@ #include "RandomNetwork.hpp" -RandomNetworkImpl::RandomNetworkImpl(long outputSize) : outputSize(outputSize) +RandomNetworkImpl::RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState) : nbOutputsPerState(nbOutputsPerState) { } @@ -9,7 +9,7 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input) if (input.dim() == 1) input = input.unsqueeze(0); - return torch::randn({input.size(0), outputSize}, torch::TensorOptions().device(device).requires_grad(true)); + return torch::randn({input.size(0), (long)nbOutputsPerState[getState()]}, torch::TensorOptions().device(device).requires_grad(true)); } std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &, Dict &) const diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 7994a2f200b323bd23eeab8a08859c258dbf89a1..91aedbe7687c93a49bee34f386074f79f7026562 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -9,6 +9,17 @@ class Trainer { private : + struct Examples + { + std::vector<torch::Tensor> contexts; + std::vector<torch::Tensor> classes; + + int currentExampleIndex{0}; + int lastSavedIndex{0}; + }; + + private : + using Dataset = ConfigDataset; using DataLoader = std::unique_ptr<torch::data::StatefulDataLoader<Dataset>>; @@ -26,7 +37,7 @@ class Trainer void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples); - void saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir); +void saveExamples(std::string state, Examples & examples, std::filesystem::path dir); public : diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 95e98eba3e8162e01a71d8a6a5373a0d96293d30..f76f220acb4e7962ecf67e31feed8358a7c981a4 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -33,30 +33,28 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); } -void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir) +void Trainer::saveExamples(std::string state, Examples & examples, std::filesystem::path dir) { - auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1); - auto filename = fmt::format("{}-{}.tensor", lastSavedIndex, currentExampleIndex-1); + auto tensorToSave = torch::cat({torch::stack(examples.contexts), torch::stack(examples.classes)}, 1); + auto filename = fmt::format("{}_{}-{}.tensor", state, examples.lastSavedIndex, examples.currentExampleIndex-1); torch::save(tensorToSave, dir/filename); - lastSavedIndex = currentExampleIndex; - contexts.clear(); - classes.clear(); + examples.lastSavedIndex = examples.currentExampleIndex; + examples.contexts.clear(); + examples.classes.clear(); } void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) { torch::AutoGradMode useGrad(false); - int maxNbExamplesPerFile = 250000; - int currentExampleIndex = 0; - int lastSavedIndex = 0; - std::vector<torch::Tensor> contexts; - std::vector<torch::Tensor> classes; + int maxNbExamplesPerFile = 50000; + std::map<std::string, Examples> examplesPerState; std::filesystem::create_directories(dir); config.addPredicted(machine.getPredicted()); config.setState(machine.getStrategy().getInitialState()); + machine.getClassifier()->setState(machine.getStrategy().getInitialState()); machine.getStrategy().reset(); auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch); @@ -75,6 +73,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p if (entry.is_regular_file()) std::filesystem::remove(entry.path()); + int totalNbExamples = 0; + while (true) { if (debug) @@ -85,6 +85,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p std::vector<std::vector<long>> context; + auto & contexts = examplesPerState[config.getState()].contexts; + auto & classes = examplesPerState[config.getState()].classes; + auto & currentExampleIndex = examplesPerState[config.getState()].currentExampleIndex; + auto & lastSavedIndex = examplesPerState[config.getState()].lastSavedIndex; + try { context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); @@ -133,10 +138,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p gold[0] = goldIndex; currentExampleIndex += context.size(); + totalNbExamples += context.size(); classes.insert(classes.end(), context.size(), gold); if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile) - saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir); + saveExamples(config.getState(), examplesPerState[config.getState()], dir); transition->apply(config); config.addToHistory(transition->getName()); @@ -148,14 +154,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p break; config.setState(movement.first); + machine.getClassifier()->setState(movement.first); config.moveWordIndexRelaxed(movement.second); if (config.needsUpdate()) config.update(); } - if (!contexts.empty()) - saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir); + for (auto & it : examplesPerState) + if (!it.second.contexts.empty()) + saveExamples(it.first, it.second, dir); std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w"); if (!f) @@ -164,7 +172,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p machine.saveDicts(); - fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex)); + fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples)); } float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples) @@ -188,8 +196,11 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance if (train) machine.getClassifier()->getOptimizer().zero_grad(); - auto data = batch.first; - auto labels = batch.second; + auto data = std::get<0>(batch); + auto labels = std::get<1>(batch); + auto state = std::get<2>(batch); + + machine.getClassifier()->setState(state); auto prediction = machine.getClassifier()->getNN()(data); if (prediction.dim() == 1)