diff --git a/CMakeLists.txt b/CMakeLists.txt index ed388ba6e6af5c5f93a48ab48ae819ffbf077919..0e841bdc34df293c23ac42dcb8dbe8a4a293ea72 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,6 +5,7 @@ add_compile_definitions(BOOST_DISABLE_THREADS) find_package(Torch REQUIRED) find_package(Boost 1.53.0 REQUIRED COMPONENTS program_options) +find_package(TBB REQUIRED tbb) include_directories(SYSTEM ${TORCH_INCLUDE_DIRS}) diff --git a/common/include/util.hpp b/common/include/util.hpp index bd1a48b19470bc21b24ead6d808fff07e20689ca..0a3f6fef67402d9a4bcc51006a415929b40cf871 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -9,6 +9,7 @@ #include <filesystem> #include <experimental/source_location> #include <boost/flyweight.hpp> +#include <boost/flyweight/no_tracking.hpp> #include <boost/circular_buffer.hpp> #include "fmt/core.h" #include "utf8.hpp" @@ -16,6 +17,7 @@ namespace util { +using String = boost::flyweights::flyweight<std::string,boost::flyweights::no_tracking>; constexpr float float2longScale = 10000; @@ -56,18 +58,6 @@ float long2float(long l); std::vector<std::vector<std::string>> readTSV(std::string_view tsvFilename); -template <typename T> -bool isEmpty(const std::vector<T> & s) -{ - return s.empty(); -} - -template <typename T> -bool isEmpty(const std::basic_string<T> & s) -{ - return s.empty(); -} - template <typename T> std::size_t getSize(const std::vector<T> & s) { @@ -75,17 +65,11 @@ std::size_t getSize(const std::vector<T> & s) } template <typename T> -std::size_t getSize(const boost::flyweight<T> & s) +std::size_t getSize(const boost::flyweights::flyweight<T> & s) { return getSize(s.get()); } -template <typename T> -bool isEmpty(const boost::flyweight<T> & s) -{ - return isEmpty(s.get()); -} - bool doIfNameMatch(const std::regex & reg, std::string_view name, const std::function<void(const std::smatch &)> & f); bool choiceWithProbability(float probability); @@ -156,16 +140,16 @@ struct fmt::formatter<std::experimental::source_location> } }; -template <typename T> -struct fmt::formatter<boost::flyweight<T>> -{ - constexpr auto parse(format_parse_context & ctx) { return ctx.begin(); } - - template <typename FormatContext> - auto format(const boost::flyweight<T> & s, FormatContext & ctx) - { - return format_to(ctx.out(), "{}", s.get()); - } -}; +//template <typename T> +//struct fmt::formatter<boost::flyweights::flyweight<T>> +//{ +// constexpr auto parse(format_parse_context & ctx) { return ctx.begin(); } +// +// template <typename FormatContext> +// auto format(const boost::flyweights::flyweight<T> & s, FormatContext & ctx) +// { +// return format_to(ctx.out(), "{}", s.get()); +// } +//}; #endif diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index ad2e6a66bbc9eb5fdbe874bbc1760895294f3e42..38957afa3eb2189f860d1d748e1099f7e9673e69 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -10,10 +10,6 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh { constexpr int printInterval = 50; - torch::AutoGradMode useGrad(false); - machine.trainMode(false); - machine.setDictsState(Dict::State::Closed); - int nbExamplesProcessed = 0; auto pastTime = std::chrono::high_resolution_clock::now(); diff --git a/macaon/CMakeLists.txt b/macaon/CMakeLists.txt index 2d95abc71b9225dc91f3813e39f87cba7f52747a..5e2686a17a43724ba8fca9ed5f0e0768ab48d4f8 100644 --- a/macaon/CMakeLists.txt +++ b/macaon/CMakeLists.txt @@ -3,4 +3,5 @@ FILE(GLOB SOURCES src/*.cpp) add_executable(macaon src/macaon.cpp) target_link_libraries(macaon trainer) target_link_libraries(macaon decoder) +target_link_libraries(macaon ${TBB_IMPORTED_TARGETS}) install(TARGETS macaon DESTINATION bin) diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 801173f083cda83d6f37087a7f18ead64414670d..49eff0d1d6ac3d29088e3ae585738b85b2c2c839 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -4,7 +4,6 @@ #include <memory> #include <string> #include <vector> -#include <boost/flyweight.hpp> #include <boost/circular_buffer.hpp> #include "util.hpp" #include "Dict.hpp" @@ -39,14 +38,13 @@ class Config public : - using String = boost::flyweight<std::string>; using Utf8String = util::utf8string; - using ValueIterator = std::vector<String>::iterator; - using ConstValueIterator = std::vector<String>::const_iterator; + using ValueIterator = std::vector<util::String>::iterator; + using ConstValueIterator = std::vector<util::String>::const_iterator; private : - std::vector<String> lines; + std::vector<util::String> lines; protected : @@ -54,8 +52,8 @@ class Config std::size_t wordIndex{0}; std::size_t characterIndex{0}; std::size_t currentSentenceStartRawInput{0}; - String state{"NONE"}; - boost::circular_buffer<String> history{10}; + util::String state{"NONE"}; + boost::circular_buffer<util::String> history{10}; boost::circular_buffer<std::size_t> stack{50}; float chosenActionScore{0.0}; std::vector<std::string> extraColumns{commentsColName, rawRangeStartColName, rawRangeEndColName, isMultiColName, childsColName, sentIdColName, EOSColName}; @@ -91,13 +89,13 @@ class Config void addLines(unsigned int nbLines); void resizeLines(unsigned int nbLines); bool has(int colIndex, int lineIndex, int hypothesisIndex) const; - String & get(int colIndex, int lineIndex, int hypothesisIndex); - const String & getConst(int colIndex, int lineIndex, int hypothesisIndex) const; - String & getLastNotEmpty(int colIndex, int lineIndex); - String & getLastNotEmptyHyp(int colIndex, int lineIndex); - const String & getLastNotEmptyHypConst(int colIndex, int lineIndex) const; - const String & getLastNotEmptyConst(int colIndex, int lineIndex) const; - const String & getAsFeature(int colIndex, int lineIndex) const; + util::String & get(int colIndex, int lineIndex, int hypothesisIndex); + const util::String & getConst(int colIndex, int lineIndex, int hypothesisIndex) const; + util::String & getLastNotEmpty(int colIndex, int lineIndex); + util::String & getLastNotEmptyHyp(int colIndex, int lineIndex); + const util::String & getLastNotEmptyHypConst(int colIndex, int lineIndex) const; + const util::String & getLastNotEmptyConst(int colIndex, int lineIndex) const; + const util::String & getAsFeature(int colIndex, int lineIndex) const; ValueIterator getIterator(int colIndex, int lineIndex, int hypothesisIndex); ConstValueIterator getConstIterator(int colIndex, int lineIndex, int hypothesisIndex) const; std::size_t & getStackRef(int relativeIndex); @@ -110,15 +108,15 @@ class Config void print(FILE * dest, bool printHeader = true) const; void printForDebug(FILE * dest) const; bool has(const std::string & colName, int lineIndex, int hypothesisIndex) const; - String & get(const std::string & colName, int lineIndex, int hypothesisIndex); - const String & getConst(const std::string & colName, int lineIndex, int hypothesisIndex) const; - String & getLastNotEmpty(const std::string & colName, int lineIndex); - const String & getLastNotEmptyConst(const std::string & colName, int lineIndex) const; - String & getLastNotEmptyHyp(const std::string & colName, int lineIndex); - const String & getLastNotEmptyHypConst(const std::string & colName, int lineIndex) const; - const String & getAsFeature(const std::string & colName, int lineIndex) const; - String & getFirstEmpty(int colIndex, int lineIndex); - String & getFirstEmpty(const std::string & colName, int lineIndex); + util::String & get(const std::string & colName, int lineIndex, int hypothesisIndex); + const util::String & getConst(const std::string & colName, int lineIndex, int hypothesisIndex) const; + util::String & getLastNotEmpty(const std::string & colName, int lineIndex); + const util::String & getLastNotEmptyConst(const std::string & colName, int lineIndex) const; + util::String & getLastNotEmptyHyp(const std::string & colName, int lineIndex); + const util::String & getLastNotEmptyHypConst(const std::string & colName, int lineIndex) const; + const util::String & getAsFeature(const std::string & colName, int lineIndex) const; + util::String & getFirstEmpty(int colIndex, int lineIndex); + util::String & getFirstEmpty(const std::string & colName, int lineIndex); bool hasCharacter(int letterIndex) const; const util::utf8char & getLetter(int letterIndex) const; void addToHistory(const std::string & transition); @@ -143,12 +141,12 @@ class Config std::size_t getCharacterIndex() const; long getRelativeWordIndex(Object object, int relativeIndex) const; bool hasRelativeWordIndex(Object object, int relativeIndex) const; - const String & getHistory(int relativeIndex) const; + const util::String & getHistory(int relativeIndex) const; std::size_t getStack(int relativeIndex) const; std::size_t getStackSize() const; bool hasHistory(int relativeIndex) const; bool hasStack(int relativeIndex) const; - String getState() const; + util::String getState() const; void setState(const std::string state); bool stateIsDone() const; void addPredicted(const std::set<std::string> & predicted); diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index a974ac47715775054b9a84216ca7a08834d284de..3135635f69c7d29956bd2b28c860af497e58281f 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -18,7 +18,7 @@ class ReadingMachine std::string name; std::filesystem::path path; std::vector<std::unique_ptr<Classifier>> classifiers; - std::map<std::string, int> state2classifier; + std::unordered_map<std::string, int> state2classifier; std::vector<std::string> strategyDefinition; std::vector<std::vector<std::string>> classifierDefinitions; std::vector<std::string> classifierNames; diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index e5287bdb4e8834f3e2fc8f30fc8df95917b7584f..2fc63ce295dae575b4cb1d0d9e881f8f1639da4e 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -153,7 +153,7 @@ Action Action::addToHypothesis(const std::string & colName, std::size_t lineInde { auto apply = [colName, lineIndex, addition](Config & config, Action &) { - auto currentElems = util::split(config.getLastNotEmptyHypConst(colName, lineIndex).get(), '|'); + auto currentElems = util::split(std::string(config.getLastNotEmptyHypConst(colName, lineIndex)), '|'); currentElems.emplace_back(addition); std::sort(currentElems.begin(), currentElems.end()); @@ -163,7 +163,7 @@ Action Action::addToHypothesis(const std::string & colName, std::size_t lineInde auto undo = [colName, lineIndex, addition](Config & config, Action &) { - auto curElems = util::split(config.getLastNotEmptyHypConst(colName, lineIndex).get(), '|'); + auto curElems = util::split(std::string(config.getLastNotEmptyHypConst(colName, lineIndex)), '|'); std::vector<std::string> newElems; for (auto & elem : curElems) if (elem != addition) @@ -177,7 +177,7 @@ Action Action::addToHypothesis(const std::string & colName, std::size_t lineInde if (!config.has(colName, lineIndex, 0)) return false; auto & current = config.getLastNotEmptyHypConst(colName, lineIndex); - auto splited = util::split(current.get(), '|'); + auto splited = util::split(std::string(current), '|'); for (auto & part : splited) if (part == addition) return false; @@ -191,10 +191,10 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde { auto apply = [colName, lineIndex, addition, mean](Config & config, Action &) { - std::string totalStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); + std::string totalStr = std::string(config.getLastNotEmptyHypConst(colName, lineIndex)); if (totalStr.empty() || totalStr == "_") - totalStr = fmt::format("{}={}|{}", config.getState(), 0.0, 0); + totalStr = fmt::format("{}={}|{}", std::string(config.getState()), 0.0, 0); auto byStates = util::split(totalStr, ','); int index = -1; @@ -209,7 +209,7 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde } if (index == -1) { - byStates.emplace_back(fmt::format("{}={}|{}", config.getState(), 0.0, 0)); + byStates.emplace_back(fmt::format("{}={}|{}", std::string(config.getState()), 0.0, 0)); index = byStates.size()-1; } @@ -235,17 +235,17 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde curVal += addition; } - byStates[index] = fmt::format("{}={}|{}", config.getState(), curVal, curNb); + byStates[index] = fmt::format("{}={}|{}", std::string(config.getState()), curVal, curNb); config.getLastNotEmptyHyp(colName, lineIndex) = util::join(",", byStates); }; auto undo = [colName, lineIndex, addition, mean](Config & config, Action &) { - std::string totalStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); + std::string totalStr = std::string(config.getLastNotEmptyHypConst(colName, lineIndex)); if (totalStr.empty() || totalStr == "_") - totalStr = fmt::format("{}={}|{}", config.getState(), 0.0, 0); + totalStr = fmt::format("{}={}|{}", std::string(config.getState()), 0.0, 0); auto byStates = util::split(totalStr, ','); int index = -1; @@ -260,7 +260,7 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde } if (index == -1) { - byStates.emplace_back(fmt::format("{}={}|{}", config.getState(), 0.0, 0)); + byStates.emplace_back(fmt::format("{}={}|{}", std::string(config.getState()), 0.0, 0)); index = byStates.size()-1; } @@ -282,7 +282,7 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde else curVal -= addition; - byStates[index] = fmt::format("{}={}|{}", config.getState(), curVal, curNb); + byStates[index] = fmt::format("{}={}|{}", std::string(config.getState()), curVal, curNb); config.getLastNotEmptyHyp(colName, lineIndex) = util::join(",", byStates); }; @@ -456,9 +456,9 @@ Action Action::endWord() auto appliable = [](const Config & config, const Action &) { - if (util::isEmpty(config.getAsFeature("FORM", config.getWordIndex()))) + if (std::string(config.getAsFeature("FORM", config.getWordIndex())).empty()) return false; - if (!util::isEmpty(config.getAsFeature(Config::idColName, config.getWordIndex())) and config.getAsFeature(Config::isMultiColName, config.getWordIndex()) != Config::EOSSymbol1) + if (!std::string(config.getAsFeature(Config::idColName, config.getWordIndex())).empty() and config.getAsFeature(Config::isMultiColName, config.getWordIndex()) != Config::EOSSymbol1) return false; return true; @@ -532,7 +532,7 @@ Action Action::assertIsEmpty(const std::string & colName, Config::Object object, if (!config.hasRelativeWordIndex(object, relativeIndex)) return false; auto lineIndex = config.getRelativeWordIndex(object, relativeIndex); - return util::isEmpty(config.getAsFeature(colName, lineIndex)); + return std::string(config.getAsFeature(colName, lineIndex)).empty(); } catch (std::exception & e) { util::myThrow(fmt::format("colName='{}' object='{}' relativeIndex='{}' {}", colName, object == Config::Object::Stack ? "Stack" : "Buffer", relativeIndex, e.what())); @@ -561,7 +561,7 @@ Action Action::assertIsNotEmpty(const std::string & colName, Config::Object obje if (!config.hasRelativeWordIndex(object, relativeIndex)) return false; auto lineIndex = config.getRelativeWordIndex(object, relativeIndex); - return !util::isEmpty(config.getAsFeature(colName, lineIndex)); + return !std::string(config.getAsFeature(colName, lineIndex)).empty(); } catch (std::exception & e) { util::myThrow(fmt::format("colName='{}' object='{}' relativeIndex='{}' {}", colName, object == Config::Object::Stack ? "Stack" : "Buffer", relativeIndex, e.what())); @@ -581,21 +581,21 @@ Action Action::addCharsToCol(const std::string & col, int n, Config::Object obje auto & curWord = config.getLastNotEmptyHyp(col, index); if (col == "FORM") { - if (util::isEmpty(config.getAsFeature(Config::rawRangeStartColName, index))) + if (std::string(config.getAsFeature(Config::rawRangeStartColName, index)).empty()) config.getLastNotEmptyHyp(Config::rawRangeStartColName, index) = fmt::format("{}", config.getCharacterIndex()); config.getLastNotEmptyHyp(Config::rawRangeEndColName, index) = fmt::format("{}", config.getCharacterIndex()); int curEndValue = std::stoi(config.getAsFeature(Config::rawRangeEndColName, index)); config.getLastNotEmptyHyp(Config::rawRangeEndColName, index) = fmt::format("{}", curEndValue+n); } for (int i = 0; i < n; i++) - curWord = fmt::format("{}{}", curWord, config.getLetter(config.getCharacterIndex()+i)); + curWord = fmt::format("{}{}", std::string(curWord), config.getLetter(config.getCharacterIndex()+i)); }; auto undo = [col, n, object, relativeIndex](Config & config, Action &) { auto index = config.getRelativeWordIndex(object, relativeIndex); auto & curWord = config.getLastNotEmptyHyp(col, index); - auto newWord = util::splitAsUtf8(curWord.get()); + auto newWord = util::splitAsUtf8(std::string(curWord)); for (int i = 0; i < n; i++) newWord.pop_back(); curWord = fmt::format("{}", newWord); @@ -612,7 +612,7 @@ Action Action::addCharsToCol(const std::string & col, int n, Config::Object obje auto firstLetter = config.getLetter(config.getCharacterIndex()); - if (firstLetter == ' ' and util::isEmpty(config.getAsFeature(col, config.getRelativeWordIndex(object, relativeIndex)))) + if (firstLetter == ' ' and std::string(config.getAsFeature(col, config.getRelativeWordIndex(object, relativeIndex))).empty()) return false; for (int i = 0; i < n; i++) @@ -646,7 +646,7 @@ Action Action::setRoot(int bufferIndex) if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1) break; - if (util::isEmpty(config.getAsFeature(Config::headColName, i))) + if (std::string(config.getAsFeature(Config::headColName, i)).empty()) { rootIndex = i; a.data.push_back(std::to_string(i)); @@ -667,7 +667,7 @@ Action Action::setRoot(int bufferIndex) if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1) break; - if (util::isEmpty(config.getAsFeature(Config::headColName, i))) + if (std::string(config.getAsFeature(Config::headColName, i)).empty()) { if (i == rootIndex) { @@ -752,13 +752,13 @@ Action Action::updateIds(int bufferIndex) int firstIndex = 0; int lastIndex = 0; try {firstIndex = std::stoi(config.getAsFeature(Config::rawRangeStartColName, firstIndexOfSentence));} - catch (std::exception & e) {util::myThrow(fmt::format("{} : '{}'", e.what(), config.getAsFeature(Config::rawRangeStartColName, firstIndexOfSentence)));} + catch (std::exception & e) {util::myThrow(fmt::format("{} : '{}'", e.what(), std::string(config.getAsFeature(Config::rawRangeStartColName, firstIndexOfSentence))));} try {lastIndex = std::stoi(config.getAsFeature(Config::rawRangeEndColName, lineIndex));} - catch (std::exception & e) {util::myThrow(fmt::format("{} : '{}'", e.what(), config.getAsFeature(Config::rawRangeEndColName, lineIndex)));} + catch (std::exception & e) {util::myThrow(fmt::format("{} : '{}'", e.what(), std::string(config.getAsFeature(Config::rawRangeEndColName, lineIndex))));} for (auto i = firstIndex; i < lastIndex; i++) textMetadata = fmt::format("{}{}", textMetadata, config.getLetter(i)); - config.getLastNotEmptyHyp(Config::commentsColName, firstIndexOfSentence) = fmt::format("{}\n# sent_id = {}", textMetadata, config.getAsFeature(Config::sentIdColName, firstIndexOfSentence)); + config.getLastNotEmptyHyp(Config::commentsColName, firstIndexOfSentence) = fmt::format("{}\n# sent_id = {}", textMetadata, std::string(config.getAsFeature(Config::sentIdColName, firstIndexOfSentence))); } }; @@ -812,7 +812,7 @@ Action Action::attach(Config::Object governorObject, int governorIndex, Config:: return false; // Check if dep is not already attached - if (!util::isEmpty(config.getAsFeature(Config::headColName, depLineIndex))) + if (!std::string(config.getAsFeature(Config::headColName, depLineIndex)).empty()) return false; return true; @@ -877,7 +877,7 @@ Action Action::setRootUpdateIdsEmptyStackIfSentChanged() break; } - if (util::isEmpty(config.getAsFeature(Config::headColName, i))) + if (std::string(config.getAsFeature(Config::headColName, i)).empty()) rootIndex = i; firstIndexOfSentence = i; @@ -894,7 +894,7 @@ Action Action::setRootUpdateIdsEmptyStackIfSentChanged() if (!config.isTokenPredicted(i)) continue; - if (util::isEmpty(config.getAsFeature(Config::headColName, i))) + if (std::string(config.getAsFeature(Config::headColName, i)).empty()) { if (i == rootIndex) { @@ -968,11 +968,11 @@ Action Action::transformSuffix(std::string fromCol, Config::Object fromObj, int if (toRemove.empty() and toAdd.empty()) { - addHypothesis(toCol, toLineIndex, config.getAsFeature(fromCol, fromLineIndex).get()).apply(config, a); + addHypothesis(toCol, toLineIndex, std::string(config.getAsFeature(fromCol, fromLineIndex))).apply(config, a); return; } - util::utf8string res = util::splitAsUtf8(util::lower(config.getAsFeature(fromCol, fromLineIndex).get())); + util::utf8string res = util::splitAsUtf8(util::lower(std::string(config.getAsFeature(fromCol, fromLineIndex)))); for (unsigned int i = 0; i < toRemove.size(); i++) res.pop_back(); for (auto & letter : toAdd) @@ -993,7 +993,7 @@ Action Action::transformSuffix(std::string fromCol, Config::Object fromObj, int int fromLineIndex = config.getRelativeWordIndex(fromObj, fromIndex); int toLineIndex = config.getRelativeWordIndex(toObj, toIndex); - util::utf8string res = util::splitAsUtf8(util::lower(config.getAsFeature(fromCol, fromLineIndex).get())); + util::utf8string res = util::splitAsUtf8(util::lower(std::string(config.getAsFeature(fromCol, fromLineIndex)))); if (res.size() < toRemove.size()) return false; @@ -1069,7 +1069,7 @@ Action Action::uppercaseIndex(std::string col, Config::Object obj, int index, in auto apply = [col, obj, index, inIndex](Config & config, Action & a) { int lineIndex = config.getRelativeWordIndex(obj, index); - auto res = util::splitAsUtf8(config.getAsFeature(col, lineIndex).get()); + auto res = util::splitAsUtf8(std::string(config.getAsFeature(col, lineIndex))); util::upper(res[inIndex]); addHypothesis(col, lineIndex, fmt::format("{}", res)).apply(config, a); @@ -1079,7 +1079,7 @@ Action Action::uppercaseIndex(std::string col, Config::Object obj, int index, in { int lineIndex = config.getRelativeWordIndex(obj, index); auto & value = config.getLastNotEmptyHyp(col, lineIndex); - auto res = util::splitAsUtf8(value.get()); + auto res = util::splitAsUtf8(std::string(value)); value = fmt::format("{}", res); }; @@ -1090,7 +1090,7 @@ Action Action::uppercaseIndex(std::string col, Config::Object obj, int index, in int lineIndex = config.getRelativeWordIndex(obj, index); - if ((int)util::splitAsUtf8(config.getAsFeature(col, lineIndex).get()).size() <= inIndex) + if ((int)util::splitAsUtf8(std::string(config.getAsFeature(col, lineIndex))).size() <= inIndex) return false; return addHypothesis(col, lineIndex, "").appliable(config, a); @@ -1133,7 +1133,7 @@ Action Action::lowercaseIndex(std::string col, Config::Object obj, int index, in auto apply = [col, obj, index, inIndex](Config & config, Action & a) { int lineIndex = config.getRelativeWordIndex(obj, index); - auto res = util::splitAsUtf8(config.getAsFeature(col, lineIndex).get()); + auto res = util::splitAsUtf8(std::string(config.getAsFeature(col, lineIndex))); util::lower(res[inIndex]); addHypothesis(col, lineIndex, fmt::format("{}", res)).apply(config, a); @@ -1143,7 +1143,7 @@ Action Action::lowercaseIndex(std::string col, Config::Object obj, int index, in { int lineIndex = config.getRelativeWordIndex(obj, index); auto & value = config.getLastNotEmptyHyp(col, lineIndex); - auto res = util::splitAsUtf8(value.get()); + auto res = util::splitAsUtf8(std::string(value)); value = fmt::format("{}", res); }; @@ -1154,7 +1154,7 @@ Action Action::lowercaseIndex(std::string col, Config::Object obj, int index, in int lineIndex = config.getRelativeWordIndex(obj, index); - if ((int)util::splitAsUtf8(config.getAsFeature(col, lineIndex).get()).size() <= inIndex) + if ((int)util::splitAsUtf8(std::string(config.getAsFeature(col, lineIndex))).size() <= inIndex) return false; return addHypothesis(col, lineIndex, "").appliable(config, a); diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 8e08ddb73ab008312af039ff6a55689fa2684fd1..bed8601019a3385378a0f417e9fc979fa6908a90 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -53,22 +53,22 @@ bool Config::has(const std::string & colName, int lineIndex, int hypothesisIndex return hasColIndex(colName) && has(getColIndex(colName), lineIndex, hypothesisIndex); } -Config::String & Config::get(const std::string & colName, int lineIndex, int hypothesisIndex) +util::String & Config::get(const std::string & colName, int lineIndex, int hypothesisIndex) { return get(getColIndex(colName), lineIndex, hypothesisIndex); } -const Config::String & Config::getConst(const std::string & colName, int lineIndex, int hypothesisIndex) const +const util::String & Config::getConst(const std::string & colName, int lineIndex, int hypothesisIndex) const { return getConst(getColIndex(colName), lineIndex, hypothesisIndex); } -Config::String & Config::get(int colIndex, int lineIndex, int hypothesisIndex) +util::String & Config::get(int colIndex, int lineIndex, int hypothesisIndex) { return *getIterator(colIndex, lineIndex, hypothesisIndex); } -const Config::String & Config::getConst(int colIndex, int lineIndex, int hypothesisIndex) const +const util::String & Config::getConst(int colIndex, int lineIndex, int hypothesisIndex) const { return *getConstIterator(colIndex, lineIndex, hypothesisIndex); } @@ -105,8 +105,8 @@ void Config::print(FILE * dest, bool printHeader) const for (unsigned int line = 0; line < getNbLines(); line++) { - if (!util::isEmpty(getAsFeature(commentsColName, getFirstLineIndex()+line))) - currentSequenceComments.emplace_back(fmt::format("{}\n", getAsFeature(commentsColName, getFirstLineIndex()+line))); + if (!std::string(getAsFeature(commentsColName, getFirstLineIndex()+line)).empty()) + currentSequenceComments.emplace_back(fmt::format("{}\n", std::string(getAsFeature(commentsColName, getFirstLineIndex()+line)))); for (unsigned int i = 0; i < getNbColumns()-1; i++) { @@ -175,14 +175,14 @@ void Config::printForDebug(FILE * dest) const { if ((isExtraColumn(getColName(i)) and exceptions.count(getColName(i)) == 0) and getColName(i) != EOSColName) continue; - std::string colContent = has(i,line,0) ? getAsFeature(i, line).get() : "?"; + std::string colContent = has(i,line,0) ? std::string(getAsFeature(i, line)) : "?"; std::string toPrintCol = colContent; try { if (getColName(i) == headColName && toPrintCol != "_" && !toPrintCol.empty()) { if (toPrintCol != "-1" && toPrintCol != "?") - toPrintCol = has(0,std::stoi(toPrintCol),0) ? getAsFeature(idColName, std::stoi(toPrintCol)).get() : "?"; + toPrintCol = has(0,std::stoi(toPrintCol),0) ? std::string(getAsFeature(idColName, std::stoi(toPrintCol))) : "?"; else if (toPrintCol == "-1") toPrintCol = "0"; } @@ -228,9 +228,9 @@ void Config::printForDebug(FILE * dest) const fmt::print(dest, "{}\n", longLine); for (std::size_t index = characterIndex; index < util::getSize(rawInput) and index - characterIndex < lettersWindowSize; index++) fmt::print(dest, "{}", getLetter(index)); - if (!util::isEmpty(rawInput)) + if (!rawInput.empty()) fmt::print(dest, "\n{}\n", longLine); - fmt::print(dest, "State={}\nwordIndex={}/{} characterIndex={}/{}\nhistory=({})\nstack=({})\n", state, wordIndex, getNbLines(), characterIndex, rawInput.size(), historyStr, stackStr); + fmt::print(dest, "State={}\nwordIndex={}/{} characterIndex={}/{}\nhistory=({})\nstack=({})\n", std::string(state), wordIndex, getNbLines(), characterIndex, rawInput.size(), historyStr, stackStr); fmt::print(dest, "{}\n", longLine); for (unsigned int line = 0; line < toPrint.size(); line++) @@ -246,77 +246,77 @@ void Config::printForDebug(FILE * dest) const fmt::print(dest, "{}\n", longLine); } -Config::String & Config::getLastNotEmpty(int colIndex, int lineIndex) +util::String & Config::getLastNotEmpty(int colIndex, int lineIndex) { if (!has(colIndex, lineIndex, 0)) util::myThrow(fmt::format("asked for line {} but last line = {}", lineIndex, getNbLines()+getFirstLineIndex()-1)); int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex); for (int i = nbHypothesesMax; i > 0; --i) - if (!util::isEmpty(lines[baseIndex+i])) + if (!std::string(lines[baseIndex+i]).empty()) return lines[baseIndex+i]; return lines[baseIndex]; } -Config::String & Config::getLastNotEmptyHyp(int colIndex, int lineIndex) +util::String & Config::getLastNotEmptyHyp(int colIndex, int lineIndex) { if (!has(colIndex, lineIndex, 0)) util::myThrow(fmt::format("asked for line {} but nbLines = {}", lineIndex, getNbLines())); int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex); for (int i = nbHypothesesMax; i > 0; --i) - if (!util::isEmpty(lines[baseIndex+i])) + if (!std::string((lines[baseIndex+i])).empty()) return lines[baseIndex+i]; return lines[baseIndex+1]; } -Config::String & Config::getFirstEmpty(int colIndex, int lineIndex) +util::String & Config::getFirstEmpty(int colIndex, int lineIndex) { if (!has(colIndex, lineIndex, 0)) util::myThrow(fmt::format("asked for line {} but nbLines = {}", lineIndex, getNbLines())); int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex); for (int i = 1; i < nbHypothesesMax; ++i) - if (util::isEmpty(lines[baseIndex+i])) + if (std::string(lines[baseIndex+i]).empty()) return lines[baseIndex+i]; return lines[baseIndex+nbHypothesesMax]; } -Config::String & Config::getFirstEmpty(const std::string & colName, int lineIndex) +util::String & Config::getFirstEmpty(const std::string & colName, int lineIndex) { return getFirstEmpty(getColIndex(colName), lineIndex); } -const Config::String & Config::getLastNotEmptyConst(int colIndex, int lineIndex) const +const util::String & Config::getLastNotEmptyConst(int colIndex, int lineIndex) const { if (!has(colIndex, lineIndex, 0)) util::myThrow(fmt::format("asked for line {} but nbLines = {}", lineIndex, getNbLines())); int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex); for (int i = nbHypothesesMax; i > 0; --i) - if (!util::isEmpty(lines[baseIndex+i])) + if (!std::string(lines[baseIndex+i]).empty()) return lines[baseIndex+i]; return lines[baseIndex]; } -const Config::String & Config::getLastNotEmptyHypConst(int colIndex, int lineIndex) const +const util::String & Config::getLastNotEmptyHypConst(int colIndex, int lineIndex) const { if (!has(colIndex, lineIndex, 0)) util::myThrow(fmt::format("asked for line {} but nbLines = {}", lineIndex, getNbLines())); int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex); for (int i = nbHypothesesMax; i > 0; --i) - if (!util::isEmpty(lines[baseIndex+i])) + if (!std::string(lines[baseIndex+i]).empty()) return lines[baseIndex+i]; return lines[baseIndex+1]; } -const Config::String & Config::getAsFeature(int colIndex, int lineIndex) const +const util::String & Config::getAsFeature(int colIndex, int lineIndex) const { if (isPredicted(getColName(colIndex))) return getLastNotEmptyHypConst(colIndex, lineIndex); @@ -324,27 +324,27 @@ const Config::String & Config::getAsFeature(int colIndex, int lineIndex) const return getLastNotEmptyConst(colIndex, lineIndex); } -const Config::String & Config::getAsFeature(const std::string & colName, int lineIndex) const +const util::String & Config::getAsFeature(const std::string & colName, int lineIndex) const { return getAsFeature(getColIndex(colName), lineIndex); } -Config::String & Config::getLastNotEmpty(const std::string & colName, int lineIndex) +util::String & Config::getLastNotEmpty(const std::string & colName, int lineIndex) { return getLastNotEmpty(getColIndex(colName), lineIndex); } -Config::String & Config::getLastNotEmptyHyp(const std::string & colName, int lineIndex) +util::String & Config::getLastNotEmptyHyp(const std::string & colName, int lineIndex) { return getLastNotEmptyHyp(getColIndex(colName), lineIndex); } -const Config::String & Config::getLastNotEmptyConst(const std::string & colName, int lineIndex) const +const util::String & Config::getLastNotEmptyConst(const std::string & colName, int lineIndex) const { return getLastNotEmptyConst(getColIndex(colName), lineIndex); } -const Config::String & Config::getLastNotEmptyHypConst(const std::string & colName, int lineIndex) const +const util::String & Config::getLastNotEmptyHypConst(const std::string & colName, int lineIndex) const { return getLastNotEmptyHypConst(getColIndex(colName), lineIndex); } @@ -361,7 +361,7 @@ Config::ConstValueIterator Config::getConstIterator(int colIndex, int lineIndex, void Config::addToHistory(const std::string & transition) { - history.push_back(String(transition)); + history.push_back(util::String(transition)); } void Config::addToStack(std::size_t index) @@ -394,34 +394,34 @@ const util::utf8char & Config::getLetter(int letterIndex) const bool Config::isMultiword(std::size_t lineIndex) const { - return hasColIndex(idColName) && getConst(idColName, lineIndex, 0).get().find('-') != std::string::npos; + return hasColIndex(idColName) && std::string(getConst(idColName, lineIndex, 0)).find('-') != std::string::npos; } bool Config::isMultiwordPredicted(std::size_t lineIndex) const { - return hasColIndex(idColName) && getAsFeature(idColName, lineIndex).get().find('-') != std::string::npos; + return hasColIndex(idColName) && std::string(getAsFeature(idColName, lineIndex)).find('-') != std::string::npos; } int Config::getMultiwordSize(std::size_t lineIndex) const { - auto splited = util::split(getConst(idColName, lineIndex, 0).get(), '-'); + auto splited = util::split(std::string(getConst(idColName, lineIndex, 0)), '-'); return std::stoi(std::string(splited[1])) - std::stoi(std::string(splited[0])); } int Config::getMultiwordSizePredicted(std::size_t lineIndex) const { - auto splited = util::split(getAsFeature(idColName, lineIndex).get(), '-'); + auto splited = util::split(std::string(getAsFeature(idColName, lineIndex)), '-'); return std::stoi(std::string(splited[1])) - std::stoi(std::string(splited[0])); } bool Config::isEmptyNode(std::size_t lineIndex) const { - return hasColIndex(idColName) && getConst(idColName, lineIndex, 0).get().find('.') != std::string::npos; + return hasColIndex(idColName) && std::string(getConst(idColName, lineIndex, 0)).find('.') != std::string::npos; } bool Config::isEmptyNodePredicted(std::size_t lineIndex) const { - return hasColIndex(idColName) && getAsFeature(idColName, lineIndex).get().find('.') != std::string::npos; + return hasColIndex(idColName) && std::string(getAsFeature(idColName, lineIndex)).find('.') != std::string::npos; } bool Config::isToken(std::size_t lineIndex) const @@ -486,7 +486,7 @@ std::size_t Config::getCharacterIndex() const return characterIndex; } -const Config::String & Config::getHistory(int relativeIndex) const +const util::String & Config::getHistory(int relativeIndex) const { return history[history.size()-1-relativeIndex]; } @@ -515,7 +515,7 @@ bool Config::hasStack(int relativeIndex) const return relativeIndex >= 0 && relativeIndex < (int)stack.size(); } -Config::String Config::getState() const +util::String Config::getState() const { return state; } @@ -527,7 +527,7 @@ void Config::setState(const std::string state) bool Config::stateIsDone() const { - if (!util::isEmpty(rawInput)) + if (!rawInput.empty()) return rawInputOnlySeparatorsLeft() and !has(0, wordIndex+1, 0) and !hasStack(0); return !has(0, wordIndex+1, 0) and !hasStack(0); @@ -579,7 +579,7 @@ void Config::addMissingColumns() if (!isTokenPredicted(index)) continue; - if (util::isEmpty(getAsFeature(idColName, index))) + if (std::string(getAsFeature(idColName, index)).empty()) { int last = 0; if (index > 0 and isTokenPredicted(index-1)) @@ -592,7 +592,7 @@ void Config::addMissingColumns() firstIndex = index; if (hasColIndex(headColName)) - if (util::isEmpty(getAsFeature(headColName, index))) + if (std::string(getAsFeature(headColName, index)).empty()) getLastNotEmptyHyp(headColName, index) = (curId == 1) ? "-1" : std::to_string(firstIndex); } } diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 582acff8d4cf6f1311897210b368ed52bb8dabf8..33e08cd945cad3e52886d958bd304e1b62b46895 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -95,7 +95,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path) TransitionSet & ReadingMachine::getTransitionSet(const std::string & state) { - return classifiers[state2classifier.at(state)]->getTransitionSet(state); + return classifiers.at(state2classifier.at(state))->getTransitionSet(state); } bool ReadingMachine::hasSplitWordTransitionSet() const diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp index bcdd951bafe4e98ead8f60b5275c6d129660f432..b22b9e62cc19ef3420f59e4cc05cdbd59260bb45 100644 --- a/reading_machine/src/Strategy.cpp +++ b/reading_machine/src/Strategy.cpp @@ -42,7 +42,7 @@ Strategy::Movement Strategy::getMovement(const Config & c, const std::string & t Strategy::Movement Strategy::Block::getMovement(const Config & c, const std::string & transition) { std::string transitionPrefix(util::split(transition, ' ')[0]); - auto currentState = c.getState(); + auto currentState = std::string(c.getState()); for (auto & movement : movements) { diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 12e4e4a8ef8665f71ce106aa93721f8ca5f57cb5..858a04457dd4f132feafc3b27b633c7fde4e82ee 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -102,13 +102,13 @@ void Transition::apply(Config & config, float entropy) if (config.hasColIndex("STACK_SIZE")) { auto & curValue = config.get("STACK_SIZE", config.getWordIndex(), 0); - if (curValue.get().empty()) + if (std::string(curValue).empty()) curValue = fmt::format("{}", config.getStackSize()); } if (config.hasColIndex("STACK_DIST")) { auto & curValue = config.get("STACK_DIST", config.getWordIndex(), 0); - if (curValue.get().empty()) + if (std::string(curValue).empty()) { if (config.hasStack(0) and config.hasStack(1)) curValue = fmt::format("{}", config.getStack(0) - config.getStack(1)); @@ -145,8 +145,12 @@ void Transition::apply(Config & config, float entropy) void Transition::apply(Config & config) { + //TODO find a way without copy (data race, having data inside actions) for (Action & action : sequence) - action.apply(config, action); + { + Action cp = action; + cp.apply(config, cp); + } } bool Transition::appliable(const Config & config) const @@ -238,7 +242,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); - auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|'); + auto gold = util::split(std::string(config.getConst(colName, lineIndex, 0)), '|'); for (auto & part : gold) if (part == value) @@ -267,8 +271,8 @@ void Transition::initIgnoreChar() costDynamic = [](const Config & config) { auto letter = fmt::format("{}", config.getLetter(config.getCharacterIndex())); - auto goldWord = util::splitAsUtf8(config.getConst("FORM", config.getWordIndex(), 0).get()); - auto curWord = util::splitAsUtf8(config.getAsFeature("FORM", config.getWordIndex()).get()); + auto goldWord = util::splitAsUtf8(std::string(config.getConst("FORM", config.getWordIndex(), 0))); + auto curWord = util::splitAsUtf8(std::string(config.getAsFeature("FORM", config.getWordIndex()))); if (curWord.size() >= goldWord.size()) return 0; @@ -857,7 +861,7 @@ void Transition::initTransformSuffix(std::string fromCol, std::string fromObj, s { int fromLineIndex = config.getRelativeWordIndex(fromObjectValue, fromIndexValue); int toLineIndex = config.getRelativeWordIndex(toObjectValue, toIndexValue); - util::utf8string res = util::splitAsUtf8(util::lower(config.getAsFeature(fromCol, fromLineIndex).get())); + util::utf8string res = util::splitAsUtf8(util::lower(std::string(config.getAsFeature(fromCol, fromLineIndex)))); for (unsigned int i = 0; i < toRemoveUtf8.size(); i++) res.pop_back(); for (auto & letter : toAddUtf8) @@ -883,7 +887,7 @@ void Transition::initUppercase(std::string col, std::string obj, std::string ind { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); - std::string currentValue = config.getAsFeature(col, lineIndex).get(); + std::string currentValue = std::string(config.getAsFeature(col, lineIndex)); if (expectedValue == currentValue) return 1; @@ -908,7 +912,7 @@ void Transition::initUppercaseIndex(std::string col, std::string obj, std::strin { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); - std::string currentValue = config.getAsFeature(col, lineIndex).get(); + std::string currentValue = std::string(config.getAsFeature(col, lineIndex)); if (expectedValue == currentValue) return 1; @@ -932,7 +936,7 @@ void Transition::initNothing(std::string col, std::string obj, std::string index { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); - std::string currentValue = config.getAsFeature(col, lineIndex).get(); + std::string currentValue = std::string(config.getAsFeature(col, lineIndex)); if (expectedValue == currentValue) return 0; @@ -953,7 +957,7 @@ void Transition::initLowercase(std::string col, std::string obj, std::string ind { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); - std::string currentValue = config.getAsFeature(col, lineIndex).get(); + std::string currentValue = std::string(config.getAsFeature(col, lineIndex)); if (expectedValue == currentValue) return 1; @@ -978,7 +982,7 @@ void Transition::initLowercaseIndex(std::string col, std::string obj, std::strin { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); - std::string currentValue = config.getAsFeature(col, lineIndex).get(); + std::string currentValue = std::string(config.getAsFeature(col, lineIndex)); if (expectedValue == currentValue) return 1; diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 19c7972123aab7fa47822aaf5e290cb69fbdb54a..8289c13c3a87040be35ddf3b1098a092a7e37b83 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -97,7 +97,7 @@ void ContextModuleImpl::addToContext(torch::Tensor & context, const Config & con else { int childIndex = *std::get<2>(target); - auto childs = util::split(config.getAsFeature(Config::childsColName, baseIndex).get(), '|'); + auto childs = util::split(std::string(config.getAsFeature(Config::childsColName, baseIndex)), '|'); int candidate = -2; if (childIndex >= 0 and childIndex < (int)childs.size()) diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index 1c569831a4b15f71a8314326e5c7eccd96c337d6..164de44db0df1b8fca52ac74e1b2ee4f015ec719 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -118,7 +118,7 @@ void ContextualModuleImpl::addToContext(torch::Tensor & context, const Config & else { int childIndex = *std::get<2>(target); - auto childs = util::split(config.getAsFeature(Config::childsColName, baseIndex).get(), '|'); + auto childs = util::split(std::string(config.getAsFeature(Config::childsColName, baseIndex)), '|'); int candidate = -2; if (childIndex >= 0 and childIndex < (int)childs.size()) diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index ac906908408baf2c8375618294a9892929bc3062..6945a346864b45ec683d962946685dced05508b5 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -109,7 +109,7 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(torch::Tensor & context, co for (auto & child : childs) if (config.has(Config::childsColName, std::stoi(child), 0)) { - auto val = util::split(config.getAsFeature(Config::childsColName, std::stoi(child)).get(), '|'); + auto val = util::split(std::string(config.getAsFeature(Config::childsColName, std::stoi(child))), '|'); newChilds.insert(newChilds.end(), val.begin(), val.end()); } childs = newChilds; diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 3fc25f0ad53ef521f05da2bdbc5aabf0216ef096..a7af65a275a422d867e5fedb64e47fbd19d6ab87 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -112,7 +112,7 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config std::vector<std::string> elements; if (column == "FORM") { - auto asUtf8 = util::splitAsUtf8(func(config.getAsFeature(column, index).get())); + auto asUtf8 = util::splitAsUtf8(func(std::string(config.getAsFeature(column, index)))); //TODO don't use nullValueStr here for (int i = 0; i < maxNbElements; i++) @@ -123,7 +123,7 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config } else if (column == "FEATS") { - auto splited = util::split(func(config.getAsFeature(column, index).get()), '|'); + auto splited = util::split(func(std::string(config.getAsFeature(column, index))), '|'); for (int i = 0; i < maxNbElements; i++) if (i < (int)splited.size()) diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index 49d3016bb00452efee0ffa7fc0d45d5a80a58bae..57c33f64ab5924e89deafd0c2e3e86dc45ebe8f4 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -79,7 +79,7 @@ void NumericColumnModuleImpl::addToContext(torch::Tensor & context, const Config double res = 0.0; if (index >= 0) { - auto value = config.getAsFeature(column, index).get(); + auto value = std::string(config.getAsFeature(column, index)); try {res = (value == "_" or value == "NA") ? defaultValue : std::stof(value);} catch (std::exception & e) {util::myThrow(fmt::format("{} for '{}'", e.what(), value));} diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp index 7c846150c1e96af8f7886520592b8d4363e9c78d..5f6de2129503c79c6b74b864060c6f73e24238b4 100644 --- a/torch_modules/src/UppercaseRateModule.cpp +++ b/torch_modules/src/UppercaseRateModule.cpp @@ -75,7 +75,7 @@ void UppercaseRateModuleImpl::addToContext(torch::Tensor & context, const Config double res = -1.0; if (index >= 0) { - auto word = util::splitAsUtf8(config.getAsFeature("FORM", index).get()); + auto word = util::splitAsUtf8(std::string(config.getAsFeature("FORM", index))); int nbUpper = 0; for (auto & letter : word) if (util::isUppercase(letter)) diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index e0f2c1d4cad4015e567ae97a313acd81ffa9b80b..900ab29218bf839115679e60daec6a3358c736c3 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -1,5 +1,6 @@ #include "MacaonTrain.hpp" #include <filesystem> +#include <execution> #include "util.hpp" #include "NeuralNetwork.hpp" #include "WordEmbeddings.hpp" @@ -318,8 +319,15 @@ int MacaonTrain::main() devConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>()); } - for (auto & devConfig : devConfigs) - decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); + torch::AutoGradMode useGrad(false); + machine.trainMode(false); + machine.setDictsState(Dict::State::Closed); + + std::for_each(std::execution::par_unseq, devConfigs.begin(), devConfigs.end(), + [&decoder, &debug, &printAdvancement](BaseConfig & devConfig) + { + decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); + }); std::vector<const Config *> devConfigsPtrs; for (auto & devConfig : devConfigs) diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 4143b2cb5a961f9db485dbd9891a2d35e569fc0c..298e2a9bdfbba6b2626f117fefd35c84640d8fe2 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -381,7 +381,7 @@ void Trainer::extractActionSequence(BaseConfig & config) newSent = true; std::string curSeq = ""; for (int i = curSeqStartIndex; i <= curInputIndex; i++) - curSeq += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", config.getAsFeature("FORM", i)); + curSeq += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", std::string(config.getAsFeature("FORM", i))); fmt::print(stdout, "{}\n{}\n\n", config.hasCharacter(0) ? curSeq : util::strip(curSeq), util::join(" ", transitionsIndexes)); curOutputSeqSize = 0; curInputSeqSize = 0; @@ -403,7 +403,7 @@ void Trainer::extractActionSequence(BaseConfig & config) { std::string curSeq = ""; for (int i = curSeqStartIndex; i <= curInputIndex; i++) - curSeq += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", config.getAsFeature("FORM", i)); + curSeq += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", std::string(config.getAsFeature("FORM", i))); fmt::print(stdout, "{}\n{}\n\n", config.hasCharacter(0) ? curSeq : util::strip(curSeq), util::join(" ", transitionsIndexes)); curOutputSeqSize = 0; curInputSeqSize = 0;