diff --git a/common/include/utf8string.hpp b/common/include/utf8string.hpp
index 21937d305fad51143b4d0aac07b9ffc902a572a7..1b56bbaee6e739d3c2cb661cda9f07388dd05850 100644
--- a/common/include/utf8string.hpp
+++ b/common/include/utf8string.hpp
@@ -25,6 +25,7 @@ class utf8char : public std::array<char, 4>
 
 class utf8string : public std::vector<utf8char>
 {
+  using std::vector<utf8char>::vector;
   public :
 
   utf8string & operator=(const std::string & other);
diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index 2276aaaf5c287d11f67377323a4c40851c6f3618..ed0ce5757a1f77b0df95b682be8a6276c4f9c3a1 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -15,7 +15,7 @@ class ContextModuleImpl : public Submodule
   torch::nn::Embedding wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
-  std::map<std::size_t, std::function<std::string(const std::string &)>> functions;
+  std::vector<std::function<std::string(const std::string &)>> functions;
   std::vector<int> bufferContext;
   std::vector<int> stackContext;
   int inSize;
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 6e8f6898c637574e16da10351870e41723c32c75..71b100781fe538ca4d5e6f010684c5d4aa956699 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -22,7 +22,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
   virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual void registerEmbeddings(std::filesystem::path pretrained) = 0;
-  const std::function<std::string(const std::string &)> & getFunction(const std::string functionName);
+  std::function<std::string(const std::string &)> getFunction(const std::string functionNames);
 };
 
 #endif
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 244944d7ce315fbf38f263c71790bcb9c6cdd9dd..d1a31c9f03596e4bd38eabb826713d54797f6283 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -18,12 +18,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
     columns.clear();
     for (auto & funcCol : funcColumns)
     {
-      auto splited = util::split(funcCol, ':');
-      if (splited.size() > 2)
-        util::myThrow(fmt::format("invalid function:column '{}' of size {}", funcCol, splited.size()));
-      if (splited.size() == 2)
-        functions[columns.size()] = getFunction(splited[0]);
-      columns.emplace_back(splited.back());
+      functions.emplace_back() = getFunction(funcCol);
+      columns.emplace_back(util::split(funcCol, ':').back());
     }
 
     auto subModuleType = sm.str(4);
@@ -87,11 +83,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
       }
       else
       {
-        int dictIndex;
-        if (functions.count(colIndex))
-          dictIndex = dict.getIndexOrInsert(functions.at(colIndex)(config.getAsFeature(col, index)));
-        else
-          dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
+        int dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index)));
 
         for (auto & contextElement : context)
           contextElement.push_back(dictIndex);
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 5eb0db28499009e74136f10d12a1fd4b43c8cd7b..29aef9e53e765f05ccd9f1d9b35bacc2817d01c6 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -8,12 +8,8 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
 {
   try
   {
-    auto funcCol = util::split(sm.str(1), ':');
-    if (funcCol.size() > 2)
-      util::myThrow(fmt::format("invalid function:column '{}' of size {}", sm.str(1), funcCol.size()));
-    if (funcCol.size() == 2)
-      func = getFunction(funcCol[0]);
-    column = funcCol.back();
+    func = getFunction(sm.str(1));
+    column = util::split(sm.str(1), ':').back();
     maxNbElements = std::stoi(sm.str(2));
 
     for (auto & index : util::split(sm.str(3), ' '))
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index a87f44669b3a6c58befe4271f4f83e7ae09af5e4..2be33f7bd165e8fb32a12283284320db01eda065 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -63,17 +63,63 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
   getDict().setState(originalState);
 }
 
-const std::function<std::string(const std::string &)> & Submodule::getFunction(const std::string functionName)
+std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
 {
+  static auto prefix = [](const std::string & s, int length)
+  {
+    util::utf8string utf8s = util::splitAsUtf8(s);
+    util::utf8string prefix(utf8s.begin(), std::min(utf8s.end(),utf8s.begin()+length));
+    return fmt::format("{}", prefix);
+  };
+
+  static auto suffix = [](const std::string & s, int length)
+  {
+    util::utf8string utf8s = util::splitAsUtf8(s);
+    util::utf8string suffix(std::max(utf8s.begin(), utf8s.end()-length), utf8s.end());
+    return fmt::format("{}", suffix);
+  };
+
   static std::map<std::string, std::function<std::string(const std::string &)>> functions
   {
-    {"lower", [](const std::string & s) {return util::lower(s);}}
+    {"lower", [](const std::string & s) {return util::lower(s);}},
+    {"prefix1", [](const std::string & s) {return prefix(s, 1);}},
+    {"prefix2", [](const std::string & s) {return prefix(s, 2);}},
+    {"prefix3", [](const std::string & s) {return prefix(s, 3);}},
+    {"prefix4", [](const std::string & s) {return prefix(s, 4);}},
+    {"prefix5", [](const std::string & s) {return prefix(s, 5);}},
+    {"prefix6", [](const std::string & s) {return prefix(s, 6);}},
+    {"prefix7", [](const std::string & s) {return prefix(s, 7);}},
+    {"suffix1", [](const std::string & s) {return suffix(s, 1);}},
+    {"suffix2", [](const std::string & s) {return suffix(s, 2);}},
+    {"suffix3", [](const std::string & s) {return suffix(s, 3);}},
+    {"suffix4", [](const std::string & s) {return suffix(s, 4);}},
+    {"suffix5", [](const std::string & s) {return suffix(s, 5);}},
+    {"suffix6", [](const std::string & s) {return suffix(s, 6);}},
+    {"suffix7", [](const std::string & s) {return suffix(s, 7);}},
   };
 
-  auto it = functions.find(util::lower(functionName));
-  if (it == functions.end())
-    util::myThrow(fmt::format("unknown function name '{}'", functionName));
+  auto splited = util::split(functionNames, ':');
+  if (splited.size() == 1)
+    return [](const std::string & s){return s;};
+
+  std::vector<std::function<std::string(const std::string &)>> sequence;
 
-  return it->second;
+  for (unsigned int i = 0; i < splited.size()-1; i++)
+  {
+    auto & functionName = splited[i];
+    auto it = functions.find(util::lower(functionName));
+    if (it == functions.end())
+      util::myThrow(fmt::format("unknown function name '{}'", functionName));
+
+    sequence.emplace_back(it->second);
+  }
+
+  return [sequence](const std::string & s)
+  {
+    auto result = s;
+    for (auto & f : sequence)
+      result = f(result);
+    return result;
+  };
 }
 
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index ea0f6f52eced990cc0e05f893c44302e20486bbb..8a60deeea35d5488f7cc704ee7f81069576a5ada 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -90,6 +90,26 @@ Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s)
   return ts;
 }
 
+template
+<
+  typename Optimizer = torch::optim::Adam,
+  typename OptimizerOptions = torch::optim::AdamOptions
+>
+inline auto decay(Optimizer & optimizer, double rate) -> void
+{
+  for (auto & group : optimizer.param_groups())
+  {
+    for (auto & param : group.params())
+    {
+      if (!param.grad().defined())
+        continue;
+
+      auto & options = static_cast<OptimizerOptions &>(group.options());
+      options.lr(options.lr() * (1.0 - rate));
+    }
+  }
+}
+
 int MacaonTrain::main()
 {
   auto od = getOptionsDescription();
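
For reference, a minimal standalone sketch (not part of the patch) of the composition idea behind the new Submodule::getFunction: every ':'-separated token except the last names a transformation, and the returned closure applies them left to right to a feature value. The "lower"/"prefix2" helpers and the getline-based split below are simplified ASCII stand-ins for the project's util:: helpers, which work on UTF-8 code points; the specifier "lower:prefix2:FORM" is only an illustrative example.

#include <algorithm>
#include <cctype>
#include <functional>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

// Simplified stand-ins for util::lower and the prefix helper (ASCII only here).
static std::string lower(std::string s)
{
  std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c){return std::tolower(c);});
  return s;
}

static std::string prefix(const std::string & s, std::size_t length)
{
  return s.substr(0, std::min(length, s.size()));
}

// Same shape as the patched getFunction: split "f1:f2:column" on ':',
// look up every token but the last, and return one composed closure.
std::function<std::string(const std::string &)> getFunction(const std::string & functionNames)
{
  static const std::map<std::string, std::function<std::string(const std::string &)>> functions
  {
    {"lower",   [](const std::string & s){return lower(s);}},
    {"prefix2", [](const std::string & s){return prefix(s, 2);}},
  };

  // Stand-in for util::split(functionNames, ':')
  std::vector<std::string> splited;
  std::stringstream ss(functionNames);
  for (std::string part; std::getline(ss, part, ':');)
    splited.emplace_back(part);

  if (splited.size() <= 1)                           // plain column name: identity
    return [](const std::string & s){return s;};

  std::vector<std::function<std::string(const std::string &)>> sequence;
  for (std::size_t i = 0; i + 1 < splited.size(); i++)
    sequence.emplace_back(functions.at(splited[i])); // throws on unknown names (the patch calls util::myThrow instead)

  return [sequence](const std::string & s)
  {
    auto result = s;
    for (auto & f : sequence)
      result = f(result);
    return result;
  };
}

int main()
{
  auto f = getFunction("lower:prefix2:FORM");        // hypothetical context-column specifier
  std::cout << f("Bonjour") << '\n';                 // prints "bo": lower, then 2-character prefix
}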