Skip to content
Snippets Groups Projects
Commit 394637aa authored by Maxime Petit's avatar Maxime Petit
Browse files

modified reward for back

parent b699477b
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,7 @@ def randomDecode(ts, strat, config, debug=False) :
if debug :
config.printForDebug(sys.stderr)
print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
applyTransition(ts, strat, config, candidate.name)
applyTransition(ts, strat, config, candidate.name, 0.)
EOS.apply(config)
################################################################################
......@@ -37,7 +37,7 @@ def oracleDecode(ts, strat, config, debug=False) :
if debug :
config.printForDebug(sys.stderr)
print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate)
moved = applyTransition(ts, strat, config, candidate, 0.)
EOS.apply(config)
################################################################################
......@@ -53,6 +53,9 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
decodeDevice = getDevice()
network.to(decodeDevice)
if debug :
print("\n"+("-"*80)+"\n", file=sys.stderr)
with torch.no_grad():
while moved :
features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
......@@ -65,7 +68,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
if debug :
config.printForDebug(sys.stderr)
print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate)
moved = applyTransition(ts, strat, config, candidate, 0.)
EOS.apply(config, strat)
......
......@@ -76,3 +76,17 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
return float(loss)
################################################################################
################################################################################
def rewarding(appliable, config, action, missingLinks):
if appliable:
if "BACK" not in action.name :
reward = -1.0*action.getOracleScore(config, missingLinks)
else :
back = int(action.name.split()[-1])
error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0]
last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0
reward = last_error - back
else:
reward = -3.0
return reward
################################################################################
......@@ -8,7 +8,7 @@ from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp, prettyInt, numParameters, getDevice
from Rl import ReplayMemory, selectAction, optimizeModel
from Rl import ReplayMemory, selectAction, optimizeModel, rewarding
import Networks
import Decode
import Config
......@@ -201,16 +201,12 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
appliable = action.appliable(sentence)
# Reward for doing an illegal action
reward = -3.0
if appliable :
reward = -1.0*action.getOracleScore(sentence, missingLinks)
reward = torch.FloatTensor([reward]).to(getDevice())
reward_ = rewarding(appliable, sentence, action, missingLinks)
reward = torch.FloatTensor([reward_]).to(getDevice())
newState = None
if appliable :
applyTransition(transitionSet, strategy, sentence, action.name)
applyTransition(transitionSet, strategy, sentence, action.name, reward_)
newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
if memory is None :
......
......@@ -145,7 +145,7 @@ def scoreOracleReduce(config, ml) :
################################################################################
def applyBack(config, strategy, size) :
for i in range(size) :
trans, data, movement = config.historyPop.pop()
trans, data, movement, _ = config.historyPop.pop()
config.moveWordIndex(-movement)
if trans.name == "RIGHT" :
applyBackRight(config)
......@@ -231,14 +231,14 @@ def applyEOS(config) :
################################################################################
################################################################################
def applyTransition(ts, strat, config, name) :
def applyTransition(ts, strat, config, name, reward) :
transition = [trans for trans in ts if trans.name == name][0]
movement = strat[transition.name] if transition.name in strat else 0
transition.apply(config, strat)
moved = config.moveWordIndex(movement)
movement = movement if moved else 0
if len(config.historyPop) > 0 and "BACK" not in name :
config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement)
config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward)
return moved
################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment