diff --git a/Networks.py b/Networks.py index 925049a907dc274984a7d187ab3f1a998d3d94ce..4ef91eccf0179ca306c7ec76fd19f10bf976e17b 100644 --- a/Networks.py +++ b/Networks.py @@ -10,7 +10,7 @@ class BaseNet(nn.Module): self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" - self.columns = ["UPOS", "FORM"] + self.columns = ["UPOS"] self.embSize = 64 self.nbTargets = len(self.featureFunction.split()) diff --git a/Rl.py b/Rl.py index d1f63d61c3bcde60c48326cbae66c61f0f297c5b..2ec5e66c5c49a4e969f230af8eacfe3a94becde8 100644 --- a/Rl.py +++ b/Rl.py @@ -52,7 +52,7 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra ################################################################################ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) : - gamma = 0.9 + gamma = 0.8 if len(memory) < batchSize : return 0.0 diff --git a/Train.py b/Train.py index 09baefc2a06fd98da904dd1531211e0eafb56ef4..0cc5df9c092393de64d7c832ca5eb61e15b297f8 100644 --- a/Train.py +++ b/Train.py @@ -2,6 +2,7 @@ import sys import random import torch import copy +import math from Transition import Transition, getMissingLinks, applyTransition import Features @@ -151,7 +152,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti memory = None dicts = Dicts() - dicts.readConllu(filename, ["FORM", "UPOS"]) + dicts.readConllu(filename, ["FORM","UPOS"], 2) dicts.save(modelDir + "/dicts.json") policy_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice()) @@ -170,6 +171,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti sentIndex = 0 for epoch in range(1,nbIter+1) : + probaRandom = round(0.5*math.exp((-epoch+1)/4)+0.1, 2) + probaOracle = round(0.3*math.exp((-epoch+1)/2), 2) i = 0 totalLoss = 0.0 while True : @@ -188,7 +191,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti missingLinks = getMissingLinks(sentence) if debug : sentence.printForDebug(sys.stderr) - action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom=0.1, probaOracle=0.1) + action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle) if action is None : break