Skip to content
Snippets Groups Projects
Commit cf60ff93 authored by Franck Dary's avatar Franck Dary
Browse files

PrintAdvancement and corrected some bugs

parent 74a8e16e
No related branches found
No related tags found
No related merge requests found
...@@ -22,7 +22,7 @@ def timeStamp() : ...@@ -22,7 +22,7 @@ def timeStamp() :
################################################################################ ################################################################################
################################################################################ ################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile) : def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
...@@ -46,10 +46,18 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile) : ...@@ -46,10 +46,18 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile) :
for iter in range(1,nbIter+1) : for iter in range(1,nbIter+1) :
examples = examples.index_select(0, torch.randperm(examples.size(0))) examples = examples.index_select(0, torch.randperm(examples.size(0)))
totalLoss = 0.0 totalLoss = 0.0
for batchIndex in range(0,examples.size(0)-6,6) : nbEx = 0
printInterval = 2000
advancement = 0
for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
batch = examples[batchIndex:batchIndex+batchSize] batch = examples[batchIndex:batchIndex+batchSize]
targets = batch[:,:1].view(-1) targets = batch[:,:1].view(-1)
inputs = batch[:,1:] inputs = batch[:,1:]
nbEx += targets.size(0)
advancement += targets.size(0)
if not silent and advancement >= printInterval :
advancement = 0
print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
outputs = network(inputs) outputs = network(inputs)
loss = lossFct(outputs, targets) loss = lossFct(outputs, targets)
network.zero_grad() network.zero_grad()
...@@ -63,8 +71,6 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile) : ...@@ -63,8 +71,6 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile) :
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), []) res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1) devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1)
print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr) print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr)
decodeMode(debug, filename, "model", network, dicts)
return return
print("ERROR : unknown type '%s'"%type, file=sys.stderr) print("ERROR : unknown type '%s'"%type, file=sys.stderr)
...@@ -115,12 +121,14 @@ if __name__ == "__main__" : ...@@ -115,12 +121,14 @@ if __name__ == "__main__" :
help="Name of the CoNLL-U file of the dev corpus.") help="Name of the CoNLL-U file of the dev corpus.")
parser.add_argument("--debug", "-d", default=False, action="store_true", parser.add_argument("--debug", "-d", default=False, action="store_true",
help="Print debug infos on stderr.") help="Print debug infos on stderr.")
parser.add_argument("--silent", "-s", default=False, action="store_true",
help="Don't print advancement infos.")
args = parser.parse_args() args = parser.parse_args()
os.makedirs(args.model, exist_ok=True) os.makedirs(args.model, exist_ok=True)
if args.mode == "train" : if args.mode == "train" :
trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev) trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)
elif args.mode == "decode" : elif args.mode == "decode" :
decodeMode(args.debug, args.corpus, args.type) decodeMode(args.debug, args.corpus, args.type)
else : else :
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment