diff --git a/Decode.py b/Decode.py
index c2e8df021b9c65536923cd5bc38fcf45d440ed10..f5ac8341017f35e7fcd2f95e0bde530bbd598fdd 100644
--- a/Decode.py
+++ b/Decode.py
@@ -1,7 +1,6 @@
 import random
 import sys
 from Transition import Transition, getMissingLinks, applyTransition
-from Features import extractFeatures
 from Dicts import Dicts
 from Util import getDevice
 import Config
@@ -56,7 +55,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
 
   with torch.no_grad():
     while moved :
-      features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
+      features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
       output = network(features)
       scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
       candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
diff --git a/Dicts.py b/Dicts.py
index 41da2a2c50b193abe59dd599498368494b1cdcdf..010dc27fc594fa117f488c13615c9374c0fdf7e3 100644
--- a/Dicts.py
+++ b/Dicts.py
@@ -9,6 +9,9 @@
     self.nullToken = "__null__"
     self.noStackToken = "__nostack__"
     self.oobToken = "__oob__"
+    self.noDepLeft = "__nodepleft__"
+    self.noDepRight = "__nodepright__"
+    self.noGov = "__nogov__"
 
   def readConllu(self, filename, colsSet=None) :
     defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
@@ -30,7 +33,7 @@ class Dicts :
          targetColumns = list(col2index.keys())
        else :
          targetColumns = list(colsSet)
        self.dicts = {col : {self.unkToken : 0, self.nullToken : 1, self.noStackToken : 2, self.oobToken : 3, self.noDepLeft : 4, self.noDepRight : 5, self.noGov : 6} for col in targetColumns}
 
       splited = line.split('\t')
       for col in targetColumns :
diff --git a/Features.py b/Features.py
index c152e6318ad28fc7b03bece5804a49c5a3fc7923..6f67c7574077e10d41a823265bc0019a49837d91 100644
--- a/Features.py
+++ b/Features.py
@@ -3,61 +3,79 @@ import sys
 from Util import isEmpty
 
 ################################################################################
-def extractFeatures(dicts, config) :
-  return extractFeaturesPosExtended(dicts, config)
-################################################################################
-
-################################################################################
-def extractFeaturesPos(dicts, config) :
-  bufferWindow = range(-2,2+1)
-  stackWindow = range(0,3+1)
-  totalSize = len(bufferWindow)+len(stackWindow)
-
-  result = torch.zeros(totalSize, dtype=torch.int)
+# Input : b=buffer s=stack .0=governor .x=rightChild#x+1 .-x=leftChild#-x-1
+# Output : list of sentence indexes pointing to elements of featureFunction
+# Special output values :
+# -1 : Out of bounds
+# -2 : Not in stack
+# -3 : No dependent left
+# -4 : No dependent right
+# -5 : No gov
+def extractIndexes(config, featureFunction) :
+  features = featureFunction.split()
+  res = []
+  for feature in features :
+    splited = feature.split('.')
+    obj = splited[0]
+    index = int(splited[1])
+    if obj == "b" :
+      index = config.wordIndex + index
+      if index not in (range(len(config.lines))) :
+        index = -1
+    elif obj == "s" :
+      if index not in range(len(config.stack)) :
+        index = -2
+      else :
+        index = config.stack[-1-index]
+    for depIndex in map(int,splited[2:]) :
+      if index < 0 :
+        break
+      if depIndex == 0 :
+        head = config.getAsFeature(index, "HEAD")
+        if isEmpty(head) :
+          index = -5
+        else :
+          index = int(head)
+        continue
+      if depIndex > 0 :
+        rightChilds = [child for child in config.predChilds[index] if child > index]
+        if depIndex-1 in range(len(rightChilds)) :
+          index = rightChilds[depIndex-1]
+        else :
+          index = -4
+      else :
+        leftChilds = [child for child in config.predChilds[index] if child < index]
+        if abs(depIndex)-1 in range(len(leftChilds)) :
+          index = leftChilds[abs(depIndex)-1]
+        else :
+          index = -3
+    res.append(index)
 
-  insertIndex = 0
-  for i in bufferWindow :
-    index = config.wordIndex + i
-    bufferPos = dicts.oobToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS")
-    result[insertIndex] = dicts.get("UPOS", bufferPos)
-    insertIndex += 1
-
-  for i in stackWindow :
-    stackPos = dicts.noStackToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS")
-    result[insertIndex] = dicts.get("UPOS", stackPos)
-    insertIndex += 1
-
-  return result
+  return res
 ################################################################################
 
 ################################################################################
-# For each stack element, add its POS and the POS of its governor
-def extractFeaturesPosExtended(dicts, config) :
-  bufferWindow = range(-2,2+1)
-  stackWindow = range(0,3+1)
-  totalSize = len(bufferWindow)+2*len(stackWindow)
+# For each element of the feature function and for each column, concatenate the dict index
+def extractColsFeatures(dicts, config, featureFunction, cols) :
+  specialValues = {-1 : dicts.oobToken, -2 : dicts.noStackToken, -3 : dicts.noDepLeft, -4 : dicts.noDepRight, -5 : dicts.noGov}
+  indexes = extractIndexes(config, featureFunction)
+  totalSize = len(cols)*len(indexes)
 
   result = torch.zeros(totalSize, dtype=torch.int)
 
   insertIndex = 0
-  for i in bufferWindow :
-    index = config.wordIndex + i
-    bufferPos = dicts.oobToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS")
-    result[insertIndex] = dicts.get("UPOS", bufferPos)
-    insertIndex += 1
-
-  for i in stackWindow :
-    stackPos = dicts.noStackToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS")
-    stackGovHead = dicts.nullToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "HEAD")
-    stackGovPos = dicts.nullToken
-    if not isEmpty(stackGovHead) and stackGovHead != dicts.nullToken :
-      stackGovPos = config.getAsFeature(int(stackGovHead), "UPOS")
-    elif stackGovHead == dicts.nullToken :
-      stackGovPos = dicts.noStackToken
-    result[insertIndex] = dicts.get("UPOS", stackPos)
-    insertIndex += 1
-    result[insertIndex] = dicts.get("UPOS", stackGovPos)
-    insertIndex += 1
+  for index in indexes :
+    if index < 0 :
+      for col in cols :
+        result[insertIndex] = dicts.get(col, specialValues[index])
+        insertIndex += 1
+    else :
+      for col in cols :
+        value = config.getAsFeature(index, col)
+        if isEmpty(value) :
+          value = dicts.nullToken
+        result[insertIndex] = dicts.get(col, value)
+        insertIndex += 1
 
   return result
 ################################################################################
diff --git a/Networks.py b/Networks.py
index d1beadec23625290e556e5d3329378d868b60dd3..6f91141de143b073b1b1310ea8467cfa20ed1ab6 100644
--- a/Networks.py
+++ b/Networks.py
@@ -1,19 +1,23 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import Features
 
 ################################################################################
 class BaseNet(nn.Module):
-  def __init__(self, dicts, inputSize, outputSize) :
+  def __init__(self, dicts, outputSize) :
     super().__init__()
     self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
 
+    self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
+    self.columns = ["UPOS"]
+
     self.embSize = 64
-    self.inputSize = inputSize
+    self.inputSize = len(self.columns)*len(self.featureFunction.split())
     self.outputSize = outputSize
     for name in dicts.dicts :
       self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
-    self.fc1 = nn.Linear(inputSize * self.embSize, 1600)
+    self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
     self.fc2 = nn.Linear(1600, outputSize)
     self.dropout = nn.Dropout(0.3)
 
@@ -32,5 +36,9 @@ class BaseNet(nn.Module):
     if type(m) == nn.Linear:
       torch.nn.init.xavier_uniform_(m.weight)
       m.bias.data.fill_(0.01)
+
+  def extractFeatures(self, dicts, config) :
+    return Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns)
+
 
 ################################################################################
diff --git a/Train.py b/Train.py
index 3309201bab0f9095d1b8d8de1373aa6064b0318e..a0420047598332f4fd14694083ae4af125a56208 100644
--- a/Train.py
+++ b/Train.py
@@ -34,7 +34,7 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, boots
 ################################################################################
 
 ################################################################################
-def extractExamples(debug, ts, strat, config, dicts, network=None) :
+def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
   examples = []
   with torch.no_grad() :
     EOS = Transition("EOS")
@@ -47,12 +47,12 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
         break
       best = min([cand[0] for cand in candidates])
      candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
-      features = Features.extractFeatures(dicts, config)
+      features = network.extractFeatures(dicts, config)
       candidate = candidateOracle.name
       if debug :
         config.printForDebug(sys.stderr)
         print(str([[c[0],c[1].name] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
-      if network is not None :
+      if dynamic :
         output = network(features.unsqueeze(0).to(getDevice()))
         scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
         candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
@@ -95,17 +95,17 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss
 ################################################################################
 def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) :
   dicts = Dicts()
-  dicts.readConllu(filename, ["FORM", "UPOS"])
+  dicts.readConllu(filename, ["UPOS"])
   dicts.save(modelDir+"/dicts.json")
+  network = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
   examples = []
   sentences = copy.deepcopy(sentencesOriginal)
   print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
   for config in sentences :
-    examples += extractExamples(debug, transitionSet, strategy, config, dicts)
+    examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, False)
   print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
   examples = torch.stack(examples)
 
-  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)).to(getDevice())
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
   optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
   lossFct = torch.nn.CrossEntropyLoss()
@@ -117,7 +117,7 @@
       sentences = copy.deepcopy(sentencesOriginal)
       print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
       for config in sentences :
-        examples += extractExamples(debug, transitionSet, strategy, config, dicts, network)
+        examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, True)
       print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
       examples = torch.stack(examples)
 
@@ -154,9 +154,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   dicts.readConllu(filename, ["FORM", "UPOS"])
   dicts.save(modelDir + "/dicts.json")
 
-  policy_net = None
-  target_net = None
-  optimizer = None
+  policy_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
+  target_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
+  target_net.load_state_dict(policy_net.state_dict())
+  target_net.eval()
+  policy_net.train()
+  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
 
   bestLoss = None
   bestScore = None
@@ -178,16 +182,7 @@
     print("Curent epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
     sentence = sentences[sentIndex]
     sentence.moveWordIndex(0)
-    state = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
-
-    if policy_net is None :
-      policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
-      target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
-      target_net.load_state_dict(policy_net.state_dict())
-      target_net.eval()
-      policy_net.train()
-      optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
-      print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
+    state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
 
     while True :
       missingLinks = getMissingLinks(sentence)
@@ -209,7 +204,7 @@
       newState = None
      if appliable :
         applyTransition(transitionSet, strategy, sentence, action.name)
-        newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
+        newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
 
       if memory is None :
         memory = ReplayMemory(5000, state.numel())
diff --git a/main.py b/main.py
index ef8a1ea5b732954056dd9df9acdd309cd3d52d95..c9f6518c9cb82ba2f76bea15dac677665bc49a09 100755
--- a/main.py
+++ b/main.py
@@ -44,6 +44,9 @@ if __name__ == "__main__" :
   random.seed(args.seed)
   torch.manual_seed(args.seed)
 
+  if args.bootstrap is not None :
+    args.bootstrap = int(args.bootstrap)
+
   if args.mode == "train" :
     Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.silent)
   elif args.mode == "decode" :
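
Illustration (not part of the patch): the sketch below shows how the feature-function syntax handled by the new Features.extractIndexes resolves template items to word indexes or to the special values -1..-5. The ToyConfig class is hypothetical; it only stands in for the project's Config and provides the attributes that extractIndexes actually reads (wordIndex, lines, stack, predChilds, getAsFeature), and the expected output assumes Util.isEmpty("") returns True.

# Hypothetical stand-in for Config, for illustration only.
from Features import extractIndexes

class ToyConfig :
  def __init__(self) :
    # 4-word sentence, keeping only the HEAD column (predicted governor, as a string).
    self.lines = [{"HEAD" : "1"}, {"HEAD" : ""}, {"HEAD" : "1"}, {"HEAD" : "2"}]
    self.wordIndex = 2                       # b.0 points at word 2
    self.stack = [0, 1]                      # s.0 (top of the stack) is word 1, s.1 is word 0
    self.predChilds = [[], [0, 2], [3], []]  # predicted children of each word

  def getAsFeature(self, index, col) :
    return self.lines[index][col]

config = ToyConfig()
print(extractIndexes(config, "b.0 b.2 s.0 s.2 s.1.0 s.0.-1 s.0.1 s.0.2 s.0.0"))
# Expected : [2, -1, 1, -2, 1, 0, 2, -4, -5]
# b.0 -> word 2             b.2 -> out of bounds (-1)
# s.0 -> word 1             s.2 -> not in stack (-2)
# s.1.0  -> governor of word 0, i.e. word 1
# s.0.-1 -> first left child of word 1 (word 0)
# s.0.1  -> first right child of word 1 (word 2)
# s.0.2  -> no second right child (-4)
# s.0.0  -> word 1 has no governor yet (-5)

BaseNet.extractFeatures then maps each resolved index to dictionary ids for every configured column, falling back on the new special tokens (__oob__, __nostack__, __nodepleft__, __nodepright__, __nogov__) when the index is negative.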