Commit a4844d76 authored by Franck Dary

Changed the loss to CrossEntropy, to avoid using log (which caused exploding gradients)
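Why this fixes the instability: the old code applied torch::log to softmax outputs before torch::nll_loss, so any probability that underflows to zero produces -inf and NaN gradients in the backward pass. torch::nn::CrossEntropyLoss takes raw logits and fuses log-softmax with NLL in a numerically stable way. A minimal standalone sketch comparing the two formulations (hypothetical tensors, not taken from this repository):

#include <torch/torch.h>
#include <iostream>

int main()
{
  // Hypothetical batch: 4 examples, 10 classes.
  auto logits = torch::randn({4, 10});
  auto labels = torch::randint(0, 10, {4}, torch::kLong);

  // Old formulation: explicit softmax, then log, then NLL.
  // If the softmax probability of the gold class underflows to 0,
  // torch::log returns -inf and the gradients blow up.
  auto oldLoss = torch::nll_loss(torch::log(torch::softmax(logits, 1)), labels);

  // New formulation: CrossEntropyLoss consumes raw logits and fuses
  // log-softmax with NLL, which stays finite even for extreme logits.
  auto lossFct = torch::nn::CrossEntropyLoss();
  auto newLoss = lossFct(logits, labels);

  std::cout << oldLoss.item<float>() << " " << newLoss.item<float>() << std::endl;
  return 0;
}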

parent 096b59d9
@@ -35,8 +35,6 @@ torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
   // reshaped dim = {batch, sequence of embeddings}
   auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
-  auto res = torch::softmax(linear2(torch::relu(linear1(reshaped))), reshaped.dim() == 2 ? 1 : 0);
-  return res;
+  return linear2(torch::relu(linear1(reshaped)));
 }
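Note (my reading, not part of the commit): with the softmax removed, forward() now returns raw logits, which is exactly what CrossEntropyLoss expects during training. If class probabilities are still wanted at decoding time, they can be recovered outside the network, e.g. (hypothetical call site, reusing the accessor seen in Trainer::epoch below):

  auto logits = machine.getClassifier()->getNN()(data);  // raw scores, shape {batch, nbClasses}
  auto probas = torch::softmax(logits, 1);                // only when probabilities are actually needed
  auto predicted = probas.argmax(1);                      // argmax is unchanged by softmax anyway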
@@ -45,7 +45,7 @@ torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
   if (reshaped.dim() == 3)
     reshaped = wordsAsEmb.permute({1,0,2});
-  auto res = torch::softmax(linear(reshaped[focusedIndex]), reshaped.dim() == 3 ? 1 : 0);
+  auto res = linear(reshaped[focusedIndex]);
   return res;
 }
@@ -58,8 +58,8 @@ void Trainer::createDataset(SubConfig & config, bool debug)
   dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5)));
-  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5)));
+  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
+  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
 }
@@ -70,6 +70,8 @@ float Trainer::epoch(bool printAdvancement)
   int nbExamplesUntilPrint = printInterval;
   int currentBatchNumber = 0;
+  auto lossFct = torch::nn::CrossEntropyLoss();
   for (auto & batch : *dataLoader)
   {
     denseOptimizer->zero_grad();
@@ -80,7 +82,7 @@ float Trainer::epoch(bool printAdvancement)
   auto prediction = machine.getClassifier()->getNN()(data);
-  auto loss = torch::nll_loss(torch::log(prediction), labels);
+  auto loss = lossFct(prediction, labels);
   try
   {
     totalLoss += loss.item<float>();
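For context, a hedged sketch of how the rest of the batch step presumably looks with the two optimizers created in createDataset; the backward/step calls are standard torch::optim API, but their exact placement in this file is not shown in the excerpt:

  denseOptimizer->zero_grad();
  sparseOptimizer->zero_grad();
  auto prediction = machine.getClassifier()->getNN()(data);  // raw logits
  auto loss = lossFct(prediction, labels);                   // labels: torch::kLong class indices
  loss.backward();                                           // gradients for dense and sparse (embedding) parameters
  denseOptimizer->step();
  sparseOptimizer->step();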