import sys
import random
import torch
import copy
import math

from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp, prettyInt, numParameters, getDevice
from Rl import ReplayMemory, selectAction, optimizeModel, rewarding
import Networks
import Decode
import Config
from conll18_ud_eval import load_conllu, evaluate

################################################################################
def trainMode(debug, networkName, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
  sentences = Config.readConllu(filename, predicted)

  if type == "oracle" :
    trainModelOracle(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
    return

  if type == "rl" :
    trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
    return

  print("ERROR : unknown type '%s'"%type, file=sys.stderr)
  sys.exit(1)
################################################################################

################################################################################
# Return one list of training examples per transition set.
def extractExamples(debug, transitionSets, strat, config, dicts, network, dynamic) :
  examples = [[] for _ in transitionSets]
  with torch.no_grad() :
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    config.state = 0
    moved = True
    while moved :
      ts = transitionSets[config.state]
      missingLinks = getMissingLinks(config)
      # Candidate transitions, sorted by oracle score (lower is better).
      candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and trans.name != "BACK"])
      if len(candidates) == 0 :
        break
      best = min([cand[0] for cand in candidates])
      candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
      features = network.extractFeatures(dicts, config)
      candidate = candidateOracle
      if debug :
        config.printForDebug(sys.stderr)
        print(str([[c[0],str(c[1])] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
      if dynamic :
        # Follow the network's prediction instead of the oracle's choice.
        network.setState(config.state)
        output = network(features.unsqueeze(0).to(getDevice()))
        scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
        candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
        if debug :
          print(str(candidate), file=sys.stderr)
      goldIndex = [str(trans) for trans in ts].index(str(candidateOracle))
      # Each example row stores the gold transition index, then the features.
      example = torch.cat([torch.LongTensor([goldIndex]), features])
      examples[config.state].append(example)
      moved = applyTransition(strat, config, candidate, None)
    EOS.apply(config, strat)

  return examples
################################################################################
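
################################################################################
# Illustrative sketch (never called, placeholder values) : the row layout
# produced by extractExamples is [goldIndex, feature_0, ..., feature_n-1], so
# the oracle training loop can recover targets and inputs by slicing.
def _exampleLayoutSketch() :
  goldIndex = 2
  features = torch.LongTensor([11, 22, 33])
  example = torch.cat([torch.LongTensor([goldIndex]), features])
  batch = torch.stack([example, example])
  targets = batch[:,:1].view(-1) # tensor([2, 2])
  inputs = batch[:,1:] # tensor([[11, 22, 33], [11, 22, 33]])
  return targets, inputs
################################################################################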

################################################################################
def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) :
  col2metric = {"HEAD" : "UAS", "DEPREL" : "LAS", "UPOS" : "UPOS", "FEATS" : "UFeats"}

  devScore = ""
  saved = True if bestLoss is None else totalLoss < bestLoss
  bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)

  if devFile is not None :
    outFilename = modelDir+"/predicted_dev.conllu"
    Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, predicted, modelDir, model, dicts, open(outFilename, "w"))
    res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
    toEval = sorted([col for col in predicted])
    scores = [res[col2metric[col]][0].f1 for col in toEval]
    score = sum(scores)/len(scores)
    saved = True if bestScore is None else score > bestScore
    bestScore = score if bestScore is None else max(bestScore, score)
    devScore = ", Dev : "+" ".join(["%s=%.2f"%(col2metric[toEval[i]], scores[i]) for i in range(len(toEval))])

  if saved :
    torch.save(model, modelDir+"/network.pt")

  for out in [sys.stderr, open(modelDir+"/train.log", "w" if epoch == 1 else "a")] :
    print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=out)

  return bestLoss, bestScore
################################################################################
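
################################################################################
# Illustrative sketch (never called) of the dev score computed above : one
# conll18_ud_eval metric per predicted column, averaged. '_res' is a
# hand-built stand-in for the structure this script reads from evaluate().
def _devScoreSketch() :
  class _F1 :
    def __init__(self, f1) :
      self.f1 = f1
  col2metric = {"HEAD" : "UAS", "DEPREL" : "LAS", "UPOS" : "UPOS", "FEATS" : "UFeats"}
  _res = {"UAS" : [_F1(88.10)], "LAS" : [_F1(85.40)]}
  predicted = ["HEAD", "DEPREL"]
  toEval = sorted(predicted)
  scores = [_res[col2metric[col]][0].f1 for col in toEval]
  return sum(scores)/len(scores) # (85.40+88.10)/2 = 86.75
################################################################################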

################################################################################
def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
  transitionNames = {}
  for ts in transitionSets :
    for t in ts :
      transitionNames[str(t)] = (len(transitionNames), 0)
  transitionNames[dicts.nullToken] = (len(transitionNames), 0)
  dicts.addDict("HISTORY", transitionNames)
  dicts.save(modelDir+"/dicts.json")

  network = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
  examples = [[] for _ in transitionSets]
  sentences = copy.deepcopy(sentencesOriginal)
  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
  for config in sentences :
    extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, False)
    for e in range(len(examples)) :
      examples[e] += extracted[e]
  totalNbExamples = sum(map(len,examples))
  print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr)
  for e in range(len(examples)) :
    examples[e] = torch.stack(examples[e])

  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
  optimizer = torch.optim.Adam(network.parameters(), lr=lr)
  lossFct = torch.nn.CrossEntropyLoss()
  bestLoss = None
  bestScore = None

  for epoch in range(1,nbEpochs+1) :
    # Periodically re-extract examples with the current model in the loop
    # (dynamic oracle style bootstrapping).
    if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 :
      examples = [[] for _ in transitionSets]
      sentences = copy.deepcopy(sentencesOriginal)
      print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
      for config in sentences :
        extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, True)
        for e in range(len(examples)) :
          examples[e] += extracted[e]
      totalNbExamples = sum(map(len,examples))
      print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr)
      for e in range(len(examples)) :
        examples[e] = torch.stack(examples[e])

    network.train()
    # Shuffle each transition set's examples at the start of the epoch.
    for e in range(len(examples)) :
      examples[e] = examples[e].index_select(0, torch.randperm(examples[e].size(0)))
    totalLoss = 0.0
    nbEx = 0
    printInterval = 2000
    advancement = 0
    distribution = [len(e)/totalNbExamples for e in examples]
    curIndexes = [0 for _ in examples]
    while True :
      # Draw a machine state proportionally to its share of examples; fall
      # back to any state with examples left, or stop when all are consumed.
      state = random.choices(population=range(len(examples)), weights=distribution, k=1)[0]
      if curIndexes[state] >= len(examples[state]) :
        state = -1
        for i in range(len(examples)) :
          if curIndexes[i] < len(examples[i]) :
            state = i
      if state == -1 :
        break
      batch = examples[state][curIndexes[state]:curIndexes[state]+batchSize].to(getDevice())
      curIndexes[state] += batchSize
      targets = batch[:,:1].view(-1)
      inputs = batch[:,1:]
      nbEx += targets.size(0)
      advancement += targets.size(0)
      if not silent and advancement >= printInterval :
        advancement = 0
        print("Current epoch %6.2f%%"%(100.0*nbEx/totalNbExamples), end="\r", file=sys.stderr)
      network.setState(state)
      outputs = network(inputs)
      loss = lossFct(outputs, targets)
      network.zero_grad()
      loss.backward()
      optimizer.step()
      totalLoss += float(loss)
    bestLoss, bestScore = evalModelAndSave(debug, network, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted)
################################################################################
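
################################################################################
# Illustrative sketch (never called, sizes are placeholders) of the batch
# scheduling used by trainModelOracle : states are drawn proportionally to
# their number of examples, an exhausted state falls back to any remaining
# one, and the epoch ends when every state has been consumed.
def _stateSchedulingSketch(batchSize=64) :
  sizes = [1000, 250]
  distribution = [s/sum(sizes) for s in sizes]
  curIndexes = [0 for _ in sizes]
  visited = []
  while True :
    state = random.choices(population=range(len(sizes)), weights=distribution, k=1)[0]
    if curIndexes[state] >= sizes[state] :
      state = -1
      for i in range(len(sizes)) :
        if curIndexes[i] < sizes[i] :
          state = i
    if state == -1 :
      break
    curIndexes[state] += batchSize
    visited.append(state)
  return visited # e.g. [0, 0, 1, 0, ...], biased towards the larger state
################################################################################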

################################################################################
def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
  memory = None
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
  transitionNames = {}
  for ts in transitionSets :
    for t in ts :
      transitionNames[str(t)] = (len(transitionNames), 0)
  transitionNames[dicts.nullToken] = (len(transitionNames), 0)
  dicts.addDict("HISTORY", transitionNames)
  dicts.save(modelDir+"/dicts.json")

  policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
  target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
  target_net.load_state_dict(policy_net.state_dict())
  target_net.eval()
  policy_net.train()
  optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)

  bestLoss = None
  bestScore = None

  sentences = copy.deepcopy(sentencesOriginal)
  nbExByEpoch = sum(map(len,sentences))
  sentIndex = 0

  for epoch in range(1,nbIter+1) :
    # Exploration probabilities decay exponentially from probas[k][0] to
    # probas[k][2], with time constant probas[k][1].
    probaRandom = round((probas[0][0]-probas[0][2])*math.exp((-epoch+1)/probas[0][1])+probas[0][2], 2)
    probaOracle = round((probas[1][0]-probas[1][2])*math.exp((-epoch+1)/probas[1][1])+probas[1][2], 2)
    i = 0
    totalLoss = 0.0
    while True :
      if sentIndex >= len(sentences) :
        sentences = copy.deepcopy(sentencesOriginal)
        random.shuffle(sentences)
        sentIndex = 0
      if not silent :
        print("Current epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
      sentence = sentences[sentIndex]
      sentence.moveWordIndex(0)
      state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
      count = 0
      while True :
        missingLinks = getMissingLinks(sentence)
        transitionSet = transitionSets[sentence.state]
        fromState = sentence.state
        toState = sentence.state
        if debug :
          sentence.printForDebug(sys.stderr)
        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)
        if action is None :
          break
        if debug :
          print("Selected action : %s"%str(action), file=sys.stderr)

        appliable = action.appliable(sentence)
        reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
        reward = torch.FloatTensor([reward_]).to(getDevice())

        newState = None
        toState = strategy[action.name][1] if action.name in strategy else -1
        if appliable :
          applyTransition(strategy, sentence, action, reward_)
          newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
        else :
          count += 1

        # One replay memory per (fromState, toState) pair of machine states.
        if memory is None :
          memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
        memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
        state = newState
        if i % batchSize == 0 :
          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
          if i % (1*batchSize) == 0 :
            # Periodically sync the target network with the policy network.
            target_net.load_state_dict(policy_net.state_dict())
            target_net.eval()
            policy_net.train()
        i += 1
        if state is None or count == countBreak :
          break
      if i >= nbExByEpoch :
        break
      sentIndex += 1
    bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
################################################################################
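
################################################################################
# Illustrative sketch (never called, values are placeholders, not recommended
# settings) of the exploration schedule in trainModelRl : a probability decays
# exponentially from 'start' to 'end' with time constant 'decay', so
# probaRandom and probaOracle shrink as epochs go by.
def _explorationScheduleSketch(start=1.0, decay=4.0, end=0.1, nbIter=10) :
  return [round((start-end)*math.exp((-epoch+1)/decay)+end, 2) for epoch in range(1,nbIter+1)]
  # -> [1.0, 0.8, 0.65, 0.53, 0.43, 0.36, 0.3, 0.26, 0.22, 0.2] approximately
################################################################################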