diff --git a/Decode.py b/Decode.py index 80b46ea5ac514010ebd0a931834a44ea393a258a..af8fc687cad2236c42cdf36a7f724732e694e0ec 100644 --- a/Decode.py +++ b/Decode.py @@ -2,6 +2,7 @@ import random import sys from Transition import Transition, getMissingLinks, applyTransition from Features import extractFeatures +from Dicts import Dicts import Config import torch @@ -65,7 +66,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) : ################################################################################ ################################################################################ -def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) : +def decodeMode(debug, filename, type, modelDir = None, network=None, dicts=None, output=sys.stdout) : transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} @@ -79,6 +80,10 @@ def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdou for config in sentences[1:] : config.print(sys.stdout, header=False) elif type == "model" : + if dicts is None : + dicts = Dicts() + dicts.load(modelDir+"/dicts.json") + network = torch.load(modelDir+"/network.pt") for config in sentences : decodeModel(transitionSet, strategy, config, network, dicts, debug) sentences[0].print(output, header=True) diff --git a/Train.py b/Train.py index e16725f168b3a8bbc0f4b025e4e3a017bd9a6b6c..6d192121d5bac7ba72b1526935a3371c32b9c78a 100644 --- a/Train.py +++ b/Train.py @@ -67,12 +67,12 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran examples = torch.stack(examples) network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)) - network.train() optimizer = torch.optim.Adam(network.parameters(), lr=0.0001) lossFct = torch.nn.CrossEntropyLoss() bestLoss = None bestScore = None for iter in range(1,nbIter+1) : + network.train() examples = examples.index_select(0, torch.randperm(examples.size(0))) totalLoss = 0.0 nbEx = 0 @@ -98,7 +98,7 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) if devFile is not None : outFilename = modelDir+"/predicted_dev.conllu" - Decode.decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w")) + Decode.decodeMode(debug, devFile, "model", modelDir, network, dicts, open(outFilename, "w")) res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) UAS = res["UAS"][0].f1 score = UAS diff --git a/main.py b/main.py index 132f0585e1f13a76fc10b732ef9c2f828691e788..cf785b3fc162ea0004fa855ec77466624599fd05 100755 --- a/main.py +++ b/main.py @@ -35,7 +35,7 @@ if __name__ == "__main__" : if args.mode == "train" : Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent) elif args.mode == "decode" : - Decode.decodeMode(args.debug, args.corpus, args.type) + Decode.decodeMode(args.debug, args.corpus, args.type, args.model) else : print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr) exit(1)