Skip to content
Snippets Groups Projects
Select Git revision
  • 4967475f60cdb80e5c23a7c2797802aaccaa60b9
  • master default protected
  • fullUD
  • movementInAction
4 results

Oracle.cpp

Blame
  • Oracle.cpp 20.18 KiB
    #include "Oracle.hpp"
    #include "util.hpp"
    #include "File.hpp"
    #include "Action.hpp"
    #include "ProgramParameters.hpp"
    
    // Registry of every named oracle, keyed by name; filled on demand by createDatabase().
    std::map<std::string, std::unique_ptr<Oracle>> Oracle::str2oracle{};
    
    /// Builds an oracle out of its three callbacks.
    /// @param initialize  invoked by init(), i.e. lazily on the first getAction/getActionCost call
    /// @param findAction  returns the action this oracle chooses for a configuration
    /// @param getCost     returns the cost of applying a given action to a configuration
    /// The oracle starts un-initialised.
    Oracle::Oracle(std::function<void(Oracle *)> initialize,
                   std::function<std::string(Config &, Oracle *)> findAction,
                   std::function<int(Config &, Oracle *, const std::string &)> getCost)
    {
      isInit = false;
      Oracle::initialize = initialize;
      Oracle::findAction = findAction;
      Oracle::getCost = getCost;
    }
    
    /// Finds a registered oracle by name, building the registry on first use.
    /// @param name      key under which the oracle was registered in createDatabase()
    /// @param filename  when non-empty, replaces the oracle's data file path
    /// @return the registered oracle (still owned by str2oracle)
    /// Prints an error and terminates the process (exit code 1) on an unknown name.
    Oracle * Oracle::getOracle(const std::string & name, std::string filename)
    {
      createDatabase();

      const auto found = str2oracle.find(name);
      if (found == str2oracle.end())
      {
        fprintf(stderr, "ERROR (%s) : invalid oracle name \'%s\'. Aborting.\n", ERRINFO, name.c_str());
        exit(1);
      }

      Oracle * oracle = found->second.get();
      if (!filename.empty())
        oracle->setFilename(filename);

      return oracle;
    }
    
    /// Convenience overload: looks up an oracle by name, leaving its data file untouched.
    Oracle * Oracle::getOracle(const std::string & name)
    {
      const std::string noFilename;
      return getOracle(name, noFilename);
    }
    
    /// Returns the cost this oracle assigns to applying @p action in @p config.
    /// Runs the one-time initialisation callback first if it has not run yet.
    int Oracle::getActionCost(Config & config, const std::string & action)
    {
      if (!this->isInit)
        this->init();

      const auto & costOf = this->getCost;
      return costOf(config, this, action);
    }
    
    /// Returns the action this oracle chooses for @p config.
    /// Runs the one-time initialisation callback first if it has not run yet.
    std::string Oracle::getAction(Config & config)
    {
      if (!this->isInit)
        this->init();

      const auto & chooseAction = this->findAction;
      return chooseAction(config, this);
    }
    
    /// Marks the oracle as initialised, then runs its initialisation callback.
    /// Because the flag is raised before the callback runs, a callback that
    /// re-enters getAction/getActionCost does not trigger init() again.
    void Oracle::init()
    {
      this->isInit = true;
      this->initialize(this);
    }
    
    /// Sets the path of the data file that file-backed oracles read during initialisation.
    void Oracle::setFilename(const std::string & filename)
    {
      Oracle::filename = filename;
    }
    
    /// Registers every built-in oracle in str2oracle. Only the first call does any
    /// work; later calls return immediately.
    ///
    /// Each oracle is three callbacks: initialize (run once, lazily), findAction
    /// and getCost. Oracles that back a trainable classifier ("tagger",
    /// "tokenizer", "eos", "morpho", "lemma_rules", "parser") only implement
    /// getCost and abort if asked for an action. The lookup/rule oracles
    /// ("error_tagger", "error_morpho", "signature", "lemma_lookup") implement
    /// findAction and report a cost of 0 for every action.
    void Oracle::createDatabase()
    {
      static bool isInit = false;
      if(isInit)
        return;
      isInit = true;

      // "null": placeholder that must never be queried.
      str2oracle.emplace("null", std::unique_ptr<Oracle>(new Oracle(
      [](Oracle *)
      {
      },
      [](Config &, Oracle *)
      {
        fprintf(stderr, "ERROR (%s) : getAction called on null Oracle. Aborting.\n", ERRINFO);
        exit(1);

        return std::string("");
      },
      [](Config &, Oracle *, const std::string &)
      {
        fprintf(stderr, "ERROR (%s) : getActionCost called on null Oracle. Aborting.\n", ERRINFO);
        exit(1);

        return 1;
      })));

      // "error_tagger": answers "BACK 1" when the previous POS tag differs from its
      // gold (.ref) value and forms one of a few suspicious (previous, current) tag
      // pairs; "EPSILON" otherwise.
      str2oracle.emplace("error_tagger", std::unique_ptr<Oracle>(new Oracle(
      [](Oracle *)
      {
      },
      [](Config & c, Oracle *)
      {
        // No going back if one of the last two history entries is BACK, if the
        // history is too short, or if the current hash is already in hashHistory.
        if (c.getCurrentStateHistory().size() >= 2 && (c.getCurrentStateHistory().top() == "BACK" || c.getCurrentStateHistory().getElem(1) == "BACK"))
          return std::string("EPSILON");

        if (c.getCurrentStateHistory().size() < 2)
          return std::string("EPSILON");

        if (c.hashHistory.contains(c.computeHash()))
          return std::string("EPSILON");

        auto & pos = c.getTape("POS");

        if (c.head > 0 && pos[c.head-1] != pos.ref[c.head-1] && pos[c.head-1] == "det" && pos[c.head] == "prorel")
          return std::string("BACK 1");

        if (c.head > 0 && pos[c.head-1] != pos.ref[c.head-1] && pos[c.head-1] == "det" && pos[c.head] == "prep")
          return std::string("BACK 1");

        if (c.head > 0 && pos[c.head-1] != pos.ref[c.head-1] && pos[c.head-1] == "nc" && pos[c.head] == "nc")
          return std::string("BACK 1");

        if (c.head > 0 && pos[c.head-1] != pos.ref[c.head-1] && pos[c.head-1] == "nc" && pos[c.head] == "prep")
          return std::string("BACK 1");

        return std::string("EPSILON");
      },
      [](Config &, Oracle *, const std::string &)
      {
        return 0;
      })));

      // "error_morpho": answers "BACK 1" when the previous MORPHO value is wrong
      // and its first '|'-separated feature ("g=f"/"g=m") disagrees with the
      // current word's; "EPSILON" otherwise.
      str2oracle.emplace("error_morpho", std::unique_ptr<Oracle>(new Oracle(
      [](Oracle *)
      {
      },
      [](Config & c, Oracle *)
      {
        if (c.getCurrentStateHistory().size() >= 2 && (c.getCurrentStateHistory().top() == "BACK" || c.getCurrentStateHistory().getElem(1) == "BACK"))
          return std::string("EPSILON");

        if (c.getCurrentStateHistory().size() < 2)
          return std::string("EPSILON");

        if (c.hashHistory.contains(c.computeHash()))
          return std::string("EPSILON");

        auto & morpho = c.getTape("MORPHO");

        if (c.head <= 0)
          return std::string("EPSILON");

        auto & morphoRef = morpho.ref[c.head-1];
        auto & morpho0 = morpho[c.head-1];
        auto & morpho1 = morpho[c.head];

        if (morpho0 == morphoRef)
          return std::string("EPSILON");

        // Guard against split() yielding no field (e.g. for an empty cell)
        // before taking the first feature.
        auto features0 = split(morpho0, '|');
        auto features1 = split(morpho1, '|');
        if (features0.empty() || features1.empty())
          return std::string("EPSILON");

        auto & genre0 = features0[0];
        auto & genre1 = features1[0];

        if (genre0 == "g=f" && genre1 == "g=m")
          return std::string("BACK 1");

        if (genre0 == "g=m" && genre1 == "g=f")
          return std::string("BACK 1");

        return std::string("EPSILON");
      },
      [](Config &, Oracle *, const std::string &)
      {
        return 0;
      })));

      // Classifier oracles below: cost 0 for writing the gold value at the head
      // (or for anything once the head is on the last token), 1 otherwise.
      // The end-of-tape test comes first so ref[c.head] is only read in range.

      str2oracle.emplace("tagger", std::unique_ptr<Oracle>(new Oracle(
      [](Oracle *)
      {
      },
      [](Config &, Oracle *)
      {
        fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
        exit(1);

        return std::string("");
      },
      [](Config & c, Oracle *, const std::string & action)
      {
        return c.head >= (int)c.tapes[0].ref.size()-1 || action == "WRITE b.0 POS " + c.getTape("POS").ref[c.head] ? 0 : 1;
      })));

      str2oracle.emplace("tokenizer", std::unique_ptr<Oracle>(new Oracle(
      [](Oracle *)
      {
      },
      [](Config &, Oracle *)
      {
        fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
        exit(1);

        return std::string("");
      },
      [](Config & c, Oracle *, const std::string & action)
      {
        return c.head >= (int)c.tapes[0].ref.size()-1 || action == "WRITE b.0 BIO " + c.getTape("BIO").ref[c.head] ? 0 : 1;
      })));

      str2oracle.emplace("eos", std::unique_ptr<Oracle>(new Oracle(
      [](Oracle *)
      {
      },
      [](Config &, Oracle *)
      {
        fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
        exit(1);

        return std::string("");
      },
      [](Config & c, Oracle *, const std::string & action)
      {
        // The gold value is the sequence delimiter itself, or "0" for any other word.
        return c.head >= (int)c.tapes[0].ref.size()-1 || action == "WRITE b.0 " + ProgramParameters::sequenceDelimiterTape + " " + (c.getTape(ProgramParameters::sequenceDelimiterTape).ref[c.head] == std::string(ProgramParameters::sequenceDelimiter) ? std::string(ProgramParameters::sequenceDelimiter) : std::string("0")) ? 0 : 1;
      })));

      str2oracle.emplace("morpho", std::unique_ptr<Oracle>(new Oracle(
      [](Oracle *)
      {
      },
      [](Config &, Oracle *)
      {
        fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
        exit(1);

        return std::string("");
      },
      [](Config & c, Oracle *, const std::string & action)
      {
        return c.head >= (int)c.tapes[0].ref.size()-1 || action == "WRITE b.0 MORPHO " + c.getTape("MORPHO").ref[c.head] ? 0 : 1;
      })));

      // "signature": loads a "form\tsignature" table from the oracle's file and
      // writes signatures for the unfilled SGN cells in a window around the head.
      str2oracle.emplace("signature", std::unique_ptr<Oracle>(new Oracle(
      [](Oracle * oracle)
      {
        File file(oracle->filename, "r");
        FILE * fd = file.getDescriptor();
        char b1[1024];
        char b2[1024];

        // Skip up to and including the first line that parses as two fields
        // (presumably a header — TODO confirm). Field widths stop buffer
        // overflows; EOF ends the loop, and a match failure that consumed
        // nothing drops one character so the loop always makes progress.
        int matched = 0;
        while ((matched = fscanf(fd, "%1023[^\t]\t%1023[^\n]\n", b1, b2)) != 2 && matched != EOF)
          if (matched == 0)
            fgetc(fd);
        while (fscanf(fd, "%1023[^\t]\t%1023[^\n]\n", b1, b2) == 2)
          oracle->data[b1] = b2;
      },
      [](Config & c, Oracle * oracle)
      {
        int window = 3;
        int start = std::max<int>(c.head-window, 0);
        int end = std::min<int>(c.head+window, c.getTape("SGN").hyp.size()-1);

        // Start at the first SGN cell that has not been written yet.
        while (start < (int)c.getTape("SGN").hyp.size() && !c.getTape("SGN").hyp[start].empty())
          start++;

        if (start > end)
          return std::string("NOTHING");

        std::string action("MULTIWRITE " + std::to_string(start-c.head) + " " + std::to_string(end-c.head) + " " + std::string("SGN"));

        for(int i = start; i <= end; i++)
        {
          const std::string & form = c.getTape("FORM").ref[i];
          // operator[] caches the fallback lookup under the exact form.
          std::string & signature = oracle->data[form];

          if(signature.empty())
            signature = oracle->data[noAccentLower(form)];

          if(signature.empty())
            action += std::string(" ") + "UNKNOWN";
          else
            action += std::string(" ") + signature;
        }

        return action;
      },
      [](Config &, Oracle *, const std::string &)
      {
        return 0;
      })));

      // "lemma_lookup": loads a 4-column table (form, pos, lemma, unused) and
      // writes the lemma of the head word, falling back to accent-less lower-case
      // and to POS-agnostic ("_??") entries.
      str2oracle.emplace("lemma_lookup", std::unique_ptr<Oracle>(new Oracle(
      [](Oracle * oracle)
      {
        File file(oracle->filename, "r");
        FILE * fd = file.getDescriptor();
        char b1[1024];
        char b2[1024];
        char b3[1024];
        char b4[1024];

        // Same header skip as "signature": bounded widths, stop at EOF,
        // always make progress on a match failure.
        int matched = 0;
        while ((matched = fscanf(fd, "%1023[^\t]\t%1023[^\t]\t%1023[^\t]\t%1023[^\n]\n", b1, b2, b3, b4)) != 4 && matched != EOF)
          if (matched == 0)
            fgetc(fd);
        while (fscanf(fd, "%1023[^\t]\t%1023[^\t]\t%1023[^\t]\t%1023[^\n]\n", b1, b2, b3, b4) == 4)
        {
          oracle->data[std::string(b1) + std::string("_") + b2] = b3;
          oracle->data[std::string(b1) + std::string("_??")] = b3;
        }
      },
      [](Config & c, Oracle * oracle)
      {
        const std::string & form = c.getTape("FORM")[c.head];
        const std::string & pos = c.getTape("POS")[c.head];
        std::string & lemma = oracle->data[form + "_" + pos];

        if(lemma.empty())
          lemma = oracle->data[noAccentLower(form) + "_" + pos];

        if(lemma.empty())
          lemma = oracle->data[form + "_??"];

        if(lemma.empty())
          lemma = oracle->data[noAccentLower(form) + "_??"];

        if(lemma.empty())
          return std::string("NOTFOUND");
        else
          return std::string("WRITE b.0 LEMMA ") + lemma;
      },
      [](Config &, Oracle *, const std::string &)
      {
        return 0;
      })));

      str2oracle.emplace("lemma_rules", std::unique_ptr<Oracle>(new Oracle(
      [](Oracle *)
      {
      },
      [](Config &, Oracle *)
      {
        fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
        exit(1);

        return std::string("");
      },
      [](Config & c, Oracle *, const std::string & action)
      {
        if (c.head >= (int)c.tapes[0].ref.size()-1)
          return 0;

        const std::string & form = c.getTape("FORM").ref[c.head];
        const std::string & lemma = c.getTape("LEMMA").ref[c.head];
        std::string rule = getRule(form, lemma);

        return action == std::string("RULE LEMMA ON FORM ") + rule ? 0 : 1;
      })));

      // "parser": dynamic-oracle cost of a transition (SHIFT, REDUCE, LEFT, RIGHT,
      // end-of-sentence, WRITE) computed from the gold GOV/LABEL/eos tapes.
      str2oracle.emplace("parser", std::unique_ptr<Oracle>(new Oracle(
      [](Oracle *)
      {
      },
      [](Config &, Oracle *)
      {
        fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
        exit(1);

        return std::string("");
      },
      [](Config & c, Oracle *, const std::string & action)
      {
        auto & labels = c.getTape("LABEL");
        auto & govs = c.getTape("GOV");
        auto & eos = c.getTape(ProgramParameters::sequenceDelimiterTape);

        // GOV holds governors as offsets relative to the word's own position.
        int head = c.head;
        int stackHead = c.stackEmpty() ? 0 : c.stackTop();
        int stackGov = stackHead + std::stoi(govs.ref[stackHead]);
        int headGov = head + std::stoi(govs.ref[head]);
        int sentenceStart = c.head-1 < 0 ? 0 : c.head-1;
        int sentenceEnd = c.head;

        int cost = 0;

        // Bounds of the current sentence on the gold end-of-sentence tape.
        while(sentenceStart >= 0 && eos.ref[sentenceStart] != ProgramParameters::sequenceDelimiter)
          sentenceStart--;
        if (sentenceStart != 0)
          sentenceStart++;
        while(sentenceEnd < (int)eos.ref.size() && eos.ref[sentenceEnd] != ProgramParameters::sequenceDelimiter)
          sentenceEnd++;
        if (sentenceEnd == (int)eos.ref.size())
          sentenceEnd--;

        auto parts = split(action);

        if (parts[0] == "SHIFT")
        {
          for (int i = sentenceStart; i <= sentenceEnd; i++)
          {
            if (!isNum(govs.ref[i]))
            {
              fprintf(stderr, "ERROR (%s) : govs.ref[%d] = <%s>. Aborting.\n", ERRINFO, i, govs.ref[i].c_str());
              exit(1);
            }

            int otherGov = i + std::stoi(govs.ref[i]);

            for (int j = 0; j < c.stackSize(); j++)
            {
              auto s = c.stackGetElem(j);
              if (s == i)
                if (otherGov == head || headGov == s)
                  cost++;
            }
          }

          return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter ? cost : cost+1;
        }
        else if (parts[0] == "WRITE" && parts.size() == 4)
        {
          auto object = split(parts[1], '.');
          if (object[0] == "b")
          {
            if (parts[2] == "LABEL")
              return c.head >= (int)c.tapes[0].ref.size()-1 || action == "WRITE b.0 LABEL " + c.getTape("LABEL").ref[c.head] || c.getTape("LABEL").ref[c.head] == "root" ? 0 : 1;
            else if (parts[2] == "GOV")
              return c.head >= (int)c.tapes[0].ref.size()-1 || action == "WRITE b.0 GOV " + c.getTape("GOV").ref[c.head] ? 0 : 1;
          }
          else if (object[0] == "s")
          {
            int index = c.stackGetElem(-1);
            if (parts[2] == "LABEL")
              return action == "WRITE s.-1 LABEL " + c.getTape("LABEL").ref[index] || c.getTape("LABEL").ref[index] == "root" ? 0 : 1;
            else if (parts[2] == "GOV")
              return action == "WRITE s.-1 GOV " + c.getTape("GOV").ref[index] ? 0 : 1;
          }

          return 1;
        }
        else if (parts[0] == "REDUCE")
        {
          for (int i = head; i <= sentenceEnd; i++)
          {
            int otherGov = i + std::stoi(govs.ref[i]);
            if (otherGov == stackHead)
              cost++;
          }

          return eos.ref[stackHead] != ProgramParameters::sequenceDelimiter ? cost : cost+1;
        }
        else if (parts[0] == "LEFT")
        {
          if (eos.ref[stackHead] == ProgramParameters::sequenceDelimiter)
            cost++;

          for (int i = head+1; i <= sentenceEnd; i++)
          {
            int otherGov = i + std::stoi(govs.ref[i]);
            if (otherGov == stackHead || stackGov == i)
              cost++;
          }

          if (stackGov != head)
            cost++;

          return parts.size() == 1 || labels.ref[stackHead] == parts[1] ? cost : cost+1;
        }
        else if (parts[0] == "RIGHT")
        {
          for (int j = 0; j < c.stackSize(); j++)
          {
            auto s = c.stackGetElem(j);

            if (s == c.stackTop())
              continue;

            int otherGov = s + std::stoi(govs.ref[s]);
            if (otherGov == head || headGov == s)
              cost++;
          }

          for (int i = head; i <= sentenceEnd; i++)
            if (headGov == i)
              cost++;

          return parts.size() == 1 || labels.ref[head] == parts[1] ? cost : cost+1;
        }
        else if (parts[0] == ProgramParameters::sequenceDelimiterTape)
        {
          return eos.ref[stackHead] == ProgramParameters::sequenceDelimiter ? cost : cost+1;
        }

        return cost;
      })));

      // "parser_gold": reads candidate actions (one per line, the first one
      // optionally prefixed by "Default : ") and picks the first applicable
      // zero-cost one according to the "parser" cost function.
      str2oracle.emplace("parser_gold", std::unique_ptr<Oracle>(new Oracle(
      [](Oracle * oracle)
      {
        File file(oracle->filename, "r");
        FILE * fd = file.getDescriptor();
        char b1[1024];

        if (fscanf(fd, "Default : %1023[^\n]\n", b1) == 1)
          oracle->data[b1] = "ok";
        while (fscanf(fd, "%1023[^\n]\n", b1) == 1)
          oracle->data[b1] = "ok";
      },
      [](Config & c, Oracle * oracle)
      {
        for (auto & it : oracle->data)
        {
          Action action(it.first);
          if (!action.appliable(c))
            continue;
          if (oracle->getActionCost(c, it.first) == 0)
            return it.first;
        }

        fprintf(stderr, "ERROR (%s) : No zero cost action found by the oracle. Aborting.\n", ERRINFO);
        c.printForDebug(stderr);
        for (auto & it : oracle->data)
        {
          Action action(it.first);
          fprintf(stderr, "%s : ", action.name.c_str());
          explainCostOfAction(stderr, c, action.name);
          fprintf(stderr, "cost (%d)\n", oracle->getActionCost(c, it.first));
        }
        exit(1);

        return std::string("");
      },
      str2oracle["parser"]->getCost)));
    }
    
    /// Prints to @p output a one-line, human-readable reason for the cost of
    /// @p action in configuration @p c.
    /// WRITE actions are explained against the gold (.ref) value of the targeted
    /// tape cell. Transition actions (SHIFT, REDUCE, LEFT, RIGHT, end-of-sentence)
    /// are checked along the lines of the "parser" oracle cost function, and the
    /// first reason found is reported.
    /// @param output  stream that receives the explanation
    /// @param c       configuration to evaluate
    /// @param action  action string, e.g. "SHIFT", "LEFT <label>", "WRITE b.0 LABEL <label>"
    /// Terminates the process if a WRITE action targets neither "b" nor "s".
    void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string & action)
    {
      auto parts = split(action);
      if (parts.empty())
      {
        fprintf(output, "Unknown reason\n");
        return;
      }

      if (parts[0] == "WRITE")
      {
        if (parts.size() != 4)
        {
          // Report on the requested stream, like every other explanation.
          fprintf(output, "Wrong number of action arguments\n");
          return;
        }
        auto object = split(parts[1], '.');
        auto tape = parts[2];
        auto label = parts[3];
        std::string expected;

        if (object[0] == "b")
        {
          // Buffer cells are addressed relative to the head.
          int index = c.head + std::stoi(object[1]);
          expected = c.getTape(tape).ref[index];
        }
        else if (object[0] == "s")
        {
          // stackGetElem already yields an absolute tape index (below it is
          // compared with absolute positions and used to index tapes directly),
          // so it must not be offset by c.head.
          int stackIndex = std::stoi(object[1]);
          int bufferIndex = c.stackGetElem(stackIndex);
          expected = c.getTape(tape).ref[bufferIndex];
        }
        else
        {
          fprintf(stderr, "ERROR (%s) : wrong action name \'%s\'. Aborting.\n", ERRINFO, action.c_str());
          exit(1);
        }
        fprintf(output, "Wrong write (%s) expected (%s)\n", label.c_str(), expected.c_str());
        return;
      }

      auto & labels = c.getTape("LABEL");
      auto & govs = c.getTape("GOV");
      auto & eos = c.getTape(ProgramParameters::sequenceDelimiterTape);

      // GOV holds governors as offsets relative to the word's own position.
      int head = c.head;
      int stackHead = c.stackEmpty() ? 0 : c.stackTop();
      int stackGov = stackHead + std::stoi(govs.ref[stackHead]);
      int headGov = head + std::stoi(govs.ref[head]);
      int sentenceStart = c.head-1 < 0 ? 0 : c.head-1;
      int sentenceEnd = c.head;

      // Bounds of the current sentence on the gold end-of-sentence tape.
      while(sentenceStart >= 0 && eos.ref[sentenceStart] != ProgramParameters::sequenceDelimiter)
        sentenceStart--;
      if (sentenceStart != 0)
        sentenceStart++;
      while(sentenceEnd < (int)eos.ref.size() && eos.ref[sentenceEnd] != ProgramParameters::sequenceDelimiter)
        sentenceEnd++;
      if (sentenceEnd == (int)eos.ref.size())
        sentenceEnd--;

      if (parts[0] == "SHIFT")
      {
        for (int i = sentenceStart; i <= sentenceEnd; i++)
        {
          int otherGov = i + std::stoi(govs.ref[i]);

          for (int j = 0; j < c.stackSize(); j++)
          {
            auto s = c.stackGetElem(j);
            if (s == i)
            {
              if (otherGov == head)
              {
                fprintf(output, "Word on stack %d(%s)\'s governor is the current head\n", s, c.getTape("FORM").ref[s].c_str());
                return;
              }
              else if (headGov == s)
              {
                fprintf(output, "The current head\'s governor is on the stack %d(%s)\n", s, c.getTape("FORM").ref[s].c_str());
                return;
              }
            }
          }
        }

        if (eos.ref[stackHead] != ProgramParameters::sequenceDelimiter)
          fprintf(output, "Zero cost\n");
        else
          fprintf(output, "The top of the stack is end of sentence\n");
        return;
      }
      else if (parts[0] == "REDUCE")
      {
        for (int i = head; i <= sentenceEnd; i++)
        {
          int otherGov = i + std::stoi(govs.ref[i]);
          if (otherGov == stackHead)
          {
            fprintf(output, "Stack head is the governor of %d(%s)\n", i, c.getTape("FORM").ref[i].c_str());
            return;
          }
        }

        if (eos.ref[stackHead] != ProgramParameters::sequenceDelimiter)
          fprintf(output, "Zero cost\n");
        else
          fprintf(output, "The top of the stack is end of sentence\n");
        return;
      }
      else if (parts[0] == "LEFT")
      {
        if (parts.size() == 2 && stackGov == head && labels.ref[stackHead] == parts[1])
        {
          fprintf(output, "Zero cost\n");
          return;
        }
        if (parts.size() == 1 && stackGov == head)
        {
          fprintf(output, "Zero cost\n");
          return;
        }

        // An unlabelled LEFT has no parts[1] to compare.
        if (parts.size() > 1 && labels.ref[stackHead] != parts[1])
        {
          fprintf(output, "Stack head label %s mismatch with action label %s\n", labels.ref[stackHead].c_str(), parts[1].c_str());
          return;
        }

        if (eos.ref[stackHead] == ProgramParameters::sequenceDelimiter)
        {
          fprintf(output, "The top of the stack is end of sentence\n");
          return;
        }

        for (int i = head; i <= sentenceEnd; i++)
        {
          int otherGov = i + std::stoi(govs.ref[i]);
          if (otherGov == stackHead)
          {
            fprintf(output, "Word %d(%s)\'s governor is the stack head\n", i, c.getTape("FORM").ref[i].c_str());
            return;
          }
          else if (stackGov == i)
          {
            fprintf(output, "Stack head\'s governor is the word %d(%s)\n", i, c.getTape("FORM").ref[i].c_str());
            return;
          }
        }

        // LEFT attaches the stack head to the current head, so any other gold
        // governor is a lost arc.
        if (stackGov != head)
        {
          fprintf(output, "Stack head\'s governor is not the current head\n");
          return;
        }

        fprintf(output, "Unable to explain action\n");
        return;
      }
      else if (parts[0] == "RIGHT")
      {
        for (int j = 0; j < c.stackSize(); j++)
        {
          auto s = c.stackGetElem(j);

          if (s == c.stackTop())
            continue;

          int otherGov = s + std::stoi(govs.ref[s]);
          if (otherGov == head)
          {
            fprintf(output, "The governor of %d(%s) in the stack, is the current head\n", s, c.getTape("FORM").ref[s].c_str());
            return;
          }
          else if (headGov == s)
          {
            fprintf(output, "The current head's governor is the stack element %d(%s)\n", s, c.getTape("FORM").ref[s].c_str());
            return;
          }
        }

        for (int i = head; i <= sentenceEnd; i++)
          if (headGov == i)
          {
            fprintf(output, "The current head's governor is the future word %d(%s)\n", i, c.getTape("FORM").ref[i].c_str());
            return;
          }

        if (parts.size() == 1)
        {
          fprintf(output, "Zero cost\n");
          return;
        }

        if (labels.ref[head] == parts[1])
        {
          fprintf(output, "Zero cost\n");
          return;
        }
        else
        {
          fprintf(output, "Current head's label %s mismatch action label %s\n", labels.ref[head].c_str(), parts[1].c_str());
          return;
        }
      }
      else if (parts[0] == ProgramParameters::sequenceDelimiterTape)
      {
        if (eos.ref[stackHead] == ProgramParameters::sequenceDelimiter)
        {
          fprintf(output, "Zero cost\n");
          return;
        }
        else
        {
          fprintf(output, "The head of the stack is not end of sentence\n");
          return;
        }
      }

      fprintf(output, "Unknown reason\n");
    }