Skip to content
Snippets Groups Projects
Commit c26fc998 authored by Franck Dary
Browse files

Added script to compute mean embeddings

parent 3c53ce92
No related branches found
No related tags found
No related merge requests found
#! /usr/bin/env python3
import sys
################################################################################
def printUsageAndExit() :
    """Print the command-line usage on stderr and exit with status 1."""
    print("USAGE : %s embeddings.w2v"%sys.argv[0], file=sys.stderr)
    # sys.exit rather than the builtin exit : the latter is injected by the
    # site module for interactive use and is not available under `python -S`.
    sys.exit(1)
################################################################################
################################################################################
def readWordPositions(fp) :
    """Index an open embeddings file by word.

    Reads `fp` line by line, stopping at end of file or at the first blank
    line. Lines made of exactly two space-separated columns are skipped (w2v
    header). Every other line must start with a word of the form
    '<lang>_<word>'.

    Returns a dict mapping each word (language prefix removed) to the list of
    file positions, in file order, of the lines holding a vector for it.

    Prints an error on stderr and exits with status 1 when a line does not
    have as many columns as the first vector line, or when a word has no
    language prefix.
    """
    dim = None
    positionsByWord = {}
    lineNumber = 0
    while True :
        # The position is taken before reading so that seek() later lands on
        # the start of the line. readline() is used instead of iterating over
        # the file because iteration disables tell() on text files.
        curPos = fp.tell()
        line = fp.readline().strip()
        if len(line) == 0 :
            break
        lineNumber += 1
        splited = line.split(' ')
        # Ignores potential w2v header
        if len(splited) == 2 :
            continue
        if dim is None :
            dim = len(splited)
        if len(splited) != dim :
            print("ERROR : line number %d wrong number of columns. Had %d instead of %d.\n%s\n"%(lineNumber, len(splited), dim, line), file=sys.stderr)
            sys.exit(1)
        # Words are prefixed by language id ; only the first '_' separates the
        # language from the word, the word itself may contain more of them.
        _, separator, word = splited[0].partition('_')
        if separator == '' :
            print("ERROR : line number %d bad format for word. '%s'.\n%s\n"%(lineNumber, splited[0], line), file=sys.stderr)
            sys.exit(1)
        positionsByWord.setdefault(word, []).append(curPos)
    return positionsByWord


def printMeanVectors(fp, positionsByWord) :
    """Print each word, in sorted order, followed by the mean of its vectors.

    `positionsByWord` maps a word to the positions in `fp` of the lines that
    hold its vectors. Each of these lines is read again from `fp`, so only the
    running sum of one word is kept in memory at a time. Values are printed
    with six decimals.

    Prints an error on stderr and exits with status 1 if two vectors of the
    same word have different lengths.
    """
    for word in sorted(positionsByWord) :
        positions = positionsByWord[word]
        total = None
        for linePos in positions :
            fp.seek(linePos)
            line = fp.readline().strip()
            thisvec = [float(val) for val in line.split(' ')[1:]]
            if total is None :
                total = thisvec
            elif len(total) != len(thisvec) :
                print("ERROR : can't add vectors of different lengths for word '%s'."%word, file=sys.stderr)
                sys.exit(1)
            else :
                total = [a + b for a, b in zip(total, thisvec)]
        mean = ["%.6f"%(val / len(positions)) for val in total]
        print("%s %s"%(word, " ".join(mean)))


def main(argv=None) :
    """Print on stdout the per-word mean of the embeddings found in argv[1].

    `argv` defaults to sys.argv and must hold exactly the program name and the
    path of the embeddings file. An input without any vector line produces no
    output.
    """
    if argv is None :
        argv = sys.argv
    if len(argv) != 2 :
        printUsageAndExit()
    # The context manager closes the file on the error exits as well.
    with open(argv[1]) as fp :
        printMeanVectors(fp, readWordPositions(fp))


if __name__ == "__main__" :
    main()
################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment