Commit bb5b34c5 authored by Maxime Petit

Start RL

parent 04472a67
.gitignore
__pycache__
bin/*
.idea
Rl.py 0 → 100644
import random
import torch
################################################################################
class ReplayMemory(object):
    """Fixed-size cyclic buffer of transitions for experience replay."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, transition):
        """Saves a transition, overwriting the oldest one once capacity is reached."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Returns a uniformly random batch of stored transitions."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
################################################################################
################################################################################
def selectAction(network, state, ts):
    """Epsilon-greedy selection over the transition set ts (epsilon = 0.2)."""
    sample = random.random()
    if sample > 0.2:
        with torch.no_grad():
            # Exploit : name of the transition with the highest predicted probability
            # (softmax is monotonic, so this is the argmax of the raw scores)
            return ts[torch.nn.functional.softmax(network(state), dim=1).argmax().item()].name
    else:
        # Explore : name of a uniformly random transition
        return ts[random.randrange(len(ts))].name
################################################################################
\ No newline at end of file
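
A minimal usage sketch of Rl.py's two pieces, showing how transitions flow into the buffer as (state, action, newState, reward) tuples, matching the training loop further below. The DummyTransition class, the linear network, and the tensor shapes are illustrative assumptions, not part of the commit:

import torch
from Rl import ReplayMemory, selectAction

# Hypothetical stand-ins for the repository's Transition objects and BaseNet,
# used only to exercise ReplayMemory and selectAction in isolation
class DummyTransition:
    def __init__(self, name):
        self.name = name

transitionSet = [DummyTransition(n) for n in ["SHIFT", "LEFT", "RIGHT", "REDUCE"]]
network = torch.nn.Linear(13, len(transitionSet))  # maps a 1x13 state to one score per transition

memory = ReplayMemory(1000)
state = torch.randn(1, 13)
for _ in range(50):
    action = selectAction(network, state, transitionSet)  # epsilon-greedy, returns a transition name
    newState = torch.randn(1, 13)
    memory.push((state, action, newState, torch.tensor([0.])))  # oldest entry overwritten once full
    state = newState

batch = memory.sample(8)  # uniform random batch for one learning step

Sampling uniformly from a cyclic buffer like this is the standard experience-replay trick: it breaks the temporal correlation between consecutive parser transitions, which stabilises the gradient updates.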
@@ -6,6 +6,7 @@ from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp
from Rl import ReplayMemory, selectAction
import Networks
import Decode
import Config
@@ -23,6 +24,10 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silen
        trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
        return
    if type == "rl":
        trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
        return
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    exit(1)
################################################################################
@@ -110,3 +115,38 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
    print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
    memory = ReplayMemory(1000)
    dicts = Dicts()
    dicts.readConllu(filename, ["FORM", "UPOS"])
    dicts.save(modelDir + "/dicts.json")
    # The policy net is trained; the target net starts as a frozen copy of it
    policy_net = Networks.BaseNet(dicts, 13, len(transitionSet))
    target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()
    optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
    lossFct = torch.nn.CrossEntropyLoss()
    bestLoss = None
    bestScore = None
    for i_episode in range(nbIter):
        # One episode per sentence, cycling through the training set
        sentence = sentences[i_episode%len(sentences)]
        state = Features.extractFeaturesPosExtended(dicts, sentence)
        notDone = True
        while notDone:
            action = selectAction(policy_net, state, transitionSet)
            print(action, file=sys.stderr)
            notDone = applyTransition(transitionSet, strategy, sentence, action)
            # Extract the successor state first, so newState is defined before the reward is computed
            if notDone:
                newState = Features.extractFeaturesPosExtended(dicts, sentence)
            else:
                newState = None
            # getReward and optimizeModel are not yet defined in this commit
            reward = getReward(state, newState)
            reward = torch.tensor([reward])
            memory.push((state, action, newState, reward))
            state = newState
            optimizeModel()
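
getReward and optimizeModel are referenced above but not yet defined in this commit. For orientation, here is a minimal sketch of what a DQN-style optimizeModel could look like on top of this ReplayMemory; the explicit parameters (the real loop keeps these as enclosing-scope names), the discount factor gamma, and the smooth L1 loss are standard DQN assumptions, not code from this repository:

import torch

def optimizeModel(memory, policy_net, target_net, optimizer, transitionSet, batchSize, gamma=0.99):
    """One DQN gradient step over a sampled batch (sketch, hypothetical signature)."""
    if len(memory) < batchSize:
        return  # wait until the buffer holds at least one full batch
    batch = memory.sample(batchSize)
    names = [t.name for t in transitionSet]
    stateBatch = torch.cat([s for s, a, ns, r in batch])                     # (B, 13)
    actionBatch = torch.tensor([[names.index(a)] for s, a, ns, r in batch])  # (B, 1)
    rewardBatch = torch.cat([r for s, a, ns, r in batch])                    # (B,)
    # Q(s, a) from the policy net, for the actions actually taken
    qValues = policy_net(stateBatch).gather(1, actionBatch).squeeze(1)
    # max_a' Q(s', a') from the frozen target net; terminal states (newState is None) contribute 0
    nextQ = torch.zeros(batchSize)
    nonFinal = [i for i, (s, a, ns, r) in enumerate(batch) if ns is not None]
    if nonFinal:
        with torch.no_grad():
            nextQ[nonFinal] = target_net(torch.cat([batch[i][2] for i in nonFinal])).max(1)[0]
    loss = torch.nn.functional.smooth_l1_loss(qValues, rewardBatch + gamma * nextQ)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Under these assumptions, target_net would be refreshed from policy_net every few episodes; keeping that frozen copy, as trainModelRl already does at initialisation, is what keeps the bootstrap targets stable.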