Skip to content
Snippets Groups Projects
lefff2w2v.py 6.20 KiB
#! /usr/bin/env python3
# Create a w2v formatted embedding file.
# Each line associates a lowercase word with an embedding whose dimensions are the UD POS tags.
# The input to this script is a combination of the lefff lexicon and conllu UD corpora.
# Example: ./lefff2w2v.py --lefff lefff.fplm --conllu data/UD_French-GSD/*\.conllu
# Example: ./lefff2w2v.py --lefff lefff.fplm
# Example: ./lefff2w2v.py --conllu data/UD_French-GSD/*\.conllu
# We can choose to output binary vectors with the option --binary, a threshold at or above which values become 1 (and below which they become 0).
# We can ignore infrequent words in conllu by setting a threshold with --minfreq.
# We can modulate the impact of the lefff with the parameter --lefffWeight.

import sys
import argparse
from readMCD import readMCD

# UD universal part-of-speech tags, lowercased. Their order is the order of the
# embedding dimensions written in the output file.
# Reference: https://universaldependencies.org/u/pos/index.html
allPos = [
  "adj", "adp", "adv", "aux", "cconj", "det",
  "intj", "noun", "num", "part", "pron", "propn",
  "punct", "sconj", "sym", "verb", "x",
]

# Lefff part-of-speech categories mapped to their (lowercased) UD UPOS tag.
lefffPOS2UD = dict(
  adj="adj",
  csu="sconj",
  que="sconj",        # possibly not only SCONJ
  det="det",
  pres="intj",        # no exact UD counterpart: INTJ or X
  v="verb",
  nc="noun",
  cfi="noun",
  advPref="x",        # affix: no meaning with UD tokenization
  adjPref="x",        # affix: no meaning with UD tokenization
  suffAdj="x",        # affix: no meaning with UD tokenization
  cln="pron",
  ce="pron",
  clg="adp",
  cll="pron",
  ilimp="pron",
  cla="pron",
  cld="pron",
  pro="pron",
  caimp="pron",
  pri="adv",
  prel="pron",
  clr="pron",
  clar="pron",
  cldr="pron",
  adv="adv",
  advm="adv",
  advp="adv",
  coo="cconj",
  ponctw="punct",
  advneg="adv",
  clneg="adv",
  que_restr="sconj",
  np="propn",
  poncts="punct",
  parento="punct",
  epsilon="punct",
  parentf="punct",
  prep="adp",
  auxAvoir="aux",
  auxEtre="aux",
)

# Character written in place of spaces inside forms: w2v lines are space
# separated, so a form containing a space would shift all of its columns.
spaceReplacement = "\u25cc"  # DOTTED CIRCLE

# Column layout assumed for conllu files: the 10 standard CoNLL-U columns.
# A "# global.columns =" header (CoNLL-U Plus) overrides it for its file.
defaultConlluMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"


def normalizeForm(form) :
  """Return form lowercased, with each space replaced by spaceReplacement."""
  return form.lower().replace(" ", spaceReplacement)


def readLefff(filename, form2pos, formCount, weight) :
  """Add the (form, UD POS) couples of a lefff file to form2pos.

  filename -- lefff file in tab separated columns: FORM POS LEMMA MORPHO.
  form2pos -- dict FORM -> dict POS -> number of occurrences, updated in place.
              A (form, POS) couple found in the lefff gets `weight` occurrences,
              only once even if it appears on several lefff lines, and only if
              the couple is not already present.
  formCount -- dict FORM -> number of conllu occurrences, updated in place:
               lefff forms are registered with 0 occurrences if absent.
  weight -- number of occurrences credited to each lefff couple.

  Lines with fewer than two columns are skipped. Lefff categories missing from
  lefffPOS2UD are reported once on stderr and their lines are skipped.
  """
  unknownLefffPos = set()
  with open(filename, "r", encoding="utf-8") as lefffFile :
    for line in lefffFile :
      splited = line.strip().split("\t")
      if len(splited) < 2 :
        continue
      lefffPos = splited[1]
      if lefffPos not in lefffPOS2UD :
        if lefffPos not in unknownLefffPos :
          unknownLefffPos.add(lefffPos)
          print("ERROR: Unknown lefff pos '%s' (check lefffPOS2UD in the script)"%lefffPos, file=sys.stderr)
        continue
      pos = lefffPOS2UD[lefffPos]
      # Guards against a lefffPOS2UD value that is not a known UD tag.
      if pos not in allPos :
        print("ERROR: Unknown pos '%s' (check allPos in the script)"%pos, file=sys.stderr)
        continue
      form = normalizeForm(splited[0])
      form2pos.setdefault(form, {}).setdefault(pos, weight)
      formCount.setdefault(form, 0)


def readConllu(filename, form2pos, formCount) :
  """Count the (form, UPOS) occurrences of a conllu file.

  form2pos -- dict FORM -> dict POS -> number of occurrences, updated in place.
  formCount -- dict FORM -> number of conllu occurrences, updated in place.

  Columns are located with readMCD, starting from defaultConlluMCD and switching
  to the layout given by any "# global.columns =" comment line. Multiword token
  ranges (ID containing '-') and empty nodes (ID containing '.') are skipped:
  they are not surface words. Words whose UPOS is not in allPos are reported on
  stderr and skipped.
  """
  conllMCD, _ = readMCD(defaultConlluMCD)
  with open(filename, "r", encoding="utf-8") as conlluFile :
    for line in conlluFile :
      line = line.strip()
      if line.startswith("#") :
        if "global.columns =" in line :
          conllMCD, _ = readMCD(line.split("global.columns =")[-1].strip())
        continue
      if not line :
        continue
      splited = line.split("\t")
      wordId = splited[conllMCD["ID"]]
      if "-" in wordId or "." in wordId :
        continue
      form = normalizeForm(splited[conllMCD["FORM"]])
      pos = splited[conllMCD["UPOS"]].lower()
      if pos not in allPos :
        print("ERROR: Unknown pos '%s' (check allPos in the script)"%pos, file=sys.stderr)
        continue
      posCounts = form2pos.setdefault(form, {})
      posCounts[pos] = posCounts.get(pos, 0) + 1
      formCount[form] = formCount.get(form, 0) + 1


def posVector(posCounts, posList, binary=None) :
  """Compute the w2v dimensions of one form.

  posCounts -- dict POS -> (possibly lefff-weighted) number of occurrences of
               the form; every key must belong to posList.
  posList -- ordered list of POS, one per dimension.
  binary -- optional threshold: relative frequencies >= binary become 1,
            the others 0.

  Returns a tuple (vec, baseVec, used):
    vec -- list of strings, one per POS of posList: the relative frequency
           with 2 decimals, or "0"/"1" when binary is given. POS absent from
           posCounts are "0".
    baseVec -- the relative frequencies before binarization/formatting
               ("0" for absent POS, floats otherwise), for diagnostics.
    used -- set of indices in posList whose final value is > 0.
  """
  totalOccs = sum(posCounts.values())
  baseVec = ["0" for _ in posList]
  for pos, occs in posCounts.items() :
    # totalOccs is 0 only when every count is 0 (e.g. --lefffWeight 0).
    baseVec[posList.index(pos)] = occs / totalOccs if totalOccs > 0 else 0.0
  vec = list(baseVec)
  used = set()
  for pos in posCounts :
    posIndex = posList.index(pos)
    value = baseVec[posIndex]
    if binary is not None :
      value = 1 if value >= binary else 0
      vec[posIndex] = "%d"%value
    else :
      vec[posIndex] = "%.2f"%value
    if value > 0.0 :
      used.add(posIndex)
  return vec, baseVec, used


def main() :
  """Build the w2v embedding file from the command line options and print it on stdout.

  Diagnostics (unknown POS, all-zero vectors, unused POS) are printed on stderr.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("--lefff", type=str,
    help="Lefff file in tab separated columns: FORM POS LEMMA MORPHO.")
  parser.add_argument("--conllu", nargs="+", type=str,
    help="Conllu files to estimate the probability of each POS.")
  parser.add_argument("--binary", type=float,
    help="A threshold in [0,1] that will separate zeroes from ones.")
  parser.add_argument("--minfreq", type=int,
    help="A threshold in number of occurrences of words.")
  parser.add_argument("--lefffWeight", type=int, default=1,
    help="What is the weight, in number of occurrences of the couple (form,POS) in annotated conllu data, that the lefff add ?")

  args = parser.parse_args()

  if args.lefff is None and args.conllu is None :
    print("ERROR: must provide --lefff and/or --conllu", file=sys.stderr)
    sys.exit(1)

  # FORM -> dict associating each POS with its number of occurrences
  form2pos = {}
  # FORM -> number of occurrences in the conllu files (0 for lefff-only forms)
  formCount = {}

  if args.lefff is not None :
    readLefff(args.lefff, form2pos, formCount, args.lefffWeight)
  for filename in args.conllu or [] :
    readConllu(filename, form2pos, formCount)

  outputLines = []
  # Indices in allPos of the POS that get a non-zero value for some form
  usedPos = set()
  for form, posCounts in form2pos.items() :
    if args.minfreq is not None and formCount[form] < args.minfreq :
      continue
    vec, baseVec, used = posVector(posCounts, allPos, args.binary)
    usedPos |= used
    if sum(map(float, vec)) == 0 :
      print("WARNING: word '%s' gets all 0. Original: '%s'"%(form, " ".join(map(str,baseVec))), file=sys.stderr)
    outputLines.append(form+" "+" ".join(vec))

  # w2v text format: a "<number of words> <dimension>" header, then one word per line
  print(len(outputLines), len(allPos))
  outputLines.sort()
  print("\n".join(outputLines))

  for posIndex, pos in enumerate(allPos) :
    if posIndex not in usedPos :
      print("WARNING: unused POS '%s'"%pos, file=sys.stderr)


if __name__ == "__main__" :
  main()