Skip to content
Snippets Groups Projects
Commit bbdc365a authored by Franck Dary
Browse files

Started to work on states

parent 0d5dffb6
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,7 @@ class Config :
self.predicted = predicted
self.wordIndex = 0
self.maxWordIndex = 0 #To keep a track of the max value, in case of backtrack
self.state = 0 #State of the analysis (e.g. 0=tagger, 1=parser)
self.stack = []
self.comments = []
self.history = []
......@@ -83,6 +84,7 @@ class Config :
printedCols = ["ID","FORM","UPOS","HEAD","DEPREL"]
left = 5
right = 5
print("state :", self.state, file=output)
print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
print("history :",[str(trans) for trans in self.history[-10:]], file=output)
print("historyPop :",[(str(c[0]),"dat:"+str(c[1]),"mvt:"+str(c[2]),"reward:"+str(c[3])) for c in self.historyPop[-10:]], file=output)
......
......@@ -11,6 +11,7 @@ import torch
def randomDecode(ts, strat, config, debug=False) :
EOS = Transition("EOS")
config.moveWordIndex(0)
config.state = 0
while True :
candidates = [trans for trans in ts if trans.appliable(config)]
if len(candidates) == 0 :
......@@ -28,6 +29,7 @@ def randomDecode(ts, strat, config, debug=False) :
def oracleDecode(ts, strat, config, debug=False) :
EOS = Transition("EOS")
config.moveWordIndex(0)
config.state = 0
moved = True
while moved :
missingLinks = getMissingLinks(config)
......@@ -44,9 +46,10 @@ def oracleDecode(ts, strat, config, debug=False) :
################################################################################
################################################################################
def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) :
def decodeModel(transitionSets, strat, config, network, dicts, debug, rewardFunc) :
EOS = Transition("EOS")
config.moveWordIndex(0)
config.state = 0
moved = True
network.eval()
......@@ -59,6 +62,8 @@ def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) :
with torch.no_grad():
while moved :
ts = transitionSets[config.state]
network.setState(config.state)
features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
output = network(features)
scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
......
......@@ -5,11 +5,12 @@ import Features
################################################################################
class BaseNet(nn.Module):
def __init__(self, dicts, outputSize, incremental) :
def __init__(self, dicts, outputSizes, incremental) :
super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.incremental = incremental
self.state = 0
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
self.historyNb = 5
self.suffixSize = 4
......@@ -19,15 +20,19 @@ class BaseNet(nn.Module):
self.embSize = 64
self.nbTargets = len(self.featureFunction.split())
self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
self.outputSize = outputSize
self.outputSizes = outputSizes
for name in dicts.dicts :
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
self.fc2 = nn.Linear(1600, outputSize)
for i in range(len(outputSizes)) :
self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
self.dropout = nn.Dropout(0.3)
self.apply(self.initWeights)
# Store the index of the analysis state; forward() uses it to pick the
# output layer registered as "output_<state>".
def setState(self, state) :
self.state = state
def forward(self, x) :
embeddings = []
for i in range(len(self.columns)) :
......@@ -48,7 +53,7 @@ class BaseNet(nn.Module):
curIndex = curIndex+self.suffixSize
y = self.dropout(y)
y = F.relu(self.dropout(self.fc1(y)))
y = self.fc2(y)
y = getattr(self, "output_"+str(self.state))(y)
return y
def currentDevice(self) :
......@@ -70,11 +75,12 @@ class BaseNet(nn.Module):
################################################################################
class LSTMNet(nn.Module):
def __init__(self, dicts, outputSize, incremental) :
def __init__(self, dicts, outputSizes, incremental) :
super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.incremental = incremental
self.state = 0
self.featureFunctionLSTM = "b.-2 b.-1 b.0 b.1 b.2"
self.featureFunction = "s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
self.historyNb = 5
......@@ -91,11 +97,15 @@ class LSTMNet(nn.Module):
self.lstmFeat = nn.LSTM(len(self.columns)*self.embSize, len(self.columns)*int(self.embSize/2), 1, batch_first=True, bidirectional = True)
self.lstmHist = nn.LSTM(self.embSize, int(self.embSize/2), 1, batch_first=True, bidirectional = True)
self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
self.fc2 = nn.Linear(1600, outputSize)
for i in range(len(outputSizes)) :
self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
self.dropout = nn.Dropout(0.3)
self.apply(self.initWeights)
# Store the index of the analysis state; forward() uses it to pick the
# output layer registered as "output_<state>".
def setState(self, state) :
self.state = state
def forward(self, x) :
embeddings = []
embeddingsLSTM = []
......@@ -116,7 +126,7 @@ class LSTMNet(nn.Module):
y = torch.cat([y, historyEmb],-1)
y = self.dropout(y)
y = F.relu(self.dropout(self.fc1(y)))
y = self.fc2(y)
y = getattr(self, "output_"+str(self.state))(y)
return y
def currentDevice(self) :
......
......@@ -7,7 +7,7 @@ from Util import getDevice
################################################################################
class ReplayMemory() :
def __init__(self, capacity, stateSize) :
def __init__(self, capacity, stateSize, nbStates) :
self.capacity = capacity
self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
......
......@@ -32,13 +32,16 @@ def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter,
################################################################################
################################################################################
def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
examples = []
# Return list of examples for each transitionSet
def extractExamples(debug, transitionSets, strat, config, dicts, network, dynamic) :
examples = [[] for _ in transitionSets]
with torch.no_grad() :
EOS = Transition("EOS")
config.moveWordIndex(0)
config.state = 0
moved = True
while moved :
ts = transitionSets[config.state]
missingLinks = getMissingLinks(config)
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name])
if len(candidates) == 0 :
......@@ -51,6 +54,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
config.printForDebug(sys.stderr)
print(str([[c[0],str(c[1])] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
if dynamic :
# Select the output layer matching the current analysis state before scoring.
network.setState(config.state)
output = network(features.unsqueeze(0).to(getDevice()))
scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
......@@ -59,7 +63,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
goldIndex = [str(trans) for trans in ts].index(str(candidateOracle))
example = torch.cat([torch.LongTensor([goldIndex]), features])
examples.append(example)
examples[config.state].append(example)
moved = applyTransition(strat, config, candidate, None)
......@@ -94,19 +98,28 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
################################################################################
################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
dicts = Dicts()
dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
transitionNames = {}
for ts in transitionSets :
for t in ts :
transitionNames[str(t)] = (len(transitionNames), 0)
transitionNames[dicts.nullToken] = (len(transitionNames), 0)
dicts.addDict("HISTORY", transitionNames)
dicts.save(modelDir+"/dicts.json")
network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
examples = []
network = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
examples = [[] for _ in transitionSets]
sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, False)
extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, False)
for e in range(len(examples)) :
examples[e] += extracted[e]
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
examples = torch.stack(examples)
for e in range(len(examples)) :
examples[e] = torch.stack(examples[e])
print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
optimizer = torch.optim.Adam(network.parameters(), lr=lr)
......@@ -119,9 +132,12 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, True)
extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, True)
for e in range(len(examples)) :
examples[e] += extracted[e]
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
examples = torch.stack(examples)
for e in range(len(examples)) :
examples[e] = torch.stack(examples[e])
network.train()
examples = examples.index_select(0, torch.randperm(examples.size(0)))
......@@ -145,20 +161,25 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
optimizer.step()
totalLoss += float(loss)
bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted)
bestLoss, bestScore = evalModelAndSave(debug, network, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted)
################################################################################
################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
memory = None
dicts = Dicts()
dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
transitionNames = {}
for ts in transitionSets :
for t in ts :
transitionNames[str(t)] = (len(transitionNames), 0)
transitionNames[dicts.nullToken] = (len(transitionNames), 0)
dicts.addDict("HISTORY", transitionNames)
dicts.save(modelDir + "/dicts.json")
policy_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
target_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
policy_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
target_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
policy_net.train()
......@@ -193,6 +214,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
while True :
missingLinks = getMissingLinks(sentence)
transitionSet = transitionSets[sentence.state]
if debug :
sentence.printForDebug(sys.stderr)
action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle)
......@@ -208,7 +230,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
reward = torch.FloatTensor([reward_]).to(getDevice())
#newState = None
newState = None
if appliable :
applyTransition(strategy, sentence, action, reward_)
newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
......@@ -233,6 +255,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
if i >= nbExByEpoch :
break
sentIndex += 1
bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
################################################################################
......@@ -284,12 +284,14 @@ def applyTag(config, colName, tag) :
################################################################################
def applyTransition(strat, config, transition, reward) :
    """Apply `transition` to `config`, then move the word index and update the state.

    `strat` maps a transition name to a (movement, newState) pair. A transition
    whose name is absent from `strat` gets a movement of 0 and leaves
    `config.state` unchanged.

    Returns whatever `config.moveWordIndex(movement)` returned.
    """
    if transition.name in strat :
        movement, newState = strat[transition.name][0], strat[transition.name][1]
    else :
        # No strategy entry: do not move and keep the current state. A sentinel
        # such as -1 stored in config.state would later be used to index the
        # transition sets and to build the "output_<state>" layer name.
        movement, newState = 0, None
    transition.apply(config, strat)
    moved = config.moveWordIndex(movement)
    if not moved :
        movement = 0
    # Record the effective movement and the reward on the latest historyPop
    # entry; transitions whose name contains "BACK" leave that entry untouched.
    if len(config.historyPop) > 0 and "BACK" not in transition.name :
        last = config.historyPop[-1]
        config.historyPop[-1] = (last[0], last[1], movement, reward)
    if newState is not None :
        config.state = newState
    return moved
################################################################################
......@@ -15,8 +15,9 @@ from Transition import Transition
from Util import isEmpty
################################################################################
def printTS(transitionSet, output) :
    """Print one "Transition Set :" line per transition set to `output`.

    `transitionSet` is an iterable of transition sets, each an iterable of
    transitions. A transition is shown as its name, size, colName and argument
    joined by spaces, skipping the fields that are None.
    """
    for ts in transitionSet :
        described = [
            " ".join(str(e) for e in (trans.name, trans.size, trans.colName, trans.argument) if e is not None)
            for trans in ts
        ]
        print("Transition Set :", described, file=output)
################################################################################
################################################################################
......@@ -78,38 +79,42 @@ if __name__ == "__main__" :
args.bootstrap = int(args.bootstrap)
if args.transitions == "eager" :
transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]
transitionSets = [[Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]]
args.predicted = "HEAD"
args.states = ["parser"]
strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
elif args.transitions == "tagparser" :
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+tagActions+args.ts.split(',')) if len(elem) > 0]
transitionSets = [[Transition(elem) for elem in (tagActions+args.ts.split(',')) if len(elem) > 0], [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]]
args.predictedStr = "HEAD,UPOS"
args.states = ["tagger", "parser"]
strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1)}
elif args.transitions == "swift" :
transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]
transitionSets = [[Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]]
args.predictedStr = "HEAD"
args.states = ["parser"]
strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
else :
raise Exception("Unknown transition set '%s'"%args.transitions)
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0}
if args.mode == "train" :
args.predicted = set({colName for colName in args.predictedStr.split(',')})
json.dump([args.predictedStr, [str(t) for t in transitionSet]], open(args.model+"/transitions.json", "w"))
json.dump([args.predictedStr, [[str(t) for t in transitionSet] for transitionSet in transitionSets]], open(args.model+"/transitions.json", "w"))
json.dump(strategy, open(args.model+"/strategy.json", "w"))
printTS(transitionSet, sys.stderr)
printTS(transitionSets, sys.stderr)
probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
Train.trainMode(args.debug, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
elif args.mode == "decode" :
transInfos = json.load(open(args.model+"/transitions.json", "r"))
transNames = json.load(open(args.model+"/transitions.json", "r"))[1]
args.predictedStr = transInfos[0]
args.predicted = set({colName for colName in args.predictedStr.split(',')})
transitionSet = [Transition(elem) for elem in transNames]
# transNames holds one list of transition names per transition set.
transitionSets = [[Transition(elem) for elem in ts] for ts in transNames]
strategy = json.load(open(args.model+"/strategy.json", "r"))
printTS(transitionSet, sys.stderr)
Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.reward, args.predicted, args.model)
printTS(transitionSets, sys.stderr)
Decode.decodeMode(args.debug, args.corpus, args.type, transitionSets, strategy, args.reward, args.predicted, args.model)
else :
print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
exit(1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment