Skip to content
Snippets Groups Projects
Commit b7045988 authored by Maxime Petit's avatar Maxime Petit
Browse files

Added pos tag

parent ead830cc
Branches
No related tags found
No related merge requests found
...@@ -118,11 +118,11 @@ class Config : ...@@ -118,11 +118,11 @@ class Config :
toPrint = [] toPrint = []
for colIndex in range(len(self.lines[index])) : for colIndex in range(len(self.lines[index])) :
value = str(self.getAsFeature(index, self.index2col[colIndex])) value = str(self.getAsFeature(index, self.index2col[colIndex]))
if value == "" : if value == "" or value == '_':
value = "_" value = "_"
elif self.index2col[colIndex] == "HEAD" and value != "-1": elif self.index2col[colIndex] == "HEAD" and (value != "-1" and self.getAsFeature(index, "DEPREL") != 'root'):
value = self.getAsFeature(int(value), "ID") value = self.getAsFeature(int(value), "ID")
elif self.index2col[colIndex] == "HEAD" and value == "-1": elif self.index2col[colIndex] == "HEAD" and (value == "-1" or self.getAsFeature(index, "DEPREL") == 'root'):
value = "0" value = "0"
toPrint.append(value) toPrint.append(value)
print("\t".join(toPrint), file=output) print("\t".join(toPrint), file=output)
......
...@@ -77,6 +77,7 @@ def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) : ...@@ -77,6 +77,7 @@ def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) :
reward = rewarding(True, config, candidate, missingLinks, rewardFunc) reward = rewarding(True, config, candidate, missingLinks, rewardFunc)
moved = applyTransition(strat, config, candidate, reward) moved = applyTransition(strat, config, candidate, reward)
if len(strat) > 1 :
EOS.apply(config, strat) EOS.apply(config, strat)
network.to(currentDevice) network.to(currentDevice)
......
...@@ -3,6 +3,12 @@ import torch.nn as nn ...@@ -3,6 +3,12 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import Features import Features
def get_network(mlp, dicts, outputSize, incremntal):
if mlp == 'POSTagNet':
return POSTagNet(dicts, outputSize, incremntal)
elif mlp == 'BaseNet':
return BaseNet(dicts, outputSize, incremntal)
################################################################################ ################################################################################
class BaseNet(nn.Module): class BaseNet(nn.Module):
def __init__(self, dicts, outputSize, incremental) : def __init__(self, dicts, outputSize, incremental) :
...@@ -134,3 +140,67 @@ class LSTMNet(nn.Module): ...@@ -134,3 +140,67 @@ class LSTMNet(nn.Module):
return torch.cat([colsValuesBase, colsValuesLSTM, historyValues]) return torch.cat([colsValuesBase, colsValuesLSTM, historyValues])
################################################################################ ################################################################################
################################################################################
class POSTagNet(nn.Module):
def __init__(self, dicts, outputSize, incremental) :
super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.incremental = incremental
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2"
self.historyNb = 5
self.suffixSize = 4
self.prefixSize = 4
self.columns = ["UPOS", "FORM"]
self.embSize = 64
self.nbTargets = len(self.featureFunction.split())
self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
self.outputSize = outputSize
for name in dicts.dicts :
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
self.fc2 = nn.Linear(1600, outputSize)
self.dropout = nn.Dropout(0.3)
self.apply(self.initWeights)
def forward(self, x) :
embeddings = []
for i in range(len(self.columns)) :
embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
y = torch.cat(embeddings,-1).view(x.size(0),-1)
curIndex = len(self.columns)*self.nbTargets
if self.historyNb > 0 :
historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
y = torch.cat([y, historyEmb],-1)
curIndex = curIndex+self.historyNb
if self.prefixSize > 0 :
prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
y = torch.cat([y, prefixEmb],-1)
curIndex = curIndex+self.prefixSize
if self.suffixSize > 0 :
suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
y = torch.cat([y, suffixEmb],-1)
curIndex = curIndex+self.suffixSize
y = self.dropout(y)
y = F.relu(self.dropout(self.fc1(y)))
y = self.fc2(y)
return y
def currentDevice(self) :
return self.dummyParam.device
def initWeights(self,m) :
if type(m) == nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
def extractFeatures(self, dicts, config) :
colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
################################################################################
\ No newline at end of file
...@@ -16,15 +16,15 @@ import Config ...@@ -16,15 +16,15 @@ import Config
from conll18_ud_eval import load_conllu, evaluate from conll18_ud_eval import load_conllu, evaluate
################################################################################ ################################################################################
def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) : def trainMode(debug, filename, type, transitionSet, strategy, mlp, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
sentences = Config.readConllu(filename, predicted) sentences = Config.readConllu(filename, predicted)
if type == "oracle" : if type == "oracle" :
trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent) trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, mlp, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
return return
if type == "rl": if type == "rl":
trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent) trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, mlp, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
return return
print("ERROR : unknown type '%s'"%type, file=sys.stderr) print("ERROR : unknown type '%s'"%type, file=sys.stderr)
...@@ -63,6 +63,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : ...@@ -63,6 +63,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
moved = applyTransition(strat, config, candidate, None) moved = applyTransition(strat, config, candidate, None)
if len(strat) > 1:
EOS.apply(config, strat) EOS.apply(config, strat)
return examples return examples
...@@ -94,12 +95,12 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss ...@@ -94,12 +95,12 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
################################################################################ ################################################################################
################################################################################ ################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) : def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, mlp, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
dicts = Dicts() dicts = Dicts()
dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2) dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.save(modelDir+"/dicts.json") dicts.save(modelDir+"/dicts.json")
network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice()) network = Networks.get_network(mlp, dicts, len(transitionSet), incremental).to(getDevice())
examples = [] examples = []
sentences = copy.deepcopy(sentencesOriginal) sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr) print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
...@@ -149,7 +150,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr ...@@ -149,7 +150,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
################################################################################ ################################################################################
################################################################################ ################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) : def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, mlp, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
memory = None memory = None
dicts = Dicts() dicts = Dicts()
...@@ -157,8 +158,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti ...@@ -157,8 +158,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.save(modelDir + "/dicts.json") dicts.save(modelDir + "/dicts.json")
policy_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice()) policy_net = Networks.get_network(mlp, dicts, len(transitionSet), incremental).to(getDevice())
target_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice()) target_net = Networks.get_network(mlp, dicts, len(transitionSet), incremental).to(getDevice())
target_net.load_state_dict(policy_net.state_dict()) target_net.load_state_dict(policy_net.state_dict())
target_net.eval() target_net.eval()
policy_net.train() policy_net.train()
......
...@@ -51,7 +51,7 @@ if __name__ == "__main__" : ...@@ -51,7 +51,7 @@ if __name__ == "__main__" :
parser.add_argument("--silent", "-s", default=False, action="store_true", parser.add_argument("--silent", "-s", default=False, action="store_true",
help="Don't print advancement infos.") help="Don't print advancement infos.")
parser.add_argument("--transitions", default="eager", parser.add_argument("--transitions", default="eager",
help="Transition set to use (eager | swift | tagparser).") help="Transition set to use (eager | swift | tagparser | tag).")
parser.add_argument("--ts", default="", parser.add_argument("--ts", default="",
help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"") help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
parser.add_argument("--reward", default="A", parser.add_argument("--reward", default="A",
...@@ -86,13 +86,24 @@ if __name__ == "__main__" : ...@@ -86,13 +86,24 @@ if __name__ == "__main__" :
tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+tagActions+args.ts.split(',')) if len(elem) > 0] transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+tagActions+args.ts.split(',')) if len(elem) > 0]
args.predicted = "HEAD,UPOS" args.predicted = "HEAD,UPOS"
elif args.transitions == "tag":
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
tagActions = ["TAG UPOS %s" % p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
transitionSet = [Transition(elem) for elem in (tagActions + args.ts.split(',')) if len(elem) > 0]
args.predicted = "UPOS"
elif args.transitions == "swift" : elif args.transitions == "swift" :
transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0] transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]
args.predicted = "HEAD" args.predicted = "HEAD"
else : else :
raise Exception("Unknown transition set '%s'"%args.transitions) raise Exception("Unknown transition set '%s'"%args.transitions)
if args.transitions == "tag":
strategy = {"TAG": 1}
mlp = 'POSTagNet'
else:
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0} strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0}
mlp = 'BaseNet'
args.predicted = set({colName for colName in args.predicted.split(',')}) args.predicted = set({colName for colName in args.predicted.split(',')})
...@@ -101,7 +112,7 @@ if __name__ == "__main__" : ...@@ -101,7 +112,7 @@ if __name__ == "__main__" :
json.dump(strategy, open(args.model+"/strategy.json", "w")) json.dump(strategy, open(args.model+"/strategy.json", "w"))
printTS(transitionSet, sys.stderr) printTS(transitionSet, sys.stderr)
probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))] probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent) Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, mlp, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
elif args.mode == "decode" : elif args.mode == "decode" :
transNames = json.load(open(args.model+"/transitions.json", "r")) transNames = json.load(open(args.model+"/transitions.json", "r"))
transitionSet = [Transition(elem) for elem in transNames] transitionSet = [Transition(elem) for elem in transNames]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment