Submodule.cpp
    #include "Submodule.hpp"
    #include "WordEmbeddings.hpp"
    
    // When true, pretrained embeddings are reloaded even when the model is not training.
    bool Submodule::reloadPretrained = false;
    
    void Submodule::setReloadPretrained(bool value)
    {
      reloadPretrained = value;
    }
    
    void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
    {
      this->firstInputIndex = firstInputIndex;
    }
    
    void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix)
    {
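      // Read word2vec-format vectors from 'path', registering every token in the dict
      // under 'prefix' so its vector can be copied into 'embeddings'.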
      if (path.empty())
        return;
      // When not training, keep the previously saved weights unless a reload was requested.
      if (!is_training() and !reloadPretrained)
        return;
    
      if (!std::filesystem::exists(path))
        util::myThrow(fmt::format("pretrained word2vec file '{}' does not exist", path.string()));
    
      std::vector<std::vector<float>> toAdd;
    
      // Disable gradient tracking while pretrained vectors are copied into the embedding weights.
      torch::NoGradGuard no_grad;
    
      // Temporarily open the dict so that new words from the pretrained file can be inserted.
      auto originalState = getDict().getState();
      getDict().setState(Dict::State::Open);
    
      std::FILE * file = std::fopen(path.c_str(), "r");
      if (file == nullptr)
        util::myThrow(fmt::format("cannot open pretrained word2vec file '{}'", path.string()));
      char buffer[100000];
    
      bool firstLine = true;
      std::size_t embeddingsSize = embeddings->parameters()[0].size(-1); // expected vector dimension
    
      try
      {
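        // Text word2vec format: the first line is a header ("<vocab size> <dimension>"),
        // each following line is "<token> <v1> ... <vn>", space-separated.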
        while (!std::feof(file))
        {
          if (buffer != std::fgets(buffer, 100000, file))
            break;
    
          if (firstLine)
          {
            firstLine = false;
            continue;
          }
    
          auto splited = util::split(util::strip(buffer), ' ');
    
          if (splited.size() < 2)
            util::myThrow(fmt::format("invalid w2v line '{}': less than 2 columns", buffer));
    
          std::string word;
    
          if (splited[0] == "<unk>")
            word = Dict::unknownValueStr;
          else
            word = splited[0];
    
          // Restore spaces that the w2v file encodes as '◌' (space itself is the column delimiter).
          auto toInsert = util::splitAsUtf8(word);
          toInsert.replace("◌", " ");
          word = fmt::format("{}", toInsert);
    
          auto dictIndex = getDict().getIndexOrInsert(word, prefix);
    
          if (embeddingsSize != splited.size()-1)