From e8b9c9f0e524203fad15859b6c77b6fed60f67aa Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 12 May 2021 16:05:15 +0200
Subject: [PATCH] Added multiple reward functions and a program argument to
 choose them

---
 Decode.py | 19 +++++++++++-----
 Rl.py     | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 Train.py  | 20 ++++++++---------
 main.py   |  6 +++--
 4 files changed, 91 insertions(+), 19 deletions(-)
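
Note on the new --reward argument: main.py now accepts --reward (default "A",
one of A..E) and threads it through Train.trainMode and Decode.decodeMode down
to Rl.rewarding, which selects the reward function by name. Below is a minimal,
runnable sketch of that name-based dispatch; the stub reward bodies and the
example call are placeholders for illustration, not the real implementations
from Rl.py:

    def rewardA(appliable, config, action, missingLinks):
        return -1.0  # placeholder body; the real rewardA lives in Rl.py

    def rewardB(appliable, config, action, missingLinks):
        return 1.0   # placeholder body

    def rewarding(appliable, config, action, missingLinks, funcname):
        # Same lookup as Rl.rewarding after this patch:
        # "reward" + funcname resolves rewardA..rewardE in the module namespace.
        return globals()["reward" + funcname](appliable, config, action, missingLinks)

    print(rewarding(True, None, None, None, "B"))  # prints 1.0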

diff --git a/Decode.py b/Decode.py
index 720e158..9198e4d 100644
--- a/Decode.py
+++ b/Decode.py
@@ -3,6 +3,7 @@ import sys
 from Transition import Transition, getMissingLinks, applyTransition
 from Dicts import Dicts
 from Util import getDevice
+from Rl import rewarding
 import Config
 import torch
 
@@ -43,7 +44,7 @@ def oracleDecode(ts, strat, config, debug=False) :
 ################################################################################
 
 ################################################################################
-def decodeModel(ts, strat, config, network, dicts, debug) :
+def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) :
   EOS = Transition("EOS")
   config.moveWordIndex(0)
   moved = True
@@ -54,7 +55,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
   network.to(decodeDevice)
 
   if debug :
-    print("\n"+("-"*80)+"\n", file=sys.stderr)
+    print("\n"+("-"*80), file=sys.stderr)
 
   with torch.no_grad():
     while moved :
@@ -65,10 +66,16 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
       if len(candidates) == 0 :
         break
       candidate = candidates[0][1]
+      missingLinks = getMissingLinks(config)
       if debug :
         config.printForDebug(sys.stderr)
-        print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate+"\n"+("-"*80)+"\n", file=sys.stderr)
-      moved = applyTransition(ts, strat, config, candidate, 0.)
+        print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate, file=sys.stderr)
+        candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name])
+        print("Oracle costs :"+str([[c[0],c[1].name] for c in candidates]), file=sys.stderr)
+        print("-"*80, file=sys.stderr)
+
+      reward = rewarding(True, config, ts[[t.name for t in ts].index(candidate)], missingLinks, rewardFunc)
+      moved = applyTransition(ts, strat, config, candidate, reward)
 
   EOS.apply(config, strat)
 
@@ -76,7 +83,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
 ################################################################################
 
 ################################################################################
-def decodeMode(debug, filename, type, transitionSet, strategy, modelDir=None, network=None, dicts=None, output=sys.stdout) :
+def decodeMode(debug, filename, type, transitionSet, strategy, rewardFunc, modelDir=None, network=None, dicts=None, output=sys.stdout) :
 
   sentences = Config.readConllu(filename)
 
@@ -93,7 +100,7 @@ def decodeMode(debug, filename, type, transitionSet, strategy, modelDir=None, ne
       dicts.load(modelDir+"/dicts.json")
       network = torch.load(modelDir+"/network.pt")
     for config in sentences :
-      decodeModel(transitionSet, strategy, config, network, dicts, debug)
+      decodeModel(transitionSet, strategy, config, network, dicts, debug, rewardFunc)
     sentences[0].print(output, header=True)
     for config in sentences[1:] :
       config.print(output, header=False)
diff --git a/Rl.py b/Rl.py
index 29def04..6f30f40 100644
--- a/Rl.py
+++ b/Rl.py
@@ -77,7 +77,12 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
 ################################################################################
 
 ################################################################################
-def rewarding(appliable, config, action, missingLinks):
+def rewarding(appliable, config, action, missingLinks, funcname):
+  return globals()["reward"+funcname](appliable, config, action, missingLinks)
+################################################################################
+
+################################################################################
+def rewardA(appliable, config, action, missingLinks):
   if appliable:
     if "BACK" not in action.name :
       reward = -1.0*action.getOracleScore(config, missingLinks)
@@ -90,3 +95,61 @@ def rewarding(appliable, config, action, missingLinks):
     reward = -3.0
   return reward
 ################################################################################
+
+################################################################################
+def rewardB(appliable, config, action, missingLinks):
+  if appliable:
+    if "BACK" not in action.name :
+      reward = 1.0 - action.getOracleScore(config, missingLinks)
+    else :
+      back = int(action.name.split()[-1])
+      error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0]
+      last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0
+      reward = last_error - back
+  else:
+    reward = -3.0
+  return reward
+################################################################################
+
+################################################################################
+def rewardC(appliable, config, action, missingLinks):
+  if appliable:
+    if "BACK" not in action.name :
+      reward = -action.getOracleScore(config, missingLinks)
+    else :
+      back = int(action.name.split()[-1])
+      error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0]
+      canceledRewards = [h[3] for h in config.historyPop[-back:]]
+      reward = -sum(canceledRewards)
+  else:
+    reward = -3.0
+  return reward
+################################################################################
+
+################################################################################
+def rewardD(appliable, config, action, missingLinks):
+  if appliable:
+    if "BACK" not in action.name :
+      reward = -action.getOracleScore(config, missingLinks)
+    else :
+      back = int(action.name.split()[-1])
+      error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0]
+      canceledRewards = [h[3] for h in config.historyPop[-back:]]
+      reward = -sum(canceledRewards) - 1
+  else:
+    reward = -3.0
+  return reward
+################################################################################
+
+################################################################################
+def rewardE(appliable, config, action, missingLinks):
+  if appliable:
+    if "BACK" not in action.name :
+      reward = -action.getOracleScore(config, missingLinks)
+    else :
+      reward = -0.5
+  else:
+    reward = -3.0
+  return reward
+################################################################################
+
diff --git a/Train.py b/Train.py
index 2e7d280..9a57091 100644
--- a/Train.py
+++ b/Train.py
@@ -16,15 +16,15 @@ import Config
 from conll18_ud_eval import load_conllu, evaluate
 
 ################################################################################
-def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, silent=False) :
+def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, silent=False) :
   sentences = Config.readConllu(filename)
 
   if type == "oracle" :
-    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, silent)
+    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, silent)
     return
 
   if type == "rl":
-    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, silent)
+    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, silent)
     return
 
   print("ERROR : unknown type '%s'"%type, file=sys.stderr)
@@ -70,13 +70,13 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
 ################################################################################
 
 ################################################################################
-def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental) :
+def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc) :
   devScore = ""
   saved = True if bestLoss is None else totalLoss < bestLoss
   bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
   if devFile is not None :
     outFilename = modelDir+"/predicted_dev.conllu"
-    Decode.decodeMode(debug, devFile, "model", ts, strat, modelDir, model, dicts, open(outFilename, "w"))
+    Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, modelDir, model, dicts, open(outFilename, "w"))
     res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
     UAS = res["UAS"][0].f1
     score = UAS
@@ -92,7 +92,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
 ################################################################################
 
 ################################################################################
-def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, silent=False) :
+def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, silent=False) :
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM","UPOS"], 2)
   dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
@@ -143,11 +143,11 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
       optimizer.step()
       totalLoss += float(loss)
 
-    bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental)
+    bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc)
 ################################################################################
 
 ################################################################################
-def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, silent=False) :
+def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, silent=False) :
 
   memory = None
   dicts = Dicts()
@@ -201,7 +201,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
 
         appliable = action.appliable(sentence)
 
-        reward_ = rewarding(appliable, sentence, action, missingLinks)
+        reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
         reward = torch.FloatTensor([reward_]).to(getDevice())
 
         newState = None
@@ -226,6 +226,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       if i >= nbExByEpoch :
         break
       sentIndex += 1
-    bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental)
+    bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc)
 ################################################################################
 
diff --git a/main.py b/main.py
index 1c33e21..81d3118 100755
--- a/main.py
+++ b/main.py
@@ -41,6 +41,8 @@ if __name__ == "__main__" :
     help="Don't print advancement infos.")
   parser.add_argument("--ts", default="",
     help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
+  parser.add_argument("--reward", default="A",
+    help="Reward function to use (A,B,C,D,E)")
   args = parser.parse_args()
 
   if args.debug :
@@ -63,13 +65,13 @@ if __name__ == "__main__" :
     json.dump([t.name for t in transitionSet], open(args.model+"/transitions.json", "w"))
     json.dump(strategy, open(args.model+"/strategy.json", "w"))
     print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr)
-    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.silent)
+    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, args.silent)
   elif args.mode == "decode" :
     transNames = json.load(open(args.model+"/transitions.json", "r"))
     transitionSet = [Transition(elem) for elem in transNames]
     strategy = json.load(open(args.model+"/strategy.json", "r"))
     print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr)
-    Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model)
+    Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.reward, args.model)
   else :
     print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
     exit(1)
-- 
GitLab