Commit f6de0f30 authored by Franck Dary

Trainer now returns loss per example

parent 0daae795
@@ -27,7 +27,7 @@ class Trainer
   private :

   void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
-  float processDataset(DataLoader & loader, bool train, bool printAdvancement);
+  float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
   void saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir);

   public :
@@ -189,7 +189,7 @@ int MacaonTrain::main()
       if (computeDevScore)
         devScoresStr += fmt::format("{}({:5.2f}{}),", score.second, score.first, computeDevScore ? "%" : "");
       else
-        devScoresStr += fmt::format("{}({:6.1f}{}),", score.second, score.first, computeDevScore ? "%" : "");
+        devScoresStr += fmt::format("{}({:6.4f}{}),", score.second, score.first, computeDevScore ? "%" : "");
       devScoreMean += score.first;
     }
     if (!devScoresStr.empty())
@@ -207,7 +207,7 @@ int MacaonTrain::main()
     trainer.saveOptimizer(optimizerCheckpoint);
     if (printAdvancement)
       fmt::print(stderr, "\r{:80}\r", "");
-    std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.1f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
+    std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
     fmt::print(stderr, "{}\n", iterStr);
     std::FILE * f = std::fopen(trainInfos.c_str(), "a");
     fmt::print(f, "{}\t{}\n", iterStr, devScoreMean);
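Both format-spec changes above widen the printed loss from one to four decimal places. Because the loss is now averaged over the number of examples (see the Trainer changes below), its magnitude shrinks by a factor of the dataset size, and {:6.1f} would round most values down to 0.0. A minimal standalone sketch of the two fmt specs, using a hypothetical per-example loss value:

#include <fmt/core.h>

int main()
{
  float loss = 0.0123f; // hypothetical per-example loss, not from the commit
  fmt::print("old: '{:6.1f}'\n", loss); // width 6, 1 decimal  -> '   0.0'
  fmt::print("new: '{:6.4f}'\n", loss); // width 6, 4 decimals -> '0.0123'
  return 0;
}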
@@ -131,11 +131,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
       auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
       gold[0] = goldIndex;

-      for (auto & element : context)
-      {
-        currentExampleIndex++;
-        classes.emplace_back(gold);
-      }
+      currentExampleIndex += context.size();
+      classes.insert(classes.end(), context.size(), gold);

       if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile)
         saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir);
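The hunk above collapses the per-element loop into a single call to the fill overload of std::vector::insert, which appends context.size() copies of gold at once, and folds the counter update into one addition. A minimal sketch with ints standing in for the gold tensor:

#include <cassert>
#include <vector>

int main()
{
  std::vector<int> classes;
  std::size_t contextSize = 3; // stand-in for context.size()
  int gold = 7;                // stand-in for the gold tensor

  // insert(pos, count, value) appends `count` copies of `value` in one call,
  // equivalent to the old loop running emplace_back(gold) per context element.
  classes.insert(classes.end(), contextSize, gold);

  assert(classes.size() == 3 && classes.front() == 7 && classes.back() == 7);
  return 0;
}

For torch::Tensor elements the inserted copies share storage, which matches the old loop's behavior, since it also pushed the same gold tensor repeatedly.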
@@ -169,13 +166,13 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
   fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex));
 }

-float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement)
+float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
 {
   constexpr int printInterval = 50;
   int nbExamplesProcessed = 0;
+  int totalNbExamplesProcessed = 0;
   float totalLoss = 0.0;
   float lossSoFar = 0.0;
   int currentBatchNumber = 0;
   torch::AutoGradMode useGrad(train);
   machine.trainMode(train);
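The new totalNbExamplesProcessed counter is fed with torch::numel(labels) on every batch in the next hunk, so the progression report divides examples actually seen by nbExamples instead of estimating currentBatchNumber*batchSize, which overshoots whenever the final batch is short. A small sketch of the difference, with hypothetical batch sizes:

#include <fmt/core.h>
#include <vector>

int main()
{
  // Hypothetical dataset of 10 examples in batches of 4: the last batch is short.
  std::vector<int> batchSizes {4, 4, 2};
  int nbExamples = 10, batchSize = 4;

  int totalNbExamplesProcessed = 0;
  for (std::size_t batch = 0; batch < batchSizes.size(); ++batch)
  {
    totalNbExamplesProcessed += batchSizes[batch];
    double byBatches = 100.0*((batch+1)*batchSize) / nbExamples;     // estimate, overshoots to 120%
    double byExamples = 100.0*totalNbExamplesProcessed / nbExamples; // exact, ends at 100%
    fmt::print("batch {}: by batches {:6.2f}% | by examples {:6.2f}%\n", batch+1, byBatches, byExamples);
  }
  return 0;
}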
@@ -212,37 +209,41 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
       optimizer->step();
     }

+    totalNbExamplesProcessed += torch::numel(labels);
     if (printAdvancement)
     {
-      nbExamplesProcessed += labels.size(0);
+      nbExamplesProcessed += torch::numel(labels);
       ++currentBatchNumber;
       if (nbExamplesProcessed >= printInterval)
       {
         auto actualTime = std::chrono::high_resolution_clock::now();
         double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
         pastTime = actualTime;
+        auto speed = (int)(nbExamplesProcessed/seconds);
+        auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
+        auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
         if (train)
-          fmt::print(stderr, "\r{:80}\rcurrent epoch : {:6.2f}% loss={:<7.3f} speed={:<6}ex/s", "", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar, (int)(nbExamplesProcessed/seconds));
+          fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
         else
-          fmt::print(stderr, "\r{:80}\reval on dev : loss={:<7.3f} speed={:<6}ex/s", "", lossSoFar, (int)(nbExamplesProcessed/seconds));
+          fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
         lossSoFar = 0;
         nbExamplesProcessed = 0;
       }
     }
   }

-  return totalLoss;
+  return totalLoss / nbExamples;
 }

 float Trainer::epoch(bool printAdvancement)
 {
-  return processDataset(dataLoader, true, printAdvancement);
+  return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
 }

 float Trainer::evalOnDev(bool printAdvancement)
 {
-  return processDataset(devDataLoader, false, printAdvancement);
+  return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
 }

 void Trainer::loadOptimizer(std::filesystem::path path)
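Taken together, processDataset still accumulates the raw summed loss, but it now returns totalLoss / nbExamples, and epoch/evalOnDev pass their dataset sizes in, so train and dev losses are mean losses per example and stay comparable across datasets of different sizes. A compact sketch of that contract, with plain floats standing in for the per-batch losses:

#include <cstdio>
#include <vector>

// Hypothetical stand-in for Trainer::processDataset: sums per-batch losses,
// then normalizes by the dataset size so callers see a mean loss per example.
float processDataset(const std::vector<float> & batchLosses, int nbExamples)
{
  float totalLoss = 0.0f;
  for (float loss : batchLosses)
    totalLoss += loss;
  return totalLoss / nbExamples;
}

int main()
{
  // Same summed loss over datasets of different sizes gives comparable means:
  std::printf("%.4f\n", processDataset({10.0f, 10.0f}, 100)); // 0.2000
  std::printf("%.4f\n", processDataset({10.0f, 10.0f}, 400)); // 0.0500
  return 0;
}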