// ModularNetwork.cpp
#include "ModularNetwork.hpp"

// Builds the network from a list of textual definitions.
//
// Every entry of `definitions` is matched against the pattern `<Kind> : <arguments>`
// (blanks around the kind and the colon are ignored). Recognized kinds:
//   - Context, StateName, Focused, RawInput, SplitTrans, DepthLayerTree :
//       a submodule is created from <arguments>, registered under
//       "<index>_<Kind>" and appended to `modules`;
//   - MLP          : <arguments> is kept and used to build the MLP once all
//                    submodules are known;
//   - InputDropout : <arguments> is parsed with std::stof and used to build
//                    the dropout layer applied to the concatenated input.
//
// name              : name given to this network through setName().
// nbOutputsPerState : for each state name, the size of its output layer.
// definitions       : the definition lines described above.
//
// Errors are reported through util::myThrow when a line has an unknown kind
// and when no MLP or no InputDropout definition was provided.
ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions)
{
  setName(name);

  // The pattern is identical for every line, so it is compiled once here
  // instead of once per definition line.
  const std::string anyBlanks = "(?:(?:\\s|\\t)*)";
  const std::regex definitionRegex(fmt::format("{}(\\S+){}:{}(.+)",anyBlanks,anyBlanks,anyBlanks));

  // Splits a definition line into {kind, arguments}.
  // NOTE(review): presumably the callback only runs when the line matches, in
  // which case both strings stay empty for a malformed line — confirm against
  // util::doIfNameMatch.
  auto splitLine = [&definitionRegex](const std::string & line)
  {
    std::pair<std::string,std::string> result;
    util::doIfNameMatch(definitionRegex,line,[&result](auto sm)
        {
          result.first = sm.str(1);
          result.second = sm.str(2);
        });
    return result;
  };

  int currentInputSize = 0;
  int currentOutputSize = 0;
  std::string mlpDef;
  for (const auto & line : definitions)
  {
    auto [kind, args] = splitLine(line);

    // Registration name is prefixed with the submodule index; the name handed
    // to the submodule itself is additionally prefixed with this network's name.
    // (Named distinctly from the `name` parameter to avoid shadowing it.)
    std::string moduleName = fmt::format("{}_{}", modules.size(), kind);
    std::string qualifiedName = fmt::format("{}_{}", getName(), moduleName);

    if (kind == "Context")
      modules.emplace_back(register_module(moduleName, ContextModule(qualifiedName, args)));
    else if (kind == "StateName")
      modules.emplace_back(register_module(moduleName, StateNameModule(qualifiedName, args)));
    else if (kind == "Focused")
      modules.emplace_back(register_module(moduleName, FocusedColumnModule(qualifiedName, args)));
    else if (kind == "RawInput")
      modules.emplace_back(register_module(moduleName, RawInputModule(qualifiedName, args)));
    else if (kind == "SplitTrans")
      modules.emplace_back(register_module(moduleName, SplitTransModule(qualifiedName, Config::maxNbAppliableSplitTransitions, args)));
    else if (kind == "DepthLayerTree")
      modules.emplace_back(register_module(moduleName, DepthLayerTreeEmbeddingModule(qualifiedName, args)));
    else if (kind == "MLP")
    {
      // The MLP needs the total output size of all submodules, so only
      // remember its definition for now.
      mlpDef = args;
      continue;
    }
    else if (kind == "InputDropout")
    {
      inputDropout = register_module("inputDropout", torch::nn::Dropout(std::stof(args)));
      continue;
    }
    else
      util::myThrow(fmt::format("unknown module '{}' for line '{}'", kind, line));

    // Only reached when a submodule was appended: record where its slice of
    // the input starts and accumulate the sizes.
    modules.back()->setFirstInputIndex(currentInputSize);
    currentInputSize += modules.back()->getInputSize();
    currentOutputSize += modules.back()->getOutputSize();
  }

  if (mlpDef.empty())
    util::myThrow("no MLP definition found");
  if (inputDropout.is_empty())
    util::myThrow("no InputDropout definition found");

  mlp = register_module("mlp", MLP(currentOutputSize, mlpDef));

  // One linear output layer per state, fed by the MLP output.
  for (const auto & [state, nbOutputs] : nbOutputsPerState)
    outputLayersPerState.emplace(state,register_module(fmt::format("output_{}",state), torch::nn::Linear(mlp->outputSize(), nbOutputs)));
}

// Runs the network on `input` and returns the result of the output layer
// associated with the current state (getState()).
//
// A one-dimensional input first gets a leading dimension added. Every
// submodule is then run on the input; their outputs are concatenated along
// dimension 1, passed through the input dropout, the MLP and finally the
// state-specific output layer.
torch::Tensor ModularNetworkImpl::forward(torch::Tensor input)
{
  const bool isSingleExample = input.dim() == 1;
  if (isSingleExample)
  {
    input = input.unsqueeze(0);
  }

  std::vector<torch::Tensor> moduleOutputs;
  for (auto & submodule : modules)
  {
    moduleOutputs.emplace_back(submodule->forward(input));
  }

  auto concatenated = torch::cat(moduleOutputs, 1);
  auto totalInput = inputDropout(concatenated);

  // Look the output layer up before running the MLP, as the original
  // single-expression form did.
  auto & outputLayer = outputLayersPerState.at(getState());
  return outputLayer(mlp(totalInput));
}

// Builds the context for `config` by letting every submodule, in order,
// add its part to a context that starts with a single empty row.
std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config)
{
  std::vector<std::vector<long>> result(1);

  for (auto & submodule : modules)
  {
    submodule->addToContext(result, config);
  }

  return result;
}

// Asks every submodule, in order, to register its embeddings.
void ModularNetworkImpl::registerEmbeddings()
{
  for (auto & submodule : modules)
  {
    submodule->registerEmbeddings();
  }
}

// Forwards `path` to saveDict() of every submodule, in order.
void ModularNetworkImpl::saveDicts(std::filesystem::path path)
{
  for (auto & submodule : modules)
  {
    submodule->saveDict(path);
  }
}

// Forwards `path` to loadDict() of every submodule, in order.
void ModularNetworkImpl::loadDicts(std::filesystem::path path)
{
  for (auto & submodule : modules)
  {
    submodule->loadDict(path);
  }
}

// Applies `state` to the dictionary of every submodule.
void ModularNetworkImpl::setDictsState(Dict::State state)
{
  for (auto & submodule : modules)
  {
    auto & dict = submodule->getDict();
    dict.setState(state);
  }
}

// Forwards `countOcc` to countOcc() on the dictionary of every submodule.
void ModularNetworkImpl::setCountOcc(bool countOcc)
{
  for (auto & submodule : modules)
  {
    auto & dict = submodule->getDict();
    dict.countOcc(countOcc);
  }
}

// Shrinks the dictionary of every submodule by repeatedly calling
// removeRareElements() on it.
//
// For each dictionary the pruning goes on while both conditions hold:
//   - the percentage of elements removed so far (relative to the size the
//     dictionary had before pruning) is below `rarityThreshold`;
//   - the dictionary still holds more than 1000 elements.
void ModularNetworkImpl::removeRareDictElements(float rarityThreshold)
{
  const std::size_t minNbElems = 1000;

  for (auto & submodule : modules)
  {
    auto & dict = submodule->getDict();
    const std::size_t sizeBefore = dict.size();

    for (;;)
    {
      const double removedPercent = 100.0*(sizeBefore-dict.size())/sizeBefore;

      // Written as negated comparisons so that a NaN percentage (empty
      // dictionary) stops the loop, exactly like a failed `<` test would.
      if (!(removedPercent < rarityThreshold))
        break;
      if (!(dict.size() > minNbElems))
        break;

      dict.removeRareElements();
    }
  }
}