Skip to content
Snippets Groups Projects
Commit ead830cc authored by Maxime Petit
Browse files

New parameter: number of inapplicable actions picked before breaking the analysis

parent 85a275f0
No related branches found
No related tags found
No related merge requests found
......@@ -16,7 +16,7 @@ import Config
from conll18_ud_eval import load_conllu, evaluate
################################################################################
def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, predicted, silent=False) :
def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
sentences = Config.readConllu(filename, predicted)
if type == "oracle" :
......@@ -24,7 +24,7 @@ def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter,
return
if type == "rl":
trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, predicted, silent)
trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
return
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
......@@ -56,15 +56,15 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
if debug :
print(str(candidate), file=sys.stderr)
goldIndex = [str(trans) for trans in ts].index(str(candidateOracle))
example = torch.cat([torch.LongTensor([goldIndex]), features])
examples.append(example)
moved = applyTransition(strat, config, candidate, None)
EOS.apply(config, strat)
return examples
################################################################################
......@@ -149,7 +149,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
################################################################################
################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, predicted, silent=False) :
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
memory = None
dicts = Dicts()
......@@ -189,6 +189,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
sentence.moveWordIndex(0)
state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
count = 0
while True :
missingLinks = getMissingLinks(sentence)
if debug :
......@@ -206,11 +208,14 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
reward = torch.FloatTensor([reward_]).to(getDevice())
newState = None
#newState = None
if appliable :
applyTransition(strategy, sentence, action, reward_)
newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
else:
count+=1
if memory is None :
memory = ReplayMemory(5000, state.numel())
memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
......@@ -223,7 +228,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
policy_net.train()
i += 1
if state is None :
if state is None or count == countBreak:
break
if i >= nbExByEpoch :
break
......
......@@ -41,7 +41,7 @@ if __name__ == "__main__" :
parser.add_argument("--gamma", default=0.99,
help="Importance given to future rewards.")
parser.add_argument("--bootstrap", default=None,
help="If not none, extract examples in bootstrap mode (oracle train only).")
help="If not none, extract examples in bootstrap mode every n epochs (oracle train only).")
parser.add_argument("--dev", default=None,
help="Name of the CoNLL-U file of the dev corpus.")
parser.add_argument("--incr", "-i", default=False, action="store_true",
......@@ -51,7 +51,7 @@ if __name__ == "__main__" :
parser.add_argument("--silent", "-s", default=False, action="store_true",
help="Don't print advancement infos.")
parser.add_argument("--transitions", default="eager",
help="Transition set to use (eager | swift).")
help="Transition set to use (eager | swift | tagparser).")
parser.add_argument("--ts", default="",
help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
parser.add_argument("--reward", default="A",
......@@ -60,6 +60,8 @@ if __name__ == "__main__" :
help="Evolution of probability to chose action at random : (start value, decay speed, end value)")
parser.add_argument("--probaOracle", default="0.3,2,0.0",
help="Evolution of probability to chose action from oracle : (start value, decay speed, end value)")
parser.add_argument("--countBreak", default=1,
help="Number of unaplayable transition picked before breaking the analysis.")
args = parser.parse_args()
if args.debug :
......@@ -99,7 +101,7 @@ if __name__ == "__main__" :
json.dump(strategy, open(args.model+"/strategy.json", "w"))
printTS(transitionSet, sys.stderr)
probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, args.predicted, args.silent)
Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
elif args.mode == "decode" :
transNames = json.load(open(args.model+"/transitions.json", "r"))
transitionSet = [Transition(elem) for elem in transNames]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment