Skip to content
Snippets Groups Projects
Commit ecf7290b authored by Franck Dary's avatar Franck Dary
Browse files

Added dynamic oracle; extracted examples are saved to disk so as not to use too much memory

parent 8023999c
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
{
torch::AutoGradMode useGrad(false);
machine.trainMode(false);
machine.getStrategy().reset();
config.addPredicted(machine.getPredicted());
constexpr int printInterval = 50;
......@@ -27,9 +28,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto dictState = machine.getDict(config.getState()).getState();
auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back();
machine.getDict(config.getState()).setState(dictState);
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
......
......@@ -4,19 +4,24 @@
#include <torch/torch.h>
#include "Config.hpp"
class ConfigDataset : public torch::data::Dataset<ConfigDataset>
class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDataset, std::pair<torch::Tensor,torch::Tensor>>
{
private :
torch::Tensor data;
std::size_t size_{0};
std::size_t contextSize{0};
std::vector<std::tuple<int,int,std::filesystem::path>> exampleLocations;
torch::Tensor loadedTensor;
std::optional<std::size_t> loadedTensorIndex;
std::size_t nextIndexToGive{0};
public :
explicit ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes);
torch::optional<size_t> size() const override;
torch::data::Example<> get(size_t index) override;
explicit ConfigDataset(std::filesystem::path dir);
c10::optional<std::size_t> size() const override;
c10::optional<std::pair<torch::Tensor,torch::Tensor>> get_batch(std::size_t batchSize) override;
void reset() override;
void load(torch::serialize::InputArchive &) override;
void save(torch::serialize::OutputArchive &) const override;
};
#endif
#include "ConfigDataset.hpp"

#include <algorithm>
#include <random>

#include "NeuralNetwork.hpp"
// Builds the dataset by indexing every example file found in 'dir'.
// File names are expected to look like "<first>-<last>.tensor" where first and
// last are the indexes of the first and last examples stored in that file;
// files whose stem does not match this pattern are ignored.
ConfigDataset::ConfigDataset(std::filesystem::path dir)
{
  for (auto & entry : std::filesystem::directory_iterator(dir))
    if (entry.is_regular_file())
    {
      auto splited = util::split(entry.path().stem().string(), '-');
      if (splited.size() != 2)
        continue;
      exampleLocations.emplace_back(std::make_tuple(std::stoi(splited[0]), std::stoi(splited[1]), entry.path()));
      // Both bounds are inclusive, hence the +1.
      size_ += 1 + std::get<1>(exampleLocations.back()) - std::get<0>(exampleLocations.back());
    }
}
// Total number of examples across all indexed files (computed at
// construction from the file name ranges).
c10::optional<std::size_t> ConfigDataset::size() const
{
return size_;
}
size_ = contexts.size();
contextSize = contexts.back().size(0);
std::vector<torch::Tensor> total;
for (unsigned int i = 0; i < contexts.size(); i++)
c10::optional<std::pair<torch::Tensor,torch::Tensor>> ConfigDataset::get_batch(std::size_t batchSize)
{
if (!loadedTensorIndex.has_value())
{
loadedTensorIndex = 0;
nextIndexToGive = 0;
torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
}
if ((int)nextIndexToGive >= loadedTensor.size(0))
{
total.emplace_back(contexts[i]);
total.emplace_back(classes[i]);
nextIndexToGive = 0;
loadedTensorIndex = loadedTensorIndex.value() + 1;
if (loadedTensorIndex >= exampleLocations.size())
return c10::optional<std::pair<torch::Tensor,torch::Tensor>>();
torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
}
data = torch::cat(total);
std::pair<torch::Tensor, torch::Tensor> batch;
if ((int)nextIndexToGive + (int)batchSize < loadedTensor.size(0))
{
batch.first = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, 0, loadedTensor.size(1)-1);
batch.second = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, loadedTensor.size(1)-1, 1);
nextIndexToGive += batchSize;
}
else
{
batch.first = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, 0, loadedTensor.size(1)-1);
batch.second = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, loadedTensor.size(1)-1, 1);
nextIndexToGive = loadedTensor.size(0);
}
return batch;
}
torch::optional<size_t> ConfigDataset::size() const
void ConfigDataset::reset()
{
std::random_shuffle(exampleLocations.begin(), exampleLocations.end());
loadedTensorIndex = std::optional<std::size_t>();
nextIndexToGive = 0;
}
// Part of the StatefulDataset interface ; intentionally a no-op because the
// dataset state is rebuilt from the files on disk at construction.
// (Removed a stale 'return size_;' left over from a merge : returning a value
// from a void function does not compile.)
void ConfigDataset::load(torch::serialize::InputArchive &)
{
}
torch::data::Example<> ConfigDataset::get(size_t index)
void ConfigDataset::save(torch::serialize::OutputArchive &) const
{
return {data.narrow(0, index*(contextSize+1), contextSize).to(NeuralNetworkImpl::device), data.narrow(0, index*(contextSize+1)+contextSize, 1).to(NeuralNetworkImpl::device)};
}
......@@ -10,28 +10,31 @@ class Trainer
private :
using Dataset = ConfigDataset;
using DataLoader = std::unique_ptr<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler>, std::default_delete<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler> > >;
using DataLoader = std::unique_ptr<torch::data::StatefulDataLoader<Dataset>>;
private :
ReadingMachine & machine;
std::unique_ptr<Dataset> trainDataset{nullptr};
std::unique_ptr<Dataset> devDataset{nullptr};
DataLoader dataLoader{nullptr};
DataLoader devDataLoader{nullptr};
std::unique_ptr<torch::optim::Adam> optimizer;
std::size_t epochNumber{0};
int batchSize{64};
int batchSize;
int nbExamples{0};
private :
void extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes);
void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
float processDataset(DataLoader & loader, bool train, bool printAdvancement);
void saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir);
public :
Trainer(ReadingMachine & machine);
void createDataset(SubConfig & goldConfig, bool debug);
void createDevDataset(SubConfig & goldConfig, bool debug);
Trainer(ReadingMachine & machine, int batchSize);
void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
float epoch(bool printAdvancement);
float evalOnDev(bool printAdvancement);
void loadOptimizer(std::filesystem::path path);
......
......@@ -31,6 +31,10 @@ po::options_description MacaonTrain::getOptionsDescription()
"Raw text file of the development corpus")
("nbEpochs,n", po::value<int>()->default_value(5),
"Number of training epochs")
("batchSize", po::value<int>()->default_value(64),
"Number of examples per batch")
("dynamicOracleInterval", po::value<int>()->default_value(-1),
"Number of examples per batch")
("machine", po::value<std::string>()->default_value(""),
"Reading machine file content")
("help,h", "Produce this help message");
......@@ -90,6 +94,8 @@ int MacaonTrain::main()
auto devTsvFile = variables["devTSV"].as<std::string>();
auto devRawFile = variables["devTXT"].as<std::string>();
auto nbEpoch = variables["nbEpochs"].as<int>();
auto batchSize = variables["batchSize"].as<int>();
auto dynamicOracleInterval = variables["dynamicOracleInterval"].as<int>();
bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
bool computeDevScore = variables.count("devScore") == 0 ? false : true;
......@@ -115,19 +121,11 @@ int MacaonTrain::main()
BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
SubConfig config(goldConfig, goldConfig.getNbLines());
fillDicts(machine, goldConfig);
Trainer trainer(machine);
trainer.createDataset(config, debug);
if (!computeDevScore)
{
machine.getStrategy().reset();
SubConfig devConfig(devGoldConfig, devGoldConfig.getNbLines());
trainer.createDevDataset(devConfig, debug);
}
Trainer trainer(machine, batchSize);
Decoder decoder(machine);
float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
......@@ -154,14 +152,21 @@ int MacaonTrain::main()
std::fclose(f);
}
trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval);
if (!computeDevScore)
trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
if (std::filesystem::exists(trainInfos))
trainer.loadOptimizer(optimizerCheckpoint);
for (; currentEpoch < nbEpoch; currentEpoch++)
{
trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval);
if (!computeDevScore)
trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
float loss = trainer.epoch(printAdvancement);
machine.getStrategy().reset();
if (debug)
fmt::print(stderr, "Decoding dev :\n");
std::vector<std::pair<float,std::string>> devScores;
......@@ -169,7 +174,6 @@ int MacaonTrain::main()
{
auto devConfig = devGoldConfig;
decoder.decode(devConfig, 1, debug, printAdvancement);
machine.getStrategy().reset();
decoder.evaluate(devConfig, modelPath, devTsvFile);
devScores = decoder.getF1Scores(machine.getPredicted());
}
......@@ -192,9 +196,9 @@ int MacaonTrain::main()
if (!devScoresStr.empty())
devScoresStr.pop_back();
devScoreMean /= devScores.size();
bool saved = devScoreMean > bestDevScore;
bool saved = devScoreMean >= bestDevScore;
if (!computeDevScore)
saved = devScoreMean < bestDevScore;
saved = devScoreMean <= bestDevScore;
if (saved)
{
bestDevScore = devScoreMean;
......
#include "Trainer.hpp"
#include "SubConfig.hpp"
// Builds a trainer bound to the given reading machine.
// batchSize : number of examples per training batch (see the --batchSize
// command line option).
Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
{
}
void Trainer::createDataset(SubConfig & config, bool debug)
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
machine.trainMode(true);
std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes;
SubConfig config(goldConfig, goldConfig.getNbLines());
extractExamples(config, debug, contexts, classes);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
trainDataset.reset(new Dataset(dir));
nbExamples = classes.size();
nbExamples = trainDataset->size().value();
dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999)));
if (optimizer.get() == nullptr)
optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999)));
}
void Trainer::createDevDataset(SubConfig & config, bool debug)
void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
machine.trainMode(false);
std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes;
SubConfig config(goldConfig, goldConfig.getNbLines());
extractExamples(config, debug, contexts, classes);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
devDataset.reset(new Dataset(dir));
devDataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes)
void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir)
{
fmt::print(stderr, "[{}] Starting to extract examples\n", util::getTime());
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
auto filename = fmt::format("{}-{}.tensor", lastSavedIndex, currentExampleIndex-1);
torch::save(tensorToSave, dir/filename);
lastSavedIndex = currentExampleIndex;
contexts.clear();
classes.clear();
}
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
torch::AutoGradMode useGrad(false);
machine.trainMode(false);
int maxNbExamplesPerFile = 250000;
int currentExampleIndex = 0;
int lastSavedIndex = 0;
std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes;
std::filesystem::create_directories(dir);
config.addPredicted(machine.getPredicted());
config.setState(machine.getStrategy().getInitialState());
machine.getStrategy().reset();
auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch);
bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile);
if (epoch != 0 and (dynamicOracleInterval == -1 or epoch % dynamicOracleInterval))
mustExtract = false;
if (!mustExtract)
return;
bool dynamicOracle = epoch != 0;
fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
for (auto & entry : std::filesystem::directory_iterator(dir))
if (entry.is_regular_file())
std::filesystem::remove(entry.path());
while (true)
{
......@@ -46,31 +81,6 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
if (!transition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
if (config.isMultiword(config.getWordIndex()))
if (transition->getName() == "ADDCHARTOWORD")
{
config.printForDebug(stderr);
auto & splitTrans = config.getAppliableSplitTransitions();
fmt::print(stderr, "splitTrans.size() = {}\n", splitTrans.size());
for (auto & trans : splitTrans)
fmt::print(stderr, "cost {} : '{}'\n", trans->getCost(config), trans->getName());
util::myThrow(fmt::format("Transition should have been a split"));
}
if (transition->getName() == "ENDWORD")
if (config.getAsFeature("FORM",config.getWordIndex()) != config.getConst("FORM",config.getWordIndex(),0))
{
config.printForDebug(stderr);
util::myThrow(fmt::format("Words don't match"));
}
std::vector<std::vector<long>> context;
try
......@@ -83,12 +93,51 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
}
Transition * transition = nullptr;
if (dynamicOracle and config.getState() != "tokenizer")
{
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
int chosenTransition = -1;
float bestScore = std::numeric_limits<float>::min();
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
{
chosenTransition = i;
bestScore = score;
}
}
transition = machine.getTransitionSet().getTransition(chosenTransition);
}
else
{
transition = machine.getTransitionSet().getBestAppliableTransition(config);
}
if (!transition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
gold[0] = goldIndex;
for (auto & element : context)
{
currentExampleIndex++;
classes.emplace_back(gold);
}
if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile)
saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir);
transition->apply(config);
config.addToHistory(transition->getName());
......@@ -106,7 +155,15 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
config.update();
}
fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(classes.size()));
if (!contexts.empty())
saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir);
std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
if (!f)
util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
std::fclose(f);
fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex));
}
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement)
......@@ -129,8 +186,8 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
if (train)
optimizer->zero_grad();
auto data = batch.data;
auto labels = batch.target.squeeze();
auto data = batch.first;
auto labels = batch.second;
auto prediction = machine.getClassifier()->getNN()(data);
if (prediction.dim() == 1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment