diff --git a/scripts/meanSameEmbeddings.py b/scripts/meanSameEmbeddings.py index 992995cf73f9dbd88e7b602c1139c8df4811701a..de776a4c5090150d72975567fe244d13d5da9faa 100755 --- a/scripts/meanSameEmbeddings.py +++ b/scripts/meanSameEmbeddings.py @@ -1,6 +1,7 @@ #! /usr/bin/env python3 import sys +from datetime import datetime ################################################################################ def printUsageAndExit() : @@ -9,16 +10,13 @@ def printUsageAndExit() : ################################################################################ ################################################################################ -if __name__ == "__main__" : - if len(sys.argv) != 2 : - printUsageAndExit() - +def meanSameEmbeddingsSmallMemory(filename) : dim = None words = [] # Read the file, save words and their positions in the file - fp = open(sys.argv[1]) + fp = open(filename) while True : curPos = fp.tell() line = fp.readline().strip() @@ -47,7 +45,7 @@ if __name__ == "__main__" : lang = ws[0] word = "_".join(ws[1:]) - words.append([word, curPos]) + words.append([word.strip(), curPos]) # Sort words to group same words words.sort() @@ -55,6 +53,8 @@ if __name__ == "__main__" : # Mean vectors in linear time i = 0 while True : + if i % 10000 == 0 : + print("%s : %.2f%%"%(datetime.now(), 100.0*i/len(words)), file=sys.stderr) sameLines = [] j = i while j in range(len(words)) and words[i][0] == words[j][0] : @@ -86,3 +86,74 @@ if __name__ == "__main__" : fp.close() ################################################################################ +################################################################################ +def meanSameEmbeddingsBigMemory(filename) : + dim = None + + words = [] + + # Read the file, save words and their positions in the file + fp = open(filename) + while True : + line = fp.readline().strip() + if len(line) == 0 : + break + + splited = line.split(' ') + + # Ignores potential w2v header + if len(splited) == 2 : + continue + + if dim == None : + dim = len(splited) + + if len(splited) != dim : + print("ERROR : line number %d wrong number of columns. Had %d instead of %d.\n%s\n"%(nbLines, len(splited), dim, line), file=sys.stderr) + exit(1) + + # Words are prefixed by language id + ws = splited[0].split('_') + if len(ws) < 2 : + print("ERROR : line number %d bad format for word. '%s'.\n%s\n"%(nbLines, splited[0], line), file=sys.stderr) + exit(1) + + lang = ws[0] + word = "_".join(ws[1:]) + + words.append([word.strip(), map(float,splited[1:])]) + + # Sort words to group same words + words.sort() + + # Mean vectors in linear time + i = 0 + while True : + if i % 10000 == 0 : + print("%s : %.2f%%"%(datetime.now(), 100.0*i/len(words)), file=sys.stderr) + vector = words[i][1] + j = i+1 + while j in range(len(words)) and words[i][0] == words[j][0] : + for k in range(len(vector)) : + vector[k] += words[j][1][k] + j += 1 + + for k in range(len(vector)) : + vector[k] /= j - i + + vector = ["%.6f"%val for val in vector] + print("%s %s"%(words[i][0], " ".join(vector))) + i = j + if j not in range(len(words)) : + break + + fp.close() +################################################################################ + +################################################################################ +if __name__ == "__main__" : + if len(sys.argv) != 2 : + printUsageAndExit() + + meanSameEmbeddingsSmallMemory(sys.argv[1]) +################################################################################