Skip to content
Snippets Groups Projects
Commit cf60ff93 authored by Franck Dary's avatar Franck Dary
Browse files

PrintAdvancement and corrected some bugs

parent 74a8e16e
No related branches found
No related tags found
No related merge requests found
......@@ -22,7 +22,7 @@ def timeStamp() :
################################################################################
################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile) :
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
......@@ -46,10 +46,18 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile) :
for iter in range(1,nbIter+1) :
examples = examples.index_select(0, torch.randperm(examples.size(0)))
totalLoss = 0.0
for batchIndex in range(0,examples.size(0)-6,6) :
nbEx = 0
printInterval = 2000
advancement = 0
for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
batch = examples[batchIndex:batchIndex+batchSize]
targets = batch[:,:1].view(-1)
inputs = batch[:,1:]
nbEx += targets.size(0)
advancement += targets.size(0)
if not silent and advancement >= printInterval :
advancement = 0
print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
outputs = network(inputs)
loss = lossFct(outputs, targets)
network.zero_grad()
......@@ -63,8 +71,6 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile) :
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1)
print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr)
decodeMode(debug, filename, "model", network, dicts)
return
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
......@@ -115,12 +121,14 @@ if __name__ == "__main__" :
help="Name of the CoNLL-U file of the dev corpus.")
parser.add_argument("--debug", "-d", default=False, action="store_true",
help="Print debug infos on stderr.")
parser.add_argument("--silent", "-s", default=False, action="store_true",
help="Don't print advancement infos.")
args = parser.parse_args()
os.makedirs(args.model, exist_ok=True)
if args.mode == "train" :
trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev)
trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)
elif args.mode == "decode" :
decodeMode(args.debug, args.corpus, args.type)
else :
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment