import sys
import random
import copy
import math

import torch

from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp, prettyInt, numParameters, getDevice
from Rl import ReplayMemory, selectAction, optimizeModel, rewarding
import Networks
import Decode
import Config
from conll18_ud_eval import load_conllu, evaluate
################################################################################
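# Train a model on the given CoNLL-U file. type selects the training regime :
# "oracle" for supervised training on oracle-derived examples, "rl" for reinforcement learning.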
def trainMode(debug, networkName, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
sentences = Config.readConllu(filename, predicted)
if type == "oracle" :
trainModelOracle(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
  elif type == "rl" :
trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
exit(1)
################################################################################
################################################################################
# Return list of examples for each transitionSet
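# Each example is the index of the gold transition concatenated with the feature vector of the
# current configuration ; with dynamic=True the configuration is advanced by the network's own
# prediction instead of the oracle transition.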
def extractExamples(debug, transitionSets, strat, config, dicts, network, dynamic) :
examples = [[] for _ in transitionSets]
with torch.no_grad() :
EOS = Transition("EOS")
config.moveWordIndex(0)
    moved = True
    while moved :
      missingLinks = getMissingLinks(config)
      ts = transitionSets[config.state]
      candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and trans.name != "BACK"])
if len(candidates) == 0 :
break
best = min([cand[0] for cand in candidates])
candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
features = network.extractFeatures(dicts, config)
if debug :
config.printForDebug(sys.stderr)
print(str([[c[0],str(c[1])] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
      candidate = candidateOracle
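      # In dynamic mode, follow the model's best-scoring appliable prediction instead of the oracle,
      # so that examples are extracted from configurations the model itself reaches.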
if dynamic :
        network.setState(config.state)
        output = network(features.unsqueeze(0).to(getDevice()))
scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
if debug :
          print("Model prediction : %s"%str(candidate), file=sys.stderr)
goldIndex = [str(trans) for trans in ts].index(str(candidateOracle))
example = torch.cat([torch.LongTensor([goldIndex]), features])
      examples[config.state].append(example)

      moved = applyTransition(strat, config, candidate, None)
EOS.apply(config, strat)
return examples
################################################################################
################################################################################
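# Evaluate the current model and save it when it improves : on the averaged dev metrics of the
# predicted columns if a dev file is given, on the training loss otherwise.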
def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) :
col2metric = {"HEAD" : "UAS", "DEPREL" : "LAS", "UPOS" : "UPOS", "FEATS" : "UFeats"}
devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, predicted, modelDir, model, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
toEval = sorted([col for col in predicted])
scores = [res[col2metric[col]][0].f1 for col in toEval]
score = sum(scores)/len(scores)
saved = True if bestScore is None else score > bestScore
bestScore = score if bestScore is None else max(bestScore, score)
devScore = ", Dev : "+" ".join(["%s=%.2f"%(col2metric[toEval[i]], scores[i]) for i in range(len(toEval))])
if saved :
torch.save(model, modelDir+"/network.pt")
for out in [sys.stderr, open(modelDir+"/train.log", "w" if epoch == 1 else "a")] :
print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=out)
return bestLoss, bestScore
################################################################################
################################################################################
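# Supervised training : extract (gold transition, features) examples with the oracle, then train the
# network with a cross-entropy loss, optionally refreshing the examples every bootstrapInterval epochs.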
def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
transitionNames = {}
for ts in transitionSets :
for t in ts :
transitionNames[str(t)] = (len(transitionNames), 0)
transitionNames[dicts.nullToken] = (len(transitionNames), 0)
dicts.addDict("HISTORY", transitionNames)
dicts.save(modelDir+"/dicts.json")
network = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
  examples = [[] for _ in transitionSets]
  sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, False)
for e in range(len(examples)) :
examples[e] += extracted[e]
totalNbExamples = sum(map(len,examples))
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr)
for e in range(len(examples)) :
examples[e] = torch.stack(examples[e])
print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
optimizer = torch.optim.Adam(network.parameters(), lr=lr)
lossFct = torch.nn.CrossEntropyLoss()
  bestLoss = None
  bestScore = None
  for epoch in range(1,nbEpochs+1) :
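    # Bootstrap : periodically re-extract the examples by letting the current model drive the
    # transitions (dynamic extraction), so training also covers configurations reached by the model.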
if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 :
examples = [[] for _ in transitionSets]
sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, True)
for e in range(len(examples)) :
examples[e] += extracted[e]
totalNbExamples = sum(map(len,examples))
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr)
for e in range(len(examples)) :
examples[e] = torch.stack(examples[e])
for e in range(len(examples)) :
examples[e] = examples[e].index_select(0, torch.randperm(examples[e].size(0)))
totalLoss = 0.0
nbEx = 0
printInterval = 2000
advancement = 0
distribution = [len(e)/totalNbExamples for e in examples]
curIndexes = [0 for _ in examples]
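    # Mixed-state batching : sample a state (transition set) with probability proportional to its number
    # of examples, then train the corresponding output of the network on the next slice of its examples.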
while True :
state = random.choices(population=range(len(examples)), weights=distribution, k=1)[0]
if curIndexes[state] >= len(examples[state]) :
state = -1
for i in range(len(examples)) :
if curIndexes[i] < len(examples[i]) :
state = i
if state == -1 :
break
batch = examples[state][curIndexes[state]:curIndexes[state]+batchSize].to(getDevice())
curIndexes[state] += batchSize
targets = batch[:,:1].view(-1)
inputs = batch[:,1:]
nbEx += targets.size(0)
advancement += targets.size(0)
if not silent and advancement >= printInterval :
advancement = 0
print("Current epoch %6.2f%%"%(100.0*nbEx/totalNbExamples), end="\r", file=sys.stderr)
network.setState(state)
outputs = network(inputs)
loss = lossFct(outputs, targets)
network.zero_grad()
loss.backward()
optimizer.step()
totalLoss += float(loss)
bestLoss, bestScore = evalModelAndSave(debug, network, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted)
################################################################################
################################################################################
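# Reinforcement learning training (DQN style) : the policy network selects transitions with decaying
# random / oracle exploration, experiences are stored in one replay memory per (fromState, toState)
# pair, and a separate target network kept in sync with the policy network is passed to optimizeModel.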
def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
transitionNames = {}
for ts in transitionSets :
for t in ts :
transitionNames[str(t)] = (len(transitionNames), 0)
transitionNames[dicts.nullToken] = (len(transitionNames), 0)
dicts.addDict("HISTORY", transitionNames)
  dicts.save(modelDir+"/dicts.json")
policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
policy_net.train()
optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
  bestLoss = None
  bestScore = None
  memory = None
print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
sentences = copy.deepcopy(sentencesOriginal)
nbExByEpoch = sum(map(len,sentences))
sentIndex = 0
for epoch in range(1,nbIter+1) :
    i = 0
    totalLoss = 0.0
    while True :
if sentIndex >= len(sentences) :
sentences = copy.deepcopy(sentencesOriginal)
random.shuffle(sentences)
sentIndex = 0
print("Current epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
      sentence = sentences[sentIndex]
      sentence.moveWordIndex(0)
state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
count = 0
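      # Exploration probabilities (random action, forced oracle action) decay exponentially with the
      # epoch : p = (start - end) * exp(-(epoch-1)/tau) + end, with (start, tau, end) taken from probas.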
list_probas = []
for pb in range(len(probas)):
list_probas.append([round((probas[pb][0][0]-probas[pb][0][2])*math.exp((-epoch+1)/probas[pb][0][1])+probas[pb][0][2], 2),
round((probas[pb][1][0]-probas[pb][1][2])*math.exp((-epoch+1)/probas[pb][1][1])+probas[pb][1][2], 2)])
      while True :
        missingLinks = getMissingLinks(sentence)
fromState = sentence.state
toState = sentence.state
probaRandom = list_probas[fromState][0]
probaOracle = list_probas[fromState][1]
        transitionSet = transitionSets[fromState]
        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)
        if debug :
          print("Selected action : %s"%str(action), file=sys.stderr)
appliable = action.appliable(sentence)
reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
reward = torch.FloatTensor([reward_]).to(getDevice())
toState = strategy[action.name][1] if action.name in strategy else -1
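        # The strategy gives the state the parser moves to after this transition ; -1 if the
        # transition does not appear in the strategy.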
        if appliable :
          applyTransition(strategy, sentence, action, reward_)
newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
        else :
          newState = None
          count += 1
        if memory is None :
          memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
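        # One optimization step on a minibatch sampled from the replay memories, then the target
        # network is re-synchronized with the policy network.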
totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
policy_net.train()
i += 1
        state = newState
if state is None or count == countBreak:
break
if i >= nbExByEpoch :
break
sentIndex += 1
bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
################################################################################