diff --git a/Dicts.py b/Dicts.py index be4f3fe6aa72553c0a7e0e4b30a3b888bed9096c..253bf11389dc3129ed2fcc26c998d19adf739df9 100644 --- a/Dicts.py +++ b/Dicts.py @@ -15,6 +15,9 @@ class Dicts : def __init__(self) : self.dicts = {} + self.foundLower = {} + self.notFound = {} + self.found = {} def addDict(self, name, d) : if name in self.dicts : @@ -45,6 +48,8 @@ class Dicts : splited = line.split('\t') for col in targetColumns : + if col in pretrained : + continue if col == "LETTER" : values = [letter for letter in splited[col2index["FORM"]]] else : @@ -88,9 +93,19 @@ class Dicts : if col not in self.dicts : raise Exception("Unknown dict name '%s' among %s"%(col, str(list(self.dicts.keys())))) if value in self.dicts[col] : + if col not in self.found : + self.found[col] = set() + self.found[col].add(value) return self.dicts[col][value][0] if value.lower() in self.dicts[col] : + if col not in self.foundLower : + self.foundLower[col] = set() + self.foundLower[col].add(value) return self.dicts[col][value.lower()][0] + + if col not in self.notFound : + self.notFound[col] = set() + self.notFound[col].add(value) return self.dicts[col][self.unkToken][0] def getElementsOf(self, col) : @@ -103,5 +118,29 @@ class Dicts : def load(self, target) : self.dicts = json.load(open(target, "r")) + + def printStats(self, output) : + for col in self.dicts : + total = 0 + if col in self.found : + total += len(self.found[col]) + if col in self.notFound : + total += len(self.notFound[col]) + if col in self.foundLower : + total += len(self.foundLower[col]) + if total == 0 : + continue + print(col, file=output) + if col in self.found : + print("Found : %.2f%%"%(100.0*len(self.found[col])/total), file=output) + if col in self.foundLower : + print("Found Lower : %.2f%%"%(100.0*len(self.foundLower[col])/total), file=output) + if col in self.notFound : + print("Not found : %.2f%%"%(100.0*len(self.notFound[col])/total), file=output) + + def resetStats(self) : + self.found = {} + self.notFound = {} + self.foundLower = {} ################################################################################ diff --git a/Train.py b/Train.py index 9cc05a713d60c2e6481227d8a2c0c473863508b6..d83f2a65d7c824e5f0fd775079374b03e5ff5e75 100644 --- a/Train.py +++ b/Train.py @@ -78,6 +78,12 @@ def extractExamples(debug, transitionSets, strat, config, dicts, network, dynami def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) : col2metric = {"HEAD" : "UAS", "DEPREL" : "LAS", "UPOS" : "UPOS", "FEATS" : "UFeats"} + statsFile = open(modelDir+"/dicts.stats", "w") + + print("Train :", file=statsFile) + dicts.printStats(statsFile) + dicts.resetStats() + devScore = "" saved = True if bestLoss is None else totalLoss < bestLoss bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) @@ -91,11 +97,16 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss saved = True if bestScore is None else score > bestScore bestScore = score if bestScore is None else max(bestScore, score) devScore = ", Dev : "+" ".join(["%s=%.2f"%(col2metric[toEval[i]], scores[i]) for i in range(len(toEval))]) + print("\nDev :", file=statsFile) + dicts.printStats(statsFile) + dicts.resetStats() if saved : torch.save(model, modelDir+"/network.pt") for out in [sys.stderr, open(modelDir+"/train.log", "w" if epoch == 1 else "a")] : print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=out) + statsFile.close() + return bestLoss, bestScore ################################################################################