Skip to content
Snippets Groups Projects
Commit b7045988 authored by Maxime Petit's avatar Maxime Petit
Browse files

Added pos tag

parent ead830cc
No related branches found
No related tags found
No related merge requests found
......@@ -118,11 +118,11 @@ class Config :
toPrint = []
for colIndex in range(len(self.lines[index])) :
value = str(self.getAsFeature(index, self.index2col[colIndex]))
if value == "" :
if value == "" or value == '_':
value = "_"
elif self.index2col[colIndex] == "HEAD" and value != "-1":
elif self.index2col[colIndex] == "HEAD" and (value != "-1" and self.getAsFeature(index, "DEPREL") != 'root'):
value = self.getAsFeature(int(value), "ID")
elif self.index2col[colIndex] == "HEAD" and value == "-1":
elif self.index2col[colIndex] == "HEAD" and (value == "-1" or self.getAsFeature(index, "DEPREL") == 'root'):
value = "0"
toPrint.append(value)
print("\t".join(toPrint), file=output)
......
......@@ -77,6 +77,7 @@ def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) :
reward = rewarding(True, config, candidate, missingLinks, rewardFunc)
moved = applyTransition(strat, config, candidate, reward)
if len(strat) > 1 :
EOS.apply(config, strat)
network.to(currentDevice)
......
......@@ -3,6 +3,12 @@ import torch.nn as nn
import torch.nn.functional as F
import Features
def get_network(mlp, dicts, outputSize, incremntal):
if mlp == 'POSTagNet':
return POSTagNet(dicts, outputSize, incremntal)
elif mlp == 'BaseNet':
return BaseNet(dicts, outputSize, incremntal)
################################################################################
class BaseNet(nn.Module):
def __init__(self, dicts, outputSize, incremental) :
......@@ -134,3 +140,67 @@ class LSTMNet(nn.Module):
return torch.cat([colsValuesBase, colsValuesLSTM, historyValues])
################################################################################
################################################################################
class POSTagNet(nn.Module):
def __init__(self, dicts, outputSize, incremental) :
super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.incremental = incremental
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2"
self.historyNb = 5
self.suffixSize = 4
self.prefixSize = 4
self.columns = ["UPOS", "FORM"]
self.embSize = 64
self.nbTargets = len(self.featureFunction.split())
self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
self.outputSize = outputSize
for name in dicts.dicts :
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
self.fc2 = nn.Linear(1600, outputSize)
self.dropout = nn.Dropout(0.3)
self.apply(self.initWeights)
def forward(self, x) :
embeddings = []
for i in range(len(self.columns)) :
embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
y = torch.cat(embeddings,-1).view(x.size(0),-1)
curIndex = len(self.columns)*self.nbTargets
if self.historyNb > 0 :
historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
y = torch.cat([y, historyEmb],-1)
curIndex = curIndex+self.historyNb
if self.prefixSize > 0 :
prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
y = torch.cat([y, prefixEmb],-1)
curIndex = curIndex+self.prefixSize
if self.suffixSize > 0 :
suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
y = torch.cat([y, suffixEmb],-1)
curIndex = curIndex+self.suffixSize
y = self.dropout(y)
y = F.relu(self.dropout(self.fc1(y)))
y = self.fc2(y)
return y
def currentDevice(self) :
return self.dummyParam.device
def initWeights(self,m) :
if type(m) == nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
def extractFeatures(self, dicts, config) :
colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
################################################################################
\ No newline at end of file
......@@ -16,15 +16,15 @@ import Config
from conll18_ud_eval import load_conllu, evaluate
################################################################################
def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
def trainMode(debug, filename, type, transitionSet, strategy, mlp, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
sentences = Config.readConllu(filename, predicted)
if type == "oracle" :
trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, mlp, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
return
if type == "rl":
trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, mlp, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
return
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
......@@ -63,6 +63,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
moved = applyTransition(strat, config, candidate, None)
if len(strat) > 1:
EOS.apply(config, strat)
return examples
......@@ -94,12 +95,12 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
################################################################################
################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, mlp, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
dicts = Dicts()
dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.save(modelDir+"/dicts.json")
network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
network = Networks.get_network(mlp, dicts, len(transitionSet), incremental).to(getDevice())
examples = []
sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
......@@ -149,7 +150,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
################################################################################
################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, mlp, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
memory = None
dicts = Dicts()
......@@ -157,8 +158,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.save(modelDir + "/dicts.json")
policy_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
target_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
policy_net = Networks.get_network(mlp, dicts, len(transitionSet), incremental).to(getDevice())
target_net = Networks.get_network(mlp, dicts, len(transitionSet), incremental).to(getDevice())
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
policy_net.train()
......
......@@ -51,7 +51,7 @@ if __name__ == "__main__" :
parser.add_argument("--silent", "-s", default=False, action="store_true",
help="Don't print advancement infos.")
parser.add_argument("--transitions", default="eager",
help="Transition set to use (eager | swift | tagparser).")
help="Transition set to use (eager | swift | tagparser | tag).")
parser.add_argument("--ts", default="",
help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
parser.add_argument("--reward", default="A",
......@@ -86,13 +86,24 @@ if __name__ == "__main__" :
tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+tagActions+args.ts.split(',')) if len(elem) > 0]
args.predicted = "HEAD,UPOS"
elif args.transitions == "tag":
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
tagActions = ["TAG UPOS %s" % p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
transitionSet = [Transition(elem) for elem in (tagActions + args.ts.split(',')) if len(elem) > 0]
args.predicted = "UPOS"
elif args.transitions == "swift" :
transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]
args.predicted = "HEAD"
else :
raise Exception("Unknown transition set '%s'"%args.transitions)
if args.transitions == "tag":
strategy = {"TAG": 1}
mlp = 'POSTagNet'
else:
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0}
mlp = 'BaseNet'
args.predicted = set({colName for colName in args.predicted.split(',')})
......@@ -101,7 +112,7 @@ if __name__ == "__main__" :
json.dump(strategy, open(args.model+"/strategy.json", "w"))
printTS(transitionSet, sys.stderr)
probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, mlp, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
elif args.mode == "decode" :
transNames = json.load(open(args.model+"/transitions.json", "r"))
transitionSet = [Transition(elem) for elem in transNames]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment