Skip to content
Snippets Groups Projects
DepthLayerTreeEmbeddingModule.cpp 5.2 KiB
Newer Older
  • Learn to ignore specific revisions
  • #include "DepthLayerTreeEmbeddingModule.hpp"
    
    Franck Dary's avatar
    Franck Dary committed
    
    
    /// Builds the module from a textual definition of the form
    ///   Columns{c1 c2 ...} Buffer{i1 i2 ...} Stack{j1 j2 ...} LayerSizes{n1 n2 ...}
    ///   <Type>{bidirectional num_layers dropout complete} In{inSize} Out{outSize}
    /// where <Type> is LSTM, GRU or Concat. One sub-module of that type is
    /// registered per entry of LayerSizes (one per tree depth).
    ///
    /// @param name        module name passed to setName().
    /// @param definition  textual definition described above.
    /// Errors (no match, unparsable number, too few sub-module arguments, unknown
    /// sub-module type) are reported through util::myThrow, with the definition
    /// appended to the message.
    DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string name, const std::string & definition)
    {
      setName(name);
      std::regex regex("(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)LayerSizes\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
      if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
            {
              try
              {
                columns = util::split(sm.str(1), ' ');

                for (auto & index : util::split(sm.str(2), ' '))
                  focusedBuffer.emplace_back(std::stoi(index));

                for (auto & index : util::split(sm.str(3), ' '))
                  focusedStack.emplace_back(std::stoi(index));

                for (auto & elem : util::split(sm.str(4), ' '))
                  maxElemPerDepth.emplace_back(std::stoi(elem));

                auto subModuleType = sm.str(5);
                auto subModuleArguments = util::split(sm.str(6), ' ');

                // at() rather than [] so that a definition with too few arguments
                // raises std::out_of_range (reported by the catch below) instead of
                // reading past the end of the vector.
                auto options = MyModule::ModuleOptions(true)
                  .bidirectional(std::stoi(subModuleArguments.at(0)))
                  .num_layers(std::stoi(subModuleArguments.at(1)))
                  .dropout(std::stof(subModuleArguments.at(2)))
                  .complete(std::stoi(subModuleArguments.at(3)));

                // In{...}: size of the embedding of one column value. It is used below to
                // size the sub-modules and by registerEmbeddings(), so it must be set here.
                inSize = std::stoi(sm.str(7));
                int outSize = std::stoi(sm.str(8));

                for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
                {
                  std::string moduleName = fmt::format("{}_{}", i, subModuleType);
                  if (subModuleType == "LSTM")
                    depthModules.emplace_back(register_module(moduleName, LSTM(columns.size()*inSize, outSize, options)));
                  else if (subModuleType == "GRU")
                    depthModules.emplace_back(register_module(moduleName, GRU(columns.size()*inSize, outSize, options)));
                  else if (subModuleType == "Concat")
                    // NOTE(review): LSTM/GRU above are built with columns.size()*inSize input
                    // features, and forward() feeds every depth module tensors of that width;
                    // confirm Concat(inSize) is intended rather than Concat(columns.size()*inSize).
                    depthModules.emplace_back(register_module(moduleName, Concat(inSize)));
                  else
                    util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
                }

              } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
            }))
        util::myThrow(fmt::format("invalid definition '{}'", definition));
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    /// Embeds this module's slice of `input` (getInputSize() columns starting at
    /// firstInputIndex), then walks it focused element by focused element and,
    /// inside each, depth by depth. Every chunk is reshaped to
    /// (batch, maxElemPerDepth[depth], columns.size() * embeddingDim), given to the
    /// depth module of the same depth, and all results are concatenated on dim 1.
    torch::Tensor DepthLayerTreeEmbeddingModuleImpl::forward(torch::Tensor input)
    {
      auto embedded = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()));

      const long batchSize = embedded.size(0);
      // Width of one tree node once all of its column embeddings are laid side by side.
      const long nodeWidth = static_cast<long>(columns.size()) * embedded.size(2);
      const std::size_t nbFocused = focusedBuffer.size() + focusedStack.size();

      std::vector<torch::Tensor> depthOutputs;
      depthOutputs.reserve(nbFocused * maxElemPerDepth.size());

      long start = 0;
      for (std::size_t f = 0; f < nbFocused; ++f)
      {
        for (std::size_t d = 0; d < maxElemPerDepth.size(); ++d)
        {
          const long nbNodes = maxElemPerDepth[d];
          const long chunkLength = nbNodes * static_cast<long>(columns.size());

          auto chunk = embedded.narrow(1, start, chunkLength).view({batchSize, nbNodes, nodeWidth});
          depthOutputs.emplace_back(depthModules[d]->forward(chunk));

          start += chunkLength;
        }
      }

      return torch::cat(depthOutputs, 1);
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    std::size_t DepthLayerTreeEmbeddingModuleImpl::getOutputSize()
    
    {
      std::size_t outputSize = 0;
    
    Franck Dary's avatar
    Franck Dary committed
    
    
      for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
    
        outputSize += depthModules[depth]->getOutputSize(maxElemPerDepth[depth]);
    
      return outputSize*(focusedBuffer.size()+focusedStack.size());
    
    Franck Dary's avatar
    Franck Dary committed
    std::size_t DepthLayerTreeEmbeddingModuleImpl::getInputSize()
    
      int inputSize = 0;
      for (int maxElem : maxElemPerDepth)
        inputSize += (focusedBuffer.size()+focusedStack.size())*maxElem*columns.size();
      return inputSize;
    
    Franck Dary's avatar
    Franck Dary committed
    }
    
    
    void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
    
    Franck Dary's avatar
    Franck Dary committed
    {
    
      auto & dict = getDict();
    
      std::vector<long> focusedIndexes;
    
      for (int index : focusedBuffer)
        focusedIndexes.emplace_back(config.getRelativeWordIndex(index));
    
      for (int index : focusedStack)
        if (config.hasStack(index))
          focusedIndexes.emplace_back(config.getStack(index));
        else
          focusedIndexes.emplace_back(-1);
    
      for (auto & contextElement : context)
        for (auto index : focusedIndexes)
        {
    
          std::vector<std::string> childs{std::to_string(index)};
    
    Franck Dary's avatar
    Franck Dary committed
    
    
          for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
          {
            std::vector<std::string> newChilds;
            for (auto & child : childs)
              if (config.has(Config::childsColName, std::stoi(child), 0))
              {
                auto val = util::split(config.getAsFeature(Config::childsColName, std::stoi(child)).get(), '|');
                newChilds.insert(newChilds.end(), val.begin(), val.end());
              }
            childs = newChilds;
    
            for (int i = 0; i < maxElemPerDepth[depth]; i++)
              for (auto & col : columns)
    
                if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0))
                  contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i]))));
    
                else
                  contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
          }
    
    Franck Dary's avatar
    Franck Dary committed
    }
    
    
    void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
    
      wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));