Skip to content
Snippets Groups Projects
Commit 394637aa authored by Maxime Petit
Browse files

modified reward for back

parent b699477b
No related branches found
No related tags found
No related merge requests found
...@@ -18,7 +18,7 @@ def randomDecode(ts, strat, config, debug=False) : ...@@ -18,7 +18,7 @@ def randomDecode(ts, strat, config, debug=False) :
if debug : if debug :
config.printForDebug(sys.stderr) config.printForDebug(sys.stderr)
print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr) print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
applyTransition(ts, strat, config, candidate.name) applyTransition(ts, strat, config, candidate.name, 0.)
EOS.apply(config) EOS.apply(config)
################################################################################ ################################################################################
...@@ -37,7 +37,7 @@ def oracleDecode(ts, strat, config, debug=False) : ...@@ -37,7 +37,7 @@ def oracleDecode(ts, strat, config, debug=False) :
if debug : if debug :
config.printForDebug(sys.stderr) config.printForDebug(sys.stderr)
print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr) print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate) moved = applyTransition(ts, strat, config, candidate, 0.)
EOS.apply(config) EOS.apply(config)
################################################################################ ################################################################################
...@@ -53,6 +53,9 @@ def decodeModel(ts, strat, config, network, dicts, debug) : ...@@ -53,6 +53,9 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
decodeDevice = getDevice() decodeDevice = getDevice()
network.to(decodeDevice) network.to(decodeDevice)
if debug :
print("\n"+("-"*80)+"\n", file=sys.stderr)
with torch.no_grad(): with torch.no_grad():
while moved : while moved :
features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice) features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
...@@ -65,7 +68,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) : ...@@ -65,7 +68,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
if debug : if debug :
config.printForDebug(sys.stderr) config.printForDebug(sys.stderr)
print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate+"\n"+("-"*80)+"\n", file=sys.stderr) print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate) moved = applyTransition(ts, strat, config, candidate, 0.)
EOS.apply(config, strat) EOS.apply(config, strat)
......
...@@ -76,3 +76,17 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) : ...@@ -76,3 +76,17 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
return float(loss) return float(loss)
################################################################################ ################################################################################
################################################################################
def rewarding(appliable, config, action, missingLinks):
    """Return the reward for choosing `action` in `config`.

    Args:
        appliable: truthy when `action` can legally be applied to `config`.
        config: current configuration; for BACK actions its `historyPop`
            entries are read at index 3 (the reward slot that
            applyTransition stores alongside each transition).
        action: transition object exposing `name` and
            `getOracleScore(config, missingLinks)`.
        missingLinks: forwarded unchanged to `action.getOracleScore`.

    Returns:
        -3.0 for an action that is not appliable; the negated oracle score
        for an appliable non-BACK action; `last_error - back` for an
        appliable BACK action (see below).
    """
    # Fixed penalty for choosing an illegal action.
    if not appliable:
        return -3.0

    # Ordinary transitions are scored by the oracle (higher cost -> lower reward).
    if "BACK" not in action.name:
        return -1.0*action.getOracleScore(config, missingLinks)

    # The last whitespace-separated token of the name is the number of steps to undo.
    back = int(action.name.split()[-1])
    # Offsets (1 = most recent) of history entries that were stored with a
    # negative reward.
    # NOTE(review): range(1, back) never inspects the back-th most recent
    # entry, so e.g. "BACK 1" always yields -1. Confirm whether
    # range(1, back + 1) was intended before changing the reward scale.
    error_in_pop = [i for i in range(1, back) if config.historyPop[-i][3] < 0]
    # Deepest offset with a negative stored reward, or 0 when there is none.
    last_error = error_in_pop[-1] if error_in_pop else 0
    # The further BACK goes beyond the deepest recorded error, the lower the reward.
    return last_error - back
################################################################################
...@@ -8,7 +8,7 @@ from Transition import Transition, getMissingLinks, applyTransition ...@@ -8,7 +8,7 @@ from Transition import Transition, getMissingLinks, applyTransition
import Features import Features
from Dicts import Dicts from Dicts import Dicts
from Util import timeStamp, prettyInt, numParameters, getDevice from Util import timeStamp, prettyInt, numParameters, getDevice
from Rl import ReplayMemory, selectAction, optimizeModel from Rl import ReplayMemory, selectAction, optimizeModel, rewarding
import Networks import Networks
import Decode import Decode
import Config import Config
...@@ -201,16 +201,12 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti ...@@ -201,16 +201,12 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
appliable = action.appliable(sentence) appliable = action.appliable(sentence)
# Reward for doing an illegal action reward_ = rewarding(appliable, sentence, action, missingLinks)
reward = -3.0 reward = torch.FloatTensor([reward_]).to(getDevice())
if appliable :
reward = -1.0*action.getOracleScore(sentence, missingLinks)
reward = torch.FloatTensor([reward]).to(getDevice())
newState = None newState = None
if appliable : if appliable :
applyTransition(transitionSet, strategy, sentence, action.name) applyTransition(transitionSet, strategy, sentence, action.name, reward_)
newState = policy_net.extractFeatures(dicts, sentence).to(getDevice()) newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
if memory is None : if memory is None :
......
...@@ -145,7 +145,7 @@ def scoreOracleReduce(config, ml) : ...@@ -145,7 +145,7 @@ def scoreOracleReduce(config, ml) :
################################################################################ ################################################################################
def applyBack(config, strategy, size) : def applyBack(config, strategy, size) :
for i in range(size) : for i in range(size) :
trans, data, movement = config.historyPop.pop() trans, data, movement, _ = config.historyPop.pop()
config.moveWordIndex(-movement) config.moveWordIndex(-movement)
if trans.name == "RIGHT" : if trans.name == "RIGHT" :
applyBackRight(config) applyBackRight(config)
...@@ -231,14 +231,14 @@ def applyEOS(config) : ...@@ -231,14 +231,14 @@ def applyEOS(config) :
################################################################################ ################################################################################
################################################################################ ################################################################################
def applyTransition(ts, strat, config, name, reward) :
    """Apply the transition called `name` to `config` and record its outcome.

    Args:
        ts: iterable of transition objects, each exposing `name` and
            `apply(config, strat)`.
        strat: mapping from transition name to word-index movement; a name
            that is absent means a movement of 0.
        config: configuration to mutate; must expose `historyPop` and
            `moveWordIndex(movement)`.
        name: name of the transition to apply.
        reward: value stored in the fourth slot of the latest `historyPop`
            entry.

    Returns:
        Whatever `config.moveWordIndex(movement)` returned.

    Raises:
        IndexError: if no transition in `ts` is called `name`.
    """
    # First transition with a matching name; [0] raises IndexError if none.
    transition = [trans for trans in ts if trans.name == name][0]
    movement = strat[transition.name] if transition.name in strat else 0
    transition.apply(config, strat)
    moved = config.moveWordIndex(movement)
    # Record the movement that actually happened, not the requested one.
    if not moved :
        movement = 0
    # BACK transitions do not annotate the history entry.
    if len(config.historyPop) > 0 and "BACK" not in name :
        last = config.historyPop[-1]
        # Keep the first two fields of the latest entry, overwrite movement and reward.
        config.historyPop[-1] = (last[0], last[1], movement, reward)
    return moved
################################################################################ ################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment