Commit a35103cd authored by Franck Dary

Working RL

parent bb5b34c5
import random
import torch
+import torch.nn.functional as F
################################################################################
class ReplayMemory(object):
  def __init__(self, capacity):
    self.capacity = capacity
    self.memory = []
@@ -21,19 +21,49 @@ class ReplayMemory(object):
  def __len__(self):
    return len(self.memory)
################################################################################
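The hunk header above collapses the middle of ReplayMemory (its push and sample methods). For context, here is a minimal sketch of what such a DQN replay buffer conventionally looks like, inferred only from the calls made elsewhere in this commit (memory.push(...), memory.sample(batchSize), len(memory)); the cyclic-overwrite behaviour is an assumption, not the committed code:

import random

class ReplayMemorySketch(object) :
  # Hypothetical stand-in for the collapsed methods; not the committed code.
  def __init__(self, capacity) :
    self.capacity = capacity
    self.memory = []
    self.position = 0
  def push(self, transition) :
    # Store one (state, action, nextState, reward) tuple, overwriting the
    # oldest entry once capacity is reached.
    if len(self.memory) < self.capacity :
      self.memory.append(None)
    self.memory[self.position] = transition
    self.position = (self.position + 1) % self.capacity
  def sample(self, batchSize) :
    # Uniform sampling decorrelates consecutive parser transitions.
    return random.sample(self.memory, batchSize)
  def __len__(self) :
    return len(self.memory)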
################################################################################
-def selectAction(network, state, ts):
+def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
  sample = random.random()
-  if sample > .2:
+  if sample < probaRandom :
+    candidates = [trans for trans in ts if trans.appliable(config)]
+    return candidates[random.randrange(len(candidates))] if len(candidates) > 0 else None
+  elif sample < probaRandom+probaOracle :
+    candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
+    return candidates[0][1] if len(candidates) > 0 else None
+  else :
    with torch.no_grad() :
-      return ts[max(torch.nn.functional.softmax(network(state), dim=1))].name
+      output = network(torch.stack([state]))
+      candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index]] for index in range(len(ts))])[::-1]
+      candidates = [cand[2] for cand in candidates if cand[0]]
+      return candidates[0] if len(candidates) > 0 else None
-  else:
-    return ts[random.randrange(len(ts))].name
################################################################################
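The new selectAction splits the unit interval three ways: with probability probaRandom it returns a uniformly random applicable transition, with probability probaOracle the applicable transition with the lowest oracle score (the oracle-best move, since sorted() puts it first), and otherwise the applicable transition the network scores highest. For example, with probaRandom=0.3 and probaOracle=0.15, a sample in [0, 0.3) is random, [0.3, 0.45) is oracle, and [0.45, 1) is greedy. A hedged usage sketch, mirroring the call made in trainModelRl below (variable names are illustrative):

# Illustrative call; config, strategy and transitionSet come from the caller.
missingLinks = getMissingLinks(config)
action = selectAction(policy_net, state, transitionSet, config, missingLinks,
                      probaRandom=0.3, probaOracle=0.15)
if action is not None :
  applyTransition(transitionSet, strategy, config, action.name)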
################################################################################
+def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
+  gamma = 0.999
+  if len(memory) < batchSize :
+    return 0.0
+  batch = memory.sample(batchSize)
+  states = torch.stack([b[0] for b in batch])
+  actions = torch.stack([b[1] for b in batch])
+  next_states = torch.stack([b[2] for b in batch])
+  rewards = torch.stack([b[3] for b in batch])
+  predictedQ = policy_net(states).gather(1, actions)
+  nextQ = target_net(next_states).max(1)[0].unsqueeze(0)
+  nextQ = torch.transpose(nextQ, 0, 1)
+  expectedReward = gamma*nextQ + rewards
+  loss = F.smooth_l1_loss(predictedQ, expectedReward)
+  optimizer.zero_grad()
+  loss.backward()
+  optimizer.step()
+  return float(loss)
################################################################################
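optimizeModel is the standard one-step DQN update: the policy network's prediction Q(s, a) is regressed with a smooth L1 (Huber) loss toward the bootstrapped target gamma * max_a' Q_target(s', a') + r, with the max taken by the frozen target network. (The unsqueeze(0) followed by transpose(0, 1) is equivalent to unsqueeze(1); both turn the (batchSize,) vector of maxima into a (batchSize, 1) column matching predictedQ.) A self-contained shape check with dummy tensors; the sizes here are arbitrary stand-ins, not the project's real dimensions:

import torch
import torch.nn.functional as F

batchSize, nbFeatures, nbActions = 8, 13, 4
net = torch.nn.Linear(nbFeatures, nbActions)   # stand-in for Networks.BaseNet
states = torch.randn(batchSize, nbFeatures)
actions = torch.randint(nbActions, (batchSize, 1))
rewards = torch.randn(batchSize, 1)

predictedQ = net(states).gather(1, actions)    # (8, 1) : Q(s, a)
nextQ = net(states).max(1)[0].unsqueeze(1)     # (8, 1) : max_a' Q(s', a')
expectedReward = 0.999*nextQ + rewards         # Bellman target
loss = F.smooth_l1_loss(predictedQ, expectedReward)
print(loss.shape)                              # scalar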
import sys
import random
import torch
+import copy
from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp
-from Rl import ReplayMemory, selectAction
+from Rl import ReplayMemory, selectAction, optimizeModel
import Networks
import Decode
import Config
@@ -115,7 +116,9 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
memory = ReplayMemory(1000)
dicts = Dicts()
dicts.readConllu(filename, ["FORM", "UPOS"])
@@ -125,28 +128,61 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
  target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
  target_net.load_state_dict(policy_net.state_dict())
  target_net.eval()
+  policy_net.train()
  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
-  lossFct = torch.nn.CrossEntropyLoss()
  bestLoss = None
+  bestScore = None
-  for i_episode in range(nbIter):
-    sentence = sentences[i_episode%len(sentences)]
+  for epoch in range(nbIter) :
+    i = 0
+    totalLoss = 0.0
+    sentences = copy.deepcopy(sentencesOriginal)
+    for sentIndex in range(len(sentences)) :
+      if not silent :
+        print("Current epoch %6.2f%%"%(100.0*sentIndex/len(sentences)), end="\r", file=sys.stderr)
+      sentence = sentences[sentIndex]
      sentence.moveWordIndex(0)
      state = Features.extractFeaturesPosExtended(dicts, sentence)
-      notDone = True
-      while notDone:
-        action = selectAction(policy_net, state, transitionSet)
-        print(action, file=sys.stderr)
-        notDone = applyTransition(transitionSet, strategy, sentence, action)
-        reward = getReward(state, newState)
-        reward = torch.tensor([reward])
-        if notDone:
+      while True :
+        missingLinks = getMissingLinks(sentence)
+        if debug :
+          sentence.printForDebug(sys.stderr)
+        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom=0.3, probaOracle=0.15)
+        if action is None :
+          break
+        reward = -1.0*action.getOracleScore(sentence, missingLinks)
+        reward = torch.FloatTensor([reward])
+        applyTransition(transitionSet, strategy, sentence, action.name)
+        newState = Features.extractFeaturesPosExtended(dicts, sentence)
-        else:
-          newState = None
-        memory.push((state, action, newState, reward))
+        memory.push((state, torch.LongTensor([transitionSet.index(action)]), newState, reward))
        state = newState
-        optimizeModel()
+        if i % batchSize == 0 :
+          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
+          if i % (2*batchSize) == 0 :
+            target_net.load_state_dict(policy_net.state_dict())
+            target_net.eval()
+            policy_net.train()
+        i += 1
+    # End of epoch: compute score and save model
+    devScore = ""
+    saved = True if bestLoss is None else totalLoss < bestLoss
+    bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
+    if devFile is not None :
+      outFilename = modelDir+"/predicted_dev.conllu"
+      Decode.decodeMode(debug, devFile, "model", modelDir, policy_net, dicts, open(outFilename, "w"))
+      res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
+      UAS = res["UAS"][0].f1
+      score = UAS
+      saved = True if bestScore is None else score > bestScore
+      bestScore = score if bestScore is None else max(bestScore, score)
+      devScore = ", Dev : UAS=%.2f"%(UAS)
+    if saved :
+      torch.save(policy_net, modelDir+"/network.pt")
+    print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), epoch, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
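In this loop the reward is the negative oracle score of the chosen transition (presumably the number of gold links the action sacrifices, given the missingLinks argument), so the oracle-best transition costs 0 and worse ones cost more; the target network is re-synced from the policy network every 2*batchSize transitions. Since torch.save is handed the whole module rather than a state_dict, reloading the best checkpoint is a one-liner. A hedged sketch, assuming modelDir holds the path used above:

# Illustrative reload of the checkpoint saved by trainModelRl.
import torch
policy_net = torch.load(modelDir + "/network.pt")
policy_net.eval()   # inference mode for decoding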
@@ -12,6 +12,9 @@ class Transition :
      exit(1)
    self.name = name
+  def __lt__(self, other) :
+    return self.name < other.name
  def apply(self, config) :
    if self.name == "RIGHT" :
      applyRight(config)
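The new __lt__ exists so Transition objects can serve as sort keys: selectAction sorts lists of the form [score, transition], and whenever two scores tie Python falls back to comparing the transitions themselves, which raises a TypeError unless __lt__ is defined. An illustrative reproduction, assuming "LEFT" and "RIGHT" are both valid transition names (as "RIGHT" is in apply above):

# Without Transition.__lt__, this raises TypeError on the tied scores.
pairs = [[0, Transition("RIGHT")], [0, Transition("LEFT")]]
best = sorted(pairs)[0][1]   # ties now break alphabetically on .name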