Skip to content
Snippets Groups Projects
Select Git revision
  • cdfa91d8bf3dfa851ab2d95b18cef206a4421f33
  • master default protected
  • erased
  • states
  • negatives
  • temp
  • negativeExamples
  • Rl
8 results

Transition.py

Blame
  • meanSameEmbeddings.py 4.36 KiB
    #! /usr/bin/env python3
    
    import sys
    from datetime import datetime
    
    ################################################################################
    def printUsageAndExit() :
      """Print the command-line usage on stderr and terminate with status 1.

      Raises:
        SystemExit: always, with exit code 1.
      """
      print("USAGE : %s embeddings.w2v"%sys.argv[0], file=sys.stderr)
      # sys.exit rather than the builtin exit(): the latter is injected by the
      # `site` module for interactive use and is not guaranteed to exist.
      sys.exit(1)
    ################################################################################
    
    ################################################################################
    def meanSameEmbeddingsSmallMemory(filename) :
      """Average the vectors of identical words, keeping only file offsets in memory.

      Every line of `filename` is expected to be space separated, the first
      column being a word prefixed by a language id and an underscore
      (``<lang>_<word>``), the remaining columns being float values.
      Lines made of exactly two columns are treated as a w2v header and skipped.
      Reading stops at the first empty line (or at end of file).

      A first pass records, for each word (language prefix removed), the offset
      of its line. Words are then sorted so identical ones are contiguous, and a
      second pass re-reads the lines of each group to compute the mean vector.
      One line ``<word> v1 v2 ...`` (values formatted with ``%.6f``) is printed
      on stdout per distinct word; progress is reported on stderr.

      Args:
        filename: path of the embeddings file to read.

      Raises:
        SystemExit: with code 1 if a line does not have the same number of
          columns as the first data line, or if a word has no language prefix.
      """
      dim = None
      nbLines = 0

      words = []

      with open(filename) as fp :
        # Read the file, save words and their positions in the file.
        # readline() is used (instead of iterating on fp) because fp.tell()
        # is disabled while a text file is being iterated.
        while True :
          curPos = fp.tell()
          line = fp.readline().strip()
          if len(line) == 0 :
            break
          nbLines += 1

          splited = line.split(' ')

          # Ignores potential w2v header
          if len(splited) == 2 :
            continue

          if dim is None :
            dim = len(splited)

          if len(splited) != dim :
            print("ERROR : line number %d wrong number of columns. Had %d instead of %d.\n%s\n"%(nbLines, len(splited), dim, line), file=sys.stderr)
            sys.exit(1)

          # Words are prefixed by language id
          ws = splited[0].split('_')
          if len(ws) < 2 :
            print("ERROR : line number %d bad format for word. '%s'.\n%s\n"%(nbLines, splited[0], line), file=sys.stderr)
            sys.exit(1)

          word = "_".join(ws[1:])

          words.append((word.strip(), curPos))

        # Sort words to group same words (ties are ordered by file position)
        words.sort()

        # Mean vectors in linear time.
        # The loop condition also covers the "no data line" case, where there
        # is nothing to print.
        nbWords = len(words)
        i = 0
        while i < nbWords :
          if i % 10000 == 0 :
            print("%s : %.2f%%"%(datetime.now(), 100.0*i/nbWords), file=sys.stderr)

          word = words[i][0]

          # Collect the offsets of every line sharing the current word
          sameLines = []
          j = i
          while j < nbWords and words[j][0] == word :
            sameLines.append(words[j][1])
            j += 1

          vector = None
          for linePos in sameLines :
            fp.seek(linePos)
            line = fp.readline().strip()
            thisvec = list(map(float, line.split(' ')[1:]))
            if vector is None :
              vector = thisvec
              continue
            if len(vector) != len(thisvec) :
              print("ERROR : can't add vectors of different lengths for word '%s'."%word, file=sys.stderr)
              sys.exit(1)
            for k, val in enumerate(thisvec) :
              vector[k] += val

          nbSame = len(sameLines)
          print("%s %s"%(word, " ".join("%.6f"%(val / nbSame) for val in vector)))
          i = j
    ################################################################################
    ################################################################################
    
    ################################################################################
    def meanSameEmbeddingsBigMemory(filename) :
      """Average the vectors of identical words, holding every vector in memory.

      Every line of `filename` is expected to be space separated, the first
      column being a word prefixed by a language id and an underscore
      (``<lang>_<word>``), the remaining columns being float values.
      Lines made of exactly two columns are treated as a w2v header and skipped.
      Reading stops at the first empty line (or at end of file).

      All vectors are loaded, words are sorted so identical ones are contiguous,
      then each group is averaged. One line ``<word> v1 v2 ...`` (values
      formatted with ``%.6f``) is printed on stdout per distinct word; progress
      is reported on stderr.

      Args:
        filename: path of the embeddings file to read.

      Raises:
        SystemExit: with code 1 if a line does not have the same number of
          columns as the first data line, or if a word has no language prefix.
      """
      dim = None
      nbLines = 0

      words = []

      # Read the file, save words and their vectors
      with open(filename) as fp :
        for rawLine in fp :
          line = rawLine.strip()
          if len(line) == 0 :
            break
          nbLines += 1

          splited = line.split(' ')

          # Ignores potential w2v header
          if len(splited) == 2 :
            continue

          if dim is None :
            dim = len(splited)

          if len(splited) != dim :
            print("ERROR : line number %d wrong number of columns. Had %d instead of %d.\n%s\n"%(nbLines, len(splited), dim, line), file=sys.stderr)
            sys.exit(1)

          # Words are prefixed by language id
          ws = splited[0].split('_')
          if len(ws) < 2 :
            print("ERROR : line number %d bad format for word. '%s'.\n%s\n"%(nbLines, splited[0], line), file=sys.stderr)
            sys.exit(1)

          word = "_".join(ws[1:])

          # map() is lazy in Python 3: materialize it so the vector can be
          # indexed, measured and updated in place below.
          words.append([word.strip(), list(map(float, splited[1:]))])

      # Sort words to group same words. Sorting on the word only keeps file
      # order inside a group (stable sort) and never compares the vectors.
      words.sort(key=lambda entry : entry[0])

      # Mean vectors in linear time.
      # The loop condition also covers the "no data line" case, where there is
      # nothing to print.
      nbWords = len(words)
      i = 0
      while i < nbWords :
        if i % 10000 == 0 :
          print("%s : %.2f%%"%(datetime.now(), 100.0*i/nbWords), file=sys.stderr)

        word, vector = words[i]

        # Accumulate into the first vector of the group every following
        # vector attached to the same word
        j = i + 1
        while j < nbWords and words[j][0] == word :
          for k, val in enumerate(words[j][1]) :
            vector[k] += val
          j += 1

        nbSame = j - i
        print("%s %s"%(word, " ".join("%.6f"%(val / nbSame) for val in vector)))
        i = j
    ################################################################################
    ################################################################################
    
    ################################################################################
    if __name__ == "__main__" :
      # Script entry point: exactly one argument (the embeddings file) is
      # expected; anything else prints the usage and exits.
      if len(sys.argv) == 2 :
        meanSameEmbeddingsSmallMemory(sys.argv[1])
      else :
        printUsageAndExit()
    ################################################################################