Skip to content
Snippets Groups Projects
Select Git revision
  • 832d5450301832b31646b81ff46bedf48a6bcd87
  • master default
  • object
  • develop protected
  • private_algos
  • cuisine
  • SMOTE
  • revert-76c4cca5
  • archive protected
  • no_graphviz
  • 0.0.1
11 results

format_dataset.py

Blame
  • Config.py 6.58 KiB
    from readMCD import readMCD
    import sys
    
    ################################################################################
    class Config :
      # Table of lines (one list of cells per line) together with the state used
      # to walk through them : current word index, stack, comments and histories.
      # Every cell is a two-element list [gold, predicted] : the gold value is
      # the one given to addLine, the predicted value starts as "".
      #
      # col2index : maps a column name to its position inside a line.
      # index2col : maps a position inside a line to its column name.
      # predicted : container of column names; for these columns getAsFeature
      #             returns the predicted value instead of the gold one.
      def __init__(self, col2index, index2col, predicted) :
        self.lines = []
        self.goldChilds = [] # One list per line, created empty by addLine
        self.predChilds = [] # One list per line, created empty by addLine
        self.col2index = col2index
        self.index2col = index2col
        self.predicted = predicted
        self.wordIndex = 0
        self.maxWordIndex = 0 #To keep a track of the max value, in case of backtrack
        self.stack = [] # Line indexes pushed by addWordIndexToStack
        self.comments = [] # Lines printed verbatim by print() before the table
        self.history = []
        self.historyHistory = set()
        self.historyPop = [] # Entries are read as c[0]..c[3] by printForDebug

      # Append a new line whose gold values are cols (predicted values are "").
      def addLine(self, cols) :
        self.lines.append([[val,""] for val in cols])
        self.goldChilds.append([])
        self.predChilds.append([])

      # Shared validation of get and set.
      # Raise an Exception if lineIndex is not a valid line index.
      # Print an error and exit with status 1 if colname is unknown.
      def _checkAccess(self, lineIndex, colname) :
        if lineIndex not in range(len(self.lines)) :
          raise(Exception("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines))))
        if colname not in self.col2index :
          print("Unknown colname '%s'"%(colname), file=sys.stderr)
          # sys.exit rather than the 'exit' helper installed by the site module,
          # which is intended for the interactive interpreter only.
          sys.exit(1)

      # Return the predicted (if predicted is true) or gold value of column
      # colname at line lineIndex.
      def get(self, lineIndex, colname, predicted) :
        self._checkAccess(lineIndex, colname)
        return self.lines[lineIndex][self.col2index[colname]][1 if predicted else 0]

      # Overwrite the predicted (default) or gold value of column colname at
      # line lineIndex.
      def set(self, lineIndex, colname, value, predicted=True) :
        self._checkAccess(lineIndex, colname)
        self.lines[lineIndex][self.col2index[colname]][1 if predicted else 0] = value

      # Return the predicted value if colname is listed in self.predicted,
      # the gold value otherwise.
      def getAsFeature(self, lineIndex, colname) :
        return self.get(lineIndex, colname, colname in self.predicted)

      # Return the gold value of column colname at line lineIndex.
      def getGold(self, lineIndex, colname) :
        return self.get(lineIndex, colname, False)

      # Push the current word index on the stack.
      def addWordIndexToStack(self) :
        self.stack.append(self.wordIndex)

      # Remove and return the top of the stack (IndexError if it is empty).
      def popStack(self) :
        return self.stack.pop()

      # Move wordIndex by a relative movement (positive : forward, negative :
      # backward) if possible. Ignore multiwords.
      # Don't go out of bounds, but don't fail either.
      # Return true if movement was completed.
      # NOTE(review): the extra step taken to skip a multiword line is not
      # bounds-checked, only the regular steps are.
      def moveWordIndex(self, movement) :
        done = 0
        # Direction of a single step : +1 or -1 (+1 when movement is 0).
        relMov = 1 if movement == 0 else movement // abs(movement)
        if self.isMultiword(self.wordIndex) :
          self.wordIndex += relMov
        while done != abs(movement) :
          if self.wordIndex+relMov in range(0, len((self.lines))) :
            self.wordIndex += relMov
          else :
            self.maxWordIndex = max(self.maxWordIndex, self.wordIndex)
            return False
          # Landing on a multiword line does not count : step over it.
          if self.isMultiword(self.wordIndex) :
            self.wordIndex += relMov
          done += 1
        self.maxWordIndex = max(self.maxWordIndex, self.wordIndex)
        return True

      # A line is a multiword when its ID contains '-' (e.g. "1-2").
      def isMultiword(self, index) :
        return "-" in self.getAsFeature(index, "ID")

      # Number of lines.
      def __len__(self) :
        return len(self.lines)

      # Return the string to print for column colname of line lineIndex :
      # - an empty value becomes "_" ;
      # - a HEAD of -1 becomes "0" ;
      # - a HEAD of "_" is kept as is (it is not a line index) ;
      # - any other HEAD is a line index, replaced by the ID of that line.
      def _printableValue(self, lineIndex, colname) :
        value = str(self.getAsFeature(lineIndex, colname))
        if value == "" :
          return "_"
        if colname != "HEAD" or value == "_" :
          return value
        if value == "-1" :
          return "0"
        return str(self.getAsFeature(int(value), "ID"))

      # Print to output the stack, the last 10 entries of both histories, and an
      # aligned view of the lines surrounding wordIndex (5 before, 4 after),
      # the current one being marked with "=>".
      def printForDebug(self, output) :
        printedCols = ["ID","FORM","UPOS","HEAD","DEPREL"]
        left = 5
        right = 5
        print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
        print("history :",[str(trans) for trans in self.history[-10:]], file=output)
        print("historyPop :",[(str(c[0]),"dat:"+str(c[1]),"mvt:"+str(c[2]),"reward:"+str(c[3])) for c in self.historyPop[-10:]], file=output)
        rows = []
        for lineIndex in range(self.wordIndex-left, self.wordIndex+right) :
          if lineIndex not in range(len(self.lines)) :
            continue
          row = ["=>" if lineIndex == self.wordIndex else "  "]
          for colIndex in range(len(self.lines[lineIndex])) :
            colname = self.index2col[colIndex]
            if colname in printedCols :
              row.append(self._printableValue(lineIndex, colname))
          rows.append(row)
        # Nothing to align when no line falls inside the window (empty config).
        if not rows :
          return
        # Width of each printed column = length of its longest cell.
        widths = [max(len(row[j]) for row in rows) for j in range(len(rows[0]))]
        formatted = []
        for row in rows :
          cells = ["{:{}}".format(cell, widths[j]) for j, cell in enumerate(row)]
          formatted.append(cells[0]+" ".join(cells[1:]))
        print("\n".join(formatted), file=output)

      # Print the whole table to output, one tab-separated line per line,
      # preceded by the comments (if any) and followed by an empty line.
      # If header is true, first print the "# global.columns" line.
      def print(self, output, header=False) :
        if header :
          print("# global.columns = %s"%(" ".join(self.col2index.keys())), file=output)
        if len(self.comments) > 0 :
          print("\n".join(self.comments), file=output)
        for index in range(len(self.lines)) :
          cells = [self._printableValue(index, self.index2col[colIndex]) for colIndex in range(len(self.lines[index]))]
          print("\t".join(cells), file=output)
        print("", file=output)
    ################################################################################
      
    ################################################################################
    # Finish a block of lines : replace every gold HEAD (an ID string) by the
    # index of the line it refers to, register each line as a gold child of its
    # head, and attach the comments collected for this block.
    # HEAD values "_" and "0" are left untouched.
    def _finalizeConfig(config, id2index, comments) :
      for index in range(len(config)) :
        head = config.getGold(index, "HEAD")
        if head in ("_", "0") :
          continue
        headIndex = id2index[head]
        config.set(index, "HEAD", headIndex, False)
        config.goldChilds[int(headIndex)].append(index)
      config.comments = comments

    # Read a CoNLL-U-like file and return a list of Config, one per block of
    # lines separated by empty lines.
    #
    # filename  : path of the file to read.
    # predicted : passed as is to every Config that is created.
    #
    # Columns are those of the default layout unless a "# global.columns ="
    # line gives another one. Other lines starting with '#' are stored as the
    # comments of the block being read. Lines whose ID contains '.' (empty
    # nodes in CoNLL-U) are skipped.
    # A Config is only created when its first token line is met, so leading or
    # repeated empty lines never produce an empty Config, and the last block is
    # finalized even if the file does not end with an empty line.
    def readConllu(filename, predicted) :
      defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
      col2index, index2col = readMCD(defaultMCD)
      configs = []
      config = None # Block currently being filled, None between two blocks
      id2index = {} # ID string -> line index inside the current block
      comments = []

      # 'with' guarantees the file is closed, even if an exception is raised.
      with open(filename, "r") as conllFile :
        for line in conllFile :
          line = line.strip()
          if "# global.columns =" in line :
            mcd = line.split('=')[-1].strip()
            col2index, index2col = readMCD(mcd)
            continue
          if len(line) == 0 :
            # End of block : nothing to do if no token line was read since the
            # previous one.
            if config is not None :
              _finalizeConfig(config, id2index, comments)
              configs.append(config)
              config = None
              id2index = {}
              comments = []
            continue
          if line.startswith('#') :
            comments.append(line)
            continue

          columns = line.split('\t')
          wordId = columns[col2index["ID"]]
          if '.' in wordId :
            continue

          if config is None :
            config = Config(col2index, index2col, predicted)
          config.addLine(columns)
          id2index[wordId] = len(config) - 1

      # The file may end without an empty line after the last block.
      if config is not None :
        _finalizeConfig(config, id2index, comments)
        configs.append(config)

      return configs
    ################################################################################