From 394637aaaa386aaabe68dc0c187bf78e0a2b94b9 Mon Sep 17 00:00:00 2001
From: Maxime Petit <maxime.petit.3@etu.univ-amu.fr>
Date: Tue, 4 May 2021 15:01:18 +0200
Subject: [PATCH] Modify the reward for BACK transitions

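The reward for a BACK action now depends on whether the transitions it
undoes were actually erroneous, instead of being computed like the
reward of any other action:

- Transition.py: applyTransition() takes the reward as an extra
  argument and stores it as a fourth field of the last historyPop
  entry; applyBack() unpacks the extra field accordingly.
- Rl.py: the reward computation moves into a new rewarding() helper
  (illegal action: -3.0; legal non-BACK action: negated oracle score;
  BACK k: -k, offset by the depth of the deepest error recorded in the
  k-1 most recent history entries).
- Train.py: trainModelRl() calls rewarding() and forwards the reward
  to applyTransition().
- Decode.py: decoding passes a dummy reward of 0. to applyTransition()
  and prints an extra separator line in debug mode.
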
---
 Decode.py     |  9 ++++++---
 Rl.py         | 19 +++++++++++++++++++
 Train.py      | 12 ++++--------
 Transition.py |  6 +++---
 4 files changed, 32 insertions(+), 14 deletions(-)

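A minimal standalone sketch of the new BACK reward (commentary here is
not part of the diff): historyPop is assumed to hold the (transition,
data, movement, reward) 4-tuples introduced by this patch, and the toy
values are invented for illustration.

  historyPop = [("SHIFT", None, 1, 0.0),
                ("RIGHT", None, 0, -2.0),  # an erroneous transition
                ("SHIFT", None, 1, 0.0)]
  back = 3  # a "BACK 3" would undo the three entries above
  error_in_pop = [i for i in range(1, back) if historyPop[-i][3] < 0]
  last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0
  reward = last_error - back
  print(reward)  # -1 : milder penalty, the BACK reaches a recorded error

With no recorded error among the back-1 most recent entries, the reward
falls back to -back, so useless BACKs are penalized by their full size.
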
diff --git a/Decode.py b/Decode.py
index 3130b12..11060f2 100644
--- a/Decode.py
+++ b/Decode.py
@@ -18,7 +18,7 @@ def randomDecode(ts, strat, config, debug=False) :
     if debug :
       config.printForDebug(sys.stderr)
       print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
-    applyTransition(ts, strat, config, candidate.name)
+    applyTransition(ts, strat, config, candidate.name, 0.)
 
   EOS.apply(config)
 ################################################################################
@@ -37,7 +37,7 @@ def oracleDecode(ts, strat, config, debug=False) :
     if debug :
       config.printForDebug(sys.stderr)
       print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
-    moved = applyTransition(ts, strat, config, candidate)
+    moved = applyTransition(ts, strat, config, candidate, 0.)
 
   EOS.apply(config)
 ################################################################################
@@ -53,6 +53,9 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
   decodeDevice = getDevice()
   network.to(decodeDevice)
 
+  if debug :
+    print("\n"+("-"*80)+"\n", file=sys.stderr)
+
   with torch.no_grad():
     while moved :
       features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
@@ -65,7 +68,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
       if debug :
         config.printForDebug(sys.stderr)
         print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate+"\n"+("-"*80)+"\n", file=sys.stderr)
-      moved = applyTransition(ts, strat, config, candidate)
+      moved = applyTransition(ts, strat, config, candidate, 0.)
 
   EOS.apply(config, strat)
 
diff --git a/Rl.py b/Rl.py
index 2ec5e66..29def04 100644
--- a/Rl.py
+++ b/Rl.py
@@ -76,3 +76,22 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
   return float(loss)
 ################################################################################
 
+################################################################################
+def rewarding(appliable, config, action, missingLinks):
+  # Compute the reward obtained by taking 'action' in the configuration 'config'.
+  if appliable:
+    if "BACK" not in action.name :
+      # Legal non-BACK action : penalize by the action's oracle score.
+      reward = -1.0*action.getOracleScore(config, missingLinks)
+    else :
+      # BACK k : scan the k-1 most recent history entries for a stored negative
+      # reward ; the penalty shrinks when the BACK reaches a recorded error.
+      back = int(action.name.split()[-1])
+      error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0]
+      last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0
+      reward = last_error - back
+  else:
+    # Illegal action : fixed penalty.
+    reward = -3.0
+  return reward
+################################################################################
diff --git a/Train.py b/Train.py
index 8093ecf..09e1737 100644
--- a/Train.py
+++ b/Train.py
@@ -8,7 +8,7 @@ from Transition import Transition, getMissingLinks, applyTransition
 import Features
 from Dicts import Dicts
 from Util import timeStamp, prettyInt, numParameters, getDevice
-from Rl import ReplayMemory, selectAction, optimizeModel
+from Rl import ReplayMemory, selectAction, optimizeModel, rewarding
 import Networks
 import Decode
 import Config
@@ -201,16 +201,12 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
 
         appliable = action.appliable(sentence)
 
-        # Reward for doing an illegal action
-        reward = -3.0
-        if appliable :
-          reward = -1.0*action.getOracleScore(sentence, missingLinks)
-      
-        reward = torch.FloatTensor([reward]).to(getDevice())
+        reward_ = rewarding(appliable, sentence, action, missingLinks)
+        reward = torch.FloatTensor([reward_]).to(getDevice())
 
         newState = None
         if appliable :
-          applyTransition(transitionSet, strategy, sentence, action.name)
+          applyTransition(transitionSet, strategy, sentence, action.name, reward_)
           newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
 
         if memory is None :
diff --git a/Transition.py b/Transition.py
index 1b30b71..9607198 100644
--- a/Transition.py
+++ b/Transition.py
@@ -145,7 +145,7 @@ def scoreOracleReduce(config, ml) :
 ################################################################################
 def applyBack(config, strategy, size) :
   for i in range(size) :
-    trans, data, movement = config.historyPop.pop()
+    trans, data, movement, _ = config.historyPop.pop()
     config.moveWordIndex(-movement)
     if trans.name == "RIGHT" :
       applyBackRight(config)
@@ -231,14 +231,14 @@ def applyEOS(config) :
 ################################################################################
 
 ################################################################################
-def applyTransition(ts, strat, config, name) :
+def applyTransition(ts, strat, config, name, reward) :
   transition = [trans for trans in ts if trans.name == name][0]
   movement = strat[transition.name] if transition.name in strat else 0
   transition.apply(config, strat)
   moved = config.moveWordIndex(movement)
   movement = movement if moved else 0
   if len(config.historyPop) > 0 and "BACK" not in name :
-    config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement)
+    config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward)
   return moved
 ################################################################################
 
-- 
GitLab