From 9d39440aab1bae35b45916e140a0139f715c393b Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 20 Jul 2021 16:06:11 +0200
Subject: [PATCH] tout

---
 Config.py     |  2 ++
 Networks.py   | 33 ++++++++++++++++++++++-----------
 Rl.py         | 21 +++++++++++++++++++++
 Train.py      |  2 +-
 Transition.py | 18 +++++++++++-------
 main.py       | 39 ++++++++++++++++++++++++++++++++-------
 6 files changed, 89 insertions(+), 26 deletions(-)

diff --git a/Config.py b/Config.py
index a675e4d..0aa7595 100644
--- a/Config.py
+++ b/Config.py
@@ -14,6 +14,7 @@ class Config :
     self.wordIndex = 0
     self.maxWordIndex = 0 #To keep a track of the max value, in case of backtrack
     self.state = 0 #State of the analysis (e.g. 0=tagger, 1=parser)
+    self.nbUndone = 0 #Number of actions that has been undone and not replaced
     self.stack = []
     self.comments = []
     self.history = []
@@ -91,6 +92,7 @@ class Config :
     right = 5
     print("state :", self.state, file=output)
     print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
+    print("nbUndone :", self.nbUndone, file=output)
     print("history :",[str(trans) for trans in self.history], file=output)
     print("historyPop :",[(str(c[0]),"dat:"+str(c[1]),"mvt:"+str(c[2]),"reward:"+str(c[3]),"state:"+str(c[4])) for c in self.historyPop], file=output)
     toPrint = []
diff --git a/Networks.py b/Networks.py
index 1c0caf4..a34561f 100644
--- a/Networks.py
+++ b/Networks.py
@@ -24,7 +24,7 @@ def createNetwork(name, dicts, outputSizes, incremental) :
   elif name == "lstm" :
     return LSTMNet(dicts, outputSizes, incremental)
   elif name == "separated" :
-    return SeparatedNet(dicts, outputSizes, incremental)
+    return SeparatedNet(dicts, outputSizes, incremental, featureFunctionAll, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize)
   elif name == "tagger" :
     return BaseNet(dicts, outputSizes, incremental, featureFunctionNostack, historyNb, suffixSize, prefixSize, columns, hiddenSize)
 
@@ -188,28 +188,29 @@ class SemiNet(nn.Module):
 
 ################################################################################
 class SeparatedNet(nn.Module):
-  def __init__(self, dicts, outputSizes, incremental) :
+  def __init__(self, dicts, outputSizes, incremental, featureFunction, historyNb, historyPopNb, suffixSize, prefixSize, columns, hiddenSize) :
     super().__init__()
     self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
 
     self.incremental = incremental
     self.state = 0
-    self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
-    self.historyNb = 5
-    self.suffixSize = 4
-    self.prefixSize = 4
-    self.columns = ["UPOS", "FORM"]
+    self.featureFunction = featureFunction
+    self.historyNb = historyNb
+    self.historyPopNb = historyPopNb
+    self.suffixSize = suffixSize
+    self.prefixSize = prefixSize
+    self.columns = columns
 
     self.embSize = 64
     self.nbTargets = len(self.featureFunction.split())
-    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
+    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.historyPopNb+self.suffixSize+self.prefixSize
     self.outputSizes = outputSizes
 
     for i in range(len(outputSizes)) :
       for name in dicts.dicts :
         self.add_module("emb_"+name+"_"+str(i), nn.Embedding(len(dicts.dicts[name]), self.embSize))
-      self.add_module("fc1_"+str(i), nn.Linear(self.inputSize * self.embSize, 1600))
-      self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
+      self.add_module("fc1_"+str(i), nn.Linear(self.inputSize * self.embSize, hiddenSize))
+      self.add_module("output_"+str(i), nn.Linear(hiddenSize+1, outputSizes[i]))
     self.dropout = nn.Dropout(0.3)
 
     self.apply(self.initWeights)
@@ -219,6 +220,9 @@ class SeparatedNet(nn.Module):
 
   def forward(self, x) :
     embeddings = []
+    canBack = x[...,0:1]
+    x = x[...,1:]
+
     for i in range(len(self.columns)) :
       embeddings.append(getattr(self, "emb_"+self.columns[i]+"_"+str(self.state))(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
     y = torch.cat(embeddings,-1).view(x.size(0),-1)
@@ -227,6 +231,10 @@ class SeparatedNet(nn.Module):
       historyEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
       y = torch.cat([y, historyEmb],-1)
       curIndex = curIndex+self.historyNb
+    if self.historyPopNb > 0 :
+      historyPopEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyPopNb]).view(x.size(0),-1)
+      y = torch.cat([y, historyPopEmb],-1)
+      curIndex = curIndex+self.historyPopNb
     if self.prefixSize > 0 :
       prefixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
       y = torch.cat([y, prefixEmb],-1)
@@ -237,6 +245,7 @@ class SeparatedNet(nn.Module):
       curIndex = curIndex+self.suffixSize
     y = self.dropout(y)
     y = F.relu(self.dropout(getattr(self, "fc1_"+str(self.state))(y)))
+    y = torch.cat([y,canBack], 1)
     y = getattr(self, "output_"+str(self.state))(y)
     return y
 
@@ -251,9 +260,11 @@ class SeparatedNet(nn.Module):
   def extractFeatures(self, dicts, config) :
     colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
     historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
+    historyPopValues = Features.extractHistoryPopFeatures(dicts, config, self.historyPopNb)
     prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
     suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
-    return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
+    backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int)
+    return torch.cat([backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues])
 
 ################################################################################
 
diff --git a/Rl.py b/Rl.py
index 5d04b3f..cb4d996 100644
--- a/Rl.py
+++ b/Rl.py
@@ -145,6 +145,27 @@ def rewardA(appliable, config, action, missingLinks):
   return reward
 ################################################################################
 
+################################################################################
+def rewardB(appliable, config, action, missingLinks):
+  if appliable:
+    if action.name != "BACK" :
+      reward = -action.getOracleScore(config, missingLinks)
+    else :
+      canceledRewards = []
+      found = 0
+      for i in range(len(config.historyPop))[::-1] :
+        if config.historyPop[i][0].name == "NOBACK" :
+          found += 1
+          if found == action.size :
+            break
+        else :
+          canceledRewards.append(config.historyPop[i][3])
+      reward = np.log(1-sum(canceledRewards)) if -sum(canceledRewards) > 0 else -1
+  else:
+    reward = -forbiddenReward
+  return (1.0 if config.nbUndone == 0 else 2.0)*reward
+################################################################################
+
 ################################################################################
 def rewardA2(appliable, config, action, missingLinks):
   if appliable:
diff --git a/Train.py b/Train.py
index 140c896..bd20efe 100644
--- a/Train.py
+++ b/Train.py
@@ -259,7 +259,7 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
         reward = torch.FloatTensor([reward_]).to(getDevice())
 
         newState = None
-        toState = strategy[action.name][1] if action.name in strategy else -1
+        toState = strategy[fromState][action.name][1] if action.name in strategy[fromState] else -1
         if appliable :
           applyTransition(strategy, sentence, action, reward_)
           newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
diff --git a/Transition.py b/Transition.py
index e4864a1..b977cc0 100644
--- a/Transition.py
+++ b/Transition.py
@@ -15,7 +15,7 @@ class Transition :
     if len(splited) == 3 :
       self.colName = splited[1]
       self.argument = splited[2]
-    if not self.name in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","NOBACK","EOS","TAG"] :
+    if not self.name in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","NOBACK","NOBACKAB","EOS","TAG"] :
       raise(Exception("'%s' is not a valid transition type."%name))
 
   def __str__(self) :
@@ -39,8 +39,9 @@ class Transition :
       applyEOS(config)
     elif self.name == "TAG" :
       applyTag(config, self.colName, self.argument)
-    elif self.name == "NOBACK" :
+    elif "NOBACK" in self.name :
       data = None
+      config.nbUndone = max(0, config.nbUndone-1)
     elif "BACK" in self.name :
       config.historyHistory.add(str([t[0].name for t in config.historyPop]))
       applyBack(config, strategy, self.size)
@@ -80,7 +81,9 @@ class Transition :
     if self.name == "TAG" :
       return isEmpty(config.getAsFeature(config.wordIndex, self.colName)) or config.getAsFeature(config.wordIndex, self.colName) == Dicts.Dicts.erased
     if self.name == "NOBACK" :
-      return True
+      return config.nbUndone == 0
+    if self.name == "NOBACKAB" :
+      return config.nbUndone != 0
     if "BACK" in self.name :
       if len([h[0].name for h in config.historyPop if "NOBACK" in h[0].name]) < self.size :
         return False
@@ -100,7 +103,7 @@ class Transition :
       return scoreOracleReduce(config, missingLinks)
     if self.name == "TAG" :
       return 0 if self.argument == config.getGold(config.wordIndex, self.colName) else 1
-    if self.name == "NOBACK" :
+    if "NOBACK" in self.name :
       return 0
     if "BACK" in self.name :
       return 1
@@ -182,6 +185,7 @@ def scoreOracleReduce(config, ml) :
 ################################################################################
 def applyBack(config, strategy, size) :
   i = 0
+  config.nbUndone += size+1 
   while True :
     trans, data, movement, _, state = config.historyPop.pop()
     config.moveWordIndex(-movement)
@@ -195,7 +199,7 @@ def applyBack(config, strategy, size) :
       applyBackReduce(config, data)
     elif trans.name == "TAG" :
       applyBackTag(config, trans.colName)
-    elif trans.name == "NOBACK" :
+    elif "NOBACK" in trans.name :
       i += 1
     else :
       print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr)
@@ -301,8 +305,8 @@ def applyTag(config, colName, tag) :
 
 ################################################################################
 def applyTransition(strat, config, transition, reward) :
-  movement = strat[transition.name][0] if transition.name in strat else 0
-  newState = strat[transition.name][1] if transition.name in strat else -1
+  movement = strat[config.state][transition.name][0] if transition.name in strat[config.state] else 0
+  newState = strat[config.state][transition.name][1] if transition.name in strat[config.state] else -1
   transition.apply(config, strat)
   moved = config.moveWordIndex(movement)
   movement = movement if moved else 0
diff --git a/main.py b/main.py
index 38c2b87..56517ae 100755
--- a/main.py
+++ b/main.py
@@ -52,7 +52,7 @@ if __name__ == "__main__" :
   parser.add_argument("--silent", "-s", default=False, action="store_true",
     help="Don't print advancement infos.")
   parser.add_argument("--transitions", default="eager",
-    help="Transition set to use (eager | swift | tagparser | tagparserbt).")
+    help="Transition set to use (eager | swift | tagparser | tagparserbt | tagparserbt1 | recovery).")
   parser.add_argument("--ts", default="",
     help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
   parser.add_argument("--network", default="base",
@@ -89,7 +89,7 @@ if __name__ == "__main__" :
     transitionSets = [[Transition(elem) for elem in (tagActions+args.ts.split(',')) if len(elem) > 0]]
     args.predictedStr = "UPOS"
     args.states = ["tagger"]
-    strategy = {"TAG" : (1,0)}
+    strategy = [{"TAG" : (1,0)}]
     args.network = "tagger"
     probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "taggerbt" :
@@ -99,7 +99,7 @@ if __name__ == "__main__" :
     transitionSets = [[Transition(elem) for elem in (tagActions+args.ts.split(',')) if len(elem) > 0], [Transition("NOBACK"), Transition("BACK 2")]]
     args.predictedStr = "UPOS"
     args.states = ["tagger", "backer"]
-    strategy = {"TAG" : (1,1), "NOBACK" : (0,0)}
+    strategy = [{"TAG" : (1,1)}, {"NOBACK" : (0,0)}]
     args.network = "tagger"
     probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
               [list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))]]
@@ -107,7 +107,7 @@ if __name__ == "__main__" :
     transitionSets = [[Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]]
     args.predictedStr = "HEAD"
     args.states = ["parser"]
-    strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
+    strategy = [{"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}]
     probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "tagparser" :
     tmpDicts = Dicts()
@@ -116,7 +116,7 @@ if __name__ == "__main__" :
     transitionSets = [[Transition(elem) for elem in (tagActions+args.ts.split(',')) if len(elem) > 0], [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]]
     args.predictedStr = "HEAD,UPOS"
     args.states = ["tagger", "parser"]
-    strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1)}
+    strategy = [{"TAG" : (0,1)}, {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1)}]
     probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
               [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "tagparserbt" :
@@ -126,16 +126,41 @@ if __name__ == "__main__" :
     transitionSets = [[Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0], [Transition("NOBACK"),Transition("BACK 2")]]
     args.predictedStr = "HEAD,UPOS"
     args.states = ["tagger", "parser", "backer"]
-    strategy = {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1), "NOBACK" : (0,0)}
+    strategy = [{"TAG" : (0,1)}, {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1)}, {"NOBACK" : (0,0)}]
     probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
               [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
               [list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))]]
+  elif args.transitions == "recovery" :
+    tmpDicts = Dicts()
+    tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
+    tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
+    transitionSets = [[Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0], [Transition("NOBACK"),Transition("NOBACKAB"),Transition("BACK 2")], [Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0]]
+    args.predictedStr = "HEAD,UPOS"
+    args.states = ["tagger", "parser", "backer", "taggerReco", "parserReco"]
+    strategy = [{"TAG" : (0,1)}, {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1)}, {"NOBACK" : (0,0), "NOBACKAB" : (0,3)}, {"TAG" : (0,4)}, {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,4), "REDUCE" : (0,4)}]
+    probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
+              [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
+              [list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))],
+              [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
+              [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
+]
 
+  elif args.transitions == "tagparserbt1" :
+    tmpDicts = Dicts()
+    tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
+    tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
+    transitionSets = [[Transition(elem) for elem in tagActions if len(elem) > 0], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0], [Transition("NOBACK"),Transition("BACK 1")]]
+    args.predictedStr = "HEAD,UPOS"
+    args.states = ["tagger", "parser", "backer"]
+    strategy = [{"TAG" : (0,1)}, {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1)}, {"NOBACK" : (0,0)}]
+    probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
+              [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
+              [list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))]]
   elif args.transitions == "swift" :
     transitionSets = [[Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]]
     args.predictedStr = "HEAD"
     args.states = ["parser"]
-    strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
+    strategy = [{"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}]
     probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   else :
     raise Exception("Unknown transition set '%s'"%args.transitions)
-- 
GitLab