Skip to content
Snippets Groups Projects
Commit bb5b34c5 authored by Maxime Petit
Browse files

Start RL

parent 04472a67
No related branches found
No related tags found
No related merge requests found
__pycache__
bin/*
.idea
Rl.py 0 → 100644
import random
import torch
################################################################################
class ReplayMemory(object):
    """Bounded store of transitions that overwrites its oldest entry when full.

    Attributes are ``capacity`` (maximum number of stored transitions),
    ``memory`` (the backing list) and ``position`` (index of the slot the
    next ``push`` will write to).
    """

    def __init__(self, capacity):
        """Create an empty store able to hold up to ``capacity`` transitions."""
        self.memory = []
        self.position = 0
        self.capacity = capacity

    def push(self, transition):
        """Saves a transition."""
        buffer = self.memory
        write_at = self.position
        # Grow by one placeholder slot until the capacity is reached; after
        # that the write below recycles an existing slot.
        if len(buffer) < self.capacity:
            buffer.append(None)
        buffer[write_at] = transition
        # Wrap around so the next write lands on the oldest entry.
        self.position = (write_at + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions picked at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)
################################################################################
################################################################################
def selectAction(network, state, ts, epsilon=.2):
    """Pick the name of the next transition with an epsilon-greedy policy.

    With probability ``1 - epsilon`` the transition with the highest score in
    ``network(state)`` is chosen; otherwise a transition is drawn uniformly at
    random from ``ts``.

    Args:
        network: callable mapping ``state`` to a tensor of scores, one per
            entry of ``ts``. Assumed to score a single state (e.g. shape
            ``(1, len(ts))``) -- TODO confirm against the caller.
        state: input forwarded unchanged to ``network``.
        ts: indexable sequence of transitions exposing a ``name`` attribute.
        epsilon: probability of taking a random action. The default keeps the
            previously hard-coded value.

    Returns:
        The ``name`` attribute of the selected transition.
    """
    if random.random() > epsilon:
        with torch.no_grad():
            scores = network(state)
            # The action index is needed here, not the best probability:
            # builtin max() over the softmax output returned probability
            # value(s), which cannot index ``ts``. Softmax is monotonic, so
            # the argmax of the raw scores selects the same action.
            best = int(torch.argmax(scores))
        return ts[best].name
    return ts[random.randrange(len(ts))].name
################################################################################
\ No newline at end of file
......@@ -6,6 +6,7 @@ from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp
from Rl import ReplayMemory, selectAction
import Networks
import Decode
import Config
......@@ -23,6 +24,10 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silen
trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
return
if type == "rl":
trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
return
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
exit(1)
################################################################################
......@@ -110,3 +115,38 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
    """Run the reinforcement-learning training loop over ``sentences``.

    Builds the dictionaries from ``filename`` and saves them under
    ``modelDir``, creates a policy network and a frozen copy of it (target
    network), then plays ``nbIter`` episodes. Each episode processes one
    sentence: actions are chosen by ``selectAction``, applied with
    ``applyTransition``, and every step is stored in a ``ReplayMemory``
    before ``optimizeModel()`` is called.

    ``debug``, ``batchSize``, ``devFile`` and ``silent`` are accepted for
    signature parity with the other training entry points but are not read
    here.

    NOTE(review): ``getReward`` and ``optimizeModel`` are not defined in the
    code visible to this review -- confirm they exist in this module.
    """
    memory = ReplayMemory(1000)

    dicts = Dicts()
    dicts.readConllu(filename, ["FORM", "UPOS"])
    dicts.save(modelDir + "/dicts.json")

    policy_net = Networks.BaseNet(dicts, 13, len(transitionSet))
    target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
    # The target network starts as an exact copy of the policy network and is
    # kept in eval mode.
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    # NOTE(review): optimizer and lossFct are created but not yet passed to
    # optimizeModel(), which is called without arguments below.
    optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
    lossFct = torch.nn.CrossEntropyLoss()

    for i_episode in range(nbIter):
        # Cycle through the sentences when nbIter exceeds their number.
        sentence = sentences[i_episode%len(sentences)]
        state = Features.extractFeaturesPosExtended(dicts, sentence)
        notDone = True
        while notDone:
            action = selectAction(policy_net, state, transitionSet)
            print(action, file=sys.stderr)
            notDone = applyTransition(transitionSet, strategy, sentence, action)

            # The successor state must be known before the reward is
            # computed; reading it earlier hit an unassigned variable on the
            # first step and a stale state afterwards.
            if notDone:
                newState = Features.extractFeaturesPosExtended(dicts, sentence)
            else:
                # Terminal step: there is no successor state.
                newState = None

            reward = getReward(state, newState)
            reward = torch.tensor([reward])

            memory.push((state, action, newState, reward))
            state = newState
            optimizeModel()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment