#include "Submodule.hpp"

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <memory>
#include <utility>

namespace
{

// Deleter that lets std::unique_ptr own a C stdio stream, so the stream is
// closed on every exit path (normal return or exception).
struct FileCloser
{
  void operator()(std::FILE * file) const noexcept
  {
    std::fclose(file);
  }
};

} // namespace

// Stores the given value in the firstInputIndex member.
void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
{
  this->firstInputIndex = firstInputIndex;
}

// Overwrites rows of `embeddings` with vectors read from a word2vec text file.
//
// Does nothing when `path` is empty or when the module is not in training mode.
// The first line of the file is skipped (word2vec header). Every other line is
// expected to be "<word> <v1> <v2> ... <vN>" separated by single spaces, where N
// must equal the last dimension of the embedding's first parameter.
//
// The dictionary is switched to Dict::State::Closed while the file is read and
// is put back into its previous state when the function exits, including when
// it exits through an error. Words whose dictionary index equals the index of
// the unknown / null / empty special values are skipped.
//
// Errors are reported through util::myThrow when: the file does not exist, the
// file cannot be opened, the file is empty, a line has fewer than 2 columns, a
// line has the wrong number of values, or a value cannot be parsed.
//
// NOTE(review): lines longer than the 100000-byte buffer are read in several
// pieces by fgets and will most likely trigger the size-mismatch error.
void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path)
{
  if (path.empty())
    return;
  if (!is_training())
    return;

  if (!std::filesystem::exists(path))
    util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));

  torch::NoGradGuard no_grad;

  // exists() does not guarantee the file is readable; fopen may still fail and
  // a null stream must never reach feof/fgets.
  std::unique_ptr<std::FILE, FileCloser> file{std::fopen(path.c_str(), "r")};
  if (!file)
    util::myThrow(fmt::format("could not open pretrained word2vec file '{}'", path.string()));

  // Restores the dictionary state on scope exit so that an error while reading
  // the file does not leave the dictionary closed.
  struct DictStateGuard
  {
    Dict & dict;
    Dict::State state;
    ~DictStateGuard() { dict.setState(state); }
  };
  DictStateGuard stateGuard{getDict(), getDict().getState()};
  getDict().setState(Dict::State::Closed);

  constexpr int bufferSize = 100000;
  char buffer[bufferSize];
  bool firstLine = true;
  const std::size_t embeddingsSize = embeddings->parameters()[0].size(-1);

  try
  {
    while (!std::feof(file.get()))
    {
      if (buffer != std::fgets(buffer, bufferSize, file.get()))
        break;

      // The first line is the word2vec header, it carries no vector.
      if (firstLine)
      {
        firstLine = false;
        continue;
      }

      auto splited = util::split(util::strip(buffer), ' ');

      if (splited.size() < 2)
        util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));

      const auto dictIndex = getDict().getIndexOrInsert(splited[0]);

      // Never overwrite the rows that belong to the dictionary's special values.
      if (dictIndex == getDict().getIndexOrInsert(Dict::unknownValueStr)
          || dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr)
          || dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr))
        continue;

      if (embeddingsSize != splited.size() - 1)
        util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, static_cast<int>(splited.size()) - 1));

      for (std::size_t i = 1; i < splited.size(); ++i)
        embeddings->weight[dictIndex][i - 1] = std::stof(splited[i]);
    }
  }
  catch (const std::exception & e)
  {
    // Re-report any failure (including std::stof errors) with the submodule name.
    util::myThrow(fmt::format("caught '{}' for SubModule '{}'", e.what(), getName()));
  }

  if (firstLine)
    util::myThrow(fmt::format("file '{}' is empty", path.string()));
}

// Builds a string transformation from a colon-separated specification.
//
// `functionNames` is split on ':'. Every token except the last one is looked up
// (case-insensitively) among: "lower", "prefix1".."prefix7", "suffix1".."suffix7".
// The last token is not interpreted here. The returned callable applies the
// selected functions from left to right. When the specification holds at most
// one token, the identity function is returned.
//
// prefixN / suffixN keep the first / last N UTF-8 characters of the input, or
// the whole input when it has fewer than N characters.
//
// Reports an error through util::myThrow for an unknown function name.
std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
{
  static const auto prefix = [](const std::string & s, int length) -> std::string
  {
    if (s.empty())
      return s;

    util::utf8string utf8s = util::splitAsUtf8(s);
    // Clamp the count BEFORE doing iterator arithmetic: begin()+length past
    // end() is undefined behaviour even if the result is only compared.
    const std::ptrdiff_t available = utf8s.end() - utf8s.begin();
    const std::ptrdiff_t kept = std::min<std::ptrdiff_t>(available, length);
    util::utf8string head(utf8s.begin(), utf8s.begin() + kept);
    return fmt::format("{}", head);
  };

  static const auto suffix = [](const std::string & s, int length) -> std::string
  {
    if (s.empty())
      return s;

    util::utf8string utf8s = util::splitAsUtf8(s);
    // Same clamping as above, end()-length before begin() would be undefined.
    const std::ptrdiff_t available = utf8s.end() - utf8s.begin();
    const std::ptrdiff_t kept = std::min<std::ptrdiff_t>(available, length);
    util::utf8string tail(utf8s.end() - kept, utf8s.end());
    return fmt::format("{}", tail);
  };

  static const std::map<std::string, std::function<std::string(const std::string &)>> functions
  {
    {"lower", [](const std::string & s) {return util::lower(s);}},
    {"prefix1", [](const std::string & s) {return prefix(s, 1);}},
    {"prefix2", [](const std::string & s) {return prefix(s, 2);}},
    {"prefix3", [](const std::string & s) {return prefix(s, 3);}},
    {"prefix4", [](const std::string & s) {return prefix(s, 4);}},
    {"prefix5", [](const std::string & s) {return prefix(s, 5);}},
    {"prefix6", [](const std::string & s) {return prefix(s, 6);}},
    {"prefix7", [](const std::string & s) {return prefix(s, 7);}},
    {"suffix1", [](const std::string & s) {return suffix(s, 1);}},
    {"suffix2", [](const std::string & s) {return suffix(s, 2);}},
    {"suffix3", [](const std::string & s) {return suffix(s, 3);}},
    {"suffix4", [](const std::string & s) {return suffix(s, 4);}},
    {"suffix5", [](const std::string & s) {return suffix(s, 5);}},
    {"suffix6", [](const std::string & s) {return suffix(s, 6);}},
    {"suffix7", [](const std::string & s) {return suffix(s, 7);}},
  };

  auto splited = util::split(functionNames, ':');

  // "<= 1" rather than "== 1": with zero tokens, size()-1 would wrap around
  // and the loop below would index far outside the vector.
  if (splited.size() <= 1)
    return [](const std::string & s){return s;};

  std::vector<std::function<std::string(const std::string &)>> sequence;
  sequence.reserve(splited.size() - 1);

  // The last token is deliberately left out, only the ones before it name functions.
  for (std::size_t i = 0; i + 1 < splited.size(); ++i)
  {
    auto & functionName = splited[i];
    auto it = functions.find(util::lower(functionName));
    if (it == functions.end())
      util::myThrow(fmt::format("unknown function name '{}'", functionName));

    sequence.emplace_back(it->second);
  }

  return [sequence = std::move(sequence)](const std::string & s)
  {
    std::string result = s;
    for (const auto & f : sequence)
      result = f(result);
    return result;
  };
}