Skip to content
Snippets Groups Projects
Commit fb6f6ffa authored by Franck Dary
Browse files

probaOracle and probaRandom now decrease as the number of epochs grows

parent 7729907f
Branches
No related tags found
No related merge requests found
......@@ -10,7 +10,7 @@ class BaseNet(nn.Module):
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
self.columns = ["UPOS", "FORM"]
self.columns = ["UPOS"]
self.embSize = 64
self.nbTargets = len(self.featureFunction.split())
......
......@@ -52,7 +52,7 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
################################################################################
def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
gamma = 0.9
gamma = 0.8
if len(memory) < batchSize :
return 0.0
......
......@@ -2,6 +2,7 @@ import sys
import random
import torch
import copy
import math
from Transition import Transition, getMissingLinks, applyTransition
import Features
......@@ -151,7 +152,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
memory = None
dicts = Dicts()
dicts.readConllu(filename, ["FORM", "UPOS"])
dicts.readConllu(filename, ["FORM","UPOS"], 2)
dicts.save(modelDir + "/dicts.json")
policy_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
......@@ -170,6 +171,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
sentIndex = 0
for epoch in range(1,nbIter+1) :
probaRandom = round(0.5*math.exp((-epoch+1)/4)+0.1, 2)
probaOracle = round(0.3*math.exp((-epoch+1)/2), 2)
i = 0
totalLoss = 0.0
while True :
......@@ -188,7 +191,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
missingLinks = getMissingLinks(sentence)
if debug :
sentence.printForDebug(sys.stderr)
action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom=0.1, probaOracle=0.1)
action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle)
if action is None :
break
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment