Franck Dary authored
Added incremental mode (-i) : the parser cannot see the right context if it has not been fixated (so it is only accessible after a BACK).
Train.py
import sys
import random
import torch
import copy
import math
from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp, prettyInt, numParameters, getDevice
from Rl import ReplayMemory, selectAction, optimizeModel, rewarding
import Networks
import Decode
import Config
from conll18_ud_eval import load_conllu, evaluate
################################################################################
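# Entry point for training : dispatches to the oracle (imitation) regime or to
# the reinforcement learning regime depending on 'type'.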
def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, silent=False) :
  sentences = Config.readConllu(filename)
  if type == "oracle" :
    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, silent)
    return
  if type == "rl" :
    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, silent)
    return
  print("ERROR : unknown type '%s'"%type, file=sys.stderr)
  exit(1)
################################################################################
################################################################################
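# Replays one sentence configuration with the oracle and collects one training
# example per step : the gold transition index concatenated with the feature
# vector. In dynamic mode the parser follows the model's own predictions, so
# the examples reflect the states the model actually visits. BACK transitions
# are never proposed as oracle candidates here.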
def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
  examples = []
  with torch.no_grad() :
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    moved = True
    while moved :
      missingLinks = getMissingLinks(config)
      candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name])
      if len(candidates) == 0 :
        break
      best = min([cand[0] for cand in candidates])
      candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
      features = network.extractFeatures(dicts, config)
      candidate = candidateOracle.name
      if debug :
        config.printForDebug(sys.stderr)
        print(str([[c[0],c[1].name] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
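      # In dynamic mode, the transition actually applied is the model's highest
      # scoring appliable prediction, while the oracle's choice is still kept
      # as the gold label of the example.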
      if dynamic :
        output = network(features.unsqueeze(0).to(getDevice()))
        scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
        candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
        if debug :
          print(candidate, file=sys.stderr)
      goldIndex = [trans.name for trans in ts].index(candidateOracle.name)
      candidateIndex = [trans.name for trans in ts].index(candidate)
      example = torch.cat([torch.LongTensor([goldIndex]), features])
      examples.append(example)
      moved = applyTransition(ts, strat, config, candidate, None)
    EOS.apply(config, strat)
  return examples
################################################################################
################################################################################
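# End-of-epoch bookkeeping : when a dev file is given, decode it and score it
# with conll18_ud_eval (model selection on UAS); otherwise fall back to the
# training loss. The model is saved whenever the tracked metric improves.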
def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental) :
  devScore = ""
  saved = True if bestLoss is None else totalLoss < bestLoss
  bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
  if devFile is not None :
    outFilename = modelDir+"/predicted_dev.conllu"
    Decode.decodeMode(debug, devFile, "model", ts, strat, modelDir, model, dicts, open(outFilename, "w"))
    res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
    UAS = res["UAS"][0].f1
    score = UAS
    saved = True if bestScore is None else score > bestScore
    bestScore = score if bestScore is None else max(bestScore, score)
    devScore = ", Dev : UAS=%.2f"%(UAS)
  if saved :
    torch.save(model, modelDir+"/network.pt")
  for out in [sys.stderr, open(modelDir+"/train.log", "w" if epoch == 1 else "a")] :
    print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=out)
  return bestLoss, bestScore
################################################################################
################################################################################
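# Supervised (imitation) training : examples are first extracted with the
# static oracle, then every 'bootstrapInterval' epochs they are re-extracted
# dynamically, i.e. along the trajectories chosen by the current model.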
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, silent=False) :
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM","UPOS"], 2)
  dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
  dicts.save(modelDir+"/dicts.json")
  network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
  examples = []
  sentences = copy.deepcopy(sentencesOriginal)
  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
  for config in sentences :
    examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, False)
  print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
  examples = torch.stack(examples)
  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
  lossFct = torch.nn.CrossEntropyLoss()
  bestLoss = None
  bestScore = None
  for epoch in range(1,nbEpochs+1) :
    if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 :
      examples = []
      sentences = copy.deepcopy(sentencesOriginal)
      print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
      for config in sentences :
        examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, True)
      print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
      examples = torch.stack(examples)
    network.train()
    examples = examples.index_select(0, torch.randperm(examples.size(0)))
    totalLoss = 0.0
    nbEx = 0
    printInterval = 2000
    advancement = 0
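    # Iterate over the shuffled examples in minibatches; in each row, column 0
    # holds the gold transition index and the remaining columns the features.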
    for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
      batch = examples[batchIndex:batchIndex+batchSize].to(getDevice())
      targets = batch[:,:1].view(-1)
      inputs = batch[:,1:]
      nbEx += targets.size(0)
      advancement += targets.size(0)
      if not silent and advancement >= printInterval :
        advancement = 0
        print("Current epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
      outputs = network(inputs)
      loss = lossFct(outputs, targets)
      network.zero_grad()
      loss.backward()
      optimizer.step()
      totalLoss += float(loss)
    bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental)
################################################################################
################################################################################
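# Reinforcement learning in the spirit of DQN : a policy network is trained
# from a replay memory of (state, action, nextState, reward) transitions, with
# a periodically synchronized target network providing the bootstrap targets.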
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, silent=False) :
  memory = None
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM","UPOS"], 2)
  dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
  dicts.save(modelDir+"/dicts.json")
  policy_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
  target_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
  target_net.load_state_dict(policy_net.state_dict())
  target_net.eval()
  policy_net.train()
  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
  bestLoss = None
  bestScore = None
  sentences = copy.deepcopy(sentencesOriginal)
  nbExByEpoch = sum(map(len,sentences))
  sentIndex = 0
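  # Exploration schedule, passed to selectAction : the probability of a random
  # action decays exponentially towards a 0.1 floor, and the probability of
  # following the oracle decays exponentially towards 0.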
  for epoch in range(1,nbIter+1) :
    probaRandom = round(0.5*math.exp((-epoch+1)/4)+0.1, 2)
    probaOracle = round(0.3*math.exp((-epoch+1)/2), 2)
    i = 0
    totalLoss = 0.0
    while True :
      if sentIndex >= len(sentences) :
        sentences = copy.deepcopy(sentencesOriginal)
        random.shuffle(sentences)
        sentIndex = 0
      if not silent :
        print("Current epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
      sentence = sentences[sentIndex]
      sentence.moveWordIndex(0)
      state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
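      # Play transitions on the current sentence until selectAction returns
      # None, pushing each (state, action, nextState, reward) into the replay
      # memory. A non-appliable action keeps nextState at None, which ends the
      # episode.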
      while True :
        missingLinks = getMissingLinks(sentence)
        if debug :
          sentence.printForDebug(sys.stderr)
        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle)
        if action is None :
          break
        if debug :
          print("Selected action : %s"%action.name, file=sys.stderr)
        appliable = action.appliable(sentence)
        reward_ = rewarding(appliable, sentence, action, missingLinks)
        reward = torch.FloatTensor([reward_]).to(getDevice())
        newState = None
        if appliable :
          applyTransition(transitionSet, strategy, sentence, action.name, reward_)
          newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
        if memory is None :
          memory = ReplayMemory(5000, state.numel())
        memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
        state = newState
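        # One optimization step on the policy net every batchSize transitions,
        # immediately followed by a refresh of the target net (the sync
        # multiplier is 1 here).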
        if i % batchSize == 0 :
          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
          if i % (1*batchSize) == 0 :
            target_net.load_state_dict(policy_net.state_dict())
            target_net.eval()
            policy_net.train()
        i += 1
        if state is None :
          break
      if i >= nbExByEpoch :
        break
      sentIndex += 1
    bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental)
################################################################################