Skip to content
Snippets Groups Projects
Commit 98e2ffb0 authored by Franck Dary
Browse files

The canBack feature is now only used when the BACK transition is available

parent 1e33b269
Branches
Tags
No related merge requests found
...@@ -37,7 +37,7 @@ def getNeededDicts(name) : ...@@ -37,7 +37,7 @@ def getNeededDicts(name) :
################################################################################ ################################################################################
################################################################################ ################################################################################
def createNetwork(name, dicts, outputSizes, incremental, pretrained) : def createNetwork(name, dicts, outputSizes, incremental, pretrained, hasBack) :
featureFunctionAll = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" featureFunctionAll = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
featureFunctionNostack = "b.-2 b.-1 b.0 b.1 b.2" featureFunctionNostack = "b.-2 b.-1 b.0 b.1 b.2"
historyNb = 10 historyNb = 10
...@@ -48,26 +48,20 @@ def createNetwork(name, dicts, outputSizes, incremental, pretrained) : ...@@ -48,26 +48,20 @@ def createNetwork(name, dicts, outputSizes, incremental, pretrained) :
columns = ["UPOS", "FORM"] columns = ["UPOS", "FORM"]
if name == "base" : if name == "base" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained) return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack)
elif name == "semi" : elif name == "baseNoLetters" :
return SemiNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize) return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, 0, 0, columns, hiddenSize, pretrained, hasBack)
elif name == "big" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns, hiddenSize*2, pretrained)
elif name == "lstm" :
return LSTMNet(dicts, outputSizes, incremental)
elif name == "separated" :
return SeparatedNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize)
elif name == "tagger" : elif name == "tagger" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained) return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack)
elif name == "taggerLexicon" : elif name == "taggerLexicon" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, ["UPOS","FORM","LEXICON"], hiddenSize, pretrained) return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, historyPopNb, suffixSize, prefixSize, ["UPOS","FORM","LEXICON"], hiddenSize, pretrained, hasBack)
raise Exception("Unknown network name '%s'"%name) raise Exception("Unknown network name '%s'"%name)
################################################################################ ################################################################################
################################################################################ ################################################################################
class BaseNet(nn.Module): class BaseNet(nn.Module):
def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained) : def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize, pretrained, hasBack) :
super().__init__() super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
...@@ -79,6 +73,7 @@ class BaseNet(nn.Module): ...@@ -79,6 +73,7 @@ class BaseNet(nn.Module):
self.suffixSize = suffixSize self.suffixSize = suffixSize
self.prefixSize = prefixSize self.prefixSize = prefixSize
self.columns = columns self.columns = columns
self.hasBack = hasBack
self.embSize = 64 self.embSize = 64
embSizes = {} embSizes = {}
...@@ -94,10 +89,10 @@ class BaseNet(nn.Module): ...@@ -94,10 +89,10 @@ class BaseNet(nn.Module):
else : else :
embSizes[name] = self.embSize embSizes[name] = self.embSize
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize)) self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
self.inputSize = (self.historyNb+self.historyPopNb)*embSizes["HISTORY"]+(self.suffixSize+self.prefixSize)*embSizes["LETTER"] + sum([self.nbTargets*embSizes[col] for col in self.columns]) self.inputSize = (self.historyNb+self.historyPopNb)*embSizes.get("HISTORY",0)+(self.suffixSize+self.prefixSize)*embSizes.get("LETTER",0) + sum([self.nbTargets*embSizes.get(col,0) for col in self.columns])
self.fc1 = nn.Linear(self.inputSize, hiddenSize) self.fc1 = nn.Linear(self.inputSize, hiddenSize)
for i in range(len(outputSizes)) : for i in range(len(outputSizes)) :
self.add_module("output_"+str(i), nn.Linear(hiddenSize+1, outputSizes[i])) self.add_module("output_"+str(i), nn.Linear(hiddenSize+(1 if self.hasBack else 0), outputSizes[i]))
self.dropout = nn.Dropout(0.3) self.dropout = nn.Dropout(0.3)
self.apply(self.initWeights) self.apply(self.initWeights)
...@@ -107,6 +102,7 @@ class BaseNet(nn.Module): ...@@ -107,6 +102,7 @@ class BaseNet(nn.Module):
def forward(self, x) : def forward(self, x) :
embeddings = [] embeddings = []
if self.hasBack :
canBack = x[...,0:1] canBack = x[...,0:1]
x = x[...,1:] x = x[...,1:]
...@@ -132,6 +128,7 @@ class BaseNet(nn.Module): ...@@ -132,6 +128,7 @@ class BaseNet(nn.Module):
curIndex = curIndex+self.suffixSize curIndex = curIndex+self.suffixSize
y = self.dropout(y) y = self.dropout(y)
y = F.relu(self.dropout(self.fc1(y))) y = F.relu(self.dropout(self.fc1(y)))
if self.hasBack :
y = torch.cat([y,canBack], 1) y = torch.cat([y,canBack], 1)
y = getattr(self, "output_"+str(self.state))(y) y = getattr(self, "output_"+str(self.state))(y)
return y return y
...@@ -150,8 +147,11 @@ class BaseNet(nn.Module): ...@@ -150,8 +147,11 @@ class BaseNet(nn.Module):
historyPopValues = Features.extractHistoryPopFeatures(dicts, config, self.historyPopNb) historyPopValues = Features.extractHistoryPopFeatures(dicts, config, self.historyPopNb)
prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize) prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize) suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
backAction = None
if self.hasBack :
backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int) backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int)
return torch.cat([backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues]) allFeatures = [f for f in [backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues] if f is not None]
return torch.cat(allFeatures)
################################################################################ ################################################################################
################################################################################ ################################################################################
......
...@@ -18,15 +18,15 @@ import Config ...@@ -18,15 +18,15 @@ import Config
from conll18_ud_eval import load_conllu, evaluate from conll18_ud_eval import load_conllu, evaluate
################################################################################ ################################################################################
def trainMode(debug, networkName, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False) : def trainMode(debug, networkName, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False, hasBack=False) :
sentences = Config.readConllu(filename, predicted) sentences = Config.readConllu(filename, predicted)
if type == "oracle" : if type == "oracle" :
trainModelOracle(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent) trainModelOracle(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent, hasBack)
return return
if type == "rl": if type == "rl":
trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent) trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent, hasBack)
return return
print("ERROR : unknown type '%s'"%type, file=sys.stderr) print("ERROR : unknown type '%s'"%type, file=sys.stderr)
...@@ -100,7 +100,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss ...@@ -100,7 +100,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
################################################################################ ################################################################################
################################################################################ ################################################################################
def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent=False) : def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent=False, hasBack=False) :
dicts = Dicts() dicts = Dicts()
dicts.readConllu(filename, Networks.getNeededDicts(networkName), 2, pretrained) dicts.readConllu(filename, Networks.getNeededDicts(networkName), 2, pretrained)
transitionNames = {} transitionNames = {}
...@@ -111,7 +111,7 @@ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize ...@@ -111,7 +111,7 @@ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize
dicts.addDict("HISTORY", transitionNames) dicts.addDict("HISTORY", transitionNames)
dicts.save(modelDir+"/dicts.json") dicts.save(modelDir+"/dicts.json")
network = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice()) network = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained, hasBack).to(getDevice())
examples = [[] for _ in transitionSets] examples = [[] for _ in transitionSets]
sentences = copy.deepcopy(sentencesOriginal) sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr) print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
...@@ -187,7 +187,7 @@ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize ...@@ -187,7 +187,7 @@ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize
################################################################################ ################################################################################
################################################################################ ################################################################################
def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False) : def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False, hasBack=False) :
memory = None memory = None
dicts = Dicts() dicts = Dicts()
...@@ -207,8 +207,8 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF ...@@ -207,8 +207,8 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
policy_net = torch.load(modelDir+"/lastNetwork.pt") policy_net = torch.load(modelDir+"/lastNetwork.pt")
target_net = torch.load(modelDir+"/lastNetwork.pt") target_net = torch.load(modelDir+"/lastNetwork.pt")
else : else :
policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice()) policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained, hasBack).to(getDevice())
target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice()) target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained, hasBack).to(getDevice())
target_net.load_state_dict(policy_net.state_dict()) target_net.load_state_dict(policy_net.state_dict())
target_net.eval() target_net.eval()
policy_net.train() policy_net.train()
......
...@@ -85,6 +85,7 @@ if __name__ == "__main__" : ...@@ -85,6 +85,7 @@ if __name__ == "__main__" :
args.bootstrap = int(args.bootstrap) args.bootstrap = int(args.bootstrap)
networkName = args.network networkName = args.network
hasBack = False
if args.transitions == "tagger" : if args.transitions == "tagger" :
tmpDicts = Dicts() tmpDicts = Dicts()
...@@ -98,6 +99,7 @@ if __name__ == "__main__" : ...@@ -98,6 +99,7 @@ if __name__ == "__main__" :
networkName = "tagger" networkName = "tagger"
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
elif args.transitions == "taggerbt" : elif args.transitions == "taggerbt" :
hasBack = True
tmpDicts = Dicts() tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0) tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
...@@ -118,6 +120,7 @@ if __name__ == "__main__" : ...@@ -118,6 +120,7 @@ if __name__ == "__main__" :
networkName = "base" networkName = "base"
probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
elif args.transitions == "eagerbt" : elif args.transitions == "eagerbt" :
hasBack = True
transitionSets = [[Transition("NOBACK"),Transition("BACK "+args.backSize)], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0]] transitionSets = [[Transition("NOBACK"),Transition("BACK "+args.backSize)], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0]]
args.predictedStr = "HEAD" args.predictedStr = "HEAD"
args.states = ["backer", "parser"] args.states = ["backer", "parser"]
...@@ -155,6 +158,7 @@ if __name__ == "__main__" : ...@@ -155,6 +158,7 @@ if __name__ == "__main__" :
[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
elif args.transitions == "tagparserbt" : elif args.transitions == "tagparserbt" :
hasBack = True
tmpDicts = Dicts() tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0) tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
...@@ -168,6 +172,7 @@ if __name__ == "__main__" : ...@@ -168,6 +172,7 @@ if __name__ == "__main__" :
[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))], [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]] [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
elif args.transitions == "recovery" : elif args.transitions == "recovery" :
hasBack = True
tmpDicts = Dicts() tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0) tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
...@@ -197,7 +202,7 @@ if __name__ == "__main__" : ...@@ -197,7 +202,7 @@ if __name__ == "__main__" :
json.dump([args.predictedStr, [[str(t) for t in transitionSet] for transitionSet in transitionSets]], open(args.model+"/transitions.json", "w")) json.dump([args.predictedStr, [[str(t) for t in transitionSet] for transitionSet in transitionSets]], open(args.model+"/transitions.json", "w"))
json.dump(strategy, open(args.model+"/strategy.json", "w")) json.dump(strategy, open(args.model+"/strategy.json", "w"))
printTS(transitionSets, sys.stderr) printTS(transitionSets, sys.stderr)
Train.trainMode(args.debug, networkName, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.pretrained, args.silent) Train.trainMode(args.debug, networkName, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.pretrained, args.silent, hasBack)
elif args.mode == "decode" : elif args.mode == "decode" :
transInfos = json.load(open(args.model+"/transitions.json", "r")) transInfos = json.load(open(args.model+"/transitions.json", "r"))
transNames = json.load(open(args.model+"/transitions.json", "r"))[1] transNames = json.load(open(args.model+"/transitions.json", "r"))[1]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment