diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index f34f966ef669e94f8d1fb32ab6d73b5e54346d51..c583491f50cb80588cccac6a088e975a503c005f 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -11,9 +11,11 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
   private :
 
   static torch::Device device;
+  std::map<std::string, torch::Tensor> lossParameters;
 
   public :
 
+  torch::Tensor getLossParameter(std::string state);
   virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0;
   virtual torch::Tensor extractContext(Config & config) = 0;
   virtual void registerEmbeddings(bool loadPretrained) = 0;
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 84c0e13998ab30a9524f28e4808492d8954a5b1e..a3c072390995ed828b3a7e0c06199769e97ec412 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -81,7 +81,10 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
   mlp = register_module("mlp", MLP(currentOutputSize, mlpDef));
 
   for (auto & it : nbOutputsPerState)
+  {
     outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
+    getLossParameter(it.first);
+  }
 }
 
 torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string & state)
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index fe3727dc18827225b17bbf6db9133b7dd67ce54e..55a00d5538e35b1e1ee95010b98e2a8985208c3a 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -2,6 +2,14 @@
 
 torch::Device NeuralNetworkImpl::device(getPreferredDevice());
 
+torch::Tensor NeuralNetworkImpl::getLossParameter(std::string state)
+{
+  if (lossParameters.count(state) == 0)
+    lossParameters[state] = register_parameter(fmt::format("lossParam_{}", state), torch::ones(1, torch::TensorOptions().device(NeuralNetworkImpl::getDevice()).requires_grad(true)));
+
+  return lossParameters[state];
+}
+
 float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
 {
   if (probabilities.dim() != 1)
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index da9fb730ca33d601758439495a8e1bc233f15109..8051cca2a01b2c877e4be944e1988e5f3ac573eb 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -247,7 +247,9 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
         labels /= util::float2longScale;
       }
 
-      auto loss = machine.getClassifier(state)->getLossMultiplier(state)*machine.getClassifier(state)->getLossFunction()(prediction, labels);
+      auto lossParameter = machine.getClassifier(state)->getNN()->getLossParameter(state);
+
+      auto loss = machine.getClassifier(state)->getLossMultiplier(state)*machine.getClassifier(state)->getLossFunction()(prediction, labels)*(1.0/torch::exp(lossParameter)) + lossParameter;
       float lossAsFloat = 0.0;
       try
       {
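
Note (reviewer sketch, not part of the patch): the new Trainer line rescales each state's raw loss L by a learned scalar s as L * exp(-s) + s (the patch writes it as L * (1/exp(s)) + s), which resembles uncertainty-based multi-task loss weighting. exp(-s) lets training shrink the contribution of states whose loss stays large or noisy, while the additive s term penalizes driving every weight toward zero. The sketch below shows the same scheme in isolation with libtorch; the names UncertaintyWeights, get, weight and rawLoss are illustrative and do not exist in this codebase.

```cpp
#include <torch/torch.h>
#include <map>
#include <string>

// One learnable scalar per state, applied as weighted = raw * exp(-s) + s.
// This mirrors the shape of the patched Trainer expression, since
// 1/exp(s) == exp(-s). Initialization to 1 matches the patch's torch::ones(1).
struct UncertaintyWeightsImpl : torch::nn::Module
{
  std::map<std::string, torch::Tensor> params;

  torch::Tensor get(const std::string & state)
  {
    if (params.count(state) == 0)
      params[state] = register_parameter("lossParam_" + state, torch::ones(1));
    return params[state];
  }

  torch::Tensor weight(torch::Tensor rawLoss, const std::string & state)
  {
    auto s = get(state);
    return rawLoss * torch::exp(-s) + s;
  }
};
TORCH_MODULE(UncertaintyWeights);

int main()
{
  UncertaintyWeights weights;
  torch::optim::Adam optimizer(weights->parameters(), torch::optim::AdamOptions(1e-2));

  // Dummy per-state loss; in the real trainer this comes from the classifier's
  // loss function applied to (prediction, labels).
  auto rawLoss = torch::tensor(2.0);
  auto total = weights->weight(rawLoss, "parser");

  optimizer.zero_grad();
  total.backward();
  optimizer.step(); // lossParam_parser is updated alongside the model weights
}
```

Because getLossParameter() registers the scalar on the network itself (as in the patch), the existing optimizer picks it up automatically; no separate parameter group is needed.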