Skip to content
Snippets Groups Projects
Commit bf44c7fe authored by Franck Dary's avatar Franck Dary
Browse files

Added way to decode model in main

parent 18803742
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@ import random
import sys
from Transition import Transition, getMissingLinks, applyTransition
from Features import extractFeatures
from Dicts import Dicts
import Config
import torch
......@@ -65,7 +66,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
################################################################################
################################################################################
def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) :
def decodeMode(debug, filename, type, modelDir = None, network=None, dicts=None, output=sys.stdout) :
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
......@@ -79,6 +80,10 @@ def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdou
for config in sentences[1:] :
config.print(sys.stdout, header=False)
elif type == "model" :
if dicts is None :
dicts = Dicts()
dicts.load(modelDir+"/dicts.json")
network = torch.load(modelDir+"/network.pt")
for config in sentences :
decodeModel(transitionSet, strategy, config, network, dicts, debug)
sentences[0].print(output, header=True)
......
......@@ -67,12 +67,12 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
examples = torch.stack(examples)
network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
network.train()
optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
lossFct = torch.nn.CrossEntropyLoss()
bestLoss = None
bestScore = None
for iter in range(1,nbIter+1) :
network.train()
examples = examples.index_select(0, torch.randperm(examples.size(0)))
totalLoss = 0.0
nbEx = 0
......@@ -98,7 +98,7 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w"))
Decode.decodeMode(debug, devFile, "model", modelDir, network, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
UAS = res["UAS"][0].f1
score = UAS
......
......@@ -35,7 +35,7 @@ if __name__ == "__main__" :
if args.mode == "train" :
Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)
elif args.mode == "decode" :
Decode.decodeMode(args.debug, args.corpus, args.type)
Decode.decodeMode(args.debug, args.corpus, args.type, args.model)
else :
print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
exit(1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment