diff --git a/dev/CMakeLists.txt b/dev/CMakeLists.txt
index e7c5cc7e85eaad4a8229d5739c727a816f2c2777..a4738067775772d41f8b3e80f5c6f7f3ea3a6679 100644
--- a/dev/CMakeLists.txt
+++ b/dev/CMakeLists.txt
@@ -3,3 +3,4 @@ FILE(GLOB SOURCES src/*.cpp)
 add_executable(dev src/dev.cpp)
 target_link_libraries(dev common)
 target_link_libraries(dev reading_machine)
+target_link_libraries(dev torch_modules)
diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index 889b4734328e0a7e447e11e774881c2fadc5a4fc..ccf3162211c248eff38b8f01525584d38c886e20 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -5,6 +5,7 @@
 #include "SubConfig.hpp"
 #include "TransitionSet.hpp"
 #include "ReadingMachine.hpp"
+#include "TestNetwork.hpp"
 
 int main(int argc, char * argv[])
 {
@@ -26,8 +27,7 @@ int main(int argc, char * argv[])
 
   SubConfig config(goldConfig);
   config.setState(machine.getStrategy().getInitialState());
-
-  std::vector<std::pair<SubConfig, Transition*>> trainingExamples;
+  TestNetwork nn;
 
   while (true)
   {
@@ -35,7 +35,10 @@ int main(int argc, char * argv[])
     if (!transition)
       util::myThrow("No transition appliable !");
 
-    trainingExamples.emplace_back(config, transition);
+    //here train
+    auto testo = nn(config);
+
+//    std::cout << testo << std::endl;
 
     transition->apply(config);
     config.addToHistory(transition->getName());
@@ -52,8 +55,6 @@ int main(int argc, char * argv[])
 
     config.update();
   }
-
-  trainingExamples[10000].first.printForDebug(stderr);
 
   return 0;
 }
diff --git a/torch_modules/include/TestNetwork.hpp b/torch_modules/include/TestNetwork.hpp
new file mode 100644
--- /dev/null
+++ b/torch_modules/include/TestNetwork.hpp
@@ -0,0 +1,22 @@
+#ifndef TESTNETWORK__H
+#define TESTNETWORK__H
+
+#include <torch/torch.h>
+#include "Config.hpp"
+
+class TestNetworkImpl : public torch::nn::Module
+{
+  private :
+
+  std::map<Config::String, std::size_t> dict;
+  torch::nn::Embedding wordEmbeddings{nullptr};
+
+  public :
+
+  TestNetworkImpl();
+  torch::Tensor forward(const Config & config);
+  std::size_t getOrAddDictValue(Config::String s);
+};
+TORCH_MODULE(TestNetwork);
+
+#endif
diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/TestNetwork.cpp
new file mode 100644
--- /dev/null
+++ b/torch_modules/src/TestNetwork.cpp
@@ -0,0 +1,56 @@
+#include "TestNetwork.hpp"
+
+TestNetworkImpl::TestNetworkImpl()
+{
+  // Reserve the special dictionary entries (indexes 0..2) before any real word is added.
+  getOrAddDictValue(Config::String("_null_"));
+  getOrAddDictValue(Config::String("_unknown_"));
+  getOrAddDictValue(Config::String("_S_"));
+
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100));
+}
+
+torch::Tensor TestNetworkImpl::forward(const Config & config)
+{
+  // Embed the FORMs of a window of up to windowSize words on each side of the current word.
+  constexpr int windowSize = 5;
+  int startIndex = config.getWordIndex();
+  while (config.has(0,startIndex-1,0) and config.getWordIndex()-startIndex < windowSize)
+    startIndex--;
+  int endIndex = config.getWordIndex();
+  while (config.has(0,endIndex+1,0) and endIndex-config.getWordIndex() < windowSize)
+    endIndex++;
+
+  // at::kLong below expects signed 64 bit integers, so do not use std::size_t here.
+  std::vector<long int> words;
+  for (int i = startIndex; i <= endIndex; ++i)
+  {
+    if (!config.has(0, i, 0))
+      util::myThrow(fmt::format("Config does not have line {}", i));
+
+    words.emplace_back(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i)));
+  }
+
+  if (words.empty())
+    util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), config.getWordIndex(), startIndex, endIndex));
+
+  return wordEmbeddings(torch::from_blob(words.data(), {1, (long int)words.size()}, at::kLong));
+}
+
+std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
+{
+  if (s.get().empty())
+    return dict[Config::String("_null_")];
+
+  const auto & found = dict.find(s);
+
+  if (found == dict.end())
+  {
+    // Compute the new index before inserting, so the size does not count s itself.
+    std::size_t newValue = dict.size();
+    dict.emplace(std::move(s), newValue);
+    return newValue;
+  }
+
+  return found->second;
+}