Skip to content
Snippets Groups Projects
Commit 0daae795 authored by Franck Dary's avatar Franck Dary
Browse files

do not close dict during example extraction

parent ecf7290b
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
{
torch::AutoGradMode useGrad(false);
machine.trainMode(false);
machine.setDictsState(Dict::State::Closed);
machine.getStrategy().reset();
config.addPredicted(machine.getPredicted());
......
......@@ -47,8 +47,10 @@ class ReadingMachine
bool isPredicted(const std::string & columnName) const;
const std::set<std::string> & getPredicted() const;
void trainMode(bool isTrainMode);
void setDictsState(Dict::State state);
void saveBest() const;
void saveLast() const;
void saveDicts() const;
};
#endif
......@@ -134,7 +134,7 @@ Classifier * ReadingMachine::getClassifier()
return classifier.get();
}
void ReadingMachine::save(const std::string & modelNameTemplate) const
void ReadingMachine::saveDicts() const
{
for (auto & it : dicts)
{
......@@ -147,6 +147,11 @@ void ReadingMachine::save(const std::string & modelNameTemplate) const
std::fclose(file);
}
}
void ReadingMachine::save(const std::string & modelNameTemplate) const
{
saveDicts();
auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName());
torch::save(classifier->getNN(), pathToClassifier);
......@@ -175,8 +180,12 @@ const std::set<std::string> & ReadingMachine::getPredicted() const
void ReadingMachine::trainMode(bool isTrainMode)
{
classifier->getNN()->train(isTrainMode);
}
void ReadingMachine::setDictsState(Dict::State state)
{
for (auto & it : dicts)
it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed);
it.second.setState(state);
}
std::map<std::string, Dict> & ReadingMachine::getDicts()
......
......@@ -124,7 +124,6 @@ int MacaonTrain::main()
fillDicts(machine, goldConfig);
Trainer trainer(machine, batchSize);
Decoder decoder(machine);
......
......@@ -44,6 +44,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
{
torch::AutoGradMode useGrad(false);
machine.trainMode(false);
machine.setDictsState(Dict::State::Open);
int maxNbExamplesPerFile = 250000;
int currentExampleIndex = 0;
......@@ -163,6 +164,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
std::fclose(f);
machine.saveDicts();
fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex));
}
......@@ -176,6 +179,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
torch::AutoGradMode useGrad(train);
machine.trainMode(train);
machine.setDictsState(Dict::State::Closed);
auto lossFct = torch::nn::CrossEntropyLoss();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment