MacaonTrain.cpp 10.6 KB
Newer Older
Franck Dary's avatar
Franck Dary committed
1
#include "MacaonTrain.hpp"
Franck Dary's avatar
Franck Dary committed
2
3
#include <filesystem>
#include "util.hpp"
4
#include "NeuralNetwork.hpp"
5
#include "WordEmbeddings.hpp"
Franck Dary's avatar
Franck Dary committed
6
7
8

namespace po = boost::program_options;

Franck Dary's avatar
Franck Dary committed
9
/// @brief Builds the description of every command-line option accepted by the trainer.
///
/// Options are split in two groups: "Required" (model directory and training
/// TSV) and "Optional" (everything else, each with a default value or used as
/// a simple flag).
///
/// @return The merged options description (required group followed by optional group).
po::options_description MacaonTrain::getOptionsDescription()
{
  po::options_description desc("Command-Line Arguments ");

  po::options_description req("Required");
  req.add_options()
    ("model", po::value<std::string>()->required(),
      "Directory containing the machine file to train")
    ("trainTSV", po::value<std::string>()->required(),
      "TSV file of the training corpus, in CONLLU format");

  po::options_description opt("Optional");
  opt.add_options()
    ("debug,d", "Print debuging infos on stderr")
    ("silent", "Don't print speed and progress")
    ("devScore", "Compute score on dev instead of loss (slower)")
    ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"),
      "Comma separated column names that describes the input/output format")
    ("trainTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the training corpus")
    ("devTSV", po::value<std::string>()->default_value(""),
      "TSV file of the development corpus, in CONLLU format")
    ("devTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the development corpus")
    ("nbEpochs,n", po::value<int>()->default_value(5),
      "Number of training epochs")
    ("batchSize", po::value<int>()->default_value(64),
      "Number of examples per batch")
    ("explorationThreshold", po::value<float>()->default_value(0.1),
      "Maximum probability difference with the best scoring transition, for a transition to be explored during dynamic extraction of dataset")
    ("machine", po::value<std::string>()->default_value(""),
      "Reading machine file content")
    ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"),
      "Description of what should happen during training")
    ("loss", po::value<std::string>()->default_value("CrossEntropy"),
      "Loss function to use during training : CrossEntropy | bce | mse | hinge")
    // The help text used to be a copy of the batchSize one; main() feeds this
    // value to std::srand and torch::manual_seed.
    ("seed", po::value<int>()->default_value(100),
      "Seed used to initialize the random number generators")
    ("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch")
    ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
      "Max norm for the embeddings")
    ("help,h", "Produce this help message");

  desc.add(req).add(opt);

  return desc;
}

Franck Dary's avatar
Franck Dary committed
57
/// @brief Parses the stored command line against the given options description.
///
/// When the "help" option is present, the description is printed on stderr and
/// the process exits with status 0 before required options are validated.
/// Any exception raised while parsing or validating is forwarded to
/// util::myThrow with its message.
///
/// @param od Description of the accepted options.
/// @return The map of parsed option values.
po::variables_map MacaonTrain::checkOptions(po::options_description & od)
{
  po::variables_map parsed;

  try
  {
    po::store(po::parse_command_line(argc, argv, od), parsed);
  }
  catch (std::exception & parseError)
  {
    util::myThrow(parseError.what());
  }

  // Help is handled before notify() so that it works even when required
  // options are missing.
  const bool helpRequested = parsed.count("help") != 0;
  if (helpRequested)
  {
    std::stringstream usage;
    usage << od;
    fmt::print(stderr, "{}\n", usage.str());
    exit(0);
  }

  try
  {
    po::notify(parsed);
  }
  catch (std::exception & notifyError)
  {
    util::myThrow(notifyError.what());
  }

  return parsed;
}

Franck Dary's avatar
Franck Dary committed
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/// @brief Parses the textual description of a training strategy.
///
/// The string is a ':' separated list of epoch descriptions. Each epoch
/// description is a ',' separated list whose first element is the epoch number
/// and whose remaining elements are action names, converted with
/// Trainer::str2TrainAction.
///
/// @param s The strategy description, e.g. "0,ExtractGold,ResetParameters".
/// @return The strategy, associating each listed epoch with its set of actions.
///
/// Any exception raised while parsing is reported through util::myThrow,
/// together with the offending input string.
Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s)
{
  Trainer::TrainStrategy ts;

  try
  {
    for (const auto & epochDescription : util::split(s, ':'))
    {
      const auto elements = util::split(epochDescription, ',');

      // Reading elements[0] on an empty split result would be undefined behaviour.
      if (elements.empty())
        util::myThrow(fmt::format("empty epoch description '{}'", epochDescription));

      const int epoch = std::stoi(elements[0]);

      for (std::size_t i = 1; i < elements.size(); i++)
        ts[epoch].insert(Trainer::str2TrainAction(elements[i]));
    }
  } catch (const std::exception & e) {util::myThrow(fmt::format("caught '{}' parsing '{}'", e.what(), s));}

  return ts;
}

99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
/// @brief Decays the learning rate of an optimizer by a multiplicative factor.
///
/// For every parameter group holding at least one parameter with a defined
/// gradient, the group's learning rate is multiplied by (1 - rate).
///
/// The learning rate is a per-group option, so it is rescaled once per group.
/// Rescaling it once per parameter would compound the decay to (1 - rate)^k
/// for a group with k such parameters.
///
/// @tparam Optimizer Type of the optimizer.
/// @tparam OptimizerOptions Options type the group's options are cast to; it
///         must match the actual options type of the optimizer.
/// @param optimizer The optimizer whose learning rate is decayed.
/// @param rate Fraction of the current learning rate to remove.
template 
<
  typename Optimizer = torch::optim::Adam,
  typename OptimizerOptions = torch::optim::AdamOptions
>
inline auto decay(Optimizer &optimizer, double rate) -> void
{
  for (auto &group : optimizer.param_groups())
  {
    // Groups without any defined gradient are left untouched, as before.
    bool hasGradient = false;
    for (auto &param : group.params())
    {
      if (param.grad().defined())
      {
        hasGradient = true;
        break;
      }
    }

    if (!hasGradient)
      continue;

    auto &options = static_cast<OptimizerOptions &>(group.options());
    options.lr(options.lr() * (1.0 - rate));
  }
}

Franck Dary's avatar
Franck Dary committed
119
int MacaonTrain::main()
Franck Dary's avatar
Franck Dary committed
120
121
{
  auto od = getOptionsDescription();
Franck Dary's avatar
Franck Dary committed
122
  auto variables = checkOptions(od);
Franck Dary's avatar
Franck Dary committed
123
124
125

  std::filesystem::path modelPath(variables["model"].as<std::string>());
  auto machinePath = modelPath / "machine.rm";
126
  auto mcd = variables["mcd"].as<std::string>();
Franck Dary's avatar
Franck Dary committed
127
128
129
130
131
  auto trainTsvFile = variables["trainTSV"].as<std::string>();
  auto trainRawFile = variables["trainTXT"].as<std::string>();
  auto devTsvFile = variables["devTSV"].as<std::string>();
  auto devRawFile = variables["devTXT"].as<std::string>();
  auto nbEpoch = variables["nbEpochs"].as<int>();
132
  auto batchSize = variables["batchSize"].as<int>();
Franck Dary's avatar
Franck Dary committed
133
  bool debug = variables.count("debug") == 0 ? false : true;
Franck Dary's avatar
Franck Dary committed
134
  bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
135
  bool computeDevScore = variables.count("devScore") == 0 ? false : true;
136
  auto machineContent = variables["machine"].as<std::string>();
Franck Dary's avatar
Franck Dary committed
137
  auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
Franck Dary's avatar
Franck Dary committed
138
  auto lossFunction = variables["loss"].as<std::string>();
139
  auto explorationThreshold = variables["explorationThreshold"].as<float>();
Franck Dary's avatar
Franck Dary committed
140
  auto seed = variables["seed"].as<int>();
141
142
  WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
  WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
Franck Dary's avatar
Franck Dary committed
143
144
145

  std::srand(seed);
  torch::manual_seed(seed);
Franck Dary's avatar
Franck Dary committed
146
147

  auto trainStrategy = parseTrainStrategy(trainStrategyStr);
148

Franck Dary's avatar
Franck Dary committed
149
150
  torch::globalContext().setBenchmarkCuDNN(true);

151
152
153
154
155
156
157
158
  if (!machineContent.empty())
  {
    std::FILE * file = fopen(machinePath.c_str(), "w");
    if (!file)
      util::error(fmt::format("can't open file '{}'\n", machinePath.c_str()));
    fmt::print(file, "{}", machineContent);
    std::fclose(file);
  }
Franck Dary's avatar
Franck Dary committed
159

Franck Dary's avatar
Franck Dary committed
160
  fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::device.str());
161

Franck Dary's avatar
Franck Dary committed
162
163
164
  try
  {

Franck Dary's avatar
Franck Dary committed
165
  ReadingMachine machine(machinePath.string(), true);
Franck Dary's avatar
Franck Dary committed
166

167
168
  BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile);
  BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
Franck Dary's avatar
Franck Dary committed
169

Franck Dary's avatar
Franck Dary committed
170
  Trainer trainer(machine, batchSize, lossFunction);
Franck Dary's avatar
Franck Dary committed
171
172
  Decoder decoder(machine);

173
  float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
Franck Dary's avatar
Franck Dary committed
174

175
176
177
178
179
180
181
182
183
184
185
186
  auto trainInfos = machinePath.parent_path() / "train.info";

  int currentEpoch = 0;

  if (std::filesystem::exists(trainInfos))
  {
    std::FILE * f = std::fopen(trainInfos.c_str(), "r");
    char buffer[1024];
    while (!std::feof(f))
    {
      if (buffer != std::fgets(buffer, 1024, f))
        break;
Franck Dary's avatar
Franck Dary committed
187
      bool saved = util::split(util::split(buffer, '\t')[0], ' ').back() == "SAVED";
188
      float devScoreMean = std::stof(util::split(buffer, '\t').back());
Franck Dary's avatar
Franck Dary committed
189
      if (saved)
190
191
192
193
194
195
196
        bestDevScore = devScoreMean;
      currentEpoch++;
    }
    std::fclose(f);
  }

  for (; currentEpoch < nbEpoch; currentEpoch++)
Franck Dary's avatar
Franck Dary committed
197
  {
Franck Dary's avatar
Franck Dary committed
198
199
200
201
202
203
204
205
206
207
208
209
210
    bool saved = false;

    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::DeleteExamples))
    {
      for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/train"))
        if (entry.is_regular_file())
          std::filesystem::remove(entry.path());

      if (!computeDevScore)
        for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/dev"))
          if (entry.is_regular_file())
            std::filesystem::remove(entry.path());
    }
211
    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
Franck Dary's avatar
Franck Dary committed
212
    {
213
      machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open);
214
      trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
Franck Dary's avatar
Franck Dary committed
215
      if (!computeDevScore)
216
217
      {
        machine.setDictsState(Dict::State::Closed);
218
        trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
219
      }
Franck Dary's avatar
Franck Dary committed
220
221
222
223
224
    }
    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer))
    {
      if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters))
      {
Franck Dary's avatar
Franck Dary committed
225
        machine.resetClassifiers();
226
        machine.trainMode(currentEpoch == 0);
Franck Dary's avatar
Franck Dary committed
227
        fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters()));
Franck Dary's avatar
Franck Dary committed
228
229
      }

Franck Dary's avatar
Franck Dary committed
230
      machine.resetOptimizers();
Franck Dary's avatar
Franck Dary committed
231
232
233
234
235
236
237
    }
    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save))
    {
      saved = true;
    }

    trainer.makeDataLoader(modelPath/"examples/train");
238
    if (!computeDevScore)
Franck Dary's avatar
Franck Dary committed
239
      trainer.makeDevDataLoader(modelPath/"examples/dev");
240

Franck Dary's avatar
Franck Dary committed
241
    float loss = trainer.epoch(printAdvancement);
Franck Dary's avatar
Franck Dary committed
242
243
    if (debug)
      fmt::print(stderr, "Decoding dev :\n");
244
245
246
    std::vector<std::pair<float,std::string>> devScores;
    if (computeDevScore)
    {
247
      BaseConfig devConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
248
      decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
249
250
251
252
253
254
255
256
257
      decoder.evaluate(devConfig, modelPath, devTsvFile);
      devScores = decoder.getF1Scores(machine.getPredicted());
    }
    else
    {
      float devLoss = trainer.evalOnDev(printAdvancement);
      devScores.emplace_back(std::make_pair(devLoss, "Loss"));
    }

258
259
260
261
    std::string devScoresStr = "";
    float devScoreMean = 0;
    for (auto & score : devScores)
    {
262
263
264
      if (computeDevScore)
        devScoresStr += fmt::format("{}({:5.2f}{}),", score.second, score.first, computeDevScore ? "%" : "");
      else
Franck Dary's avatar
Franck Dary committed
265
        devScoresStr += fmt::format("{}({:6.4f}{}),", score.second, 100.0*score.first, computeDevScore ? "%" : "");
266
      devScoreMean += score.first;
267
268
269
270
    }
    if (!devScoresStr.empty())
      devScoresStr.pop_back();
    devScoreMean /= devScores.size();
271

Franck Dary's avatar
Franck Dary committed
272
273
274
275
    if (computeDevScore)
      saved = saved or devScoreMean >= bestDevScore;
    else
      saved = saved or devScoreMean <= bestDevScore;
276

Franck Dary's avatar
Franck Dary committed
277
278
    if (saved)
    {
279
      bestDevScore = devScoreMean;
280
      machine.saveBest();
Franck Dary's avatar
Franck Dary committed
281
    }
Franck Dary's avatar
Franck Dary committed
282

283
    machine.saveLast();
Franck Dary's avatar
Franck Dary committed
284

Franck Dary's avatar
Franck Dary committed
285
    if (printAdvancement)
Franck Dary's avatar
Franck Dary committed
286
      fmt::print(stderr, "\r{:80}\r", "");
287
    std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), 100.0*loss, devScoresStr, saved ? "SAVED" : "");
288
289
290
291
    fmt::print(stderr, "{}\n", iterStr);
    std::FILE * f = std::fopen(trainInfos.c_str(), "a");
    fmt::print(f, "{}\t{}\n", iterStr, devScoreMean);
    std::fclose(f);
Franck Dary's avatar
Franck Dary committed
292
293
  }

Franck Dary's avatar
Franck Dary committed
294
295
296
  }
  catch(std::exception & e) {util::error(e);}

Franck Dary's avatar
Franck Dary committed
297
298
299
  return 0;
}

Franck Dary's avatar
Franck Dary committed
300
301
302
303
/// @brief Stores the command-line arguments for later parsing by checkOptions().
/// @param argc Number of command-line arguments.
/// @param argv Array of command-line argument strings.
MacaonTrain::MacaonTrain(int argc, char ** argv)
  : argc(argc)
  , argv(argv)
{
}