diff --git a/scripts/meanSameEmbeddings.py b/scripts/meanSameEmbeddings.py new file mode 100755 index 0000000000000000000000000000000000000000..992995cf73f9dbd88e7b602c1139c8df4811701a --- /dev/null +++ b/scripts/meanSameEmbeddings.py @@ -0,0 +1,88 @@ +#! /usr/bin/env python3 + +import sys + +################################################################################ +def printUsageAndExit() : + print("USAGE : %s embeddings.w2v"%sys.argv[0], file=sys.stderr) + exit(1) +################################################################################ + +################################################################################ +if __name__ == "__main__" : + if len(sys.argv) != 2 : + printUsageAndExit() + + dim = None + + words = [] + + # Read the file, save words and their positions in the file + fp = open(sys.argv[1]) + while True : + curPos = fp.tell() + line = fp.readline().strip() + if len(line) == 0 : + break + + splited = line.split(' ') + + # Ignores potential w2v header + if len(splited) == 2 : + continue + + if dim == None : + dim = len(splited) + + if len(splited) != dim : + print("ERROR : line number %d wrong number of columns. Had %d instead of %d.\n%s\n"%(nbLines, len(splited), dim, line), file=sys.stderr) + exit(1) + + # Words are prefixed by language id + ws = splited[0].split('_') + if len(ws) < 2 : + print("ERROR : line number %d bad format for word. '%s'.\n%s\n"%(nbLines, splited[0], line), file=sys.stderr) + exit(1) + + lang = ws[0] + word = "_".join(ws[1:]) + + words.append([word, curPos]) + + # Sort words to group same words + words.sort() + + # Mean vectors in linear time + i = 0 + while True : + sameLines = [] + j = i + while j in range(len(words)) and words[i][0] == words[j][0] : + sameLines.append(words[j][1]) + j += 1 + + vector = [] + for linePos in sameLines : + fp.seek(linePos) + line = fp.readline().strip() + thisvec = list(map(float, line.split(' ')[1:])) + if len(vector) == 0 : + vector = thisvec + else : + if len(vector) != len(thisvec) : + print("ERROR : can't add vectors of different lengths for word '%s'."%words[i][0], file=sys.stderr) + exit(1) + for k in range(len(vector)) : + vector[k] += thisvec[k] + for k in range(len(vector)) : + vector[k] /= len(sameLines) + + vector = ["%.6f"%val for val in vector] + print("%s %s"%(words[i][0], " ".join(vector))) + i = j + if j not in range(len(words)) : + break + + fp.close() +################################################################################ +