diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp index ccf3162211c248eff38b8f01525584d38c886e20..6ae47d56af47beb0e27e91350d2e37056377b61e 100644 --- a/dev/src/dev.cpp +++ b/dev/src/dev.cpp @@ -6,6 +6,7 @@ #include "TransitionSet.hpp" #include "ReadingMachine.hpp" #include "TestNetwork.hpp" +#include "ConfigDataset.hpp" int main(int argc, char * argv[]) { @@ -27,7 +28,15 @@ int main(int argc, char * argv[]) SubConfig config(goldConfig); config.setState(machine.getStrategy().getInitialState()); - TestNetwork nn; + + + TestNetwork nn(machine.getTransitionSet().size()); + torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5)); + optimizer.zero_grad(); + + std::vector<torch::Tensor> predictionsBatch; + std::vector<torch::Tensor> referencesBatch; + std::vector<SubConfig> configs; while (true) { @@ -36,9 +45,22 @@ int main(int argc, char * argv[]) util::myThrow("No transition appliable !"); //here train - auto testo = nn(config); + int goldIndex = 3; + auto gold = torch::zeros(machine.getTransitionSet().size(), at::kLong); + gold[goldIndex] = 1; +// referencesBatch.emplace_back(gold); +// predictionsBatch.emplace_back(nn(config)); + +// auto loss = torch::nll_loss(prediction, gold); +// loss.backward(); +// optimizer.step(); + configs.emplace_back(config); + + if (config.getWordIndex()%1 == 0) + fmt::print("{:.5f}%\n", config.getWordIndex()*100.0/goldConfig.getNbLines()); -// std::cout << testo << std::endl; +// if (config.getWordIndex() >= 500) +// exit(1); transition->apply(config); config.addToHistory(transition->getName()); diff --git a/reading_machine/include/FeatureFunction.hpp b/reading_machine/include/FeatureFunction.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ed860b0492853f82835fe7782f2d5457bfaaa8b3 --- /dev/null +++ b/reading_machine/include/FeatureFunction.hpp @@ -0,0 +1,28 @@ +#ifndef FEATUREFUNCTION__H +#define FEATUREFUNCTION__H + +#include <map> +#include <string> +#include "Config.hpp" + +class FeatureFunction +{ + using Representation = std::vector<std::size_t>; + using Feature = std::function<Config::String(const Config &)>; + + private : + + std::map<std::string, Feature> features; + std::map<Config::String, std::size_t> dictionary; + + private : + + const Feature & getOrCreateFeature(const std::string & name); + + public : + + FeatureFunction(const std::vector<std::string_view> & lines); + Representation getRepresentation(const Config & config) const; +}; + +#endif diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index ab5d79483be3f0c7b919d17de8e59b307af7ba27..ede9f7a5212bbff43575a1f0b0a9f9164ec23088 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -4,6 +4,7 @@ #include <memory> #include "Classifier.hpp" #include "Strategy.hpp" +#include "FeatureFunction.hpp" class ReadingMachine { @@ -12,6 +13,7 @@ class ReadingMachine std::string name; std::unique_ptr<Classifier> classifier; std::unique_ptr<Strategy> strategy; + std::unique_ptr<FeatureFunction> featureFunction; public : diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp index 188ce70e1e4d77c3e8dfccec0f09de83d916f5eb..898ce837988ae39bac99e557599044225dd954d1 100644 --- a/reading_machine/include/TransitionSet.hpp +++ b/reading_machine/include/TransitionSet.hpp @@ -18,6 +18,7 @@ class TransitionSet TransitionSet(const std::string & filename); std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c); Transition * getBestAppliableTransition(const Config & c); + std::size_t size() const; }; #endif diff --git a/reading_machine/src/FeatureFunction.cpp b/reading_machine/src/FeatureFunction.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c516220376939faf3f23e5b28f94d8aaada02861 --- /dev/null +++ b/reading_machine/src/FeatureFunction.cpp @@ -0,0 +1,45 @@ +#include "FeatureFunction.hpp" + +FeatureFunction::FeatureFunction(const std::vector<std::string_view> & lines) +{ + if (!util::doIfNameMatch(std::regex("Features :(.*)"), lines[0], [](auto){})) + util::myThrow(fmt::format("Wrong line '{}', expected 'Features :'", lines[0])); + + for (unsigned int i = 1; i < lines.size(); i++) + { + if (util::doIfNameMatch(std::regex("(?: |\\t)*buffer from ((?:-|\\+|)\\d+) to ((?:-|\\+|)\\d+)"), lines[i], [this](auto &sm) + { + getOrCreateFeature(fmt::format("b.")); + })) + continue; + + util::myThrow(fmt::format("Unknown feature directive '{}'", lines[i])); + } + + for (auto & it : features) + fmt::print("{}\n", it.first); +} + +FeatureFunction::Representation FeatureFunction::getRepresentation(const Config & config) const +{ + Representation representation; + + return representation; +} + +const FeatureFunction::Feature & FeatureFunction::getOrCreateFeature(const std::string & name) +{ + auto found = features.find(name); + + if (found != features.end()) + return found->second; + + if (util::doIfNameMatch(std::regex(""), name, [this,name](auto){features[name] = Feature();})) + return features[name]; + + + util::myThrow(fmt::format("Unknown feature '{}'", name)); + + return found->second; +} + diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index f2d49bef54f288c70deca580d981a06b256ba556..ec04962fcdd7cfb7f799d802397e510665241bb0 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -7,17 +7,22 @@ ReadingMachine::ReadingMachine(const std::string & filename) char buffer[1024]; std::string fileContent; + std::vector<std::string> lines; while (!std::feof(file)) { if (buffer != std::fgets(buffer, 1024, file)) break; + // If line is blank or commented (# or //), ignore it + if (util::doIfNameMatch(std::regex("((\\s|\\t)*)(((#|//).*)|)(\n|)"), buffer, [](auto){})) + continue; - fileContent += buffer; + if (buffer[std::strlen(buffer)-1] == '\n') + buffer[std::strlen(buffer)-1] = '\0'; + + lines.emplace_back(buffer); } std::fclose(file); - auto lines = util::split(fileContent, '\n'); - try { unsigned int curLine = 0; @@ -28,7 +33,15 @@ ReadingMachine::ReadingMachine(const std::string & filename) if (!classifier.get()) util::myThrow("No Classifier specified"); - std::vector<std::string_view> restOfFile(lines.begin()+curLine-1, lines.end()); + --curLine; + //std::vector<std::string_view> restOfFile; + //while (curLine < lines.size() and !util::doIfNameMatch(std::regex("Strategy(.*)"),lines[curLine], [](auto){})) + // restOfFile.emplace_back(lines[curLine++]); + + //featureFunction.reset(new FeatureFunction(restOfFile)); + + auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end()); + strategy.reset(new Strategy(restOfFile)); } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", filename, e.what()));} diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp index a1a1f8904b1874230bde7c257f6f6ea36b4b2732..55b67a0425ea6245063356d8b2a9295b115a7c61 100644 --- a/reading_machine/src/TransitionSet.cpp +++ b/reading_machine/src/TransitionSet.cpp @@ -67,3 +67,8 @@ Transition * TransitionSet::getBestAppliableTransition(const Config & c) return result; } +std::size_t TransitionSet::size() const +{ + return transitions.size(); +} + diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b68f5fdbcfc70dbf16630e8d6bdbfccc79682a88 --- /dev/null +++ b/torch_modules/include/ConfigDataset.hpp @@ -0,0 +1,21 @@ +#ifndef CONFIGDATASET__H +#define CONFIGDATASET__H + +#include <torch/torch.h> +#include "Config.hpp" + +class ConfigDataset : public torch::data::Dataset<ConfigDataset> +{ + private : + + std::vector<Config> const & configs; + std::vector<std::size_t> const & classes; + + public : + + ConfigDataset(std::vector<Config> const & configs, std::vector<std::size_t> const & classes); + torch::optional<size_t> size() const override; + torch::data::Example<> get(size_t index) override; +}; + +#endif diff --git a/torch_modules/include/TestNetwork.hpp b/torch_modules/include/TestNetwork.hpp index f286a402a2b4f11e4253a10d35e6b4a50ded00c7..eceb9c9b8911364ad01f4df16314c80f0c7af550 100644 --- a/torch_modules/include/TestNetwork.hpp +++ b/torch_modules/include/TestNetwork.hpp @@ -4,16 +4,17 @@ #include <torch/torch.h> #include "Config.hpp" -class TestNetworkImpl : torch::nn::Module +class TestNetworkImpl : public torch::nn::Module { private : std::map<Config::String, std::size_t> dict; torch::nn::Embedding wordEmbeddings{nullptr}; + torch::nn::Linear linear{nullptr}; public : - TestNetworkImpl(); + TestNetworkImpl(int nbOutputs); torch::Tensor forward(const Config & config); std::size_t getOrAddDictValue(Config::String s); }; diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp new file mode 100644 index 0000000000000000000000000000000000000000..28cdefc4fd6e2435b2bf9502ee6f3aaba34eb858 --- /dev/null +++ b/torch_modules/src/ConfigDataset.cpp @@ -0,0 +1,16 @@ +#include "ConfigDataset.hpp" + +ConfigDataset::ConfigDataset(std::vector<Config> const & configs, std::vector<std::size_t> const & classes) : configs(configs), classes(classes) +{ +} + +torch::optional<size_t> ConfigDataset::size() const +{ + +} + +torch::data::Example<> ConfigDataset::get(size_t index) +{ + +} + diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/TestNetwork.cpp index fc3d0e619d61b01b199b896bdea446826207ee90..7ac71241fb9691b77829d179a092fd2232862478 100644 --- a/torch_modules/src/TestNetwork.cpp +++ b/torch_modules/src/TestNetwork.cpp @@ -1,12 +1,13 @@ #include "TestNetwork.hpp" -TestNetworkImpl::TestNetworkImpl() +TestNetworkImpl::TestNetworkImpl(int nbOutputs) { getOrAddDictValue(Config::String("_null_")); getOrAddDictValue(Config::String("_unknown_")); getOrAddDictValue(Config::String("_S_")); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100)); + linear = register_module("linear", torch::nn::Linear(100, nbOutputs)); } torch::Tensor TestNetworkImpl::forward(const Config & config) @@ -15,11 +16,12 @@ torch::Tensor TestNetworkImpl::forward(const Config & config) // torch::Tensor tens = torch::from_blob(test.data(), {1,2}); // return wordEmbeddings(tens); constexpr int windowSize = 5; - int startIndex = config.getWordIndex(); - while (config.has(0,startIndex-1,0) and config.getWordIndex()-startIndex < windowSize) + int wordIndex = config.getWordIndex(); + int startIndex = wordIndex; + while (config.has(0,startIndex-1,0) and wordIndex-startIndex < windowSize) startIndex--; - int endIndex = config.getWordIndex(); - while (config.has(0,endIndex+1,0) and -config.getWordIndex()+endIndex < windowSize) + int endIndex = wordIndex; + while (config.has(0,endIndex+1,0) and -wordIndex+endIndex < windowSize) endIndex++; std::vector<std::size_t> words; @@ -32,9 +34,11 @@ torch::Tensor TestNetworkImpl::forward(const Config & config) } if (words.empty()) - util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), config.getWordIndex(), startIndex, endIndex)); + util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), wordIndex, startIndex, endIndex)); - return wordEmbeddings(torch::from_blob(words.data(), {1, (long int)words.size()}, at::kLong)); + auto wordsAsEmb = wordEmbeddings(torch::from_blob(words.data(), {(long int)words.size()}, at::kLong)); + + return torch::softmax(linear(wordsAsEmb[wordIndex-startIndex]), 0); } std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)