From 4a7dabd5a1c3259140a46629f37d394aa6b1828f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 20 May 2022 11:23:00 +0200
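Subject: [PATCH] Fixed embeddings generation

Networks.py: loadW2v now maps the U+25CC placeholder found in pretrained
words back to a space, and prints how many embeddings were loaded.
generateLefffEmbeddings.py: progress and error messages are written to
stderr. The previous print(err, sys.stderr) passed the stream as a second
value to print, so the error went to stdout followed by the stream's repr.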
---
 Networks.py                                           | 6 +++++-
 embeddings/lefffEmbeddings/generateLefffEmbeddings.py | 6 +++---
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/Networks.py b/Networks.py
index 5612b99..32daab1 100644
--- a/Networks.py
+++ b/Networks.py
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import Features
 import Transition
+from Util import prettyInt
 
 ################################################################################
 def readPretrainedSize(w2vFile) :
@@ -13,16 +14,19 @@ def readPretrainedSize(w2vFile) :
 ################################################################################
 def loadW2v(w2vFile, weights, dicts, colname) :
   size = None
+  nbLoaded = 0 # number of rows of the w2v file copied into weights
-  for line in open(w2vFile, "r") :
+  for line in open(w2vFile, "r", encoding="utf-8") : # decode explicitly: the placeholder replaced below is non-ASCII
     line = line.strip()
     if size is None :
       size = int(line.split()[1])
       continue
     splited = line.split()
-    word = " ".join(splited[0:len(splited)-size])
+    word = " ".join(splited[0:len(splited)-size]).replace("◌"," ") # word = every field before the last `size` ones
     emb = torch.tensor(list(map(float,splited[len(splited)-size:])))
 
     weights[dicts.get(colname, word)] = emb
+    nbLoaded += 1
+  print("Loaded %s pretrained embeddings for '%s' from '%s'"%(prettyInt(nbLoaded,3), colname, w2vFile))
 ################################################################################
 
 ################################################################################
diff --git a/embeddings/lefffEmbeddings/generateLefffEmbeddings.py b/embeddings/lefffEmbeddings/generateLefffEmbeddings.py
index edfd35e..0303685 100755
--- a/embeddings/lefffEmbeddings/generateLefffEmbeddings.py
+++ b/embeddings/lefffEmbeddings/generateLefffEmbeddings.py
@@ -35,14 +35,14 @@ for conllu in [("--conllu %s"%conlluFiles, "conllu")] :
 
 nbDone = 0
 for name, command in commands :
-  print("\r%s\r%5.2f%% Generating %s"%(" "*80, 100*nbDone/len(commands), name), end="")
-  sys.stdout.flush()
+  print("\r%s\r%5.2f%% Generating %s"%(" "*80, 100*nbDone/len(commands), name), end="", file=sys.stderr)
+  sys.stderr.flush() # the progress line is written to stderr, so that is the stream to flush
   
-  err = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).stderr.read().decode()
+  err = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()[1].decode() # communicate() drains both pipes and waits for the child
   if len(err) > 0 :
-    print(err, sys.stderr)
+    print("\r%s\rWhen generating %s\n%s"%(" "*80, name, err), file=sys.stderr)
 
   nbDone += 1
 
-print()
+print(file=sys.stderr)
 
-- 
GitLab