import sys
import random
import torch
import copy

from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp, prettyInt, numParameters, getDevice
from Rl import ReplayMemory, selectAction, optimizeModel
import Networks
import Decode
import Config

from conll18_ud_eval import load_conllu, evaluate

################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, bootstrapInterval, silent=False) :
  """Train a parser on the CoNLL-U file `filename`, writing artifacts to `modelDir`.

  `type` selects the training regime:
    * "oracle" : supervised training on oracle-extracted examples
                 (see trainModelOracle; `bootstrapInterval` is only used here);
    * "rl"     : reinforcement learning (see trainModelRl).
  Any other value prints an error to stderr and exits with status 1.
  """
  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  # Passed through to applyTransition. NOTE(review): presumably how far the word
  # index advances after each transition -- confirm in Transition.applyTransition.
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  sentences = Config.readConllu(filename)

  if type == "oracle" :
    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, silent)
    return

  if type == "rl" :
    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
    return

  print("ERROR : unknown type '%s'"%type, file=sys.stderr)
  # sys.exit rather than the site-provided exit() builtin, which is not
  # guaranteed to exist (e.g. under `python -S` or in frozen executables).
  sys.exit(1)
################################################################################

################################################################################
def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
  """Run one sentence through the transition system and collect training examples.

  Starting at word index 0, each step:
    * scores every appliable transition of `ts` with getOracleScore and picks,
      uniformly at random, one of those with the lowest score (the gold action);
    * records the example torch.cat([LongTensor([gold index in ts]), features]);
    * applies the gold transition, or -- when `dynamic` is true -- the
      appliable transition the network scores highest.
  Extraction stops when no transition is appliable or when applyTransition
  returns a falsy value; the "EOS" transition is then applied to `config`.

  `config` is mutated in place. Returns the list of 1-D example tensors.
  """
  examples = []
  tsNames = [trans.name for trans in ts]
  with torch.no_grad() :
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    moved = True
    while moved :
      missingLinks = getMissingLinks(config)
      # Sort on the oracle score only: comparing Transition objects on ties would
      # raise TypeError unless Transition defines an ordering.
      candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)], key=lambda cand : cand[0])
      if len(candidates) == 0 :
        break
      best = candidates[0][0]
      candidateOracle = random.choice([cand[1] for cand in candidates if cand[0] == best])
      features = network.extractFeatures(dicts, config)
      candidate = candidateOracle.name
      if debug :
        config.printForDebug(sys.stderr)
        print(str([[c[0],c[1].name] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
      if dynamic :
        output = network(features.unsqueeze(0).to(getDevice()))
        # Highest-scoring appliable transition; exact score ties are broken by
        # the larger transition name so the choice stays deterministic.
        appliableIndexes = [index for index in range(len(ts)) if ts[index].appliable(config)]
        bestIndex = max(appliableIndexes, key=lambda index : (float(output[0][index]), ts[index].name))
        candidate = ts[bestIndex].name
        if debug :
          # `candidate` is a transition name (str), not a Transition object.
          print(candidate, file=sys.stderr)

      goldIndex = tsNames.index(candidateOracle.name)
      examples.append(torch.cat([torch.LongTensor([goldIndex]), features]))
      moved = applyTransition(ts, strat, config, candidate)

  EOS.apply(config)

  return examples
################################################################################

################################################################################
def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) :
  """Decide whether `model` improved this epoch, save it if so, and log a summary.

  Without `devFile`, the model is saved when `totalLoss` beats `bestLoss`.
  With `devFile`, the dev set is decoded into modelDir/predicted_dev.conllu,
  scored with conll18_ud_eval, and the model is saved when dev UAS beats
  `bestScore`. The saved model goes to modelDir/network.pt via torch.save.

  Returns the updated (bestLoss, bestScore).
  """
  devScore = ""
  saved = True if bestLoss is None else totalLoss < bestLoss
  bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
  if devFile is not None :
    outFilename = modelDir+"/predicted_dev.conllu"
    # Close (and therefore flush) the prediction file before it is read back;
    # otherwise buffered output may still be missing when it is evaluated.
    with open(outFilename, "w", encoding="utf-8") as outFile :
      Decode.decodeMode(debug, devFile, "model", modelDir, model, dicts, outFile)
    with open(devFile, "r", encoding="utf-8") as goldFile, open(outFilename, "r", encoding="utf-8") as predFile :
      res = evaluate(load_conllu(goldFile), load_conllu(predFile), [])
    # NOTE(review): scale of f1 (0-1 vs 0-100) depends on the local conll18_ud_eval.
    UAS = res["UAS"][0].f1
    saved = True if bestScore is None else UAS > bestScore
    bestScore = UAS if bestScore is None else max(bestScore, UAS)
    devScore = ", Dev : UAS=%.2f"%(UAS)
  if saved :
    torch.save(model, modelDir+"/network.pt")
  print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)

  return bestLoss, bestScore
################################################################################

################################################################################
def _extractAllExamples(debug, transitionSet, strategy, sentencesOriginal, dicts, network, dynamic) :
  """Extract examples from fresh deep copies of all sentences and stack them.

  Copies are used because extractExamples mutates each configuration in place,
  so `sentencesOriginal` stays untouched. Returns a 2-D tensor, one example per row.
  """
  examples = []
  for config in copy.deepcopy(sentencesOriginal) :
    examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, dynamic)
  print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
  return torch.stack(examples)

def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) :
  """Supervised training of a Networks.BaseNet on oracle transitions.

  Dicts are read from `filename` for the "UPOS" column and saved to
  modelDir/dicts.json. Examples are first extracted with the static oracle.
  If `bootstrapInterval` is set, they are re-extracted dynamically (following
  the network's own predictions) at epochs bootstrapInterval+1,
  2*bootstrapInterval+1, ... Each epoch shuffles the examples, trains with
  Adam + cross-entropy in mini-batches of `batchSize`, then calls
  evalModelAndSave.
  """
  dicts = Dicts()
  dicts.readConllu(filename, ["UPOS"])
  dicts.save(modelDir+"/dicts.json")
  network = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())

  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
  examples = _extractAllExamples(debug, transitionSet, strategy, sentencesOriginal, dicts, network, False)

  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)

  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
  lossFct = torch.nn.CrossEntropyLoss()
  bestLoss = None
  bestScore = None
  printInterval = 2000
  for epoch in range(1,nbEpochs+1) :
    if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 :
      print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
      # The network is only used for inference here; eval() makes layers such as
      # dropout (if BaseNet has any) behave deterministically during extraction.
      network.eval()
      examples = _extractAllExamples(debug, transitionSet, strategy, sentencesOriginal, dicts, network, True)
    network.train()
    examples = examples.index_select(0, torch.randperm(examples.size(0)))
    totalLoss = 0.0
    nbEx = 0
    advancement = 0
    # Visit every example, including a final (possibly smaller) batch.
    for batchIndex in range(0, examples.size(0), batchSize) :
      batch = examples[batchIndex:batchIndex+batchSize].to(getDevice())
      # Column 0 holds the gold transition index, the rest are the features.
      targets = batch[:,:1].view(-1)
      inputs = batch[:,1:]
      nbEx += targets.size(0)
      advancement += targets.size(0)
      if not silent and advancement >= printInterval :
        advancement = 0
        print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
      outputs = network(inputs)
      loss = lossFct(outputs, targets)
      network.zero_grad()
      loss.backward()
      optimizer.step()
      totalLoss += float(loss)

    bestLoss, bestScore = evalModelAndSave(debug, network, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs)
################################################################################

################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
  """Reinforcement-learning training with a policy network and a target network.

  Dicts are read from `filename` for the "FORM" and "UPOS" columns and saved to
  modelDir/dicts.json. Sentences are parsed with actions from selectAction
  (probaRandom=0.1, probaOracle=0.1). Each transition is stored in a
  ReplayMemory with reward -3.0 for an illegal action (which also ends the
  episode) or -getOracleScore for a legal one. The model is optimized every
  `batchSize` steps, and the target network is synced every 2*batchSize steps.
  An epoch lasts at least sum(len(sentence)) steps and ends on a sentence
  boundary; evalModelAndSave is called after each of the `nbIter` epochs.
  """
  memory = None
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM", "UPOS"])
  dicts.save(modelDir + "/dicts.json")

  policy_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
  target_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
  target_net.load_state_dict(policy_net.state_dict())
  target_net.eval()
  policy_net.train()

  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)

  bestLoss = None
  bestScore = None

  sentences = copy.deepcopy(sentencesOriginal)
  nbExByEpoch = sum(map(len,sentences))
  sentIndex = 0

  for epoch in range(1,nbIter+1) :
    i = 0
    totalLoss = 0.0
    while True :
      if sentIndex >= len(sentences) :
        # All sentences consumed: restart on fresh, reshuffled copies.
        sentences = copy.deepcopy(sentencesOriginal)
        random.shuffle(sentences)
        sentIndex = 0
      if not silent :
        print("Curent epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
      sentence = sentences[sentIndex]
      sentence.moveWordIndex(0)
      state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
      while True :
        missingLinks = getMissingLinks(sentence)
        if debug :
          sentence.printForDebug(sys.stderr)
        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom=0.1, probaOracle=0.1)
        if action is None :
          break

        appliable = action.appliable(sentence)

        # Reward for doing an illegal action
        reward = -3.0
        if appliable :
          reward = -1.0*action.getOracleScore(sentence, missingLinks)
        reward = torch.FloatTensor([reward]).to(getDevice())

        # An illegal action is terminal: it is stored with newState = None.
        newState = None
        if appliable :
          applyTransition(transitionSet, strategy, sentence, action.name)
          newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())

        if memory is None :
          # Sized lazily from the first state. NOTE(review): 5000 is presumably
          # the replay capacity -- confirm in Rl.ReplayMemory.
          memory = ReplayMemory(5000, state.numel())
        memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
        state = newState
        if i % batchSize == 0 :
          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
          if i % (2*batchSize) == 0 :
            target_net.load_state_dict(policy_net.state_dict())
            target_net.eval()
            policy_net.train()
        i += 1
        if state is None :
          break
      # Advance before a possible epoch end so that the next epoch does not
      # restart on this sentence, whose configuration has already been mutated.
      sentIndex += 1
      if i >= nbExByEpoch :
        break
    bestLoss, bestScore = evalModelAndSave(debug, policy_net, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter)
################################################################################