From 4a7dabd5a1c3259140a46629f37d394aa6b1828f Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 20 May 2022 11:23:00 +0200 Subject: [PATCH] Fixed embeddings generation --- Networks.py | 6 +++++- embeddings/lefffEmbeddings/generateLefffEmbeddings.py | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/Networks.py b/Networks.py index 5612b99..32daab1 100644 --- a/Networks.py +++ b/Networks.py @@ -3,6 +3,7 @@ import torch.nn as nn import torch.nn.functional as F import Features import Transition +from Util import prettyInt ################################################################################ def readPretrainedSize(w2vFile) : @@ -13,16 +14,19 @@ def readPretrainedSize(w2vFile) : ################################################################################ def loadW2v(w2vFile, weights, dicts, colname) : size = None + nbLoaded = 0 for line in open(w2vFile, "r") : line = line.strip() if size is None : size = int(line.split()[1]) continue splited = line.split() - word = " ".join(splited[0:len(splited)-size]) + word = " ".join(splited[0:len(splited)-size]).replace("◌"," ") emb = torch.tensor(list(map(float,splited[len(splited)-size:]))) weights[dicts.get(colname, word)] = emb + nbLoaded += 1 + print("Loaded %s pretrained embeddings for '%s' from '%s'"%(prettyInt(nbLoaded,3), colname, w2vFile)) ################################################################################ ################################################################################ diff --git a/embeddings/lefffEmbeddings/generateLefffEmbeddings.py b/embeddings/lefffEmbeddings/generateLefffEmbeddings.py index edfd35e..0303685 100755 --- a/embeddings/lefffEmbeddings/generateLefffEmbeddings.py +++ b/embeddings/lefffEmbeddings/generateLefffEmbeddings.py @@ -35,14 +35,14 @@ for conllu in [("--conllu %s"%conlluFiles, "conllu")] : nbDone = 0 for name, command in commands : - print("\r%s\r%5.2f%% Generating %s"%(" "*80, 100*nbDone/len(commands), name), end="") + print("\r%s\r%5.2f%% Generating %s"%(" "*80, 100*nbDone/len(commands), name), end="", file=sys.stderr) sys.stdout.flush() err = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).stderr.read().decode() if len(err) > 0 : - print(err, sys.stderr) + print("\r%s\rWhen generating %s\n%s"%(" "*80, name, err), file=sys.stderr) nbDone += 1 -print() +print(file=sys.stderr) -- GitLab