Skip to content
Snippets Groups Projects
Commit c9ca5a3c authored by Franck Dary
Browse files

Improved meanEmbeddings

parent c26fc998
No related branches found
No related tags found
No related merge requests found
#! /usr/bin/env python3
import sys
from datetime import datetime
################################################################################
def printUsageAndExit() :
......@@ -9,16 +10,13 @@ def printUsageAndExit() :
################################################################################
################################################################################
if __name__ == "__main__" :
if len(sys.argv) != 2 :
printUsageAndExit()
def meanSameEmbeddingsSmallMemory(filename) :
dim = None
words = []
# Read the file, save words and their positions in the file
fp = open(sys.argv[1])
fp = open(filename)
while True :
curPos = fp.tell()
line = fp.readline().strip()
......@@ -47,7 +45,7 @@ if __name__ == "__main__" :
lang = ws[0]
word = "_".join(ws[1:])
words.append([word, curPos])
words.append([word.strip(), curPos])
# Sort words to group same words
words.sort()
......@@ -55,6 +53,8 @@ if __name__ == "__main__" :
# Mean vectors in linear time
i = 0
while True :
if i % 10000 == 0 :
print("%s : %.2f%%"%(datetime.now(), 100.0*i/len(words)), file=sys.stderr)
sameLines = []
j = i
while j in range(len(words)) and words[i][0] == words[j][0] :
......@@ -86,3 +86,74 @@ if __name__ == "__main__" :
fp.close()
################################################################################
################################################################################
def meanSameEmbeddingsBigMemory(filename) :
    """Average the vectors of identical words, keeping every vector in memory.

    Each data line of `filename` is expected to look like
    "<lang>_<word> v1 v2 ... vn" (space separated). The language prefix is
    dropped, lines sharing the same word are grouped, and one line
    "<word> m1 m2 ... mn" holding the component-wise mean (6 decimals) is
    printed on stdout per distinct word, in sorted word order.

    A line made of exactly two columns is treated as a word2vec header and
    ignored. Reading stops at the first empty line or at end of file.
    Progress messages and errors are written to stderr; a malformed line
    terminates the program with exit status 1.
    """
    dim = None
    words = []

    # Read the file, save every word together with its parsed vector
    with open(filename) as fp :
        for nbLines, rawLine in enumerate(fp, start=1) :
            line = rawLine.strip()
            # An empty line ends the input, exactly like reaching end of file
            if len(line) == 0 :
                break
            splited = line.split(' ')
            # Ignores potential w2v header
            if len(splited) == 2 :
                continue
            if dim is None :
                dim = len(splited)
            if len(splited) != dim :
                print("ERROR : line number %d wrong number of columns. Had %d instead of %d.\n%s\n"%(nbLines, len(splited), dim, line), file=sys.stderr)
                sys.exit(1)
            # Words are prefixed by language id
            ws = splited[0].split('_')
            if len(ws) < 2 :
                print("ERROR : line number %d bad format for word. '%s'.\n%s\n"%(nbLines, splited[0], line), file=sys.stderr)
                sys.exit(1)
            word = "_".join(ws[1:])
            # Materialise the floats: in Python 3 map() is a one-shot iterator
            # that supports neither len() nor indexing
            words.append([word.strip(), [float(val) for val in splited[1:]]])

    # Sort on the word only to group same words (the vectors never need comparing)
    words.sort(key=lambda entry : entry[0])

    # Mean vectors in linear time
    nbWords = len(words)
    i = 0
    # The loop condition also covers an input without any data line
    while i < nbWords :
        if i % 10000 == 0 :
            print("%s : %.2f%%"%(datetime.now(), 100.0*i/nbWords), file=sys.stderr)
        word = words[i][0]
        total = list(words[i][1])
        j = i+1
        # Accumulate every following entry that carries the same word
        while j < nbWords and words[j][0] == word :
            for k, val in enumerate(words[j][1]) :
                total[k] += val
            j += 1
        count = j - i
        print("%s %s"%(word, " ".join("%.6f"%(val / count) for val in total)))
        i = j
################################################################################
################################################################################
if __name__ == "__main__" :
    # Script entry point: exactly one command-line argument (the embeddings
    # file) is required, otherwise printUsageAndExit() is called.
    nbArgs = len(sys.argv) - 1
    if nbArgs != 1 :
        printUsageAndExit()
    embeddingsFile = sys.argv[1]
    meanSameEmbeddingsSmallMemory(embeddingsFile)
################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment