From 07592b27a358f2aff9c9952e5a7294adffc6e62e Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 7 Jun 2020 17:41:35 +0200 Subject: [PATCH] Added functions prefix/suffix and allowed for multiple functions to be applied --- common/include/utf8string.hpp | 1 + torch_modules/include/ContextModule.hpp | 2 +- torch_modules/include/Submodule.hpp | 2 +- torch_modules/src/ContextModule.cpp | 14 ++---- torch_modules/src/FocusedColumnModule.cpp | 8 +--- torch_modules/src/Submodule.cpp | 58 ++++++++++++++++++++--- trainer/src/MacaonTrain.cpp | 20 ++++++++ 7 files changed, 80 insertions(+), 25 deletions(-) diff --git a/common/include/utf8string.hpp b/common/include/utf8string.hpp index 21937d3..1b56bba 100644 --- a/common/include/utf8string.hpp +++ b/common/include/utf8string.hpp @@ -25,6 +25,7 @@ class utf8char : public std::array<char, 4> class utf8string : public std::vector<utf8char> { + using std::vector<utf8char>::vector; public : utf8string & operator=(const std::string & other); diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index 2276aaa..ed0ce57 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -15,7 +15,7 @@ class ContextModuleImpl : public Submodule torch::nn::Embedding wordEmbeddings{nullptr}; std::shared_ptr<MyModule> myModule{nullptr}; std::vector<std::string> columns; - std::map<std::size_t, std::function<std::string(const std::string &)>> functions; + std::vector<std::function<std::string(const std::string &)>> functions; std::vector<int> bufferContext; std::vector<int> stackContext; int inSize; diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index 6e8f689..71b1007 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -22,7 +22,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde virtual void 
addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0; virtual torch::Tensor forward(torch::Tensor input) = 0; virtual void registerEmbeddings(std::filesystem::path pretrained) = 0; - const std::function<std::string(const std::string &)> & getFunction(const std::string functionName); + std::function<std::string(const std::string &)> getFunction(const std::string functionNames); }; #endif diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 244944d..d1a31c9 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -18,12 +18,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin columns.clear(); for (auto & funcCol : funcColumns) { - auto splited = util::split(funcCol, ':'); - if (splited.size() > 2) - util::myThrow(fmt::format("invalid function:column '{}' of size {}", funcCol, splited.size())); - if (splited.size() == 2) - functions[columns.size()] = getFunction(splited[0]); - columns.emplace_back(splited.back()); + functions.emplace_back() = getFunction(funcCol); + columns.emplace_back(util::split(funcCol, ':').back()); } auto subModuleType = sm.str(4); @@ -87,11 +83,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c } else { - int dictIndex; - if (functions.count(colIndex)) - dictIndex = dict.getIndexOrInsert(functions.at(colIndex)(config.getAsFeature(col, index))); - else - dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index)); + int dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index))); for (auto & contextElement : context) contextElement.push_back(dictIndex); diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 5eb0db2..29aef9e 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -8,12 +8,8 @@ 
FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st { try { - auto funcCol = util::split(sm.str(1), ':'); - if (funcCol.size() > 2) - util::myThrow(fmt::format("invalid function:column '{}' of size {}", sm.str(1), funcCol.size())); - if (funcCol.size() == 2) - func = getFunction(funcCol[0]); - column = funcCol.back(); + func = getFunction(sm.str(1)); + column = util::split(sm.str(1), ':').back(); maxNbElements = std::stoi(sm.str(2)); for (auto & index : util::split(sm.str(3), ' ')) diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index a87f446..2be33f7 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -63,17 +63,63 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s getDict().setState(originalState); } -const std::function<std::string(const std::string &)> & Submodule::getFunction(const std::string functionName) +std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames) { + static auto prefix = [](const std::string & s, int length) + { + util::utf8string utf8s = util::splitAsUtf8(s); + util::utf8string prefix(utf8s.begin(), std::min(utf8s.end(),utf8s.begin()+length)); + return fmt::format("{}", prefix); + }; + + static auto suffix = [](const std::string & s, int length) + { + util::utf8string utf8s = util::splitAsUtf8(s); + util::utf8string suffix(std::max(utf8s.begin(), utf8s.end()-length), utf8s.end()); + return fmt::format("{}", suffix); + }; + static std::map<std::string, std::function<std::string(const std::string &)>> functions { - {"lower", [](const std::string & s) {return util::lower(s);}} + {"lower", [](const std::string & s) {return util::lower(s);}}, + {"prefix1", [](const std::string & s) {return prefix(s, 1);}}, + {"prefix2", [](const std::string & s) {return prefix(s, 2);}}, + {"prefix3", [](const std::string & s) {return prefix(s, 3);}}, + {"prefix4", [](const std::string & s) 
{return prefix(s, 4);}}, + {"prefix5", [](const std::string & s) {return prefix(s, 5);}}, + {"prefix6", [](const std::string & s) {return prefix(s, 6);}}, + {"prefix7", [](const std::string & s) {return prefix(s, 7);}}, + {"suffix1", [](const std::string & s) {return suffix(s, 1);}}, + {"suffix2", [](const std::string & s) {return suffix(s, 2);}}, + {"suffix3", [](const std::string & s) {return suffix(s, 3);}}, + {"suffix4", [](const std::string & s) {return suffix(s, 4);}}, + {"suffix5", [](const std::string & s) {return suffix(s, 5);}}, + {"suffix6", [](const std::string & s) {return suffix(s, 6);}}, + {"suffix7", [](const std::string & s) {return suffix(s, 7);}}, }; - auto it = functions.find(util::lower(functionName)); - if (it == functions.end()) - util::myThrow(fmt::format("unknown function name '{}'", functionName)); + auto splited = util::split(functionNames, ':'); + if (splited.size() == 1) + return [](const std::string & s){return s;}; + + std::vector<std::function<std::string(const std::string &)>> sequence; - return it->second; + for (unsigned int i = 0; i < splited.size()-1; i++) + { + auto & functionName = splited[i]; + auto it = functions.find(util::lower(functionName)); + if (it == functions.end()) + util::myThrow(fmt::format("unknown function name '{}'", functionName)); + + sequence.emplace_back(it->second); + } + + return [sequence](const std::string & s) + { + auto result = s; + for (auto & f : sequence) + result = f(result); + return result; + }; } diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index ea0f6f5..8a60dee 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -90,6 +90,26 @@ Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s) return ts; } +template +< + typename Optimizer = torch::optim::Adam, + typename OptimizerOptions = torch::optim::AdamOptions +> +inline auto decay(Optimizer &optimizer, double rate) -> void +{ + for (auto &group : optimizer.param_groups()) + { + 
for (auto &param : group.params()) + { + if (!param.grad().defined()) + continue; + + auto &options = static_cast<OptimizerOptions &>(group.options()); + options.lr(options.lr() * (1.0 - rate)); + } + } +} + int MacaonTrain::main() { auto od = getOptionsDescription(); -- GitLab