Commit a4844d76 authored by Franck Dary

Changed the loss to CrossEntropy, to avoid an explicit log (which caused exploding gradients)

parent 096b59d9
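For context: torch::nn::CrossEntropyLoss in the libtorch C++ API fuses log_softmax and nll_loss, so it never materialises probabilities that can underflow to zero, whereas the torch::log(torch::softmax(...)) pattern removed below yields -inf losses and NaN gradients as soon as a probability rounds to 0 in float32. A minimal standalone sketch (not part of this repository) illustrating the difference:

#include <torch/torch.h>
#include <iostream>

int main()
{
  // One example with widely spread logits: after softmax, the probability of
  // class 1 underflows to 0.0f in float32, so an explicit log gives -inf.
  auto logits = torch::tensor({100.0f, -100.0f, 0.0f}).view({1, 3});
  auto labels = torch::tensor({1}, torch::kLong);

  // Old formulation removed in this commit: nll_loss over log(softmax(x)).
  auto fragile = torch::nll_loss(torch::log(torch::softmax(logits, 1)), labels);

  // New formulation: CrossEntropyLoss stays in log-space (log_softmax + nll),
  // so the same example yields a finite loss and usable gradients.
  auto lossFct = torch::nn::CrossEntropyLoss();
  auto stable = lossFct(logits, labels);

  std::cout << "log + nll_loss  : " << fragile.item<float>() << '\n'  // inf
            << "CrossEntropyLoss: " << stable.item<float>() << '\n';  // ~200
}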
@@ -35,8 +35,6 @@ torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
   // reshaped dim = {batch, sequence of embeddings}
   auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
-  auto res = torch::softmax(linear2(torch::relu(linear1(reshaped))), reshaped.dim() == 2 ? 1 : 0);
-  return res;
+  return linear2(torch::relu(linear1(reshaped)));
 }
@@ -45,7 +45,7 @@ torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
   if (reshaped.dim() == 3)
     reshaped = wordsAsEmb.permute({1,0,2});
-  auto res = torch::softmax(linear(reshaped[focusedIndex]), reshaped.dim() == 3 ? 1 : 0);
+  auto res = linear(reshaped[focusedIndex]);
   return res;
 }
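Since both forward() methods now return raw logits rather than a distribution, any code path that still needs actual probabilities (for instance when scoring transitions at prediction time, which is not shown in this diff) would have to apply the softmax itself. A hedged sketch of such a helper, mirroring the dimension handling that was removed:

#include <torch/torch.h>

// Hypothetical helper, not part of this commit: recover probabilities from
// the raw logits that the networks now return.
torch::Tensor logitsToProbabilities(torch::Tensor logits)
{
  // Softmax over the class dimension: dim 1 for a batched 2-D output,
  // dim 0 for a single unbatched example, as in the removed forward() code.
  return torch::softmax(logits, logits.dim() == 2 ? 1 : 0);
}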
...
@@ -58,8 +58,8 @@ void Trainer::createDataset(SubConfig & config, bool debug)
   dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5)));
-  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5)));
+  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
+  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
 }
 float Trainer::epoch(bool printAdvancement)
@@ -70,6 +70,8 @@ float Trainer::epoch(bool printAdvancement)
   int nbExamplesUntilPrint = printInterval;
   int currentBatchNumber = 0;
+  auto lossFct = torch::nn::CrossEntropyLoss();
   for (auto & batch : *dataLoader)
   {
     denseOptimizer->zero_grad();
@@ -80,7 +82,7 @@ float Trainer::epoch(bool printAdvancement)
     auto prediction = machine.getClassifier()->getNN()(data);
-    auto loss = torch::nll_loss(torch::log(prediction), labels);
+    auto loss = lossFct(prediction, labels);
     try
     {
       totalLoss += loss.item<float>();
...