Skip to content
Snippets Groups Projects
Select Git revision
  • 10483f58f7d28e8bcb17752c4ed33e6a01254246
  • master default
  • object
  • develop protected
  • private_algos
  • cuisine
  • SMOTE
  • revert-76c4cca5
  • archive protected
  • no_graphviz
  • 0.0.1
11 results

Fusion.py

Blame
  • LossFunction.cpp 1.93 KiB
    #include "LossFunction.hpp"
    #include "util.hpp"
    
    // Select the loss implementation stored in `fct` from a case-insensitive name.
    //
    // Accepted names (compared after util::lower): "crossentropy", "bce", "mse", "hinge".
    // The three torch losses are built with mean reduction; "hinge" uses CustomHingeLoss.
    // `name` is copied into this->name before matching; operator() reuses it in its
    // error messages. Any other name is reported through util::myThrow.
    void LossFunction::init(std::string name)
    {
      this->name = name;

      // Lower-case once instead of re-running util::lower for every comparison.
      auto const lowerName = util::lower(name);

      if (lowerName == "crossentropy")
        fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean));
      else if (lowerName == "bce")
        fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean));
      else if (lowerName == "mse")
        fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean));
      else if (lowerName == "hinge")
        fct = CustomHingeLoss();
      else
        util::myThrow(fmt::format("unknown loss function name '{}' available losses are 'crossentropy, bce, mse, hinge'", name));
    }
    
    // Compute the configured loss of `prediction` against `gold`.
    //
    // Dispatch is on the active alternative of the `fct` variant:
    //   0: gold is reshaped to 1-D (a 0-dim gold becomes shape {1}) before the call
    //   1: softmax over dim 1 of prediction; gold converted to float
    //   2: prediction and gold passed through unchanged
    //   3: softmax over dim 1 of prediction; gold passed through unchanged
    // NOTE(review): assumes the variant alternatives are declared in the order init()
    // assigns them (CrossEntropy, BCE, MSE, Hinge) — confirm in LossFunction.hpp.
    // NOTE(review): for alternative 1, softmax forces class probabilities to sum to 1,
    // which sits oddly with multi-hot BCE targets; sigmoid may have been intended — confirm.
    //
    // A std::exception thrown while computing is re-reported through util::myThrow with the
    // loss name prepended. An unrecognised variant index also ends in util::myThrow.
    torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor gold)
    {
      try
      {
        switch (fct.index())
        {
          case 0:
          {
            auto targets = gold.reshape(gold.dim() == 0 ? 1 : gold.size(0));
            return std::get<0>(fct)(prediction, targets);
          }
          case 1:
          {
            auto probabilities = torch::softmax(prediction, 1);
            return std::get<1>(fct)(probabilities, gold.to(torch::kFloat));
          }
          case 2:
            return std::get<2>(fct)(prediction, gold);
          case 3:
          {
            auto probabilities = torch::softmax(prediction, 1);
            return std::get<3>(fct)(probabilities, gold);
          }
          default:
            break;
        }
      }
      catch (const std::exception & e)
      {
        util::myThrow(fmt::format("computing loss '{}' caught '{}'", name, e.what()));
      }

      util::myThrow("loss is not defined");
      return torch::Tensor();
    }
    
    // Build the gold tensor for the active loss from a list of class indexes.
    //
    // Alternatives 0 and 2 (cross-entropy / MSE as assigned in init()): returns a 1-element
    // kLong tensor holding goldIndexes[0]. Further indexes are ignored, and an empty vector
    // makes .at(0) throw std::out_of_range.
    // Alternatives 1 and 3 (BCE / hinge as assigned in init()): returns a multi-hot kLong
    // tensor of length nbClasses with a 1 at every listed index. Each index must lie in
    // [0, nbClasses). Tensor indexing wraps negative positions, so an unchecked -1 would
    // silently mark the last class. Out-of-range indexes are therefore reported through
    // util::myThrow.
    // Any other variant index is reported through util::myThrow.
    torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::vector<long> & goldIndexes) const
    {
      auto index = fct.index();

      if (index == 0 or index == 2)
      {
        auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
        gold[0] = goldIndexes.at(0);
        return gold;
      }
      if (index == 1 or index == 3)
      {
        auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong));
        for (auto goldIndex : goldIndexes)
        {
          // Reject indexes torch would otherwise wrap (negative) or reject later with a less clear error.
          if (goldIndex < 0 or goldIndex >= nbClasses)
            util::myThrow(fmt::format("gold class index {} is out of range [0, {})", goldIndex, nbClasses));
          gold[goldIndex] = 1;
        }
        return gold;
      }

      util::myThrow("loss is not defined");
      return torch::Tensor();
    }