Commit a35103cd authored by Franck Dary

Working RL

parent bb5b34c5
import random
import torch
import torch.nn.functional as F
################################################################################
class ReplayMemory(object):
  def __init__(self, capacity):
    self.capacity = capacity
    self.memory = []
...@@ -21,19 +21,49 @@ class ReplayMemory(object):
  def __len__(self):
    return len(self.memory)
################################################################################
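The diff collapses the middle of ReplayMemory (its push and sample methods). A minimal, self-contained sketch of a buffer with the same interface, assuming only what the training code below relies on (push stores one transition, sample draws a random batch, capacity bounds the buffer), could look like this; the class name MinimalReplayMemory is hypothetical:

import random

class MinimalReplayMemory(object) :
  def __init__(self, capacity) :
    self.capacity = capacity
    self.memory = []
    self.position = 0
  def push(self, transition) :
    # Append until full, then overwrite the oldest entry
    if len(self.memory) < self.capacity :
      self.memory.append(transition)
    else :
      self.memory[self.position] = transition
    self.position = (self.position + 1) % self.capacity
  def sample(self, batchSize) :
    return random.sample(self.memory, batchSize)
  def __len__(self) :
    return len(self.memory)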
################################################################################
def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
  sample = random.random()
  if sample < probaRandom :
    # With probability probaRandom : pick a random applicable transition
    candidates = [trans for trans in ts if trans.appliable(config)]
    return candidates[random.randrange(len(candidates))] if len(candidates) > 0 else None
  elif sample < probaRandom+probaOracle :
    # With probability probaOracle : pick the applicable transition with the lowest oracle score
    candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
    return candidates[0][1] if len(candidates) > 0 else None
  else :
    # Otherwise : pick the applicable transition with the highest network output
    with torch.no_grad() :
      output = network(torch.stack([state]))
      candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index]] for index in range(len(ts))])[::-1]
      candidates = [cand[2] for cand in candidates if cand[0]]
      return candidates[0] if len(candidates) > 0 else None
################################################################################
################################################################################
def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
  gamma = 0.999
  if len(memory) < batchSize :
    return 0.0
  batch = memory.sample(batchSize)
  states = torch.stack([b[0] for b in batch])
  actions = torch.stack([b[1] for b in batch])
  next_states = torch.stack([b[2] for b in batch])
  rewards = torch.stack([b[3] for b in batch])
  # Q(s,a) of the actions actually taken, according to the policy network
  predictedQ = policy_net(states).gather(1, actions)
  # max_a Q(s',a) according to the frozen target network, reshaped to [batchSize, 1]
  nextQ = target_net(next_states).max(1)[0].unsqueeze(0)
  nextQ = torch.transpose(nextQ, 0, 1)
  # Bellman target : reward + gamma * max_a Q(s',a)
  expectedReward = gamma*nextQ + rewards
  loss = F.smooth_l1_loss(predictedQ, expectedReward)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  return float(loss)
################################################################################
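A hedged usage sketch of optimizeModel with toy stand-ins (torch.nn.Linear instead of Networks.BaseNet, random float features instead of the real feature extraction, and ReplayMemory.push/sample assumed to behave as used above), only to illustrate the tensor shapes the function expects: states of shape [batch, features], actions as LongTensors of transition indices with shape [batch, 1], rewards of shape [batch, 1]:

import torch
from Rl import ReplayMemory, optimizeModel

nbFeatures, nbTransitions, batchSize = 13, 4, 8
policy_net = torch.nn.Linear(nbFeatures, nbTransitions)
target_net = torch.nn.Linear(nbFeatures, nbTransitions)
target_net.load_state_dict(policy_net.state_dict())
optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)

memory = ReplayMemory(100)
for _ in range(batchSize) :
  state = torch.randn(nbFeatures)
  newState = torch.randn(nbFeatures)
  action = torch.LongTensor([0])       # index of the chosen transition
  reward = torch.FloatTensor([-1.0])   # negated oracle score
  memory.push((state, action, newState, reward))

loss = optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
print("loss = %f"%loss)

The returned value is the smooth L1 loss between Q(s,a) of the policy network and the DQN target r + gamma*max_a Q(s',a) computed with the target network.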
import sys
import random
import torch
import copy
from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp
from Rl import ReplayMemory, selectAction, optimizeModel
import Networks
import Decode
import Config
...@@ -115,7 +116,9 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
    print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################
################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
  memory = ReplayMemory(1000)
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM", "UPOS"])
...@@ -125,28 +128,61 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
  target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
  target_net.load_state_dict(policy_net.state_dict())
  target_net.eval()
  policy_net.train()
  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
  bestLoss = None
  bestScore = None
  for epoch in range(nbIter) :
    i = 0
    totalLoss = 0.0
    sentences = copy.deepcopy(sentencesOriginal)
    for sentIndex in range(len(sentences)) :
      if not silent :
        print("Current epoch %6.2f%%"%(100.0*sentIndex/len(sentences)), end="\r", file=sys.stderr)
      sentence = sentences[sentIndex]
      sentence.moveWordIndex(0)
      state = Features.extractFeaturesPosExtended(dicts, sentence)
      while True :
        missingLinks = getMissingLinks(sentence)
        if debug :
          sentence.printForDebug(sys.stderr)
        # Choose a transition : random, oracle-best or network-best, depending on probaRandom and probaOracle
        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom=0.3, probaOracle=0.15)
        if action is None :
          break
        # Reward is the negated oracle score of the chosen transition
        reward = -1.0*action.getOracleScore(sentence, missingLinks)
        reward = torch.FloatTensor([reward])
        applyTransition(transitionSet, strategy, sentence, action.name)
        newState = Features.extractFeaturesPosExtended(dicts, sentence)
        memory.push((state, torch.LongTensor([transitionSet.index(action)]), newState, reward))
        state = newState
        if i % batchSize == 0 :
          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
        if i % (2*batchSize) == 0 :
          target_net.load_state_dict(policy_net.state_dict())
          target_net.eval()
          policy_net.train()
        i += 1
    # End of epoch : compute the dev score and save the model if it improved
    devScore = ""
    saved = True if bestLoss is None else totalLoss < bestLoss
    bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
    if devFile is not None :
      outFilename = modelDir+"/predicted_dev.conllu"
      Decode.decodeMode(debug, devFile, "model", modelDir, policy_net, dicts, open(outFilename, "w"))
      res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
      UAS = res["UAS"][0].f1
      score = UAS
      saved = True if bestScore is None else score > bestScore
      bestScore = score if bestScore is None else max(bestScore, score)
      devScore = ", Dev : UAS=%.2f"%(UAS)
    if saved :
      torch.save(policy_net, modelDir+"/network.pt")
    print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), epoch, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################
...@@ -12,6 +12,9 @@ class Transition :
      exit(1)
    self.name = name

  # Allows Transition objects to be ordered, so that the sorted() calls in
  # selectAction can break ties between [score, transition] pairs
  def __lt__(self, other) :
    return self.name < other.name

  def apply(self, config) :
    if self.name == "RIGHT" :
      applyRight(config)
...
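A note on the new Transition.__lt__: selectAction sorts lists whose last element is a Transition (for example [oracleScore, transition]); when the leading elements are equal, Python falls back to comparing the Transition objects themselves, which fails without __lt__. A tiny self-contained illustration with a hypothetical Dummy class:

class Dummy :
  pass

try :
  sorted([[1.0, Dummy()], [1.0, Dummy()]])
except TypeError as e :
  # '<' not supported between instances of 'Dummy' and 'Dummy'
  print(e)

Defining __lt__ on Transition (comparing names) makes these sorts succeed and deterministic instead of raising.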