Skip to content
Snippets Groups Projects
Select Git revision
  • f5c33df2f38aacab7c1bc833eb6a038d60406cb1
  • master default
  • object
  • develop protected
  • private_algos
  • cuisine
  • SMOTE
  • revert-76c4cca5
  • archive protected
  • no_graphviz
  • 0.0.1
11 results

EarlyFusion.py

Blame
  • Submodule.cpp 5.40 KiB
    #include "Submodule.hpp"
    #include <cstdio>
    #include <memory>
    #include "WordEmbeddings.hpp"
    
    // Class-wide flag, off by default. When true, loadPretrainedW2vEmbeddings
    // reads the pretrained file even if the module is not in training mode.
    bool Submodule::reloadPretrained = false;
    
    /// Updates the static reloadPretrained flag shared by all Submodule instances.
    ///
    /// @param value New state of the flag.
    void Submodule::setReloadPretrained(bool value)
    {
      Submodule::reloadPretrained = value;
    }
    
    /// Records the value of firstInputIndex for this submodule.
    ///
    /// @param firstInputIndex Value copied into the member of the same name.
    void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
    {
      // The qualified name designates the member; the unqualified one is the parameter.
      Submodule::firstInputIndex = firstInputIndex;
    }
    
    /// Fills `embeddings` with pretrained vectors read from a word2vec text file.
    ///
    /// The first line of the file is skipped. Every following line must be
    /// `word v1 v2 ... vN`, separated by single spaces, where N equals the last
    /// dimension of the embedding weights. For each line:
    ///  - "<unk>" is mapped to Dict::unknownValueStr;
    ///  - every "◌" in the word is replaced by a space;
    ///  - the word is looked up / inserted in getDict() with `prefix`;
    ///  - if the resulting index lies inside the current weight matrix the row
    ///    is overwritten, otherwise the vector is queued and appended as a new
    ///    row once the whole file has been read.
    ///
    /// The dictionary is switched to Dict::State::Open while reading and put
    /// back to its previous state on normal completion. Finally requires_grad
    /// of the weights is set from WordEmbeddingsImpl::getCanTrainPretrained().
    ///
    /// Nothing is done when `path` is empty, or when the module is not training
    /// and reloadPretrained is false.
    ///
    /// @param embeddings Embedding layer receiving the pretrained vectors.
    /// @param path       Location of the word2vec text file.
    /// @param prefix     Prefix forwarded to getIndexOrInsert.
    ///
    /// Errors (missing, unopenable or empty file, malformed line, dimension
    /// mismatch, unparsable number) are reported through util::myThrow.
    void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix)
    {
      if (path.empty())
        return;
      if (!is_training() and !reloadPretrained)
        return;

      if (!std::filesystem::exists(path))
        util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));

      // Own the FILE* through RAII so it is closed on every exit path,
      // including when an error is raised while parsing.
      struct FileCloser
      {
        void operator()(std::FILE * f) const { std::fclose(f); }
      };
      // path.string() yields a narrow string regardless of path::value_type.
      std::unique_ptr<std::FILE, FileCloser> file(std::fopen(path.string().c_str(), "r"));
      if (!file)
      {
        util::myThrow(fmt::format("could not open pretrained word2vec file '{}'", path.string()));
        return; // defensive: never read from a null stream, whatever myThrow does
      }

      std::vector<std::vector<float>> toAdd;

      torch::NoGradGuard no_grad;

      auto originalState = getDict().getState();
      getDict().setState(Dict::State::Open);

      char buffer[100000];

      bool firstLine = true;
      std::size_t embeddingsSize = embeddings->parameters()[0].size(-1);

      try
      {
        // fgets returns nullptr on end of file or read error, which ends the loop.
        while (std::fgets(buffer, static_cast<int>(sizeof(buffer)), file.get()) != nullptr)
        {
          // The first line is skipped without being parsed.
          if (firstLine)
          {
            firstLine = false;
            continue;
          }

          auto splited = util::split(util::strip(buffer), ' ');

          if (splited.size() < 2)
            util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));

          std::string word;

          if (splited[0] == "<unk>")
            word = Dict::unknownValueStr;
          else
            word = splited[0];

          auto toInsert = util::splitAsUtf8(word);
          toInsert.replace("◌", " ");
          word = fmt::format("{}", toInsert);

          auto dictIndex = getDict().getIndexOrInsert(word, prefix);

          if (embeddingsSize != splited.size()-1)
            util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, static_cast<int>(splited.size())-1));

          if (dictIndex >= embeddings->weight.size(0))
          {
            // NOTE(review): rows queued here are appended in file order, which
            // presumably relies on the dict handing out consecutive new indices
            // starting at weight.size(0) -- confirm against Dict.
            toAdd.emplace_back();
            for (unsigned int i = 1; i < splited.size(); i++)
              toAdd.back().emplace_back(std::stof(splited[i]));
          }
          else
          {
            for (unsigned int i = 1; i < splited.size(); i++)
              embeddings->weight[dictIndex][i-1] = std::stof(splited[i]);
          }
        }
      } catch (const std::exception & e)
      {
        util::myThrow(fmt::format("caught '{}' for SubModule '{}'", e.what(), getName()));
      }

      // Reading is over: close the file now rather than at scope exit.
      file.reset();

      if (firstLine)
        util::myThrow(fmt::format("file '{}' is empty", path.string()));

      if (!toAdd.empty())
      {
        // Build a larger matrix: existing rows first, then the queued ones.
        const auto oldRows = embeddings->weight.size(0);
        auto newEmb = torch::nn::Embedding(oldRows+toAdd.size(), embeddingsSize);
        for (unsigned int i = 0; i < oldRows; i++)
          newEmb->weight[i] = embeddings->weight[i];
        for (unsigned int i = 0; i < toAdd.size(); i++)
          for (unsigned int j = 0; j < embeddingsSize; j++)
            newEmb->weight[oldRows+i][j] = toAdd[i][j];
        embeddings->weight = newEmb->weight;
      }

      getDict().setState(originalState);
      embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
    }
    
    std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
    {
      static auto prefix = [](const std::string & s, int length)
      {
        if (s.size() == 0)
          return s;
    
        util::utf8string utf8s = util::splitAsUtf8(s);
        util::utf8string prefix(utf8s.begin(), std::min(utf8s.end(),utf8s.begin()+length));
        return fmt::format("{}", prefix);
      };
    
      static auto suffix = [](const std::string & s, int length)
      {
        if (s.size() == 0)
          return s;
    
        util::utf8string utf8s = util::splitAsUtf8(s);
        util::utf8string suffix(std::max(utf8s.begin(), utf8s.end()-length), utf8s.end());
        return fmt::format("{}", suffix);
      };
    
      static std::map<std::string, std::function<std::string(const std::string &)>> functions
      {
        {"lower", [](const std::string & s) {return util::lower(s);}},
        {"prefix1", [](const std::string & s) {return prefix(s, 1);}},
        {"prefix2", [](const std::string & s) {return prefix(s, 2);}},
        {"prefix3", [](const std::string & s) {return prefix(s, 3);}},
        {"prefix4", [](const std::string & s) {return prefix(s, 4);}},
        {"prefix5", [](const std::string & s) {return prefix(s, 5);}},
        {"prefix6", [](const std::string & s) {return prefix(s, 6);}},
        {"prefix7", [](const std::string & s) {return prefix(s, 7);}},
        {"suffix1", [](const std::string & s) {return suffix(s, 1);}},
        {"suffix2", [](const std::string & s) {return suffix(s, 2);}},
        {"suffix3", [](const std::string & s) {return suffix(s, 3);}},
        {"suffix4", [](const std::string & s) {return suffix(s, 4);}},
        {"suffix5", [](const std::string & s) {return suffix(s, 5);}},
        {"suffix6", [](const std::string & s) {return suffix(s, 6);}},
        {"suffix7", [](const std::string & s) {return suffix(s, 7);}},
      };
    
      auto splited = util::split(functionNames, ':');
      if (splited.size() == 1)
        return [](const std::string & s){return s;};
    
      std::vector<std::function<std::string(const std::string &)>> sequence;
    
      for (unsigned int i = 0; i < splited.size()-1; i++)
      {
        auto & functionName = splited[i];
        auto it = functions.find(util::lower(functionName));
        if (it == functions.end())
          util::myThrow(fmt::format("unknown function name '{}'", functionName));
    
        sequence.emplace_back(it->second);
      }
    
      return [sequence](const std::string & s)
      {
        auto result = s; 
        for (auto & f : sequence)
          result = f(result);
        return result;
      };
    }