Skip to content
Snippets Groups Projects
Commit 057e1d9d authored by Franck Dary's avatar Franck Dary
Browse files

Do not read conllu file for pretrained embeddings

parent 98e2ffb0
No related branches found
No related tags found
No related merge requests found
......@@ -15,6 +15,9 @@ class Dicts :
def __init__(self) :
self.dicts = {}
self.foundLower = {}
self.notFound = {}
self.found = {}
def addDict(self, name, d) :
if name in self.dicts :
......@@ -45,6 +48,8 @@ class Dicts :
splited = line.split('\t')
for col in targetColumns :
if col in pretrained :
continue
if col == "LETTER" :
values = [letter for letter in splited[col2index["FORM"]]]
else :
......@@ -88,9 +93,19 @@ class Dicts :
if col not in self.dicts :
raise Exception("Unknown dict name '%s' among %s"%(col, str(list(self.dicts.keys()))))
if value in self.dicts[col] :
if col not in self.found :
self.found[col] = set()
self.found[col].add(value)
return self.dicts[col][value][0]
if value.lower() in self.dicts[col] :
if col not in self.foundLower :
self.foundLower[col] = set()
self.foundLower[col].add(value)
return self.dicts[col][value.lower()][0]
if col not in self.notFound :
self.notFound[col] = set()
self.notFound[col].add(value)
return self.dicts[col][self.unkToken][0]
def getElementsOf(self, col) :
......@@ -103,5 +118,29 @@ class Dicts :
def load(self, target) :
self.dicts = json.load(open(target, "r"))
def printStats(self, output) :
for col in self.dicts :
total = 0
if col in self.found :
total += len(self.found[col])
if col in self.notFound :
total += len(self.notFound[col])
if col in self.foundLower :
total += len(self.foundLower[col])
if total == 0 :
continue
print(col, file=output)
if col in self.found :
print("Found : %.2f%%"%(100.0*len(self.found[col])/total), file=output)
if col in self.foundLower :
print("Found Lower : %.2f%%"%(100.0*len(self.foundLower[col])/total), file=output)
if col in self.notFound :
print("Not found : %.2f%%"%(100.0*len(self.notFound[col])/total), file=output)
def resetStats(self) :
self.found = {}
self.notFound = {}
self.foundLower = {}
################################################################################
......@@ -78,6 +78,12 @@ def extractExamples(debug, transitionSets, strat, config, dicts, network, dynami
def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) :
col2metric = {"HEAD" : "UAS", "DEPREL" : "LAS", "UPOS" : "UPOS", "FEATS" : "UFeats"}
statsFile = open(modelDir+"/dicts.stats", "w")
print("Train :", file=statsFile)
dicts.printStats(statsFile)
dicts.resetStats()
devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
......@@ -91,11 +97,16 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
saved = True if bestScore is None else score > bestScore
bestScore = score if bestScore is None else max(bestScore, score)
devScore = ", Dev : "+" ".join(["%s=%.2f"%(col2metric[toEval[i]], scores[i]) for i in range(len(toEval))])
print("\nDev :", file=statsFile)
dicts.printStats(statsFile)
dicts.resetStats()
if saved :
torch.save(model, modelDir+"/network.pt")
for out in [sys.stderr, open(modelDir+"/train.log", "w" if epoch == 1 else "a")] :
print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=out)
statsFile.close()
return bestLoss, bestScore
################################################################################
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment