Skip to content
Snippets Groups Projects
Commit 33789c0f authored by Franck Dary's avatar Franck Dary
Browse files

Improved script to generate w2v based on lefff pos

parent 6a093c3e
No related branches found
No related tags found
No related merge requests found
......@@ -7,11 +7,15 @@
# Example: ./lefff2w2v --conllu data/UD_French-GSD/*\.conllu
# We can chose to output binary vector with the option --binary which is a threshold above which values will become 1.
# We can ignore infrequent words in conllu by setting a threshold with --minfreq.
# We can modulate the impact of the lefff with the parameter --lefffweight.
import sys
import argparse
from readMCD import readMCD
# List of UD POS tags : https://universaldependencies.org/u/pos/index.html
allPos = ["adj", "adp", "adv", "aux", "cconj", "det", "intj", "noun", "num", "part", "pron", "propn", "punct", "sconj", "sym", "verb", "x"]
# Convert lefff part of speech into UD UPOS.
lefffPOS2UD = {
"adj" : "adj",
......@@ -78,8 +82,6 @@ if __name__ == "__main__" :
# Dict with key=FORM and value= dict associationg pos with number of occ
form2pos = {}
# List of all pos (UD format) present in data
allPos = []
# Associate each form with a counter, only for conllu files
formCount = {}
......@@ -92,9 +94,11 @@ if __name__ == "__main__" :
# In lefff there might be spaces in forms. W2v format don't allow it. We replace space by dotted circle.
form.replace(" ", "")
if pos not in allPos :
allPos.append(pos)
print("ERROR: Unknown pos '%s' (check allPos in the script)"%pos, file=sys.stderr)
if form not in form2pos :
form2pos[form] = {}
if form not in formCount :
formCount[form] = 0
if pos not in form2pos[form] :
form2pos[form][pos] = args.lefffWeight
......@@ -113,11 +117,14 @@ if __name__ == "__main__" :
if len(line) == 0 or line[0] == "#" :
continue
splited = line.split("\t")
wordId = splited[conllMCD["ID"]].lower()
if "-" in wordId :
continue
form = splited[conllMCD["FORM"]].lower()
pos = splited[conllMCD["UPOS"]].lower()
form.replace(" ", "")
if pos not in allPos :
allPos.append(pos)
print("ERROR: Unknown pos '%s' (check allPos in the script)"%pos, file=sys.stderr)
if form not in form2pos :
form2pos[form] = {}
if pos not in form2pos[form] :
......@@ -129,6 +136,9 @@ if __name__ == "__main__" :
outputLines = []
# To check if all pos are represented in our embeddings
usedPos = set()
# Compute probability for each pos and form
for form in form2pos :
if args.minfreq is not None and formCount[form] < args.minfreq :
......@@ -141,15 +151,18 @@ if __name__ == "__main__" :
vec[allPos.index(pos)] = form2pos[form][pos] / totalOccs
baseVec = vec.copy()
for pos in form2pos[form] :
posIndex = allPos.index(pos)
if args.binary is not None :
if vec[allPos.index(pos)] >= args.binary :
vec[allPos.index(pos)] = 1
if vec[posIndex] >= args.binary :
vec[posIndex] = 1
else :
vec[allPos.index(pos)] = 0
vec[posIndex] = 0
if vec[posIndex] > 0.0 :
usedPos.add(posIndex)
if args.binary is not None :
vec[allPos.index(pos)] = "%d"%vec[allPos.index(pos)]
vec[posIndex] = "%d"%vec[posIndex]
else :
vec[allPos.index(pos)] = "%.2f"%vec[allPos.index(pos)]
vec[posIndex] = "%.2f"%vec[posIndex]
if sum(map(float, vec)) == 0 :
print("WARNING: word '%s' gets all 0. Original: '%s'"%(form, " ".join(map(str,baseVec))), file=sys.stderr)
outputLines.append(form+" "+" ".join(vec))
......@@ -159,3 +172,8 @@ if __name__ == "__main__" :
outputLines.sort()
print("\n".join(outputLines))
# Check unused pos
for posIndex in range(len(allPos)) :
if posIndex not in usedPos :
print("WARNING: unused POS '%s'"%allPos[posIndex], file=sys.stderr)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment