Submodule.cpp 4.52 KB
Newer Older
Franck Dary's avatar
Franck Dary committed
1
#include "Submodule.hpp"
#include <algorithm>
#include <memory>
#include "WordEmbeddings.hpp"
Franck Dary's avatar
Franck Dary committed
3
4
5
6
7
8

/// Stores `firstInputIndex` into the member of the same name.
/// NOTE(review): presumably the offset of this submodule's first feature in
/// the shared input tensor — confirm against Submodule.hpp.
void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
{
  // Qualified lookup names the data member, not the shadowing parameter.
  Submodule::firstInputIndex = firstInputIndex;
}

9
/// Copies pretrained vectors from a word2vec-style text file into rows of `embeddings`.
///
/// While the file is read, the Dict is switched to `Dict::State::Closed` and
/// restored afterwards, including on error. NOTE(review): presumably a closed
/// Dict makes getIndexOrInsert map unseen words to an existing index instead of
/// inserting them — confirm against Dict.
///
/// Does nothing when `path` is empty or when the module is not training.
///
/// @param embeddings  embedding table whose `weight` rows are overwritten; its
///                    `requires_grad` is then set from
///                    WordEmbeddingsImpl::getCanTrainPretrained().
/// @param path        text file: the first line is skipped, then each line is
///                    "word v1 v2 ... vD" separated by single spaces, where D
///                    must equal the embedding width. "<unk>" maps to
///                    Dict::unknownValueStr, and every "◌" in a word is
///                    replaced by a space.
/// @param prefix      forwarded to Dict::getIndexOrInsert.
/// @throws via util::myThrow if the file is missing, cannot be opened, cannot be
///         read, is empty, or contains a malformed line.
void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix)
{
  if (path.empty())
    return;

  if (!is_training())
    return;

  if (!std::filesystem::exists(path))
    util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));

  // Open before touching the Dict, so an open failure leaves nothing to undo.
  // The unique_ptr closes the file on every exit path, including exceptions.
  std::unique_ptr<std::FILE, decltype(&std::fclose)> file(std::fopen(path.c_str(), "r"), &std::fclose);
  if (!file)
    util::myThrow(fmt::format("could not open pretrained word2vec file '{}'", path.string()));

  // Writing into the weight tensor must not be recorded by autograd.
  torch::NoGradGuard no_grad;

  auto originalState = getDict().getState();
  getDict().setState(Dict::State::Closed);

  char buffer[100000];

  bool firstLine = true;
  std::size_t embeddingsSize = embeddings->parameters()[0].size(-1);

  try
  {
    // fgets returns nullptr at EOF or on error; ferror below tells the two apart.
    while (std::fgets(buffer, sizeof(buffer), file.get()))
    {
      // Skip the first line (presumably the word2vec "<count> <dim>" header).
      if (firstLine)
      {
        firstLine = false;
        continue;
      }

      auto splited = util::split(util::strip(buffer), ' ');

      if (splited.size() < 2)
        util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));

      std::string word;

      if (splited[0] == "<unk>")
        word = Dict::unknownValueStr;
      else
        word = splited[0];

      auto toInsert = util::splitAsUtf8(word);
      toInsert.replace("◌", " ");
      word = fmt::format("{}", toInsert);

      auto dictIndex = getDict().getIndexOrInsert(word, prefix);

      if (embeddingsSize != splited.size()-1)
        util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, splited.size()-1));

      for (std::size_t i = 1; i < splited.size(); i++)
        embeddings->weight[dictIndex][i-1] = std::stof(splited[i]);
    }

    if (std::ferror(file.get()))
      util::myThrow(fmt::format("error while reading pretrained word2vec file '{}'", path.string()));
  } catch (std::exception & e)
  {
    // Do not leave the Dict closed when loading fails.
    getDict().setState(originalState);
    util::myThrow(fmt::format("caught '{}' for SubModule '{}'", e.what(), getName()));
  }

  getDict().setState(originalState);

  // firstLine is still true only if not even one line could be read.
  if (firstLine)
    util::myThrow(fmt::format("file '{}' is empty", path.string()));

  embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
}

81
std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
82
{
83
84
  static auto prefix = [](const std::string & s, int length)
  {
85
86
87
    if (s.size() == 0)
      return s;

88
89
90
91
92
93
94
    util::utf8string utf8s = util::splitAsUtf8(s);
    util::utf8string prefix(utf8s.begin(), std::min(utf8s.end(),utf8s.begin()+length));
    return fmt::format("{}", prefix);
  };

  static auto suffix = [](const std::string & s, int length)
  {
95
96
97
    if (s.size() == 0)
      return s;

98
99
100
101
102
    util::utf8string utf8s = util::splitAsUtf8(s);
    util::utf8string suffix(std::max(utf8s.begin(), utf8s.end()-length), utf8s.end());
    return fmt::format("{}", suffix);
  };

103
104
  static std::map<std::string, std::function<std::string(const std::string &)>> functions
  {
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
    {"lower", [](const std::string & s) {return util::lower(s);}},
    {"prefix1", [](const std::string & s) {return prefix(s, 1);}},
    {"prefix2", [](const std::string & s) {return prefix(s, 2);}},
    {"prefix3", [](const std::string & s) {return prefix(s, 3);}},
    {"prefix4", [](const std::string & s) {return prefix(s, 4);}},
    {"prefix5", [](const std::string & s) {return prefix(s, 5);}},
    {"prefix6", [](const std::string & s) {return prefix(s, 6);}},
    {"prefix7", [](const std::string & s) {return prefix(s, 7);}},
    {"suffix1", [](const std::string & s) {return suffix(s, 1);}},
    {"suffix2", [](const std::string & s) {return suffix(s, 2);}},
    {"suffix3", [](const std::string & s) {return suffix(s, 3);}},
    {"suffix4", [](const std::string & s) {return suffix(s, 4);}},
    {"suffix5", [](const std::string & s) {return suffix(s, 5);}},
    {"suffix6", [](const std::string & s) {return suffix(s, 6);}},
    {"suffix7", [](const std::string & s) {return suffix(s, 7);}},
120
121
  };

122
123
124
125
126
  auto splited = util::split(functionNames, ':');
  if (splited.size() == 1)
    return [](const std::string & s){return s;};

  std::vector<std::function<std::string(const std::string &)>> sequence;
127

128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
  for (unsigned int i = 0; i < splited.size()-1; i++)
  {
    auto & functionName = splited[i];
    auto it = functions.find(util::lower(functionName));
    if (it == functions.end())
      util::myThrow(fmt::format("unknown function name '{}'", functionName));

    sequence.emplace_back(it->second);
  }

  return [sequence](const std::string & s)
  {
    auto result = s; 
    for (auto & f : sequence)
      result = f(result);
    return result;
  };
145
146
}