From 07592b27a358f2aff9c9952e5a7294adffc6e62e Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 7 Jun 2020 17:41:35 +0200
Subject: [PATCH] Added functions prefix/suffix and allowed for multiple
 functions to be applied

---
 common/include/utf8string.hpp             |  1 +
 torch_modules/include/ContextModule.hpp   |  2 +-
 torch_modules/include/Submodule.hpp       |  2 +-
 torch_modules/src/ContextModule.cpp       | 14 ++----
 torch_modules/src/FocusedColumnModule.cpp |  8 +---
 torch_modules/src/Submodule.cpp           | 58 ++++++++++++++++++++---
 trainer/src/MacaonTrain.cpp               | 20 ++++++++
 7 files changed, 80 insertions(+), 25 deletions(-)

diff --git a/common/include/utf8string.hpp b/common/include/utf8string.hpp
index 21937d3..1b56bba 100644
--- a/common/include/utf8string.hpp
+++ b/common/include/utf8string.hpp
@@ -25,6 +25,7 @@ class utf8char : public std::array<char, 4>
 
 class utf8string : public std::vector<utf8char>
 {
   public :
+  using std::vector<utf8char>::vector;
 
   utf8string & operator=(const std::string & other);
diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index 2276aaa..ed0ce57 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -15,7 +15,7 @@ class ContextModuleImpl : public Submodule
   torch::nn::Embedding wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
-  std::map<std::size_t, std::function<std::string(const std::string &)>> functions;
+  std::vector<std::function<std::string(const std::string &)>> functions;
   std::vector<int> bufferContext;
   std::vector<int> stackContext;
   int inSize;
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 6e8f689..71b1007 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -22,7 +22,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
   virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual void registerEmbeddings(std::filesystem::path pretrained) = 0;
-  const std::function<std::string(const std::string &)> & getFunction(const std::string functionName);
+  std::function<std::string(const std::string &)> getFunction(const std::string functionNames);
 };
 
 #endif
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 244944d..d1a31c9 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -18,12 +18,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
             columns.clear();
             for (auto & funcCol : funcColumns)
             {
-              auto splited = util::split(funcCol, ':');
-              if (splited.size() > 2)
-                util::myThrow(fmt::format("invalid function:column '{}' of size {}", funcCol, splited.size()));
-              if (splited.size() == 2)
-                functions[columns.size()] = getFunction(splited[0]);
-              columns.emplace_back(splited.back());
+              functions.emplace_back(getFunction(funcCol));
+              columns.emplace_back(util::split(funcCol, ':').back());
             }
 
             auto subModuleType = sm.str(4);
@@ -87,11 +83,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
       }
       else
       {
-        int dictIndex;
-        if (functions.count(colIndex))
-          dictIndex = dict.getIndexOrInsert(functions.at(colIndex)(config.getAsFeature(col, index)));
-        else
-          dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
+        int dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index)));
 
         for (auto & contextElement : context)
           contextElement.push_back(dictIndex);
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 5eb0db2..29aef9e 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -8,12 +8,8 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
         {
           try
           {
-            auto funcCol = util::split(sm.str(1), ':');
-            if (funcCol.size() > 2)
-              util::myThrow(fmt::format("invalid function:column '{}' of size {}", sm.str(1), funcCol.size()));
-            if (funcCol.size() == 2)
-              func = getFunction(funcCol[0]);
-            column = funcCol.back();
+            func = getFunction(sm.str(1));
+            column = util::split(sm.str(1), ':').back();
             maxNbElements = std::stoi(sm.str(2));
 
             for (auto & index : util::split(sm.str(3), ' '))
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index a87f446..2be33f7 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -63,17 +63,63 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
   getDict().setState(originalState);
 }
 
-const std::function<std::string(const std::string &)> & Submodule::getFunction(const std::string functionName)
+std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
 {
+  static auto prefix = [](const std::string & s, int length)
+  {
+    util::utf8string utf8s = util::splitAsUtf8(s);
+    util::utf8string prefix(utf8s.begin(), utf8s.begin()+std::min<std::size_t>(length, utf8s.size()));
+    return fmt::format("{}", prefix);
+  };
+
+  static auto suffix = [](const std::string & s, int length)
+  {
+    util::utf8string utf8s = util::splitAsUtf8(s);
+    util::utf8string suffix(utf8s.end()-std::min<std::size_t>(length, utf8s.size()), utf8s.end());
+    return fmt::format("{}", suffix);
+  };
+
   static std::map<std::string, std::function<std::string(const std::string &)>> functions
   {
-    {"lower", [](const std::string & s) {return util::lower(s);}}
+    {"lower", [](const std::string & s) {return util::lower(s);}},
+    {"prefix1", [](const std::string & s) {return prefix(s, 1);}},
+    {"prefix2", [](const std::string & s) {return prefix(s, 2);}},
+    {"prefix3", [](const std::string & s) {return prefix(s, 3);}},
+    {"prefix4", [](const std::string & s) {return prefix(s, 4);}},
+    {"prefix5", [](const std::string & s) {return prefix(s, 5);}},
+    {"prefix6", [](const std::string & s) {return prefix(s, 6);}},
+    {"prefix7", [](const std::string & s) {return prefix(s, 7);}},
+    {"suffix1", [](const std::string & s) {return suffix(s, 1);}},
+    {"suffix2", [](const std::string & s) {return suffix(s, 2);}},
+    {"suffix3", [](const std::string & s) {return suffix(s, 3);}},
+    {"suffix4", [](const std::string & s) {return suffix(s, 4);}},
+    {"suffix5", [](const std::string & s) {return suffix(s, 5);}},
+    {"suffix6", [](const std::string & s) {return suffix(s, 6);}},
+    {"suffix7", [](const std::string & s) {return suffix(s, 7);}},
   };
 
-  auto it = functions.find(util::lower(functionName));
-  if (it == functions.end())
-    util::myThrow(fmt::format("unknown function name '{}'", functionName));
+  auto splited = util::split(functionNames, ':');
+  if (splited.size() == 1)
+    return [](const std::string & s){return s;};
+
+  std::vector<std::function<std::string(const std::string &)>> sequence;
 
-  return it->second;
+  for (unsigned int i = 0; i < splited.size()-1; i++)
+  {
+    auto & functionName = splited[i];
+    auto it = functions.find(util::lower(functionName));
+    if (it == functions.end())
+      util::myThrow(fmt::format("unknown function name '{}'", functionName));
+
+    sequence.emplace_back(it->second);
+  }
+
+  return [sequence](const std::string & s)
+  {
+    auto result = s; 
+    for (auto & f : sequence)
+      result = f(result);
+    return result;
+  };
 }
 
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index ea0f6f5..8a60dee 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -90,6 +90,26 @@ Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s)
   return ts;
 }
 
+template 
+<
+  typename Optimizer = torch::optim::Adam,
+  typename OptimizerOptions = torch::optim::AdamOptions
+>
+inline auto decay(Optimizer &optimizer, double rate) -> void
+{
+  for (auto &group : optimizer.param_groups())
+  {
+    // Decay the learning rate once per group, not once per parameter
+    bool anyGradDefined = false;
+    for (auto &param : group.params())
+      anyGradDefined = anyGradDefined || param.grad().defined();
+    if (!anyGradDefined)
+      continue;
+    auto &options = static_cast<OptimizerOptions &>(group.options());
+    options.lr(options.lr() * (1.0 - rate));
+  }
+}
+
 int MacaonTrain::main()
 {
   auto od = getOptionsDescription();
-- 
GitLab