diff --git a/Train.py b/Train.py index 0724dbdf4b2b5071ddc844ce98ddd71d12f431c3..e56c47eed27bd4a3b0134e4ad6e12409e520ed38 100644 --- a/Train.py +++ b/Train.py @@ -254,9 +254,9 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF reward = torch.FloatTensor([reward_]).to(getDevice()) newState = None + toState = strategy[action.name][1] if action.name in strategy else -1 if appliable : applyTransition(strategy, sentence, action, reward_) - toState = sentence.state newState = policy_net.extractFeatures(dicts, sentence).to(getDevice()) else: count+=1