Skip to content
Snippets Groups Projects
Commit 07592b27 authored by Franck Dary
Browse files

Added functions prefix/suffix and allowed for multiple functions to be applied

parent 559cbc8a
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,7 @@ class utf8char : public std::array<char, 4>
class utf8string : public std::vector<utf8char>
{
using std::vector<utf8char>::vector;
public :
utf8string & operator=(const std::string & other);
......
......@@ -15,7 +15,7 @@ class ContextModuleImpl : public Submodule
torch::nn::Embedding wordEmbeddings{nullptr};
std::shared_ptr<MyModule> myModule{nullptr};
std::vector<std::string> columns;
std::map<std::size_t, std::function<std::string(const std::string &)>> functions;
std::vector<std::function<std::string(const std::string &)>> functions;
std::vector<int> bufferContext;
std::vector<int> stackContext;
int inSize;
......
......@@ -22,7 +22,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual void registerEmbeddings(std::filesystem::path pretrained) = 0;
const std::function<std::string(const std::string &)> & getFunction(const std::string functionName);
std::function<std::string(const std::string &)> getFunction(const std::string functionNames);
};
#endif
......
......@@ -18,12 +18,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
columns.clear();
for (auto & funcCol : funcColumns)
{
auto splited = util::split(funcCol, ':');
if (splited.size() > 2)
util::myThrow(fmt::format("invalid function:column '{}' of size {}", funcCol, splited.size()));
if (splited.size() == 2)
functions[columns.size()] = getFunction(splited[0]);
columns.emplace_back(splited.back());
functions.emplace_back() = getFunction(funcCol);
columns.emplace_back(util::split(funcCol, ':').back());
}
auto subModuleType = sm.str(4);
......@@ -87,11 +83,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
}
else
{
int dictIndex;
if (functions.count(colIndex))
dictIndex = dict.getIndexOrInsert(functions.at(colIndex)(config.getAsFeature(col, index)));
else
dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
int dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index)));
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
......
......@@ -8,12 +8,8 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
{
try
{
auto funcCol = util::split(sm.str(1), ':');
if (funcCol.size() > 2)
util::myThrow(fmt::format("invalid function:column '{}' of size {}", sm.str(1), funcCol.size()));
if (funcCol.size() == 2)
func = getFunction(funcCol[0]);
column = funcCol.back();
func = getFunction(sm.str(1));
column = util::split(sm.str(1), ':').back();
maxNbElements = std::stoi(sm.str(2));
for (auto & index : util::split(sm.str(3), ' '))
......
......@@ -63,17 +63,63 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
getDict().setState(originalState);
}
const std::function<std::string(const std::string &)> & Submodule::getFunction(const std::string functionName)
std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
{
static auto prefix = [](const std::string & s, int length)
{
util::utf8string utf8s = util::splitAsUtf8(s);
util::utf8string prefix(utf8s.begin(), std::min(utf8s.end(),utf8s.begin()+length));
return fmt::format("{}", prefix);
};
static auto suffix = [](const std::string & s, int length)
{
util::utf8string utf8s = util::splitAsUtf8(s);
util::utf8string suffix(std::max(utf8s.begin(), utf8s.end()-length), utf8s.end());
return fmt::format("{}", suffix);
};
static std::map<std::string, std::function<std::string(const std::string &)>> functions
{
{"lower", [](const std::string & s) {return util::lower(s);}}
{"lower", [](const std::string & s) {return util::lower(s);}},
{"prefix1", [](const std::string & s) {return prefix(s, 1);}},
{"prefix2", [](const std::string & s) {return prefix(s, 2);}},
{"prefix3", [](const std::string & s) {return prefix(s, 3);}},
{"prefix4", [](const std::string & s) {return prefix(s, 4);}},
{"prefix5", [](const std::string & s) {return prefix(s, 5);}},
{"prefix6", [](const std::string & s) {return prefix(s, 6);}},
{"prefix7", [](const std::string & s) {return prefix(s, 7);}},
{"suffix1", [](const std::string & s) {return suffix(s, 1);}},
{"suffix2", [](const std::string & s) {return suffix(s, 2);}},
{"suffix3", [](const std::string & s) {return suffix(s, 3);}},
{"suffix4", [](const std::string & s) {return suffix(s, 4);}},
{"suffix5", [](const std::string & s) {return suffix(s, 5);}},
{"suffix6", [](const std::string & s) {return suffix(s, 6);}},
{"suffix7", [](const std::string & s) {return suffix(s, 7);}},
};
auto it = functions.find(util::lower(functionName));
if (it == functions.end())
util::myThrow(fmt::format("unknown function name '{}'", functionName));
auto splited = util::split(functionNames, ':');
if (splited.size() == 1)
return [](const std::string & s){return s;};
std::vector<std::function<std::string(const std::string &)>> sequence;
return it->second;
for (unsigned int i = 0; i < splited.size()-1; i++)
{
auto & functionName = splited[i];
auto it = functions.find(util::lower(functionName));
if (it == functions.end())
util::myThrow(fmt::format("unknown function name '{}'", functionName));
sequence.emplace_back(it->second);
}
return [sequence](const std::string & s)
{
auto result = s;
for (auto & f : sequence)
result = f(result);
return result;
};
}
......@@ -90,6 +90,26 @@ Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s)
return ts;
}
/// Decays the learning rate of an optimizer's parameter groups.
///
/// For every parameter group that holds at least one parameter with a defined
/// gradient, the group's learning rate is multiplied by (1 - rate). Groups in
/// which no parameter has a defined gradient are left untouched.
///
/// The decay is applied exactly ONCE per group. Applying it inside the loop
/// over parameters would multiply the rate by (1 - rate) once per parameter
/// tensor, making the effective decay depend on how many parameters the group
/// happens to contain.
///
/// @tparam Optimizer        optimizer type exposing param_groups().
/// @tparam OptimizerOptions concrete options type exposing lr() getter/setter.
/// @param optimizer the optimizer whose groups are updated in place.
/// @param rate      fraction of the learning rate removed by the call.
template
<
  typename Optimizer = torch::optim::Adam,
  typename OptimizerOptions = torch::optim::AdamOptions
>
inline auto decay(Optimizer &optimizer, double rate) -> void
{
  for (auto &group : optimizer.param_groups())
  {
    // Does this group contain any parameter that currently has a gradient?
    bool hasGradient = false;
    for (auto &param : group.params())
    {
      if (param.grad().defined())
      {
        hasGradient = true;
        break;
      }
    }
    if (!hasGradient)
      continue;

    // NOTE(review): group.options() presumably returns the generic options
    // base type, hence the cast to the concrete options type -- confirm that
    // OptimizerOptions matches the optimizer actually passed in.
    auto &options = static_cast<OptimizerOptions &>(group.options());
    options.lr(options.lr() * (1.0 - rate));
  }
}
int MacaonTrain::main()
{
auto od = getOptionsDescription();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment