#include "ModularNetwork.hpp"
    
    
    ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path)
    
      setName(name);
    
      std::string anyBlanks = "(?:(?:\\s|\\t)*)";
      auto splitLine = [anyBlanks](std::string line)
      {
        std::pair<std::string,std::string> result;
        util::doIfNameMatch(std::regex(fmt::format("{}(\\S+){}:{}(.+)",anyBlanks,anyBlanks,anyBlanks)),line,[&result](auto sm)
            {
              result.first = sm.str(1);
              result.second = sm.str(2);
            });
        return result;
      };
    
    
    Franck Dary's avatar
    Franck Dary committed
      std::size_t maxNbOutputs = 0;
      for (auto & it : nbOutputsPerState)
        maxNbOutputs = std::max<std::size_t>(it.second, maxNbOutputs);
    
    
      int currentInputSize = 0;
      int currentOutputSize = 0;
      std::string mlpDef;
      for (auto & line : definitions)
      {
        auto splited = splitLine(line);
        std::string name = fmt::format("{}_{}", modules.size(), splited.first);
    
        std::string nameH = fmt::format("{}_{}", getName(), name);
    
        if (splited.first == "Context")
    
          modules.emplace_back(register_module(name, ContextModule(nameH, splited.second, path)));
    
    Franck Dary's avatar
    Franck Dary committed
        else if (splited.first == "StateName")
    
          modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
    
    Franck Dary's avatar
    Franck Dary committed
        else if (splited.first == "History")
          modules.emplace_back(register_module(name, HistoryModule(nameH, splited.second)));
    
    Franck Dary's avatar
    Franck Dary committed
        else if (splited.first == "NumericColumn")
          modules.emplace_back(register_module(name, NumericColumnModule(nameH, splited.second)));
    
    Franck Dary's avatar
    Franck Dary committed
        else if (splited.first == "UppercaseRate")
          modules.emplace_back(register_module(name, UppercaseRateModule(nameH, splited.second)));
    
    Franck Dary's avatar
    Franck Dary committed
        else if (splited.first == "Focused")
    
          modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second)));
    
    Franck Dary's avatar
    Franck Dary committed
        else if (splited.first == "RawInput")
    
          modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second)));
    
    Franck Dary's avatar
    Franck Dary committed
        else if (splited.first == "SplitTrans")
    
          modules.emplace_back(register_module(name, SplitTransModule(nameH, Config::maxNbAppliableSplitTransitions, splited.second)));
    
    Franck Dary's avatar
    Franck Dary committed
        else if (splited.first == "AppliableTrans")
          modules.emplace_back(register_module(name, AppliableTransModule(nameH, maxNbOutputs)));
    
    Franck Dary's avatar
    Franck Dary committed
        else if (splited.first == "Distance")
          modules.emplace_back(register_module(name, DistanceModule(nameH, splited.second)));
    
    Franck Dary's avatar
    Franck Dary committed
        else if (splited.first == "DepthLayerTree")
    
          modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(nameH, splited.second)));
    
        else if (splited.first == "MLP")
        {
          mlpDef = splited.second;
          continue;
        }
        else if (splited.first == "InputDropout")
        {
          inputDropout = register_module("inputDropout", torch::nn::Dropout(std::stof(splited.second)));
          continue;
        }
        else
          util::myThrow(fmt::format("unknown module '{}' for line '{}'", splited.first, line));
    
        modules.back()->setFirstInputIndex(currentInputSize);
        currentInputSize += modules.back()->getInputSize();
        currentOutputSize += modules.back()->getOutputSize();
      }
    
      if (mlpDef.empty())
        util::myThrow("no MLP definition found");
      if (inputDropout.is_empty())
        util::myThrow("no InputDropout definition found");
    
      mlp = register_module("mlp", MLP(currentOutputSize, mlpDef));
    
      for (auto & it : nbOutputsPerState)
        outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
    }
    
    torch::Tensor ModularNetworkImpl::forward(torch::Tensor input)
    {
      if (input.dim() == 1)
        input = input.unsqueeze(0);
    
      std::vector<torch::Tensor> outputs;
    
      for (auto & mod : modules)
        outputs.emplace_back(mod->forward(input));
    
      auto totalInput = inputDropout(torch::cat(outputs, 1));
    
      return outputLayersPerState.at(getState())(mlp(totalInput));
    }
    
    
    std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config)
    
    {
      std::vector<std::vector<long>> context(1);
      for (auto & mod : modules)
    
        mod->addToContext(context, config);
    
    void ModularNetworkImpl::registerEmbeddings()
    
    }
    
    void ModularNetworkImpl::saveDicts(std::filesystem::path path)
    {
      for (auto & mod : modules)
        mod->saveDict(path);
    }
    
    void ModularNetworkImpl::loadDicts(std::filesystem::path path)
    {
      for (auto & mod : modules)
        mod->loadDict(path);
    }
    
    void ModularNetworkImpl::setDictsState(Dict::State state)
    {
      for (auto & mod : modules)
    
      {
        if (!mod->dictIsPretrained())
          mod->getDict().setState(state);
      }
    
    }
    
    void ModularNetworkImpl::setCountOcc(bool countOcc)
    {
      for (auto & mod : modules)
        mod->getDict().countOcc(countOcc);
    }
    
    void ModularNetworkImpl::removeRareDictElements(float rarityThreshold)
    {
      std::size_t minNbElems = 1000;
    
      for (auto & mod : modules)
      {
        auto & dict = mod->getDict();
        std::size_t originalSize = dict.size();
        while (100.0*(originalSize-dict.size())/originalSize < rarityThreshold and dict.size() > minNbElems)
          dict.removeRareElements();
      }
    
    Franck Dary's avatar
    Franck Dary committed
    void ModularNetworkImpl::setState(const std::string & state)
    {
      NeuralNetworkImpl::setState(state);
      for (auto & mod : modules)
        mod->setState(state);
    }