diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index fb4b04bae5a812d2710def5e83f3418a7c57e60b..d976573f3c3c383acfb1e769a8ec84277ed43861 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -11,7 +11,7 @@ class Dict
   enum State {Open, Closed};
   enum Encoding {Binary, Ascii};
 
-  private :
+  public :
 
  static constexpr char const * unknownValueStr = "__unknownValue__";
  static constexpr char const * nullValueStr = "__nullValue__";
diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index 8d83e28f843445eafa04cb7ae068b23f4e91e547..b5bd81f47d2f1b37427c699e3eb1d3eb48b15c27 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -29,17 +29,10 @@ int main(int argc, char * argv[])
 
   config.setState(machine.getStrategy().getInitialState());
 
+  std::vector<torch::Tensor> contexts;
+  std::vector<torch::Tensor> classes;
-  TestNetwork nn(machine.getTransitionSet().size());
-  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
-  optimizer.zero_grad();
-
-  std::vector<torch::Tensor> predictionsBatch;
-  std::vector<torch::Tensor> referencesBatch;
-  std::vector<std::unique_ptr<Config>> configs;
-  std::vector<std::size_t> classes;
-
-  fmt::print("Generating dataset...");
+  fmt::print("Generating dataset...\n");
 
   Dict dict(Dict::State::Open);
@@ -49,21 +42,13 @@ int main(int argc, char * argv[])
     if (!transition)
       util::myThrow("No transition appliable !");
 
-    //here train
-    int goldIndex = 3;
-    auto gold = torch::zeros(machine.getTransitionSet().size(), at::kLong);
-    gold[goldIndex] = 1;
-//    referencesBatch.emplace_back(gold);
-//    predictionsBatch.emplace_back(nn(config));
+    auto context = config.extractContext(5,5,dict);
+    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
 
-//    auto loss = torch::nll_loss(prediction, gold);
-//    loss.backward();
-//    optimizer.step();
-    configs.emplace_back(std::unique_ptr<Config>(new SubConfig(config)));
-    classes.emplace_back(goldIndex);
+    long goldIndex = 3;
+    auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
 
-//    if (config.getWordIndex() >= 500)
-//      exit(1);
+    classes.emplace_back(gold);
 
     transition->apply(config);
     config.addToHistory(transition->getName());
@@ -80,16 +65,39 @@ int main(int argc, char * argv[])
       config.update();
     }
 
-  auto dataset = ConfigDataset(configs, classes, machine.getTransitionSet().size(), dict).map(torch::data::transforms::Stack<>());
+  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
+
+  fmt::print("Done! size={}\n", *dataset.size());
 
-  fmt::print("Done!\n");
+  int batchSize = 100;
+  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize));
 
-  auto dataLoader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(dataset), 50);
+  TestNetwork nn(machine.getTransitionSet().size(), 5);
+  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
 
-  for (auto & batch : *dataLoader)
+  for (int epoch = 1; epoch <= 5; ++epoch)
   {
-    auto data = batch.data;
-    auto labels = batch.target.squeeze();
+    float totalLoss = 0.0;
+    torch::Tensor example;
+
+    for (auto & batch : *dataLoader)
+    {
+      optimizer.zero_grad();
+
+      auto data = batch.data;
+      auto labels = batch.target.squeeze();
+
+      auto prediction = nn(data);
+      example = prediction[0];
+
+      auto loss = torch::nll_loss(torch::log(prediction), labels);
+      totalLoss += loss.item<float>();
+      loss.backward();
+      optimizer.step();
+    }
+
+    fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
+    std::cout << example << std::endl;
   }
 
   return 0;
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index 21edddd928d7a4fd7820ef5aea8778f9142f51ac..5f34b241d7e494a48b345b2d0713647f3e6000e3 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -99,7 +99,7 @@ class Config
 
   String getState() const;
   void setState(const std::string state);
   bool stateIsDone() const;
-  std::vector<int> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
+  std::vector<long> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
 
 };
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 0075bcb458f864cd05b884a4ba747e039a4d8be9..2f6e7cf138875e225877990894964700e78fbc35 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -366,30 +366,29 @@ bool Config::stateIsDone() const
   return !has(0, wordIndex+1, 0);
 }
 
-std::vector<int> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
+std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
 {
-  std::vector<int> context;
+  std::stack<int> leftContext;
+  for (int index = wordIndex-1; has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
+    if (isToken(index))
+      leftContext.push(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", index)));
 
-  int startIndex = wordIndex;
+  std::vector<long> context;
 
-  for (int i = 0; i < leftBorder and has(0,startIndex-1,0); i++)
-    do
-      --startIndex;
-    while (!isToken(startIndex) and has(0,startIndex-1,0));
-
-  int endIndex = wordIndex;
-
-  for (int i = 0; i < rightBorder and has(0,endIndex+1,0); i++)
-    do
-      ++endIndex;
-    while (!isToken(endIndex) and has(0,endIndex+1,0));
+  while ((int)context.size() < leftBorder-(int)leftContext.size())
+    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  while (!leftContext.empty())
+  {
+    context.emplace_back(leftContext.top());
+    leftContext.pop();
+  }
 
-  for (int i = startIndex; i <= endIndex; ++i)
-    if (isToken(i))
-      context.emplace_back(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", i)));
+  for (int index = wordIndex; has(0,index,0) && (int)context.size() < leftBorder+rightBorder+1; ++index)
+    if (isToken(index))
+      context.emplace_back(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", index)));
 
-  //TODO handle the cases where the size is different...
-  return {0};
+  while ((int)context.size() < leftBorder+rightBorder+1)
+    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
   return context;
 }
diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp
index ee4430ce3171cda9f9e8f49575d09a5f821121c8..7aa878e62ce6d8beeabca4cd763353bb9023ff22 100644
--- a/torch_modules/include/ConfigDataset.hpp
+++ b/torch_modules/include/ConfigDataset.hpp
@@ -8,14 +8,12 @@ class ConfigDataset : public torch::data::Dataset<ConfigDataset>
 {
   private :
 
-  std::vector<std::unique_ptr<Config>> const & configs;
-  std::vector<std::size_t> const & classes;
-  std::size_t nbClasses;
-  Dict & dict;
+  std::vector<torch::Tensor> contexts;
+  std::vector<torch::Tensor> classes;
 
   public :
 
-  explicit ConfigDataset(std::vector<std::unique_ptr<Config>> const & configs, std::vector<std::size_t> const & classes, std::size_t nbClasses, Dict & dict);
+  explicit ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes);
   torch::optional<size_t> size() const override;
   torch::data::Example<> get(size_t index) override;
 };
diff --git a/torch_modules/include/TestNetwork.hpp b/torch_modules/include/TestNetwork.hpp
index eceb9c9b8911364ad01f4df16314c80f0c7af550..b1bb4e11200197f029556f7f3d93338ca0ce072e 100644
--- a/torch_modules/include/TestNetwork.hpp
+++ b/torch_modules/include/TestNetwork.hpp
@@ -8,15 +8,14 @@ class TestNetworkImpl : public torch::nn::Module
 {
   private :
 
-  std::map<Config::String, std::size_t> dict;
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Linear linear{nullptr};
+  int focusedIndex;
 
   public :
 
-  TestNetworkImpl(int nbOutputs);
-  torch::Tensor forward(const Config & config);
-  std::size_t getOrAddDictValue(Config::String s);
+  TestNetworkImpl(int nbOutputs, int focusedIndex);
+  torch::Tensor forward(torch::Tensor input);
 };
 
 TORCH_MODULE(TestNetwork);
diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index f9b5b57c976c30bcbdedb6abe2bd4685af13a95d..e2d3853312fc657c4dfd31f08197238adbab7e47 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -1,20 +1,16 @@
 #include "ConfigDataset.hpp"
 
-ConfigDataset::ConfigDataset(std::vector<std::unique_ptr<Config>> const & configs, std::vector<std::size_t> const & classes, std::size_t nbClasses, Dict & dict) : configs(configs), classes(classes), nbClasses(nbClasses), dict(dict)
+ConfigDataset::ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes) : contexts(std::move(contexts)), classes(std::move(classes))
 {
 }
 
 torch::optional<size_t> ConfigDataset::size() const
 {
-  return configs.size();
+  return contexts.size();
 }
 
 torch::data::Example<> ConfigDataset::get(size_t index)
 {
-  auto context = configs[index]->extractContext(5,5,dict);
-  auto tensorClass = torch::zeros(nbClasses);
-  tensorClass[classes[index]] = 1;
-
-  return {torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone(), tensorClass};
+  return {contexts[index], classes[index]};
 }
diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/TestNetwork.cpp
index 7ac71241fb9691b77829d179a092fd2232862478..3e3c010a7634824181e205b69161afbaefd96585 100644
--- a/torch_modules/src/TestNetwork.cpp
+++ b/torch_modules/src/TestNetwork.cpp
@@ -1,56 +1,22 @@
 #include "TestNetwork.hpp"
 
-TestNetworkImpl::TestNetworkImpl(int nbOutputs)
+TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
 {
-  getOrAddDictValue(Config::String("_null_"));
-  getOrAddDictValue(Config::String("_unknown_"));
-  getOrAddDictValue(Config::String("_S_"));
-
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100));
-  linear = register_module("linear", torch::nn::Linear(100, nbOutputs));
-}
-
-torch::Tensor TestNetworkImpl::forward(const Config & config)
-{
-//  std::vector<std::size_t> test{0,1};
-//  torch::Tensor tens = torch::from_blob(test.data(), {1,2});
-//  return wordEmbeddings(tens);
-  constexpr int windowSize = 5;
-  int wordIndex = config.getWordIndex();
-  int startIndex = wordIndex;
-  while (config.has(0,startIndex-1,0) and wordIndex-startIndex < windowSize)
-    startIndex--;
-  int endIndex = wordIndex;
-  while (config.has(0,endIndex+1,0) and -wordIndex+endIndex < windowSize)
-    endIndex++;
-
-  std::vector<std::size_t> words;
-  for (int i = startIndex; i <= endIndex; ++i)
-  {
-    if (!config.has(0, i, 0))
-      util::myThrow(fmt::format("Config do not have line %d", i));
-
-    words.emplace_back(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i)));
-  }
-
-  if (words.empty())
-    util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), wordIndex, startIndex, endIndex));
-
-  auto wordsAsEmb = wordEmbeddings(torch::from_blob(words.data(), {(long int)words.size()}, at::kLong));
-
-  return torch::softmax(linear(wordsAsEmb[wordIndex-startIndex]), 0);
+  constexpr int embeddingsSize = 100;
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, embeddingsSize));
+  linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
+  this->focusedIndex = focusedIndex;
 }
 
-std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
+torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
 {
-  if (s.get().empty())
-    return dict[Config::String("_null_")];
-
-  const auto & found = dict.find(s);
+  // input dim = {batch, sequence}; after the embedding lookup: {batch, sequence, embeddings}
+  auto wordsAsEmb = wordEmbeddings(input);
+  // reshaped dim = {sequence, batch, embeddings}
+  auto reshaped = wordsAsEmb.permute({1,0,2});
 
-  if (found == dict.end())
-    return dict[s] = dict.size();
+  auto res = torch::softmax(linear(reshaped[focusedIndex]), 1);
 
-  return found->second;
+  return res;
 }
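
Note on the tensor construction in dev.cpp above: torch::from_blob wraps caller-owned memory without copying or taking ownership, so the trailing .clone() is what lets the stored tensors outlive the local vectors, and the buffer must genuinely hold 64-bit values once the dtype is declared as at::kLong (hence long goldIndex rather than int). A minimal standalone sketch of the idiom, with made-up word indices and nullIndex standing in for dict.getIndexOrInsert(Dict::nullValueStr):

#include <torch/torch.h>
#include <iostream>
#include <vector>

int main()
{
  // Hypothetical padded context for leftBorder=2, rightBorder=2 at the
  // start of a sentence: two null-padding slots, then three word indices.
  long nullIndex = 0;
  std::vector<long> context{nullIndex, nullIndex, 7, 42, 13};

  // from_blob wraps the vector's memory without copying or owning it;
  // clone() makes an owning copy so the tensor can outlive the vector.
  auto contextTensor = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone();

  // The label buffer must be 64-bit: from_blob reads sizeof(int64_t)
  // bytes per element once told the dtype is kLong.
  long goldIndex = 3;
  auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();

  std::cout << contextTensor << '\n' << gold << std::endl;
  return 0;
}

Relatedly, torch::nll_loss(torch::log_softmax(x, 1), labels) would be the numerically safer equivalent of the torch::log(torch::softmax(x, 1)) composition in the epoch loop.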
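
For Config::extractContext, a simplified self-contained sketch of the windowing invariant the dataset relies on: the result always has exactly leftBorder + rightBorder + 1 entries, null-padded at whichever edge runs past the sentence. extractWindow, words and nullIndex are illustrative stand-ins; the real method additionally skips non-token lines via isToken and maps FORMs through the Dict.

#include <cstdio>
#include <stack>
#include <vector>

// words plays the role of token FORMs already mapped to dict indices,
// nullIndex the index of Dict::nullValueStr.
std::vector<long> extractWindow(const std::vector<long> & words, int wordIndex,
                                int leftBorder, int rightBorder, long nullIndex)
{
  // Collect up to leftBorder indices left of the focused word; the stack
  // restores them to left-to-right order.
  std::stack<long> leftContext;
  for (int index = wordIndex-1; index >= 0 && (int)leftContext.size() < leftBorder; --index)
    leftContext.push(words[index]);

  std::vector<long> context;

  // Pad on the left when the sentence start was reached too early.
  while ((int)context.size() < leftBorder-(int)leftContext.size())
    context.emplace_back(nullIndex);
  while (!leftContext.empty())
  {
    context.emplace_back(leftContext.top());
    leftContext.pop();
  }

  // Focused word plus right context, then right padding.
  for (int index = wordIndex; index < (int)words.size() && (int)context.size() < leftBorder+rightBorder+1; ++index)
    context.emplace_back(words[index]);
  while ((int)context.size() < leftBorder+rightBorder+1)
    context.emplace_back(nullIndex);

  return context;
}

int main()
{
  std::vector<long> words{10, 11, 12, 13};
  auto window = extractWindow(words, 0, 2, 2, 0); // focused word at sentence start
  for (auto w : window)
    std::printf("%ld ", w); // prints: 0 0 10 11 12
  std::printf("\n");
  return 0;
}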