Skip to content
Snippets Groups Projects
Commit bf44c7fe authored by Franck Dary
Browse files

Added way to decode model in main

parent 18803742
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ import random ...@@ -2,6 +2,7 @@ import random
import sys import sys
from Transition import Transition, getMissingLinks, applyTransition from Transition import Transition, getMissingLinks, applyTransition
from Features import extractFeatures from Features import extractFeatures
from Dicts import Dicts
import Config import Config
import torch import torch
...@@ -65,7 +66,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) : ...@@ -65,7 +66,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
################################################################################ ################################################################################
################################################################################ ################################################################################
def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) : def decodeMode(debug, filename, type, modelDir = None, network=None, dicts=None, output=sys.stdout) :
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
...@@ -79,6 +80,10 @@ def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdou ...@@ -79,6 +80,10 @@ def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdou
for config in sentences[1:] : for config in sentences[1:] :
config.print(sys.stdout, header=False) config.print(sys.stdout, header=False)
elif type == "model" : elif type == "model" :
if dicts is None :
dicts = Dicts()
dicts.load(modelDir+"/dicts.json")
network = torch.load(modelDir+"/network.pt")
for config in sentences : for config in sentences :
decodeModel(transitionSet, strategy, config, network, dicts, debug) decodeModel(transitionSet, strategy, config, network, dicts, debug)
sentences[0].print(output, header=True) sentences[0].print(output, header=True)
......
...@@ -67,12 +67,12 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran ...@@ -67,12 +67,12 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
examples = torch.stack(examples) examples = torch.stack(examples)
network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)) network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
network.train()
optimizer = torch.optim.Adam(network.parameters(), lr=0.0001) optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
lossFct = torch.nn.CrossEntropyLoss() lossFct = torch.nn.CrossEntropyLoss()
bestLoss = None bestLoss = None
bestScore = None bestScore = None
for iter in range(1,nbIter+1) : for iter in range(1,nbIter+1) :
network.train()
examples = examples.index_select(0, torch.randperm(examples.size(0))) examples = examples.index_select(0, torch.randperm(examples.size(0)))
totalLoss = 0.0 totalLoss = 0.0
nbEx = 0 nbEx = 0
...@@ -98,7 +98,7 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran ...@@ -98,7 +98,7 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
if devFile is not None : if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu" outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w")) Decode.decodeMode(debug, devFile, "model", modelDir, network, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
UAS = res["UAS"][0].f1 UAS = res["UAS"][0].f1
score = UAS score = UAS
......
...@@ -35,7 +35,7 @@ if __name__ == "__main__" : ...@@ -35,7 +35,7 @@ if __name__ == "__main__" :
if args.mode == "train" : if args.mode == "train" :
Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent) Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)
elif args.mode == "decode" : elif args.mode == "decode" :
Decode.decodeMode(args.debug, args.corpus, args.type) Decode.decodeMode(args.debug, args.corpus, args.type, args.model)
else : else :
print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr) print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
exit(1) exit(1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment