diff --git a/CMakeLists.txt b/CMakeLists.txt index 9f600e062af093e9964b4f26033181c76f1cb1ee..17dc5f054a48304e9daea570b48b252b64c66f27 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,8 @@ cmake_minimum_required(VERSION 2.8.7) project(test_torch) +add_compile_definitions(BOOST_DISABLE_THREADS) + find_package(Torch REQUIRED) find_package(fmt REQUIRED) diff --git a/common/include/util.hpp b/common/include/util.hpp index 0c6ae63a4e4e9eab4972d9737a94e5ac8fedad50..5f2d5205e7313228e62c5c475cfbc4fe8bd7a4a5 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -39,35 +39,6 @@ void myThrow(std::string_view message, const std::experimental::source_location std::string int2HumanStr(int number); -template<typename T> -std::size_t memorySize(const T & val) -{ - myThrow("Type not yet supported"); - return sizeof val; -} - -inline std::size_t memorySize(int val) -{ - return sizeof val; -} - -template<typename T> -std::size_t memorySize(const std::basic_string<T> & val) -{ - return sizeof val + val.capacity() * sizeof (T); -} - -template<typename T> -std::size_t memorySize(const std::vector<T> & vec) -{ - std::size_t result = sizeof vec + sizeof (T) * (vec.capacity()-vec.size()); - - for (auto & elem : vec) - result += memorySize(elem); - - return result; -} - }; template <> diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp index 4863c31cc37a47857f4f0d921226dfe71793af3e..0aa2137a7d3901e7e082ba692843de3ae7308022 100644 --- a/dev/src/dev.cpp +++ b/dev/src/dev.cpp @@ -10,8 +10,9 @@ int main(int argc, char * argv[]) Config config(argv[3], argv[1], argv[2]); - config.printSize(stderr); + config.print(stdout); + fmt::print(stderr, "ok\n"); std::scanf("%*c"); return 0; diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index cc9425ed88153e576a975c0684c5f85f829f42fe..7185c8e49437fa6df35498dbfdd4209add579fa5 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -18,14 +18,10 @@ #include <vector> #include <unordered_map> #include "util.hpp" +#include <boost/flyweight.hpp> class Config; -namespace util -{ -std::size_t memorySize(const Config & c); -}; - class Config { public : @@ -34,6 +30,8 @@ class Config static constexpr const char * EOSSymbol1 = "1"; static constexpr const char * EOSSymbol0 = "0"; + static constexpr int nbHypothesesMax = 1; + private : std::vector<std::string> colIndex2Name; @@ -42,9 +40,8 @@ class Config std::string rawInput; util::utf8string rawInputUtf8; - using ReferenceAndHypotheses = std::vector<std::string>; - using Line = std::vector<ReferenceAndHypotheses>; - std::vector<Line> lines; + int nbColumns; + std::vector<boost::flyweight<std::string>> lines; private : @@ -52,13 +49,16 @@ class Config void readRawInput(std::string_view rawFilename); void readTSVInput(std::string_view tsvFilename); - friend std::size_t util::memorySize(const Config &); - public : Config(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename); - void print(FILE * dest) const; - void printSize(FILE * dest); + void print(FILE * dest); + void addLine(); + boost::flyweight<std::string> & get(const std::string & colName, int lineIndex, int hypothesisIndex); + boost::flyweight<std::string> & get(int colIndex, int lineIndex, int hypothesisIndex); + boost::flyweight<std::string> & getLastNotEmpty(const std::string & colName, int lineIndex); + boost::flyweight<std::string> & getLastNotEmpty(int colIndex, int lineIndex); + std::size_t getNbLines() const; }; #endif diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index be26c0bbb9f4e7e173fc74aebd820cfea7465e72..0f8fcb8b2a09cc49a0ccc83cf5d62f1e0632fd0a 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -32,6 +32,8 @@ void Config::readMCD(std::string_view mcdFilename) util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, EOSColName)); colIndex2Name.emplace_back(EOSColName); colName2Index.emplace(EOSColName, colIndex2Name.size()-1); + + nbColumns = colIndex2Name.size(); } void Config::readRawInput(std::string_view rawFilename) @@ -73,7 +75,7 @@ void Config::readTSVInput(std::string_view tsvFilename) if (!inputHasBeenRead) continue; - lines.back()[colName2Index[EOSColName]][0] = EOSSymbol1; + get(EOSColName, getNbLines()-1, 0) = EOSSymbol1; continue; } @@ -91,41 +93,17 @@ void Config::readTSVInput(std::string_view tsvFilename) if ((int)splited.size() != usualNbCol) util::myThrow(fmt::format("in file {} line {} is invalid, it shoud have {} columns", tsvFilename, line, usualNbCol)); - lines.emplace_back(); - for (unsigned int i = 0; i < colIndex2Name.size(); i++) - { - lines.back().emplace_back(); - lines.back().back().emplace_back(""); - } - - lines.back()[colName2Index[EOSColName]][0] = EOSSymbol0; + addLine(); + get(EOSColName, getNbLines()-1, 0) = EOSSymbol0; for (unsigned int i = 0; i < splited.size(); i++) if (i < colIndex2Name.size()) - lines.back()[i][0] = splited[i]; + get(i, getNbLines()-1, 0) = std::string(splited[i]); } std::fclose(file); } -void Config::printSize(FILE * dest) -{ - int rawInputSize = util::memorySize(rawInput); - int rawInputUtf8Size = util::memorySize(rawInputUtf8); - int linesSize = util::memorySize(lines); - - int totalSize = rawInputSize + rawInputUtf8Size + linesSize; - - std::string unit = "Mo"; - int unitPower = 6; - float unitMultiplier = std::stof(fmt::format("0.{:0^{}}1","",unitPower-1)); - - fmt::print(dest, "{:<15} : {:<{}.2f} {}\n", "rawInput", unitMultiplier*rawInputSize, 2+11-unitPower, unit); - fmt::print(dest, "{:<15} : {:<{}.2f} {}\n", "rawInputUtf8", unitMultiplier*rawInputUtf8Size, 2+11-unitPower, unit); - fmt::print(dest, "{:<15} : {:<{}.2f} {}\n", "lines", unitMultiplier*linesSize, 2+11-unitPower, unit); - fmt::print(dest, "{:<15} : {:<{}.2f} {}\n", "Total", unitMultiplier*totalSize, 2+11-unitPower, unit); -} - Config::Config(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename) { if (tsvFilename.empty() and rawFilename.empty()) @@ -142,19 +120,49 @@ Config::Config(std::string_view mcdFilename, std::string_view tsvFilename, std:: readTSVInput(tsvFilename); } -void Config::print(FILE * dest) const +void Config::print(FILE * dest) { - for (auto & line : lines) + for (unsigned int line = 0; line < getNbLines(); line++) { - for (unsigned int i = 0; i < line.size()-1; i++) - fmt::print(dest, "{}{}", line[i].back(), i < line.size()-2 ? "\t" : "\n"); - if (line[colName2Index.at(EOSColName)].back() == EOSSymbol1) + for (int i = 0; i < nbColumns-1; i++) + fmt::print(dest, "{}{}", getLastNotEmpty(i, line).get(), i < nbColumns-2 ? "\t" : "\n"); + if (getLastNotEmpty(EOSColName, line) == EOSSymbol1) fmt::print(dest, "\n"); } } -std::size_t util::memorySize(const Config & c) +void Config::addLine() +{ + lines.resize(lines.size() + nbColumns*(nbHypothesesMax+1)); +} + +boost::flyweight<std::string> & Config::get(const std::string & colName, int lineIndex, int hypothesisIndex) +{ + return get(colName2Index[colName], lineIndex, hypothesisIndex); +} + +boost::flyweight<std::string> & Config::get(int colIndex, int lineIndex, int hypothesisIndex) +{ + return lines[lineIndex * nbColumns * (nbHypothesesMax+1) + colIndex * (nbHypothesesMax+1) + hypothesisIndex]; +} + +boost::flyweight<std::string> & Config::getLastNotEmpty(int colIndex, int lineIndex) +{ + int baseIndex = lineIndex * nbColumns * (nbHypothesesMax+1) + colIndex * (nbHypothesesMax+1); + for (int i = nbHypothesesMax; i > 0; --i) + if (!lines[baseIndex+i].get().empty()) + return lines[baseIndex+i]; + + return lines[baseIndex]; +} + +boost::flyweight<std::string> & Config::getLastNotEmpty(const std::string & colName, int lineIndex) +{ + return getLastNotEmpty(colName2Index[colName], lineIndex); +} + +std::size_t Config::getNbLines() const { - return sizeof c + memorySize(c.rawInput) + memorySize(c.rawInputUtf8) + memorySize(c.lines) + memorySize(c.colIndex2Name) + memorySize(c.colName2Index); + return lines.size() / (nbColumns * (nbHypothesesMax+1)); }