MacaonTrain.cpp 10.8 KB
Newer Older
Franck Dary's avatar
Franck Dary committed
1
#include "MacaonTrain.hpp"
Franck Dary's avatar
Franck Dary committed
2
3
#include <filesystem>
#include "util.hpp"
4
#include "NeuralNetwork.hpp"
5
#include "WordEmbeddings.hpp"
Franck Dary's avatar
Franck Dary committed
6
7
8

// Short alias for Boost.Program_options, used by the option declaration and parsing code in this file.
namespace po = boost::program_options;

Franck Dary's avatar
Franck Dary committed
9
// Build the description of every command-line option understood by the trainer.
//
// Options are declared in two groups, "Required" (model, trainTSV) and
// "Optional" (everything else, including --help), which are then merged into
// the single description that is returned.
po::options_description MacaonTrain::getOptionsDescription()
{
  po::options_description desc("Command-Line Arguments ");

  po::options_description req("Required");
  req.add_options()
    ("model", po::value<std::string>()->required(),
      "Directory containing the machine file to train")
    ("trainTSV", po::value<std::string>()->required(),
      "TSV file of the training corpus, in CONLLU format");

  po::options_description opt("Optional");
  opt.add_options()
    ("debug,d", "Print debuging infos on stderr")
    ("silent", "Don't print speed and progress")
    ("devScore", "Compute score on dev instead of loss (slower)")
    ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"),
      "Comma separated column names that describes the input/output format")
    ("trainTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the training corpus")
    ("devTSV", po::value<std::string>()->default_value(""),
      "TSV file of the development corpus, in CONLLU format")
    ("devTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the development corpus")
    ("nbEpochs,n", po::value<int>()->default_value(5),
      "Number of training epochs")
    ("batchSize", po::value<int>()->default_value(64),
      "Number of examples per batch")
    ("explorationThreshold", po::value<float>()->default_value(0.1),
      "Maximum probability difference with the best scoring transition, for a transition to be explored during dynamic extraction of dataset")
    ("machine", po::value<std::string>()->default_value(""),
      "Reading machine file content")
    ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"),
      "Description of what should happen during training")
    // The help text used to be a copy of the batchSize one; the value is
    // actually passed to std::srand and torch::manual_seed in main().
    ("seed", po::value<int>()->default_value(100),
      "Seed of the random number generators")
    ("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch")
    ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
      "Max norm for the embeddings")
    ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
    ("help,h", "Produce this help message");

  desc.add(req).add(opt);

  return desc;
}

Franck Dary's avatar
Franck Dary committed
56
// Parse the stored command line (argc/argv) against `od` and return the
// resulting option values.
//
// Any exception raised while parsing or while notifying is forwarded to
// util::myThrow with its message. If --help was given, the option description
// is printed on stderr and the process exits with status 0 before
// po::notify is called.
po::variables_map MacaonTrain::checkOptions(po::options_description & od)
{
  po::variables_map parsed;

  try
  {
    po::store(po::parse_command_line(argc, argv, od), parsed);
  }
  catch (std::exception & error)
  {
    util::myThrow(error.what());
  }

  if (parsed.count("help"))
  {
    std::stringstream helpText;
    helpText << od;
    fmt::print(stderr, "{}\n", helpText.str());
    exit(0);
  }

  try
  {
    po::notify(parsed);
  }
  catch (std::exception & error)
  {
    util::myThrow(error.what());
  }

  return parsed;
}

Franck Dary's avatar
Franck Dary committed
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// Convert the textual --trainStrategy value into a Trainer::TrainStrategy.
//
// `s` is split on ':' into per-epoch segments; each segment is split on ','.
// The first element of a segment is parsed as the epoch number (std::stoi)
// and every following element is converted with Trainer::str2TrainAction and
// inserted into the action set of that epoch.
// Example (the option's default): "0,ExtractGold,ResetParameters".
//
// Any std::exception raised while parsing is re-raised through util::myThrow
// with a message quoting both the original error and `s`.
Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s)
{
  Trainer::TrainStrategy ts;

  try
  {
    auto splited = util::split(s, ':');
    for (auto & ss : splited)
    {
      auto elements = util::split(ss, ',');

      // .at() rather than [] : a segment that yields no element must end up
      // in the catch below as a parse error instead of reading out of bounds.
      int epoch = std::stoi(elements.at(0));

      for (unsigned int i = 1; i < elements.size(); i++)
        ts[epoch].insert(Trainer::str2TrainAction(elements[i]));
    }
  } catch (std::exception & e) {util::myThrow(fmt::format("caught '{}' parsing '{}'", e.what(), s));}

  return ts;
}

98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
// Multiply the learning rate of `optimizer`'s parameter groups by (1 - rate).
//
// A group is only decayed when at least one of its parameters has a defined
// gradient. The learning rate is an option of the whole group, so it is
// updated once per group: updating it inside the per-parameter loop would
// compound the decay once for every parameter carrying a gradient.
//
// Optimizer / OptimizerOptions must match (the group options are
// static_cast to OptimizerOptions); they default to Adam / AdamOptions.
template
<
  typename Optimizer = torch::optim::Adam,
  typename OptimizerOptions = torch::optim::AdamOptions
>
inline auto decay(Optimizer &optimizer, double rate) -> void
{
  for (auto &group : optimizer.param_groups())
  {
    bool hasGradient = false;
    for (auto &param : group.params())
    {
      if (param.grad().defined())
      {
        hasGradient = true;
        break;
      }
    }

    if (!hasGradient)
      continue;

    auto &options = static_cast<OptimizerOptions &>(group.options());
    options.lr(options.lr() * (1.0 - rate));
  }
}

Franck Dary's avatar
Franck Dary committed
118
int MacaonTrain::main()
Franck Dary's avatar
Franck Dary committed
119
120
{
  auto od = getOptionsDescription();
Franck Dary's avatar
Franck Dary committed
121
  auto variables = checkOptions(od);
Franck Dary's avatar
Franck Dary committed
122
123
124

  std::filesystem::path modelPath(variables["model"].as<std::string>());
  auto machinePath = modelPath / "machine.rm";
125
  auto mcd = variables["mcd"].as<std::string>();
Franck Dary's avatar
Franck Dary committed
126
127
128
129
130
  auto trainTsvFile = variables["trainTSV"].as<std::string>();
  auto trainRawFile = variables["trainTXT"].as<std::string>();
  auto devTsvFile = variables["devTSV"].as<std::string>();
  auto devRawFile = variables["devTXT"].as<std::string>();
  auto nbEpoch = variables["nbEpochs"].as<int>();
131
  auto batchSize = variables["batchSize"].as<int>();
Franck Dary's avatar
Franck Dary committed
132
  bool debug = variables.count("debug") == 0 ? false : true;
Franck Dary's avatar
Franck Dary committed
133
  bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
134
  bool computeDevScore = variables.count("devScore") == 0 ? false : true;
135
  auto machineContent = variables["machine"].as<std::string>();
Franck Dary's avatar
Franck Dary committed
136
  auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
137
  auto explorationThreshold = variables["explorationThreshold"].as<float>();
Franck Dary's avatar
Franck Dary committed
138
  auto seed = variables["seed"].as<int>();
139
140
  WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
  WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
141
  WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);
Franck Dary's avatar
Franck Dary committed
142
143
144

  std::srand(seed);
  torch::manual_seed(seed);
Franck Dary's avatar
Franck Dary committed
145
146

  auto trainStrategy = parseTrainStrategy(trainStrategyStr);
147

Franck Dary's avatar
Franck Dary committed
148
149
  torch::globalContext().setBenchmarkCuDNN(true);

150
151
152
153
154
155
156
157
  if (!machineContent.empty())
  {
    std::FILE * file = fopen(machinePath.c_str(), "w");
    if (!file)
      util::error(fmt::format("can't open file '{}'\n", machinePath.c_str()));
    fmt::print(file, "{}", machineContent);
    std::fclose(file);
  }
Franck Dary's avatar
Franck Dary committed
158

Franck Dary's avatar
Franck Dary committed
159
  fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::device.str());
160

Franck Dary's avatar
Franck Dary committed
161
162
163
  try
  {

Franck Dary's avatar
Franck Dary committed
164
  ReadingMachine machine(machinePath.string(), true);
Franck Dary's avatar
Franck Dary committed
165

166
167
  BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile);
  BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
Franck Dary's avatar
Franck Dary committed
168

Franck Dary's avatar
Franck Dary committed
169
  Trainer trainer(machine, batchSize);
Franck Dary's avatar
Franck Dary committed
170
171
  Decoder decoder(machine);

Franck Dary's avatar
Franck Dary committed
172
  float bestDevScore = computeDevScore ? -std::numeric_limits<float>::max() : std::numeric_limits<float>::max();
Franck Dary's avatar
Franck Dary committed
173

174
175
176
177
178
179
180
181
182
183
184
185
  auto trainInfos = machinePath.parent_path() / "train.info";

  int currentEpoch = 0;

  if (std::filesystem::exists(trainInfos))
  {
    std::FILE * f = std::fopen(trainInfos.c_str(), "r");
    char buffer[1024];
    while (!std::feof(f))
    {
      if (buffer != std::fgets(buffer, 1024, f))
        break;
Franck Dary's avatar
Franck Dary committed
186
      bool saved = util::split(util::split(buffer, '\t')[0], ' ').back() == "SAVED";
187
      float devScoreMean = std::stof(util::split(buffer, '\t').back());
Franck Dary's avatar
Franck Dary committed
188
      if (saved)
189
190
191
192
193
194
195
        bestDevScore = devScoreMean;
      currentEpoch++;
    }
    std::fclose(f);
  }

  for (; currentEpoch < nbEpoch; currentEpoch++)
Franck Dary's avatar
Franck Dary committed
196
  {
Franck Dary's avatar
Franck Dary committed
197
198
199
200
201
202
203
204
205
206
207
208
209
    bool saved = false;

    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::DeleteExamples))
    {
      for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/train"))
        if (entry.is_regular_file())
          std::filesystem::remove(entry.path());

      if (!computeDevScore)
        for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/dev"))
          if (entry.is_regular_file())
            std::filesystem::remove(entry.path());
    }
210
    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
Franck Dary's avatar
Franck Dary committed
211
    {
212
      machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open);
213
      trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
Franck Dary's avatar
Franck Dary committed
214
      if (!computeDevScore)
215
216
      {
        machine.setDictsState(Dict::State::Closed);
217
        trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
218
      }
Franck Dary's avatar
Franck Dary committed
219
220
221
222
223
    }
    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer))
    {
      if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters))
      {
Franck Dary's avatar
Franck Dary committed
224
        machine.resetClassifiers();
225
        machine.trainMode(currentEpoch == 0);
Franck Dary's avatar
Franck Dary committed
226
        fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters()));
Franck Dary's avatar
Franck Dary committed
227
228
      }

Franck Dary's avatar
Franck Dary committed
229
      machine.resetOptimizers();
Franck Dary's avatar
Franck Dary committed
230
231
232
233
234
235
236
    }
    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save))
    {
      saved = true;
    }

    trainer.makeDataLoader(modelPath/"examples/train");
237
    if (!computeDevScore)
Franck Dary's avatar
Franck Dary committed
238
      trainer.makeDevDataLoader(modelPath/"examples/dev");
239

Franck Dary's avatar
Franck Dary committed
240
    float loss = trainer.epoch(printAdvancement);
Franck Dary's avatar
Franck Dary committed
241
242
    if (debug)
      fmt::print(stderr, "Decoding dev :\n");
243
244
245
    std::vector<std::pair<float,std::string>> devScores;
    if (computeDevScore)
    {
246
      BaseConfig devConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
247
      decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
248
249
250
251
252
253
254
255
256
      decoder.evaluate(devConfig, modelPath, devTsvFile);
      devScores = decoder.getF1Scores(machine.getPredicted());
    }
    else
    {
      float devLoss = trainer.evalOnDev(printAdvancement);
      devScores.emplace_back(std::make_pair(devLoss, "Loss"));
    }

257
258
    std::string devScoresStr = "";
    float devScoreMean = 0;
Franck Dary's avatar
Franck Dary committed
259
260
    int totalLen = 0;
    std::string toAdd;
261
262
    for (auto & score : devScores)
    {
263
      if (computeDevScore)
Franck Dary's avatar
Franck Dary committed
264
        toAdd = fmt::format("{}({}{}),", score.second, util::shrink(fmt::format("{:.2f}", std::abs(score.first)),7), score.first >= 0 ? "%" : "");
265
      else
Franck Dary's avatar
Franck Dary committed
266
        toAdd = fmt::format("{}({}),", score.second, util::shrink(fmt::format("{}", 100.0*score.first),7));
267
      devScoreMean += score.first;
Franck Dary's avatar
Franck Dary committed
268
269
270

      devScoresStr += toAdd;
      totalLen += util::printedLength(score.second) + 3;
271
272
273
    }
    if (!devScoresStr.empty())
      devScoresStr.pop_back();
Franck Dary's avatar
Franck Dary committed
274
    devScoresStr = fmt::format("{:{}}", devScoresStr, totalLen+7*devScores.size());
275
    devScoreMean /= devScores.size();
276

Franck Dary's avatar
Franck Dary committed
277
278
279
280
    if (computeDevScore)
      saved = saved or devScoreMean >= bestDevScore;
    else
      saved = saved or devScoreMean <= bestDevScore;
281

Franck Dary's avatar
Franck Dary committed
282
283
    if (saved)
    {
284
      bestDevScore = devScoreMean;
285
      machine.saveBest();
Franck Dary's avatar
Franck Dary committed
286
    }
Franck Dary's avatar
Franck Dary committed
287

288
    machine.saveLast();
Franck Dary's avatar
Franck Dary committed
289

Franck Dary's avatar
Franck Dary committed
290
    if (printAdvancement)
Franck Dary's avatar
Franck Dary committed
291
      fmt::print(stderr, "\r{:80}\r", "");
Franck Dary's avatar
Franck Dary committed
292
    std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:7} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), util::shrink(fmt::format("{}",100.0*loss), 7), devScoresStr, saved ? "SAVED" : "");
293
294
295
296
    fmt::print(stderr, "{}\n", iterStr);
    std::FILE * f = std::fopen(trainInfos.c_str(), "a");
    fmt::print(f, "{}\t{}\n", iterStr, devScoreMean);
    std::fclose(f);
Franck Dary's avatar
Franck Dary committed
297
298
  }

Franck Dary's avatar
Franck Dary committed
299
300
301
  }
  catch(std::exception & e) {util::error(e);}

Franck Dary's avatar
Franck Dary committed
302
303
304
  return 0;
}

Franck Dary's avatar
Franck Dary committed
305
306
307
308
// Keep the raw command line; it is only parsed later, by checkOptions().
MacaonTrain::MacaonTrain(int argc, char ** argv)
  : argc(argc),
    argv(argv)
{
}