Skip to content
Snippets Groups Projects
Commit 0daae795 authored by Franck Dary's avatar Franck Dary
Browse files

do not close dict during example extraction

parent ecf7290b
Branches
No related tags found
No related merge requests found
...@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool ...@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
{ {
torch::AutoGradMode useGrad(false); torch::AutoGradMode useGrad(false);
machine.trainMode(false); machine.trainMode(false);
machine.setDictsState(Dict::State::Closed);
machine.getStrategy().reset(); machine.getStrategy().reset();
config.addPredicted(machine.getPredicted()); config.addPredicted(machine.getPredicted());
......
...@@ -47,8 +47,10 @@ class ReadingMachine ...@@ -47,8 +47,10 @@ class ReadingMachine
bool isPredicted(const std::string & columnName) const; bool isPredicted(const std::string & columnName) const;
const std::set<std::string> & getPredicted() const; const std::set<std::string> & getPredicted() const;
void trainMode(bool isTrainMode); void trainMode(bool isTrainMode);
void setDictsState(Dict::State state);
void saveBest() const; void saveBest() const;
void saveLast() const; void saveLast() const;
void saveDicts() const;
}; };
#endif #endif
...@@ -134,7 +134,7 @@ Classifier * ReadingMachine::getClassifier() ...@@ -134,7 +134,7 @@ Classifier * ReadingMachine::getClassifier()
return classifier.get(); return classifier.get();
} }
void ReadingMachine::save(const std::string & modelNameTemplate) const void ReadingMachine::saveDicts() const
{ {
for (auto & it : dicts) for (auto & it : dicts)
{ {
...@@ -147,6 +147,11 @@ void ReadingMachine::save(const std::string & modelNameTemplate) const ...@@ -147,6 +147,11 @@ void ReadingMachine::save(const std::string & modelNameTemplate) const
std::fclose(file); std::fclose(file);
} }
}
void ReadingMachine::save(const std::string & modelNameTemplate) const
{
saveDicts();
auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName()); auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName());
torch::save(classifier->getNN(), pathToClassifier); torch::save(classifier->getNN(), pathToClassifier);
...@@ -175,8 +180,12 @@ const std::set<std::string> & ReadingMachine::getPredicted() const ...@@ -175,8 +180,12 @@ const std::set<std::string> & ReadingMachine::getPredicted() const
void ReadingMachine::trainMode(bool isTrainMode) void ReadingMachine::trainMode(bool isTrainMode)
{ {
classifier->getNN()->train(isTrainMode); classifier->getNN()->train(isTrainMode);
}
void ReadingMachine::setDictsState(Dict::State state)
{
for (auto & it : dicts) for (auto & it : dicts)
it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed); it.second.setState(state);
} }
std::map<std::string, Dict> & ReadingMachine::getDicts() std::map<std::string, Dict> & ReadingMachine::getDicts()
......
...@@ -124,7 +124,6 @@ int MacaonTrain::main() ...@@ -124,7 +124,6 @@ int MacaonTrain::main()
fillDicts(machine, goldConfig); fillDicts(machine, goldConfig);
Trainer trainer(machine, batchSize); Trainer trainer(machine, batchSize);
Decoder decoder(machine); Decoder decoder(machine);
......
...@@ -44,6 +44,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -44,6 +44,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
{ {
torch::AutoGradMode useGrad(false); torch::AutoGradMode useGrad(false);
machine.trainMode(false); machine.trainMode(false);
machine.setDictsState(Dict::State::Open);
int maxNbExamplesPerFile = 250000; int maxNbExamplesPerFile = 250000;
int currentExampleIndex = 0; int currentExampleIndex = 0;
...@@ -163,6 +164,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -163,6 +164,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str())); util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
std::fclose(f); std::fclose(f);
machine.saveDicts();
fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex)); fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex));
} }
...@@ -176,6 +179,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance ...@@ -176,6 +179,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
torch::AutoGradMode useGrad(train); torch::AutoGradMode useGrad(train);
machine.trainMode(train); machine.trainMode(train);
machine.setDictsState(Dict::State::Closed);
auto lossFct = torch::nn::CrossEntropyLoss(); auto lossFct = torch::nn::CrossEntropyLoss();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment