diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..fb4b04bae5a812d2710def5e83f3418a7c57e60b
--- /dev/null
+++ b/common/include/Dict.hpp
@@ -0,0 +1,57 @@
+#ifndef DICT__H
+#define DICT__H
+
+#include <cstddef>
+#include <cstdio>
+#include <string>
+#include <unordered_map>
+
+// Maps string elements to integer indexes.
+// While Open, getIndexOrInsert adds unseen elements; while Closed, unseen
+// elements resolve to the index of the reserved unknown-value entry.
+class Dict
+{
+  public :
+
+  enum State {Open, Closed};
+  enum Encoding {Binary, Ascii};
+
+  private :
+
+  // Reserved entries; Dict(State) inserts them first, at indexes 0 and 1.
+  static constexpr char const * unknownValueStr = "__unknownValue__";
+  static constexpr char const * nullValueStr = "__nullValue__";
+  // Longest element accepted by insert(), terminator excluded.
+  static constexpr std::size_t maxEntrySize = 5000;
+
+  private :
+
+  std::unordered_map<std::string, int> elementsToIndexes;
+  State state;
+
+  public :
+
+  // Builds a dictionary holding only the two reserved entries.
+  Dict(State state);
+  // Loads the entries stored in `filename` (format written by save()).
+  Dict(const char * filename, State state);
+
+  private :
+
+  void readFromFile(const char * filename);
+
+  public :
+
+  // Adds `element` with the next free index; no-op if already present.
+  void insert(const std::string & element);
+  // Returns the index of `element`, inserting it first when Open.
+  int getIndexOrInsert(const std::string & element);
+  void setState(State state);
+  // Writes a two-line header followed by every entry to `destination`.
+  void save(std::FILE * destination, Encoding encoding);
+  // Reads one entry; `entry` must hold at least maxEntrySize+1 chars.
+  bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding);
+  void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding);
+};
+
+#endif
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fcfc942dd73e38140bcc18ae341bea9071a8c7c8
--- /dev/null
+++ b/common/src/Dict.cpp
@@ -0,0 +1,147 @@
+#include "Dict.hpp"
+#include <cstdio>
+#include <memory>
+#include <string>
+#include "util.hpp"
+
+// Creates a dictionary containing only the reserved unknown and null entries.
+Dict::Dict(State state)
+{
+  setState(state);
+  insert(unknownValueStr);
+  insert(nullValueStr);
+}
+
+// Loads every entry stored in `filename`, then applies `state`.
+Dict::Dict(const char * filename, State state)
+{
+  readFromFile(filename);
+  setState(state);
+}
+
+// Parses a file produced by save(): an "Encoding : <Ascii|Binary>" line, a
+// "Nb entries : <n>" line, then <n> entries. Every failure is reported
+// through util::myThrow.
+void Dict::readFromFile(const char * filename)
+{
+  // Owning handle: the file is closed when this scope is left, error paths included.
+  std::unique_ptr<std::FILE, decltype(&std::fclose)> file(std::fopen(filename, "r"), &std::fclose);
+
+  if (!file)
+    util::myThrow(fmt::format("could not open file \'{}\'", filename));
+
+  char buffer[1048];
+  if (std::fscanf(file.get(), "Encoding : %1047s\n", buffer) != 1)
+    util::myThrow(fmt::format("file \'{}\' bad format", filename));
+
+  Encoding encoding{Encoding::Ascii};
+  if (std::string(buffer) == "Ascii")
+    encoding = Encoding::Ascii;
+  else if (std::string(buffer) == "Binary")
+    encoding = Encoding::Binary;
+  else
+    util::myThrow(fmt::format("file \'{}\' bad format", filename));
+
+  int nbEntries;
+
+  // Consume exactly one newline after the count: a bare "\n" directive would
+  // skip any run of whitespace, eating leading bytes of a Binary entry.
+  if (std::fscanf(file.get(), "Nb entries : %d%*1[\n]", &nbEntries) != 1 || nbEntries < 0)
+    util::myThrow(fmt::format("file \'{}\' bad format", filename));
+
+  elementsToIndexes.reserve(nbEntries);
+
+  int entryIndex;
+  char entryString[maxEntrySize+1];
+  for (int i = 0; i < nbEntries; i++)
+  {
+    if (!readEntry(file.get(), &entryIndex, entryString, encoding))
+      util::myThrow(fmt::format("file \'{}\' line {} bad format", filename, i));
+
+    elementsToIndexes[entryString] = entryIndex;
+  }
+}
+
+// Registers `element` under the next free index (the current entry count).
+// Does nothing when the element is already known.
+void Dict::insert(const std::string & element)
+{
+  if (element.size() > maxEntrySize)
+    util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));
+
+  elementsToIndexes.emplace(element, elementsToIndexes.size());
+}
+
+// Returns the index of `element`. When the dictionary is Open the element is
+// inserted first; otherwise an unseen element maps to the unknown-value index.
+int Dict::getIndexOrInsert(const std::string & element)
+{
+  if (state == State::Open)
+    insert(element);
+
+  const auto found = elementsToIndexes.find(element);
+
+  if (found == elementsToIndexes.end())
+    return elementsToIndexes[unknownValueStr];
+
+  return found->second;
+}
+
+void Dict::setState(State state)
+{
+  this->state = state;
+}
+
+// Writes the two header lines followed by one record per entry, in the map's
+// (unspecified) iteration order.
+void Dict::save(std::FILE * destination, Encoding encoding)
+{
+  std::fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
+  // %zu is the conversion specifier matching std::size_t.
+  std::fprintf(destination, "Nb entries : %zu\n", elementsToIndexes.size());
+  for (const auto & it : elementsToIndexes)
+    printEntry(destination, it.second, it.first, encoding);
+}
+
+// Reads one entry written by printEntry into `index` and `entry`.
+// `entry` must have room for maxEntrySize+1 chars. Returns false on a
+// malformed or truncated record.
+bool Dict::readEntry(std::FILE * file, int * index, char * entry, Encoding encoding)
+{
+  if (encoding == Encoding::Ascii)
+  {
+    static const std::string readFormat = "%d\t%"+std::to_string(maxEntrySize)+"[^\n]\n";
+    return std::fscanf(file, readFormat.c_str(), index, entry) == 2;
+  }
+
+  if (std::fread(index, sizeof *index, 1, file) != 1)
+    return false;
+
+  // <= : an entry of exactly maxEntrySize chars has its terminator at
+  // position maxEntrySize, which must still be read.
+  for (std::size_t i = 0; i <= maxEntrySize; i++)
+  {
+    if (std::fread(entry+i, 1, 1, file) != 1)
+      return false;
+    if (!entry[i])
+      return true;
+  }
+
+  return false;
+}
+
+// Writes one entry: "<index>\t<entry>\n" in Ascii, or the raw int followed
+// by the NUL-terminated string in Binary.
+void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding)
+{
+  if (encoding == Encoding::Ascii)
+  {
+    std::fprintf(file, "%d\t%s\n", index, entry.c_str());
+  }
+  else
+  {
+    std::fwrite(&index, sizeof index, 1, file);
+    std::fwrite(entry.c_str(), 1, entry.size()+1, file);
+  }
+}
+
diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index 6ae47d56af47beb0e27e91350d2e37056377b61e..8d83e28f843445eafa04cb7ae068b23f4e91e547 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -36,7 +36,12 @@ int main(int argc, char * argv[])
 
   std::vector<torch::Tensor> predictionsBatch;
   std::vector<torch::Tensor> referencesBatch;
-  std::vector<SubConfig> configs;
+  std::vector<std::unique_ptr<Config>> configs;
+  std::vector<std::size_t> classes;
+
+  fmt::print("Generating dataset...");
+
+  Dict dict(Dict::State::Open);
 
   while (true)
   {
@@ -54,10 +59,8 @@ int main(int argc, char * argv[])
 //    auto loss = torch::nll_loss(prediction, gold);
 //    loss.backward();
 //    optimizer.step();
-    configs.emplace_back(config);
-
-    if (config.getWordIndex()%1 == 0)
-      fmt::print("{:.5f}%\n", config.getWordIndex()*100.0/goldConfig.getNbLines());
+    configs.emplace_back(std::make_unique<SubConfig>(config));
+    classes.emplace_back(goldIndex);
 
 //    if (config.getWordIndex() >= 500)
 //      exit(1);
@@ -77,6 +80,18 @@ int main(int argc, char * argv[])
     config.update();
   }
 
+  auto dataset = ConfigDataset(configs, classes, machine.getTransitionSet().size(), dict).map(torch::data::transforms::Stack<>());
+
+  fmt::print("Done!\n");
+
+  auto dataLoader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(dataset), 50);
+
+  for (auto & batch : *dataLoader)
+  {
+    auto data = batch.data;
+    auto labels = batch.target.squeeze();
+  }
+
   return 0;
 }
 
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index 6738810d24af1cc3921659c5c9e38083bfbce6dd..21edddd928d7a4fd7820ef5aea8778f9142f51ac 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -7,6 +7,7 @@
 #include <boost/flyweight.hpp>
 #include <boost/circular_buffer.hpp>
 #include "util.hpp"
+#include "Dict.hpp"
 
 class Config
 {
@@ -15,6 +16,8 @@ class Config
   static constexpr const char * EOSColName = "EOS";
   static constexpr const char * EOSSymbol1 = "1";
   static constexpr const char * EOSSymbol0 = "0";
+  static constexpr const char * headColName = "HEAD";
+  static constexpr const char * idColName = "ID";
   static constexpr int nbHypothesesMax = 1;
 
   public :
@@ -96,6 +99,7 @@ class Config
   String getState() const;
   void setState(const std::string state);
   bool stateIsDone() const;
+  std::vector<int> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
 
 };
 
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index f4d1001d3d5cc35b49d352d93bc889a3b280a3bb..0075bcb458f864cd05b884a4ba747e039a4d8be9 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -366,3 +366,35 @@ bool Config::stateIsDone() const
   return !has(0, wordIndex+1, 0);
 }
 
+// Passes to dict.getIndexOrInsert the FORM of every token lying within
+// `leftBorder` tokens before and `rightBorder` tokens after the current word.
+// NOTE: the collected indexes are currently discarded and the one-element
+// placeholder {0} is returned instead (see the TODO below).
+std::vector<int> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
+{
+  std::vector<int> context;
+
+  int startIndex = wordIndex;
+
+  for (int i = 0; i < leftBorder and has(0,startIndex-1,0); i++)
+    do
+      --startIndex;
+    while (!isToken(startIndex) and has(0,startIndex-1,0));
+
+  int endIndex = wordIndex;
+
+  for (int i = 0; i < rightBorder and has(0,endIndex+1,0); i++)
+    do
+      ++endIndex;
+    while (!isToken(endIndex) and has(0,endIndex+1,0));
+
+  for (int i = startIndex; i <= endIndex; ++i)
+    if (isToken(i))
+      context.emplace_back(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", i)));
+
+  //TODO handle the cases where the size is different...
+  return {0};
+
+  return context;
+}
+
diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp
index b68f5fdbcfc70dbf16630e8d6bdbfccc79682a88..ee4430ce3171cda9f9e8f49575d09a5f821121c8 100644
--- a/torch_modules/include/ConfigDataset.hpp
+++ b/torch_modules/include/ConfigDataset.hpp
@@ -8,12 +8,14 @@ class ConfigDataset : public torch::data::Dataset<ConfigDataset>
 {
   private :
 
-  std::vector<Config> const & configs;
+  std::vector<std::unique_ptr<Config>> const & configs;
   std::vector<std::size_t> const & classes;
+  std::size_t nbClasses;
+  Dict & dict;
 
   public :
 
-  ConfigDataset(std::vector<Config> const & configs, std::vector<std::size_t> const & classes);
+  explicit ConfigDataset(std::vector<std::unique_ptr<Config>> const & configs, std::vector<std::size_t> const & classes, std::size_t nbClasses, Dict & dict);
   torch::optional<size_t> size() const override;
   torch::data::Example<> get(size_t index) override;
 };
diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index 28cdefc4fd6e2435b2bf9502ee6f3aaba34eb858..f9b5b57c976c30bcbdedb6abe2bd4685af13a95d 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -1,16 +1,28 @@
 #include "ConfigDataset.hpp"
+#include <cstdint>
+#include <vector>
 
-ConfigDataset::ConfigDataset(std::vector<Config> const & configs, std::vector<std::size_t> const & classes) : configs(configs), classes(classes)
+ConfigDataset::ConfigDataset(std::vector<std::unique_ptr<Config>> const & configs, std::vector<std::size_t> const & classes, std::size_t nbClasses, Dict & dict) : configs(configs), classes(classes), nbClasses(nbClasses), dict(dict)
 {
 }
 
 torch::optional<size_t> ConfigDataset::size() const
 {
-
+  return configs.size();
 }
 
+// Builds the example for config `index`: its context indexes as a 1-D int64
+// tensor, paired with its class as a one-hot tensor of size nbClasses.
 torch::data::Example<> ConfigDataset::get(size_t index)
 {
+  auto context = configs[index]->extractContext(5,5,dict);
+  // from_blob reinterprets raw memory: the buffer's element type must match
+  // the requested dtype, so widen the int indexes to 64 bits for at::kLong.
+  std::vector<std::int64_t> wideContext(context.begin(), context.end());
+  auto tensorClass = torch::zeros(nbClasses);
+  tensorClass[classes[index]] = 1;
 
+  // clone() copies the data out of wideContext before it is destroyed.
+  return {torch::from_blob(wideContext.data(), {static_cast<std::int64_t>(wideContext.size())}, at::kLong).clone(), tensorClass};
 }
 