Skip to content
Snippets Groups Projects
Select Git revision
  • 355429b0b67f501e22d33d12915850b2ed03f13d
  • main default protected
  • audio_stream-dev
  • DEPRECATED_Paul_Best_Code
  • HighBlueRec
5 results

filewriter.cpp

Blame
  • DepthLayerTreeEmbeddingModule.cpp 5.20 KiB
    #include "DepthLayerTreeEmbeddingModule.hpp"
    
    // Constructs the module from a textual definition of the form:
    //   Columns{c1 c2 ...} Buffer{i1 i2 ...} Stack{j1 j2 ...}
    //   LayerSizes{n1 n2 ...} <SubModuleType>{bidir layers dropout complete}
    //   In{inSize} Out{outSize}
    // where <SubModuleType> is one of LSTM, GRU or Concat. One sub-module is
    // registered per depth level (one entry of LayerSizes).
    //
    // @param name        module name passed to setName().
    // @param definition  specification string; a parse or conversion failure
    //                    aborts via util::myThrow with the offending definition.
    DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string name, const std::string & definition)
    {
      setName(name);
      // Whitespace-tolerant grammar; each {...} group is captured in order.
      std::regex regex("(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)LayerSizes\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
      if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
            {
              try
              {
                // Group 1: column names whose embeddings are concatenated per word.
                columns = util::split(sm.str(1), ' ');

                // Groups 2-3: relative buffer indexes and stack depths to focus on.
                for (auto & index : util::split(sm.str(2), ' '))
                  focusedBuffer.emplace_back(std::stoi(index));

                for (auto & index : util::split(sm.str(3), ' '))
                  focusedStack.emplace_back(std::stoi(index));

                // Group 4: max number of tree children kept at each depth level.
                for (auto & elem : util::split(sm.str(4), ' '))
                  maxElemPerDepth.emplace_back(std::stoi(elem));

                auto subModuleType = sm.str(5);
                auto subModuleArguments = util::split(sm.str(6), ' ');

                // Sub-module hyper-parameters, in fixed positional order.
                auto options = MyModule::ModuleOptions(true)
                  .bidirectional(std::stoi(subModuleArguments[0]))
                  .num_layers(std::stoi(subModuleArguments[1]))
                  .dropout(std::stof(subModuleArguments[2]))
                  .complete(std::stoi(subModuleArguments[3]));

                inSize = std::stoi(sm.str(7));
                int outSize = std::stoi(sm.str(8));

                // One recurrent (or Concat) sub-module per depth level.
                for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
                {
                  std::string name = fmt::format("{}_{}", i, subModuleType);
                  if (subModuleType == "LSTM")
                    depthModules.emplace_back(register_module(name, LSTM(columns.size()*inSize, outSize, options)));
                  else if (subModuleType == "GRU")
                    depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options)));
                  else if (subModuleType == "Concat")
                    depthModules.emplace_back(register_module(name, Concat(inSize)));
                  else
                    util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
                }

              } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
            }))
        util::myThrow(fmt::format("invalid definition '{}'", definition));
    }
    
    // Embeds the module's slice of the input, then runs each depth's sub-module
    // over its chunk of the context — one chunk per (focused element, depth)
    // pair — and concatenates every sub-module output along dimension 1.
    torch::Tensor DepthLayerTreeEmbeddingModuleImpl::forward(torch::Tensor input)
    {
      auto embedded = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()));

      const unsigned int nbFocused = focusedBuffer.size() + focusedStack.size();
      const long nbColumns = (long)columns.size();

      std::vector<torch::Tensor> results;
      results.reserve(nbFocused * maxElemPerDepth.size());

      int pos = 0;
      for (unsigned int focusedIndex = 0; focusedIndex < nbFocused; ++focusedIndex)
      {
        for (unsigned int depth = 0; depth < maxElemPerDepth.size(); ++depth)
        {
          // Each chunk holds maxElemPerDepth[depth] words of nbColumns columns;
          // reshape so the sub-module sees (batch, elems, columns*embSize).
          auto chunk = embedded.narrow(1, pos, maxElemPerDepth[depth]*columns.size());
          auto reshaped = chunk.view({embedded.size(0), maxElemPerDepth[depth], nbColumns*embedded.size(2)});
          results.emplace_back(depthModules[depth]->forward(reshaped));
          pos += maxElemPerDepth[depth]*columns.size();
        }
      }

      return torch::cat(results, 1);
    }
    
    // Width (dim 1) of the tensor returned by forward(): the sum of every
    // depth sub-module's output size, replicated once per focused element.
    std::size_t DepthLayerTreeEmbeddingModuleImpl::getOutputSize()
    {
      std::size_t perFocusedElement = 0;

      for (unsigned int depth = 0; depth < maxElemPerDepth.size(); ++depth)
        perFocusedElement += depthModules[depth]->getOutputSize(maxElemPerDepth[depth]);

      return perFocusedElement * (focusedBuffer.size() + focusedStack.size());
    }
    
    // Number of dictionary indexes this module consumes from each context
    // element: for every focused element and every depth, maxElemPerDepth[d]
    // children times one index per column. Must match what addToContext()
    // appends.
    std::size_t DepthLayerTreeEmbeddingModuleImpl::getInputSize()
    {
      // Accumulate in std::size_t (the return type) instead of int: the old
      // int accumulator mixed signed/unsigned arithmetic and could overflow
      // before the widening return conversion.
      std::size_t totalElems = 0;
      for (int maxElem : maxElemPerDepth)
        totalElems += maxElem;
      // (focused elements) x (children over all depths) x (columns per child).
      return (focusedBuffer.size() + focusedStack.size()) * totalElems * columns.size();
    }
    
    // Appends this module's dictionary indexes to every context row: for each
    // focused word (buffer positions, then stack positions), walk its
    // dependency children breadth-first up to maxElemPerDepth.size() levels,
    // and for each level emit maxElemPerDepth[depth] slots x columns.size()
    // indexes, padding missing children with the null value. The number of
    // values appended per row therefore equals getInputSize().
    //
    // @param context  rows of dict indexes being built; extended in place.
    // @param config   parser state queried for word indexes and features.
    void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
    {
      auto & dict = getDict();
      std::vector<long> focusedIndexes;

      // Buffer positions are resolved relative to the current word.
      for (int index : focusedBuffer)
        focusedIndexes.emplace_back(config.getRelativeWordIndex(index));

      // Stack positions may be absent; -1 marks "no such stack element".
      // NOTE(review): the depth loop below still calls std::stoi("-1") and
      // config.has(...) on it — presumably Config treats -1 as not-present;
      // confirm against Config::has.
      for (int index : focusedStack)
        if (config.hasStack(index))
          focusedIndexes.emplace_back(config.getStack(index))
        else
          focusedIndexes.emplace_back(-1);

      for (auto & contextElement : context)
        for (auto index : focusedIndexes)
        {
          // Frontier of the tree walk, kept as strings because child lists
          // are stored as '|'-separated strings in the childs column.
          std::vector<std::string> childs{std::to_string(index)};

          for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
          {
            // Replace the frontier with the concatenated children of all
            // current frontier nodes (next tree level).
            std::vector<std::string> newChilds;
            for (auto & child : childs)
              if (config.has(Config::childsColName, std::stoi(child), 0))
              {
                auto val = util::split(config.getAsFeature(Config::childsColName, std::stoi(child)).get(), '|');
                newChilds.insert(newChilds.end(), val.begin(), val.end());
              }
            childs = newChilds;

            // Emit exactly maxElemPerDepth[depth] slots per column: real
            // feature indexes for existing children, null padding otherwise,
            // so every row has a fixed width.
            for (int i = 0; i < maxElemPerDepth[depth]; i++)
              for (auto & col : columns)
                if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0))
                  contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])), col));
                else
                  contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
          }
        }
    }
    
    // Lazily creates and registers the word-embedding table sized to the
    // current dictionary; a no-op once the table already exists.
    void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
    {
      if (wordEmbeddings)
        return;

      wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
    }