#include "Trainer.hpp"
#include "SubConfig.hpp"
Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
{
}
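// createDataset / createDevDataset extract training examples from the gold
// configuration (unless up-to-date example files already exist on disk), then
// wrap the dump directory in a Dataset and a single-threaded libtorch DataLoader.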
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
SubConfig config(goldConfig, goldConfig.getNbLines());
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
trainDataset.reset(new Dataset(dir));
dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
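// A minimal usage sketch (hypothetical driver code; `machine`, `goldConfig`,
// `epoch`, `dynamicOracleInterval` and the path are assumptions, not part of
// this file):
//
//   Trainer trainer(machine, 64);
//   trainer.createDataset(goldConfig, /*debug=*/false, "tmp/train_examples", epoch, dynamicOracleInterval);
//   float trainLoss = trainer.epoch(/*printAdvancement=*/true);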
void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
SubConfig config(goldConfig, goldConfig.getNbLines());
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
devDataset.reset(new Dataset(dir));
devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir)
{
// Build one matrix whose rows are [context features..., gold class index]:
// contexts and classes are stacked row-wise, then concatenated column-wise.
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
// The filename records the (inclusive) range of example indices it contains.
auto filename = fmt::format("{}-{}.tensor", lastSavedIndex, currentExampleIndex-1);
torch::save(tensorToSave, dir/filename);
lastSavedIndex = currentExampleIndex;
contexts.clear();
classes.clear();
}
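// For example, with maxNbExamplesPerFile == 250000 (see extractExamples below),
// successive chunks are named along the lines of "0-249999.tensor",
// "250000-499999.tensor", ... (boundaries can overshoot slightly, since a whole
// step's worth of examples is always kept together).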
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
torch::AutoGradMode useGrad(false);
machine.setDictsState(Dict::State::Open);
int maxNbExamplesPerFile = 250000;
int currentExampleIndex = 0;
int lastSavedIndex = 0;
std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes;
std::filesystem::create_directories(dir);
config.addPredicted(machine.getPredicted());
config.setState(machine.getStrategy().getInitialState());
machine.getStrategy().reset();
auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch);
bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile);
if (epoch != 0 and (dynamicOracleInterval == -1 or epoch % dynamicOracleInterval))
mustExtract = false;
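// i.e. with dynamicOracleInterval == 2, examples are (re-)extracted at epochs
// 0, 2, 4, ...; with -1, only the epoch-0 extraction is ever performed.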
if (!mustExtract)
return;
bool dynamicOracle = epoch != 0;
fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
for (auto & entry : std::filesystem::directory_iterator(dir))
if (entry.is_regular_file())
std::filesystem::remove(entry.path());
// Main extraction loop: apply one transition per iteration until the strategy
// signals the end of the input.
while (true)
{
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
std::vector<std::vector<long>> context;
try
{
context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
for (auto & element : context)
contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
} catch(std::exception & e)
{
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
}
Transition * transition = nullptr;
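// Dynamic oracle: after epoch 0, follow the classifier's own highest-scoring
// appliable transition (except in the tokenizer state), so that training also
// visits configurations the model will actually reach at prediction time.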
if (dynamicOracle and config.getState() != "tokenizer")
{
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
int chosenTransition = -1;
// lowest() is the correct sentinel here: numeric_limits<float>::min() is the
// smallest *positive* float, which would be wrong for negative scores (it was
// harmless only because chosenTransition == -1 forces the first appliable
// transition to be taken).
float bestScore = std::numeric_limits<float>::lowest();
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
{
chosenTransition = i;
bestScore = score;
}
}
transition = machine.getTransitionSet().getTransition(chosenTransition);
}
else
{
transition = machine.getTransitionSet().getBestAppliableTransition(config);
}
if (!transition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
gold[0] = goldIndex;
currentExampleIndex += context.size();
classes.insert(classes.end(), context.size(), gold);
if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile)
saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir);
transition->apply(config);
config.addToHistory(transition->getName());
auto movement = machine.getStrategy().getMovement(config, transition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement)
break;
config.setState(movement.first);
}
if (!contexts.empty())
saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir);
std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
if (!f)
util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
std::fclose(f);
fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex));
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
{
// Running totals for the loss and for the advancement display.
int totalNbExamplesProcessed = 0;
int nbExamplesProcessed = 0;
float totalLoss = 0.0;
float lossSoFar = 0.0;
torch::AutoGradMode useGrad(train);
machine.setDictsState(Dict::State::Closed);
auto lossFct = torch::nn::CrossEntropyLoss();
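// torch::nn::CrossEntropyLoss applies log-softmax internally, so the network
// is expected to output raw, unnormalized scores (logits).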
auto pastTime = std::chrono::high_resolution_clock::now();
for (auto & batch : *loader)
{
machine.getClassifier()->getOptimizer().zero_grad();
auto data = batch.first;
auto labels = batch.second;
auto prediction = machine.getClassifier()->getNN()(data);
if (prediction.dim() == 1)
prediction = prediction.unsqueeze(0);
labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
auto loss = lossFct(prediction, labels);
try
{
totalLoss += loss.item<float>();
lossSoFar += loss.item<float>();
} catch(std::exception & e) {util::myThrow(e.what());}
if (train)
{
loss.backward();
machine.getClassifier()->getOptimizer().step();
if (printAdvancement)
{
totalNbExamplesProcessed += torch::numel(labels);
nbExamplesProcessed += torch::numel(labels);
auto actualTime = std::chrono::high_resolution_clock::now();
double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
pastTime = actualTime;
auto speed = (int)(nbExamplesProcessed/seconds);
auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
if (train)
fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
else
fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
// Reset the per-interval counters so that speed and loss reflect the
// examples seen since the previous status line.
nbExamplesProcessed = 0;
lossSoFar = 0.0;
}
}
return totalLoss;
}
float Trainer::epoch(bool printAdvancement)
{
return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
}
float Trainer::evalOnDev(bool printAdvancement)
{
return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
}
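// A typical outer training loop (hypothetical sketch; everything outside this
// file -- machine, goldConfig, devGoldConfig, the paths and counts -- is an
// assumption):
//
//   Trainer trainer(machine, batchSize);
//   for (int epoch = 0; epoch < nbEpochs; ++epoch)
//   {
//     trainer.createDataset(goldConfig, false, "tmp/train", epoch, dynamicOracleInterval);
//     trainer.createDevDataset(devGoldConfig, false, "tmp/dev", epoch, -1);
//     float trainLoss = trainer.epoch(true);
//     float devLoss = trainer.evalOnDev(true);
//   }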