Skip to content
Snippets Groups Projects
Commit 6c8dcca4 authored by Franck Dary's avatar Franck Dary
Browse files

Added tagger transition set

parent 69aeae9c
No related branches found
No related tags found
No related merge requests found
...@@ -18,6 +18,9 @@ class Config : ...@@ -18,6 +18,9 @@ class Config :
self.history = [] self.history = []
self.historyHistory = set() self.historyHistory = set()
self.historyPop = [] self.historyPop = []
def hasCol(self, colname) :
return colname in self.col2index
def addLine(self, cols) : def addLine(self, cols) :
self.lines.append([[val,""] for val in cols]) self.lines.append([[val,""] for val in cols])
...@@ -28,8 +31,7 @@ class Config : ...@@ -28,8 +31,7 @@ class Config :
if lineIndex not in range(len(self.lines)) : if lineIndex not in range(len(self.lines)) :
raise(Exception("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines)))) raise(Exception("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines))))
if colname not in self.col2index : if colname not in self.col2index :
print("Unknown colname '%s'"%(colname), file=sys.stderr) raise Exception("Unknown colname '%s'"%(colname))
exit(1)
index = 1 if predicted else 0 index = 1 if predicted else 0
return self.lines[lineIndex][self.col2index[colname]][index] return self.lines[lineIndex][self.col2index[colname]][index]
...@@ -37,8 +39,7 @@ class Config : ...@@ -37,8 +39,7 @@ class Config :
if lineIndex not in range(len(self.lines)) : if lineIndex not in range(len(self.lines)) :
raise(Exception("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines)))) raise(Exception("Line index %d is out of range (0,%d)"%(lineIndex, len(self.lines))))
if colname not in self.col2index : if colname not in self.col2index :
print("Unknown colname '%s'"%(colname), file=sys.stderr) raise Exception("Unknown colname '%s'"%(colname))
exit(1)
index = 1 if predicted else 0 index = 1 if predicted else 0
self.lines[lineIndex][self.col2index[colname]][index] = value self.lines[lineIndex][self.col2index[colname]][index] = value
...@@ -148,7 +149,7 @@ def readConllu(filename, predicted) : ...@@ -148,7 +149,7 @@ def readConllu(filename, predicted) :
continue continue
if len(line) == 0 : if len(line) == 0 :
for index in range(len(configs[-1])) : for index in range(len(configs[-1])) :
head = configs[-1].getGold(index, "HEAD") head = configs[-1].getGold(index, "HEAD") if "HEAD" in col2index else "_"
if head == "_" : if head == "_" :
continue continue
if head == "0" : if head == "0" :
......
...@@ -75,7 +75,7 @@ def decodeModel(transitionSets, strat, config, network, dicts, debug, rewardFunc ...@@ -75,7 +75,7 @@ def decodeModel(transitionSets, strat, config, network, dicts, debug, rewardFunc
if debug : if debug :
config.printForDebug(sys.stderr) config.printForDebug(sys.stderr)
print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%str(candidate), file=sys.stderr) print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%str(candidate), file=sys.stderr)
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name]) candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and trans.name != "BACK"])
print("Oracle costs :"+str([[c[0],str(c[1])] for c in candidates]), file=sys.stderr) print("Oracle costs :"+str([[c[0],str(c[1])] for c in candidates]), file=sys.stderr)
print("-"*80, file=sys.stderr) print("-"*80, file=sys.stderr)
......
...@@ -5,29 +5,38 @@ import Features ...@@ -5,29 +5,38 @@ import Features
################################################################################ ################################################################################
def createNetwork(name, dicts, outputSizes, incremental) : def createNetwork(name, dicts, outputSizes, incremental) :
featureFunctionAll = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
featureFunctionNostack = "b.-2 b.-1 b.0 b.1 b.2"
historyNb = 5
suffixSize = 4
prefixSize = 4
columns = ["UPOS", "FORM"]
if name == "base" : if name == "base" :
return BaseNet(dicts, outputSizes, incremental) return BaseNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, suffixSize, prefixSize, columns)
elif name == "lstm" : elif name == "lstm" :
return LSTMNet(dicts, outputSizes, incremental) return LSTMNet(dicts, outputSizes, incremental)
elif name == "separated" : elif name == "separated" :
return SeparatedNet(dicts, outputSizes, incremental) return SeparatedNet(dicts, outputSizes, incremental)
elif name == "tagger" :
return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, suffixSize, prefixSize, columns)
raise Exception("Unknown network name '%s'"%name) raise Exception("Unknown network name '%s'"%name)
################################################################################ ################################################################################
################################################################################ ################################################################################
class BaseNet(nn.Module): class BaseNet(nn.Module):
def __init__(self, dicts, outputSizes, incremental) : def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, suffixSize, prefixSize, columns) :
super().__init__() super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.incremental = incremental self.incremental = incremental
self.state = 0 self.state = 0
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" self.featureFunction = featureFunction
self.historyNb = 5 self.historyNb = historyNb
self.suffixSize = 4 self.suffixSize = suffixSize
self.prefixSize = 4 self.prefixSize = prefixSize
self.columns = ["UPOS", "FORM"] self.columns = columns
self.embSize = 64 self.embSize = 64
self.nbTargets = len(self.featureFunction.split()) self.nbTargets = len(self.featureFunction.split())
......
...@@ -95,7 +95,7 @@ def rewarding(appliable, config, action, missingLinks, funcname): ...@@ -95,7 +95,7 @@ def rewarding(appliable, config, action, missingLinks, funcname):
################################################################################ ################################################################################
def rewardA(appliable, config, action, missingLinks): def rewardA(appliable, config, action, missingLinks):
if appliable: if appliable:
if "BACK" not in action.name : if action.name != "BACK" :
reward = -1.0*action.getOracleScore(config, missingLinks) reward = -1.0*action.getOracleScore(config, missingLinks)
else : else :
back = action.size back = action.size
...@@ -110,7 +110,7 @@ def rewardA(appliable, config, action, missingLinks): ...@@ -110,7 +110,7 @@ def rewardA(appliable, config, action, missingLinks):
################################################################################ ################################################################################
def rewardB(appliable, config, action, missingLinks): def rewardB(appliable, config, action, missingLinks):
if appliable: if appliable:
if "BACK" not in action.name : if action.name != "BACK" :
reward = 1.0 - action.getOracleScore(config, missingLinks) reward = 1.0 - action.getOracleScore(config, missingLinks)
else : else :
back = action.size back = action.size
...@@ -125,7 +125,7 @@ def rewardB(appliable, config, action, missingLinks): ...@@ -125,7 +125,7 @@ def rewardB(appliable, config, action, missingLinks):
################################################################################ ################################################################################
def rewardC(appliable, config, action, missingLinks): def rewardC(appliable, config, action, missingLinks):
if appliable: if appliable:
if "BACK" not in action.name : if action.name != "BACK" :
reward = -action.getOracleScore(config, missingLinks) reward = -action.getOracleScore(config, missingLinks)
else : else :
back = action.size back = action.size
...@@ -140,7 +140,7 @@ def rewardC(appliable, config, action, missingLinks): ...@@ -140,7 +140,7 @@ def rewardC(appliable, config, action, missingLinks):
################################################################################ ################################################################################
def rewardD(appliable, config, action, missingLinks): def rewardD(appliable, config, action, missingLinks):
if appliable: if appliable:
if "BACK" not in action.name : if action.name != "BACK" :
reward = -action.getOracleScore(config, missingLinks) reward = -action.getOracleScore(config, missingLinks)
else : else :
back = action.size back = action.size
...@@ -155,7 +155,7 @@ def rewardD(appliable, config, action, missingLinks): ...@@ -155,7 +155,7 @@ def rewardD(appliable, config, action, missingLinks):
################################################################################ ################################################################################
def rewardE(appliable, config, action, missingLinks): def rewardE(appliable, config, action, missingLinks):
if appliable: if appliable:
if "BACK" not in action.name : if action.name != "BACK" :
reward = -action.getOracleScore(config, missingLinks) reward = -action.getOracleScore(config, missingLinks)
else : else :
reward = -0.5 reward = -0.5
...@@ -167,7 +167,7 @@ def rewardE(appliable, config, action, missingLinks): ...@@ -167,7 +167,7 @@ def rewardE(appliable, config, action, missingLinks):
################################################################################ ################################################################################
def rewardF(appliable, config, action, missingLinks): def rewardF(appliable, config, action, missingLinks):
if appliable: if appliable:
if "BACK" not in action.name : if action.name != "BACK" :
reward = -1.0*action.getOracleScore(config, missingLinks) reward = -1.0*action.getOracleScore(config, missingLinks)
else : else :
back = action.size back = action.size
...@@ -182,7 +182,7 @@ def rewardF(appliable, config, action, missingLinks): ...@@ -182,7 +182,7 @@ def rewardF(appliable, config, action, missingLinks):
################################################################################ ################################################################################
def rewardG(appliable, config, action, missingLinks): def rewardG(appliable, config, action, missingLinks):
if appliable: if appliable:
if "BACK" not in action.name : if action.name != "BACK" :
reward = -action.getOracleScore(config, missingLinks) reward = -action.getOracleScore(config, missingLinks)
else : else :
back = action.size back = action.size
......
...@@ -43,7 +43,7 @@ def extractExamples(debug, transitionSets, strat, config, dicts, network, dynami ...@@ -43,7 +43,7 @@ def extractExamples(debug, transitionSets, strat, config, dicts, network, dynami
while moved : while moved :
ts = transitionSets[config.state] ts = transitionSets[config.state]
missingLinks = getMissingLinks(config) missingLinks = getMissingLinks(config)
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name]) candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and trans.name != "BACK"])
if len(candidates) == 0 : if len(candidates) == 0 :
break break
best = min([cand[0] for cand in candidates]) best = min([cand[0] for cand in candidates])
......
...@@ -47,7 +47,7 @@ class Transition : ...@@ -47,7 +47,7 @@ class Transition :
print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr) print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr)
exit(1) exit(1)
config.history.append(self) config.history.append(self)
if "BACK" not in self.name : if self.name != "BACK" :
config.historyPop.append((self,data,None, None, config.state)) config.historyPop.append((self,data,None, None, config.state))
def appliable(self, config) : def appliable(self, config) :
...@@ -111,6 +111,8 @@ class Transition : ...@@ -111,6 +111,8 @@ class Transition :
################################################################################ ################################################################################
# Compute numeric values that will be used in the oracle to decide score of transitions # Compute numeric values that will be used in the oracle to decide score of transitions
def getMissingLinks(config) : def getMissingLinks(config) :
if not config.hasCol("HEAD") :
return {}
return {**{"StackRight"+str(n) : nbLinksStackRight(config, n) for n in range(1,6)}, **{"BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)}} return {**{"StackRight"+str(n) : nbLinksStackRight(config, n) for n in range(1,6)}, **{"BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)}}
################################################################################ ################################################################################
...@@ -191,7 +193,7 @@ def applyBack(config, strategy, size) : ...@@ -191,7 +193,7 @@ def applyBack(config, strategy, size) :
applyBackReduce(config, data) applyBackReduce(config, data)
elif trans.name == "TAG" : elif trans.name == "TAG" :
applyBackTag(config, trans.colName) applyBackTag(config, trans.colName)
else : elif trans.name != "NOBACK" :
print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr) print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr)
exit(1) exit(1)
config.state = state config.state = state
...@@ -264,6 +266,9 @@ def applyReduce(config) : ...@@ -264,6 +266,9 @@ def applyReduce(config) :
################################################################################ ################################################################################
def applyEOS(config) : def applyEOS(config) :
if not config.hasCol("HEAD") :
return
rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))] rootCandidates = [index for index in config.stack if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))]
if len(rootCandidates) == 0 : if len(rootCandidates) == 0 :
rootCandidates = [index for index in range(len(config.lines)) if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))] rootCandidates = [index for index in range(len(config.lines)) if not config.isMultiword(index) and isEmpty(config.getAsFeature(index, "HEAD"))]
...@@ -296,9 +301,9 @@ def applyTransition(strat, config, transition, reward) : ...@@ -296,9 +301,9 @@ def applyTransition(strat, config, transition, reward) :
transition.apply(config, strat) transition.apply(config, strat)
moved = config.moveWordIndex(movement) moved = config.moveWordIndex(movement)
movement = movement if moved else 0 movement = movement if moved else 0
if len(config.historyPop) > 0 and "BACK" not in transition.name : if len(config.historyPop) > 0 and transition.name != "BACK" :
config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward, config.historyPop[-1][4]) config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward, config.historyPop[-1][4])
if "BACK" not in transition.name : if transition.name != "BACK" :
config.state = newState config.state = newState
return moved return moved
################################################################################ ################################################################################
......
...@@ -56,7 +56,7 @@ if __name__ == "__main__" : ...@@ -56,7 +56,7 @@ if __name__ == "__main__" :
parser.add_argument("--ts", default="", parser.add_argument("--ts", default="",
help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"") help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
parser.add_argument("--network", default="base", parser.add_argument("--network", default="base",
help="Name of the neural network to use (base | lstm | separated).") help="Name of the neural network to use (base | lstm | separated | tagger).")
parser.add_argument("--reward", default="A", parser.add_argument("--reward", default="A",
help="Reward function to use (A,B,C,D,E)") help="Reward function to use (A,B,C,D,E)")
parser.add_argument("--probaRandom", default="0.6,4,0.1", parser.add_argument("--probaRandom", default="0.6,4,0.1",
...@@ -80,9 +80,27 @@ if __name__ == "__main__" : ...@@ -80,9 +80,27 @@ if __name__ == "__main__" :
if args.bootstrap is not None : if args.bootstrap is not None :
args.bootstrap = int(args.bootstrap) args.bootstrap = int(args.bootstrap)
if args.transitions == "eager" : if args.transitions == "tagger" :
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
transitionSets = [[Transition(elem) for elem in (tagActions+args.ts.split(',')) if len(elem) > 0]]
args.predictedStr = "UPOS"
args.states = ["tagger"]
strategy = {"TAG" : (1,0)}
args.network = "tagger"
elif args.transitions == "taggerbt" :
tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
transitionSets = [[Transition(elem) for elem in (tagActions+args.ts.split(',')) if len(elem) > 0], [Transition("NOBACK"), Transition("BACK 3")]]
args.predictedStr = "UPOS"
args.states = ["tagger", "backer"]
strategy = {"TAG" : (1,1), "NOBACK" : (0,0)}
args.network = "tagger"
elif args.transitions == "eager" :
transitionSets = [[Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]] transitionSets = [[Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]]
args.predicted = "HEAD" args.predictedStr = "HEAD"
args.states = ["parser"] args.states = ["parser"]
strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)} strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
elif args.transitions == "tagparser" : elif args.transitions == "tagparser" :
...@@ -97,7 +115,7 @@ if __name__ == "__main__" : ...@@ -97,7 +115,7 @@ if __name__ == "__main__" :
tmpDicts = Dicts() tmpDicts = Dicts()
tmpDicts.readConllu(args.corpus, ["UPOS"], 0) tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)] tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
transitionSets = [[Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0], ["NOBACK","BACK 4"]] transitionSets = [[Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0], [Transition("NOBACK"),Transition("BACK 4")]]
args.predictedStr = "HEAD,UPOS" args.predictedStr = "HEAD,UPOS"
args.states = ["tagger", "parser", "backer"] args.states = ["tagger", "parser", "backer"]
strategy = {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1), "NOBACK" : (0,0)} strategy = {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1), "NOBACK" : (0,0)}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment