Commit 219be1d7 authored by Franck Dary

Fixed unknownValueThreshold usage

parent 2261c98b
@@ -24,10 +24,11 @@ std::size_t SplitTransLSTMImpl::getInputSize()
 void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
 {
   auto & splitTransitions = config.getAppliableSplitTransitions();
+  for (auto & contextElement : context)
   for (int i = 0; i < maxNbTrans; i++)
     if (i < (int)splitTransitions.size())
-      context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
+      contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
     else
-      context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 }
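
What this first hunk fixes: context is a matrix of feature rows (presumably one row per unknown-value variant, which would tie in with the commit message), and the old code appended the split-transition features only to context.back(), leaving every other row short and misaligned. Below is a minimal, self-contained sketch of the corrected loop, not the project's code: it uses plain strings instead of Dict indices, and maxNbTrans plus the "<null>" padding marker are assumptions.

#include <iostream>
#include <string>
#include <vector>

int main()
{
  // Three feature rows, standing in for the rows of SplitTransLSTMImpl's context.
  std::vector<std::vector<std::string>> context(3);
  std::vector<std::string> splitTransitions{"SPLIT_0", "SPLIT_1"};
  const int maxNbTrans = 4; // assumed fixed feature width

  // Corrected behavior: every row receives exactly maxNbTrans entries,
  // padded with a null marker once the available transitions run out.
  for (auto & contextElement : context)
    for (int i = 0; i < maxNbTrans; i++)
      if (i < (int)splitTransitions.size())
        contextElement.emplace_back(splitTransitions[i]);
      else
        contextElement.emplace_back("<null>");

  // Every row now has the same length; with the old code only the last row would.
  for (auto & row : context)
    std::cout << row.size() << " entries\n"; // prints "4 entries" three times
}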
@@ -22,7 +22,6 @@ class Trainer
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
   int batchSize;
-  int nbExamples{0};

   private :
@@ -9,11 +9,10 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
 {
   SubConfig config(goldConfig, goldConfig.getNbLines());
+  machine.trainMode(true);
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
   trainDataset.reset(new Dataset(dir));
-  nbExamples = trainDataset->size().value();
   dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
   if (optimizer.get() == nullptr)
@@ -24,6 +23,7 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
 {
   SubConfig config(goldConfig, goldConfig.getNbLines());
+  machine.trainMode(false);
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
   devDataset.reset(new Dataset(dir));
@@ -43,7 +43,6 @@ void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<to
 void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
 {
   torch::AutoGradMode useGrad(false);
-  machine.trainMode(false);
   machine.setDictsState(Dict::State::Open);
   int maxNbExamplesPerFile = 250000;
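
Taken together, the three Trainer hunks move the machine.trainMode(...) call out of extractExamples and into its two callers: training examples are now extracted with the machine in train mode, dev examples in eval mode. Presumably train mode is what activates the unknownValueThreshold replacement named in the commit message, so extracting both datasets in the same mode was the bug. A sketch of the resulting call structure, with abbreviated signatures and a stand-in NeuralMachine type (the real types and members live in Trainer.hpp / Trainer.cpp):

struct NeuralMachine
{
  void trainMode(bool isTraining) { /* toggles training-only behavior */ }
};

struct Trainer
{
  NeuralMachine machine;

  // No trainMode call in here anymore: the caller decides the mode,
  // because train and dev extraction need different settings.
  void extractExamples() { /* ... */ }

  void createDataset()
  {
    machine.trainMode(true); // train-time behavior on while extracting
    extractExamples();
  }

  void createDevDataset()
  {
    machine.trainMode(false); // dev examples extracted deterministically
    extractExamples();
  }
};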