Skip to content
Snippets Groups Projects
Commit bc508cdb authored by Franck Dary's avatar Franck Dary
Browse files

Code refactoring and cosmetic changes to print

parent a35103cd
No related branches found
No related tags found
No related merge requests found
...@@ -6,7 +6,7 @@ import copy ...@@ -6,7 +6,7 @@ import copy
from Transition import Transition, getMissingLinks, applyTransition from Transition import Transition, getMissingLinks, applyTransition
import Features import Features
from Dicts import Dicts from Dicts import Dicts
from Util import timeStamp from Util import timeStamp, prettyInt, numParameters
from Rl import ReplayMemory, selectAction, optimizeModel from Rl import ReplayMemory, selectAction, optimizeModel
import Networks import Networks
import Decode import Decode
...@@ -61,7 +61,28 @@ def extractExamples(ts, strat, config, dicts, debug=False) : ...@@ -61,7 +61,28 @@ def extractExamples(ts, strat, config, dicts, debug=False) :
################################################################################ ################################################################################
################################################################################ ################################################################################
def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) :
    """Evaluate the model at the end of an epoch, save it if it improved, and log a summary.

    Improvement is judged on the training loss (lower is better) when no dev
    file is given, and on the dev score (higher is better) otherwise.

    Args:
        debug: Debug flag, forwarded unchanged to Decode.decodeMode.
        model: Network to evaluate; saved whole with torch.save on improvement.
        dicts: Dictionaries forwarded to Decode.decodeMode.
        modelDir: Directory receiving "predicted_dev.conllu" and "network.pt".
        devFile: Path of the dev CoNLL-U file, or None to skip dev evaluation.
        bestLoss: Lowest total loss seen so far, or None on the first call.
        totalLoss: Total training loss of the epoch that just finished.
        bestScore: Highest dev score seen so far, or None if never computed.
        epoch: Number of the epoch that just finished (only used for logging).
        nbIter: Total number of epochs (used for logging and to pad `epoch`).

    Returns:
        The updated (bestLoss, bestScore) pair.
    """
    devScore = ""
    saved = True if bestLoss is None else totalLoss < bestLoss
    bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
    if devFile is not None :
        outFilename = modelDir+"/predicted_dev.conllu"
        # Close (and therefore flush) the predictions before they are read back.
        with open(outFilename, "w") as predFile :
            Decode.decodeMode(debug, devFile, "model", modelDir, model, dicts, predFile)
        # Keep both files open until evaluate() returns, in case loading is lazy.
        with open(devFile, "r") as goldFile, open(outFilename, "r") as sysFile :
            res = evaluate(load_conllu(goldFile), load_conllu(sysFile), [])
        UAS = res["UAS"][0].f1
        score = UAS
        # With a dev set, the dev score overrides the loss-based decision.
        saved = True if bestScore is None else score > bestScore
        bestScore = score if bestScore is None else max(bestScore, score)
        devScore = ", Dev : UAS=%.2f"%(UAS)
    if saved :
        torch.save(model, modelDir+"/network.pt")
    # The epoch number is padded to the width of nbIter so log lines stay aligned.
    print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
    return bestLoss, bestScore
################################################################################
################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
examples = [] examples = []
dicts = Dicts() dicts = Dicts()
dicts.readConllu(filename, ["FORM", "UPOS"]) dicts.readConllu(filename, ["FORM", "UPOS"])
...@@ -69,15 +90,16 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran ...@@ -69,15 +90,16 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr) print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
for config in sentences : for config in sentences :
examples += extractExamples(transitionSet, strategy, config, dicts, debug) examples += extractExamples(transitionSet, strategy, config, dicts, debug)
print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr) print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
examples = torch.stack(examples) examples = torch.stack(examples)
network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)) network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
optimizer = torch.optim.Adam(network.parameters(), lr=0.0001) optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
lossFct = torch.nn.CrossEntropyLoss() lossFct = torch.nn.CrossEntropyLoss()
bestLoss = None bestLoss = None
bestScore = None bestScore = None
for iter in range(1,nbIter+1) : for epoch in range(1,nbEpochs+1) :
network.train() network.train()
examples = examples.index_select(0, torch.randperm(examples.size(0))) examples = examples.index_select(0, torch.randperm(examples.size(0)))
totalLoss = 0.0 totalLoss = 0.0
...@@ -99,21 +121,8 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran ...@@ -99,21 +121,8 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
loss.backward() loss.backward()
optimizer.step() optimizer.step()
totalLoss += float(loss) totalLoss += float(loss)
devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss bestLoss, bestScore = evalModelAndSave(debug, network, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs)
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", modelDir, network, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
UAS = res["UAS"][0].f1
score = UAS
saved = True if bestScore is None else score > bestScore
bestScore = score if bestScore is None else max(bestScore, score)
devScore = ", Dev : UAS=%.2f"%(UAS)
if saved :
torch.save(network, modelDir+"/network.pt")
print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################ ################################################################################
################################################################################ ################################################################################
...@@ -130,11 +139,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti ...@@ -130,11 +139,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
target_net.eval() target_net.eval()
policy_net.train() policy_net.train()
print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001) optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
bestLoss = None bestLoss = None
bestScore = None bestScore = None
for epoch in range(nbIter) : for epoch in range(1,nbIter+1) :
i = 0 i = 0
totalLoss = 0.0 totalLoss = 0.0
sentences = copy.deepcopy(sentencesOriginal) sentences = copy.deepcopy(sentencesOriginal)
...@@ -167,22 +178,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti ...@@ -167,22 +178,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
target_net.eval() target_net.eval()
policy_net.train() policy_net.train()
i += 1 i += 1
# Fin epoch, compute score and save model bestLoss, bestScore = evalModelAndSave(debug, policy_net, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter)
devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", modelDir, policy_net, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
UAS = res["UAS"][0].f1
score = UAS
saved = True if bestScore is None else score > bestScore
bestScore = score if bestScore is None else max(bestScore, score)
devScore = ", Dev : UAS=%.2f"%(UAS)
if saved :
torch.save(policy_net, modelDir+"/network.pt")
print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), epoch, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################ ################################################################################
...@@ -10,3 +10,15 @@ def isEmpty(value) : ...@@ -10,3 +10,15 @@ def isEmpty(value) :
return value == "_" or value == "" return value == "_" or value == ""
################################################################################ ################################################################################
################################################################################
def prettyInt(value, p) :
    """Return `value` as a string with its digits split into space-separated groups.

    Digits are grouped from the right, so only the leftmost group may be
    shorter than `p`: prettyInt(1234567, 3) gives "1 234 567". A leading minus
    sign stays attached to the first group instead of counting as a digit.

    Args:
        value: Integer (or anything whose str() is a number) to format.
        p: Number of digits per group; must be at least 1.

    Returns:
        The grouped string representation of `value`.

    Raises:
        ValueError: If `p` is smaller than 1.
    """
    if p < 1 :
        raise ValueError("group size p must be >= 1, got %s"%(p))
    digits = str(value)
    sign = ""
    if digits.startswith("-") :
        # Keep the sign out of the grouping so it never ends up alone in a group.
        sign, digits = "-", digits[1:]
    # Length of the leading, possibly shorter, group.
    head = len(digits) % p
    groups = [digits[:head]] if head else []
    groups += [digits[i:i+p] for i in range(head, len(digits), p)]
    return sign + " ".join(groups)
################################################################################
################################################################################
def numParameters(model) :
    """Count the elements of every parameter of `model` that requires a gradient.

    Args:
        model: Object exposing parameters(), which yields items having a
            `requires_grad` attribute and a numel() method.

    Returns:
        The sum of numel() over the parameters whose `requires_grad` is truthy.
    """
    total = 0
    for param in model.parameters() :
        if not param.requires_grad :
            continue
        total += param.numel()
    return total
################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment