diff --git a/Train.py b/Train.py index 7985a456857d45f3cabc9d7c9b05971d2b76e7ac..01c5aaf0a0bb2018ec743486920d7ab7aaab8e05 100644 --- a/Train.py +++ b/Train.py @@ -79,11 +79,12 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss outFilename = modelDir+"/predicted_dev.conllu" Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, predicted, modelDir, model, dicts, open(outFilename, "w")) res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) - scores = [res[col2metric[col]][0].f1 for col in predicted] + toEval = sorted([col for col in predicted]) + scores = [res[col2metric[col]][0].f1 for col in toEval] score = sum(scores)/len(scores) saved = True if bestScore is None else score > bestScore bestScore = score if bestScore is None else max(bestScore, score) - devScore = ", Dev : "+" ".join(["%s=%.2f"%(col2metric[list(predicted)[i]], scores[i]) for i in range(len(predicted))]) + devScore = ", Dev : "+" ".join(["%s=%.2f"%(col2metric[toEval[i]], scores[i]) for i in range(len(toEval))]) if saved : torch.save(model, modelDir+"/network.pt") for out in [sys.stderr, open(modelDir+"/train.log", "w" if epoch == 1 else "a")] :