Skip to content
Snippets Groups Projects
Commit bc508cdb authored by Franck Dary's avatar Franck Dary
Browse files

Code refactoring and cosmetic changes to print

parent a35103cd
No related branches found
No related tags found
No related merge requests found
......@@ -6,7 +6,7 @@ import copy
from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp
from Util import timeStamp, prettyInt, numParameters
from Rl import ReplayMemory, selectAction, optimizeModel
import Networks
import Decode
......@@ -61,7 +61,28 @@ def extractExamples(ts, strat, config, dicts, debug=False) :
################################################################################
################################################################################
def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) :
devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", modelDir, model, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
UAS = res["UAS"][0].f1
score = UAS
saved = True if bestScore is None else score > bestScore
bestScore = score if bestScore is None else max(bestScore, score)
devScore = ", Dev : UAS=%.2f"%(UAS)
if saved :
torch.save(model, modelDir+"/network.pt")
print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
return bestLoss, bestScore
################################################################################
################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
examples = []
dicts = Dicts()
dicts.readConllu(filename, ["FORM", "UPOS"])
......@@ -69,15 +90,16 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
examples += extractExamples(transitionSet, strategy, config, dicts, debug)
print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
examples = torch.stack(examples)
network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
lossFct = torch.nn.CrossEntropyLoss()
bestLoss = None
bestScore = None
for iter in range(1,nbIter+1) :
for epoch in range(1,nbEpochs+1) :
network.train()
examples = examples.index_select(0, torch.randperm(examples.size(0)))
totalLoss = 0.0
......@@ -99,21 +121,8 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
loss.backward()
optimizer.step()
totalLoss += float(loss)
devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", modelDir, network, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
UAS = res["UAS"][0].f1
score = UAS
saved = True if bestScore is None else score > bestScore
bestScore = score if bestScore is None else max(bestScore, score)
devScore = ", Dev : UAS=%.2f"%(UAS)
if saved :
torch.save(network, modelDir+"/network.pt")
print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
bestLoss, bestScore = evalModelAndSave(debug, network, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs)
################################################################################
################################################################################
......@@ -130,11 +139,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
target_net.eval()
policy_net.train()
print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
bestLoss = None
bestScore = None
for epoch in range(nbIter) :
for epoch in range(1,nbIter+1) :
i = 0
totalLoss = 0.0
sentences = copy.deepcopy(sentencesOriginal)
......@@ -167,22 +178,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
target_net.eval()
policy_net.train()
i += 1
# Fin epoch, compute score and save model
devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", modelDir, policy_net, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
UAS = res["UAS"][0].f1
score = UAS
saved = True if bestScore is None else score > bestScore
bestScore = score if bestScore is None else max(bestScore, score)
devScore = ", Dev : UAS=%.2f"%(UAS)
if saved :
torch.save(policy_net, modelDir+"/network.pt")
print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), epoch, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
bestLoss, bestScore = evalModelAndSave(debug, policy_net, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter)
################################################################################
......@@ -10,3 +10,15 @@ def isEmpty(value) :
return value == "_" or value == ""
################################################################################
################################################################################
def prettyInt(value, p) :
l = ['' for _ in range((p-len(str(value))%p)%p)] + list(str(value))
l = ["".join(l[i:i+p]) for i in range(0,len(l),p)]
return " ".join(l)
################################################################################
################################################################################
def numParameters(model) :
return sum(p.numel() for p in model.parameters() if p.requires_grad)
################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment