From fb6f6ffa3a72b312ffd64b4d0b837fa9c19cbdbe Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 27 Apr 2021 13:56:38 +0200 Subject: [PATCH] ProbaOracle and probaRandom get smaller with number of epochs --- Networks.py | 2 +- Rl.py | 2 +- Train.py | 7 +++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Networks.py b/Networks.py index 925049a..4ef91ec 100644 --- a/Networks.py +++ b/Networks.py @@ -10,7 +10,7 @@ class BaseNet(nn.Module): self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" - self.columns = ["UPOS", "FORM"] + self.columns = ["UPOS"] self.embSize = 64 self.nbTargets = len(self.featureFunction.split()) diff --git a/Rl.py b/Rl.py index d1f63d6..2ec5e66 100644 --- a/Rl.py +++ b/Rl.py @@ -52,7 +52,7 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra ################################################################################ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) : - gamma = 0.9 + gamma = 0.8 if len(memory) < batchSize : return 0.0 diff --git a/Train.py b/Train.py index 09baefc..0cc5df9 100644 --- a/Train.py +++ b/Train.py @@ -2,6 +2,7 @@ import sys import random import torch import copy +import math from Transition import Transition, getMissingLinks, applyTransition import Features @@ -151,7 +152,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti memory = None dicts = Dicts() - dicts.readConllu(filename, ["FORM", "UPOS"]) + dicts.readConllu(filename, ["FORM","UPOS"], 2) dicts.save(modelDir + "/dicts.json") policy_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice()) @@ -170,6 +171,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti sentIndex = 0 for epoch in range(1,nbIter+1) : + probaRandom = round(0.5*math.exp((-epoch+1)/4)+0.1, 2) + probaOracle = round(0.3*math.exp((-epoch+1)/2), 2) i = 0 totalLoss = 0.0 while True : @@ -188,7 +191,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti missingLinks = getMissingLinks(sentence) if debug : sentence.printForDebug(sys.stderr) - action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom=0.1, probaOracle=0.1) + action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle) if action is None : break -- GitLab