Skip to content
Snippets Groups Projects
Commit ecf7290b authored by Franck Dary
Browse files

Added dynamic oracle; extracted examples are saved to disk so as not to use too much memory

parent 8023999c
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool ...@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
{ {
torch::AutoGradMode useGrad(false); torch::AutoGradMode useGrad(false);
machine.trainMode(false); machine.trainMode(false);
machine.getStrategy().reset();
config.addPredicted(machine.getPredicted()); config.addPredicted(machine.getPredicted());
constexpr int printInterval = 50; constexpr int printInterval = 50;
...@@ -27,9 +28,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool ...@@ -27,9 +28,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
if (machine.hasSplitWordTransitionSet()) if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto dictState = machine.getDict(config.getState()).getState();
auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back(); auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back();
machine.getDict(config.getState()).setState(dictState);
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
......
...@@ -4,19 +4,24 @@ ...@@ -4,19 +4,24 @@
#include <torch/torch.h>
#include <filesystem>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
#include "Config.hpp"

/// Stateful dataset that streams training examples from tensor files on disk
/// instead of holding them all in memory. Each file stores a 2-D tensor whose
/// last column is the gold class and whose other columns are the context.
class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDataset, std::pair<torch::Tensor,torch::Tensor>>
{
  private :

  // Total number of examples across every file, computed once in the constructor.
  std::size_t size_{0};
  // One entry per example file: (first example index, last example index, file path).
  std::vector<std::tuple<int,int,std::filesystem::path>> exampleLocations;
  // Content of the currently loaded example file.
  torch::Tensor loadedTensor;
  // Index into exampleLocations of the currently loaded file; empty before the first get_batch.
  std::optional<std::size_t> loadedTensorIndex;
  // Row of loadedTensor that the next batch will start from.
  std::size_t nextIndexToGive{0};

  public :

  /// Scans dir for files named "<first>-<last>.tensor" and records their locations.
  explicit ConfigDataset(std::filesystem::path dir);
  /// Total number of examples across all files.
  c10::optional<std::size_t> size() const override;
  /// Returns the next (contexts, classes) batch, or an empty optional when exhausted.
  c10::optional<std::pair<torch::Tensor,torch::Tensor>> get_batch(std::size_t batchSize) override;
  /// Shuffles the file order and rewinds to the beginning of the epoch.
  void reset() override;
  /// No-op: dataset state is rebuilt from disk, nothing to restore.
  void load(torch::serialize::InputArchive &) override;
  /// No-op: examples already live on disk, nothing to persist.
  void save(torch::serialize::OutputArchive &) const override;
};
#endif #endif
#include "ConfigDataset.hpp"
#include "NeuralNetwork.hpp"
#include <algorithm>
#include <exception>
#include <random>
#include <string>
/// Scans dir for example files named "<first>-<last>.tensor" and records, for
/// each one, the index range of the examples it contains plus its path.
/// Files whose stem is not two dash-separated integers are ignored instead of
/// letting std::stoi throw on stray files (e.g. the "extracted.N" marker).
ConfigDataset::ConfigDataset(std::filesystem::path dir)
{
  for (auto & entry : std::filesystem::directory_iterator(dir))
    if (entry.is_regular_file())
    {
      auto splited = util::split(entry.path().stem().string(), '-');
      if (splited.size() != 2)
        continue;
      int firstIndex, lastIndex;
      try
      {
        firstIndex = std::stoi(splited[0]);
        lastIndex = std::stoi(splited[1]);
      }
      catch (const std::exception &)
      {
        continue; // stem is not "<int>-<int>": not an example file
      }
      exampleLocations.emplace_back(std::make_tuple(firstIndex, lastIndex, entry.path()));
      // Index range is inclusive on both ends, hence the +1.
      size_ += 1 + lastIndex - firstIndex;
    }
}
/// Reports the total number of examples across every file on disk.
/// The value is fixed at construction time; it never changes afterwards.
c10::optional<std::size_t> ConfigDataset::size() const
{
  return size_;
}
/// Returns the next batch of up to batchSize examples as a (contexts, classes)
/// pair, loading example files lazily one at a time. Returns an empty optional
/// once every file has been consumed. A batch never spans two files, so the
/// last batch of each file may be smaller than batchSize.
c10::optional<std::pair<torch::Tensor,torch::Tensor>> ConfigDataset::get_batch(std::size_t batchSize)
{
  // Guard: without this, an empty example directory would index
  // exampleLocations[0] below and crash.
  if (exampleLocations.empty())
    return c10::optional<std::pair<torch::Tensor,torch::Tensor>>();
  // First call of the epoch: load the first file.
  if (!loadedTensorIndex.has_value())
  {
    loadedTensorIndex = 0;
    nextIndexToGive = 0;
    torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
  }
  // Current file exhausted: move on to the next one, or signal end of epoch.
  if ((long)nextIndexToGive >= loadedTensor.size(0))
  {
    nextIndexToGive = 0;
    loadedTensorIndex = loadedTensorIndex.value() + 1;
    if (loadedTensorIndex.value() >= exampleLocations.size())
      return c10::optional<std::pair<torch::Tensor,torch::Tensor>>();
    torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
  }
  // Take whatever is left in the file, capped at batchSize; the last column
  // holds the gold class, the others hold the context.
  long nbToGive = std::min<long>((long)batchSize, loadedTensor.size(0) - (long)nextIndexToGive);
  auto rows = loadedTensor.narrow(0, nextIndexToGive, nbToGive);
  std::pair<torch::Tensor, torch::Tensor> batch;
  batch.first = rows.narrow(1, 0, loadedTensor.size(1)-1);
  batch.second = rows.narrow(1, loadedTensor.size(1)-1, 1);
  nextIndexToGive += nbToGive;
  return batch;
}
/// Starts a new epoch: shuffles the order in which example files will be
/// visited and rewinds the cursor so the next get_batch reloads from scratch.
void ConfigDataset::reset()
{
  // std::random_shuffle is deprecated in C++14 and removed in C++17;
  // std::shuffle with an explicit engine is the supported replacement.
  static std::mt19937 rng(std::random_device{}());
  std::shuffle(exampleLocations.begin(), exampleLocations.end(), rng);
  loadedTensorIndex.reset();
  nextIndexToGive = 0;
}
/// Required by the StatefulDataset interface. Intentionally a no-op:
/// the dataset rebuilds all of its state from the files on disk.
void ConfigDataset::load(torch::serialize::InputArchive &)
{
}
torch::data::Example<> ConfigDataset::get(size_t index) void ConfigDataset::save(torch::serialize::OutputArchive &) const
{ {
return {data.narrow(0, index*(contextSize+1), contextSize).to(NeuralNetworkImpl::device), data.narrow(0, index*(contextSize+1)+contextSize, 1).to(NeuralNetworkImpl::device)};
} }
...@@ -10,28 +10,31 @@ class Trainer ...@@ -10,28 +10,31 @@ class Trainer
private : private :
using Dataset = ConfigDataset; using Dataset = ConfigDataset;
using DataLoader = std::unique_ptr<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler>, std::default_delete<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler> > >; using DataLoader = std::unique_ptr<torch::data::StatefulDataLoader<Dataset>>;
private : private :
ReadingMachine & machine; ReadingMachine & machine;
std::unique_ptr<Dataset> trainDataset{nullptr};
std::unique_ptr<Dataset> devDataset{nullptr};
DataLoader dataLoader{nullptr}; DataLoader dataLoader{nullptr};
DataLoader devDataLoader{nullptr}; DataLoader devDataLoader{nullptr};
std::unique_ptr<torch::optim::Adam> optimizer; std::unique_ptr<torch::optim::Adam> optimizer;
std::size_t epochNumber{0}; std::size_t epochNumber{0};
int batchSize{64}; int batchSize;
int nbExamples{0}; int nbExamples{0};
private : private :
void extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes); void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
float processDataset(DataLoader & loader, bool train, bool printAdvancement); float processDataset(DataLoader & loader, bool train, bool printAdvancement);
void saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir);
public : public :
Trainer(ReadingMachine & machine); Trainer(ReadingMachine & machine, int batchSize);
void createDataset(SubConfig & goldConfig, bool debug); void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
void createDevDataset(SubConfig & goldConfig, bool debug); void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
float epoch(bool printAdvancement); float epoch(bool printAdvancement);
float evalOnDev(bool printAdvancement); float evalOnDev(bool printAdvancement);
void loadOptimizer(std::filesystem::path path); void loadOptimizer(std::filesystem::path path);
......
...@@ -31,6 +31,10 @@ po::options_description MacaonTrain::getOptionsDescription() ...@@ -31,6 +31,10 @@ po::options_description MacaonTrain::getOptionsDescription()
"Raw text file of the development corpus") "Raw text file of the development corpus")
("nbEpochs,n", po::value<int>()->default_value(5), ("nbEpochs,n", po::value<int>()->default_value(5),
"Number of training epochs") "Number of training epochs")
("batchSize", po::value<int>()->default_value(64),
  "Number of examples per batch")
("dynamicOracleInterval", po::value<int>()->default_value(-1),
  // Fixed copy-paste: this description previously duplicated batchSize's.
  "Number of epochs between two example extractions with the dynamic oracle (-1 to disable)")
("machine", po::value<std::string>()->default_value(""), ("machine", po::value<std::string>()->default_value(""),
"Reading machine file content") "Reading machine file content")
("help,h", "Produce this help message"); ("help,h", "Produce this help message");
...@@ -90,6 +94,8 @@ int MacaonTrain::main() ...@@ -90,6 +94,8 @@ int MacaonTrain::main()
auto devTsvFile = variables["devTSV"].as<std::string>(); auto devTsvFile = variables["devTSV"].as<std::string>();
auto devRawFile = variables["devTXT"].as<std::string>(); auto devRawFile = variables["devTXT"].as<std::string>();
auto nbEpoch = variables["nbEpochs"].as<int>(); auto nbEpoch = variables["nbEpochs"].as<int>();
auto batchSize = variables["batchSize"].as<int>();
auto dynamicOracleInterval = variables["dynamicOracleInterval"].as<int>();
bool debug = variables.count("debug") == 0 ? false : true; bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
bool computeDevScore = variables.count("devScore") == 0 ? false : true; bool computeDevScore = variables.count("devScore") == 0 ? false : true;
...@@ -115,19 +121,11 @@ int MacaonTrain::main() ...@@ -115,19 +121,11 @@ int MacaonTrain::main()
BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
SubConfig config(goldConfig, goldConfig.getNbLines());
fillDicts(machine, goldConfig); fillDicts(machine, goldConfig);
Trainer trainer(machine);
trainer.createDataset(config, debug);
if (!computeDevScore)
{
machine.getStrategy().reset();
SubConfig devConfig(devGoldConfig, devGoldConfig.getNbLines());
trainer.createDevDataset(devConfig, debug);
}
Trainer trainer(machine, batchSize);
Decoder decoder(machine); Decoder decoder(machine);
float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max(); float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
...@@ -154,14 +152,21 @@ int MacaonTrain::main() ...@@ -154,14 +152,21 @@ int MacaonTrain::main()
std::fclose(f); std::fclose(f);
} }
trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval);
if (!computeDevScore)
trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt"; auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
if (std::filesystem::exists(trainInfos)) if (std::filesystem::exists(trainInfos))
trainer.loadOptimizer(optimizerCheckpoint); trainer.loadOptimizer(optimizerCheckpoint);
for (; currentEpoch < nbEpoch; currentEpoch++) for (; currentEpoch < nbEpoch; currentEpoch++)
{ {
trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval);
if (!computeDevScore)
trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
float loss = trainer.epoch(printAdvancement); float loss = trainer.epoch(printAdvancement);
machine.getStrategy().reset();
if (debug) if (debug)
fmt::print(stderr, "Decoding dev :\n"); fmt::print(stderr, "Decoding dev :\n");
std::vector<std::pair<float,std::string>> devScores; std::vector<std::pair<float,std::string>> devScores;
...@@ -169,7 +174,6 @@ int MacaonTrain::main() ...@@ -169,7 +174,6 @@ int MacaonTrain::main()
{ {
auto devConfig = devGoldConfig; auto devConfig = devGoldConfig;
decoder.decode(devConfig, 1, debug, printAdvancement); decoder.decode(devConfig, 1, debug, printAdvancement);
machine.getStrategy().reset();
decoder.evaluate(devConfig, modelPath, devTsvFile); decoder.evaluate(devConfig, modelPath, devTsvFile);
devScores = decoder.getF1Scores(machine.getPredicted()); devScores = decoder.getF1Scores(machine.getPredicted());
} }
...@@ -192,9 +196,9 @@ int MacaonTrain::main() ...@@ -192,9 +196,9 @@ int MacaonTrain::main()
if (!devScoresStr.empty()) if (!devScoresStr.empty())
devScoresStr.pop_back(); devScoresStr.pop_back();
devScoreMean /= devScores.size(); devScoreMean /= devScores.size();
bool saved = devScoreMean > bestDevScore; bool saved = devScoreMean >= bestDevScore;
if (!computeDevScore) if (!computeDevScore)
saved = devScoreMean < bestDevScore; saved = devScoreMean <= bestDevScore;
if (saved) if (saved)
{ {
bestDevScore = devScoreMean; bestDevScore = devScoreMean;
......
#include "Trainer.hpp" #include "Trainer.hpp"
#include "SubConfig.hpp" #include "SubConfig.hpp"
/// Builds a Trainer bound to the given machine, using batchSize examples
/// per optimization step.
Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
{
}
void Trainer::createDataset(SubConfig & config, bool debug) void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{ {
machine.trainMode(true); SubConfig config(goldConfig, goldConfig.getNbLines());
std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes;
extractExamples(config, debug, contexts, classes); extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
trainDataset.reset(new Dataset(dir));
nbExamples = classes.size(); nbExamples = trainDataset->size().value();
dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
if (optimizer.get() == nullptr)
optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999))); optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999)));
} }
void Trainer::createDevDataset(SubConfig & config, bool debug) void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{ {
machine.trainMode(false); SubConfig config(goldConfig, goldConfig.getNbLines());
std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes;
extractExamples(config, debug, contexts, classes); extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
devDataset.reset(new Dataset(dir));
devDataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
} }
void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes) void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir)
{ {
fmt::print(stderr, "[{}] Starting to extract examples\n", util::getTime()); auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
auto filename = fmt::format("{}-{}.tensor", lastSavedIndex, currentExampleIndex-1);
torch::save(tensorToSave, dir/filename);
lastSavedIndex = currentExampleIndex;
contexts.clear();
classes.clear();
}
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
torch::AutoGradMode useGrad(false);
machine.trainMode(false);
int maxNbExamplesPerFile = 250000;
int currentExampleIndex = 0;
int lastSavedIndex = 0;
std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes;
std::filesystem::create_directories(dir);
config.addPredicted(machine.getPredicted()); config.addPredicted(machine.getPredicted());
config.setState(machine.getStrategy().getInitialState()); config.setState(machine.getStrategy().getInitialState());
machine.getStrategy().reset();
auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch);
bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile);
if (epoch != 0 and (dynamicOracleInterval == -1 or epoch % dynamicOracleInterval))
mustExtract = false;
if (!mustExtract)
return;
bool dynamicOracle = epoch != 0;
fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
for (auto & entry : std::filesystem::directory_iterator(dir))
if (entry.is_regular_file())
std::filesystem::remove(entry.path());
while (true) while (true)
{ {
...@@ -46,31 +81,6 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: ...@@ -46,31 +81,6 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
if (machine.hasSplitWordTransitionSet()) if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
if (!transition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
if (config.isMultiword(config.getWordIndex()))
if (transition->getName() == "ADDCHARTOWORD")
{
config.printForDebug(stderr);
auto & splitTrans = config.getAppliableSplitTransitions();
fmt::print(stderr, "splitTrans.size() = {}\n", splitTrans.size());
for (auto & trans : splitTrans)
fmt::print(stderr, "cost {} : '{}'\n", trans->getCost(config), trans->getName());
util::myThrow(fmt::format("Transition should have been a split"));
}
if (transition->getName() == "ENDWORD")
if (config.getAsFeature("FORM",config.getWordIndex()) != config.getConst("FORM",config.getWordIndex(),0))
{
config.printForDebug(stderr);
util::myThrow(fmt::format("Words don't match"));
}
std::vector<std::vector<long>> context; std::vector<std::vector<long>> context;
try try
...@@ -83,12 +93,51 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: ...@@ -83,12 +93,51 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
util::myThrow(fmt::format("Failed to extract context : {}", e.what())); util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
} }
Transition * transition = nullptr;
if (dynamicOracle and config.getState() != "tokenizer")
{
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
int chosenTransition = -1;
float bestScore = std::numeric_limits<float>::min();
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
{
chosenTransition = i;
bestScore = score;
}
}
transition = machine.getTransitionSet().getTransition(chosenTransition);
}
else
{
transition = machine.getTransitionSet().getBestAppliableTransition(config);
}
if (!transition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong)); auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
gold[0] = goldIndex; gold[0] = goldIndex;
for (auto & element : context) for (auto & element : context)
{
currentExampleIndex++;
classes.emplace_back(gold); classes.emplace_back(gold);
}
if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile)
saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir);
transition->apply(config); transition->apply(config);
config.addToHistory(transition->getName()); config.addToHistory(transition->getName());
...@@ -106,7 +155,15 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: ...@@ -106,7 +155,15 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
config.update(); config.update();
} }
fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(classes.size())); if (!contexts.empty())
saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir);
std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
if (!f)
util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
std::fclose(f);
fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex));
} }
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement) float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement)
...@@ -129,8 +186,8 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance ...@@ -129,8 +186,8 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
if (train) if (train)
optimizer->zero_grad(); optimizer->zero_grad();
auto data = batch.data; auto data = batch.first;
auto labels = batch.target.squeeze(); auto labels = batch.second;
auto prediction = machine.getClassifier()->getNN()(data); auto prediction = machine.getClassifier()->getNN()(data);
if (prediction.dim() == 1) if (prediction.dim() == 1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment