Skip to content
Snippets Groups Projects
Commit c9ca5a3c authored by Franck Dary's avatar Franck Dary
Browse files

Improved meanEmbeddings

parent c26fc998
No related branches found
No related tags found
No related merge requests found
#! /usr/bin/env python3 #! /usr/bin/env python3
import sys import sys
from datetime import datetime
################################################################################ ################################################################################
def printUsageAndExit() : def printUsageAndExit() :
...@@ -9,16 +10,13 @@ def printUsageAndExit() : ...@@ -9,16 +10,13 @@ def printUsageAndExit() :
################################################################################ ################################################################################
################################################################################ ################################################################################
if __name__ == "__main__" : def meanSameEmbeddingsSmallMemory(filename) :
if len(sys.argv) != 2 :
printUsageAndExit()
dim = None dim = None
words = [] words = []
# Read the file, save words and their positions in the file # Read the file, save words and their positions in the file
fp = open(sys.argv[1]) fp = open(filename)
while True : while True :
curPos = fp.tell() curPos = fp.tell()
line = fp.readline().strip() line = fp.readline().strip()
...@@ -47,7 +45,7 @@ if __name__ == "__main__" : ...@@ -47,7 +45,7 @@ if __name__ == "__main__" :
lang = ws[0] lang = ws[0]
word = "_".join(ws[1:]) word = "_".join(ws[1:])
words.append([word, curPos]) words.append([word.strip(), curPos])
# Sort words to group same words # Sort words to group same words
words.sort() words.sort()
...@@ -55,6 +53,8 @@ if __name__ == "__main__" : ...@@ -55,6 +53,8 @@ if __name__ == "__main__" :
# Mean vectors in linear time # Mean vectors in linear time
i = 0 i = 0
while True : while True :
if i % 10000 == 0 :
print("%s : %.2f%%"%(datetime.now(), 100.0*i/len(words)), file=sys.stderr)
sameLines = [] sameLines = []
j = i j = i
while j in range(len(words)) and words[i][0] == words[j][0] : while j in range(len(words)) and words[i][0] == words[j][0] :
...@@ -86,3 +86,74 @@ if __name__ == "__main__" : ...@@ -86,3 +86,74 @@ if __name__ == "__main__" :
fp.close() fp.close()
################################################################################ ################################################################################
################################################################################
def meanSameEmbeddingsBigMemory(filename) :
dim = None
words = []
# Read the file, save words and their positions in the file
fp = open(filename)
while True :
line = fp.readline().strip()
if len(line) == 0 :
break
splited = line.split(' ')
# Ignores potential w2v header
if len(splited) == 2 :
continue
if dim == None :
dim = len(splited)
if len(splited) != dim :
print("ERROR : line number %d wrong number of columns. Had %d instead of %d.\n%s\n"%(nbLines, len(splited), dim, line), file=sys.stderr)
exit(1)
# Words are prefixed by language id
ws = splited[0].split('_')
if len(ws) < 2 :
print("ERROR : line number %d bad format for word. '%s'.\n%s\n"%(nbLines, splited[0], line), file=sys.stderr)
exit(1)
lang = ws[0]
word = "_".join(ws[1:])
words.append([word.strip(), map(float,splited[1:])])
# Sort words to group same words
words.sort()
# Mean vectors in linear time
i = 0
while True :
if i % 10000 == 0 :
print("%s : %.2f%%"%(datetime.now(), 100.0*i/len(words)), file=sys.stderr)
vector = words[i][1]
j = i+1
while j in range(len(words)) and words[i][0] == words[j][0] :
for k in range(len(vector)) :
vector[k] += words[j][1][k]
j += 1
for k in range(len(vector)) :
vector[k] /= j - i
vector = ["%.6f"%val for val in vector]
print("%s %s"%(words[i][0], " ".join(vector)))
i = j
if j not in range(len(words)) :
break
fp.close()
################################################################################
################################################################################
if __name__ == "__main__" :
if len(sys.argv) != 2 :
printUsageAndExit()
meanSameEmbeddingsSmallMemory(sys.argv[1])
################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment