Commit 2038e1bc authored by Franck Dary

Added hyperparameters as program arguments
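The learning rate (--lr, previously hard-coded to 0.0001), the discount factor (--gamma, previously a constant 0.8 inside optimizeModel, now defaulting to 0.99) and the exploration schedules (--probaRandom, --probaOracle) are now program arguments. The schedule defaults reproduce the previously hard-coded behavior.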

parent e8b9c9f0
@@ -51,8 +51,7 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
 ################################################################################
 ################################################################################
-def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
-  gamma = 0.8
+def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
   if len(memory) < batchSize :
     return 0.0
...
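Here gamma is the discount factor of the Q-learning target. The body of optimizeModel is collapsed in this diff; below is a minimal, hypothetical sketch of how such a function typically uses gamma, assuming a standard PyTorch DQN update and a replay memory whose sample() returns batched tensors (both assumptions, not this repository's actual code):

# Hypothetical sketch only, not the repository's optimizeModel.
import torch
import torch.nn.functional as F

def optimizeModelSketch(batchSize, policy_net, target_net, memory, optimizer, gamma) :
  if len(memory) < batchSize :
    return 0.0
  # Assumed sampler: batched tensors for a random set of stored transitions.
  states, actions, nextStates, rewards = memory.sample(batchSize)
  # Q(s, a) for the actions that were actually taken.
  qValues = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
  # Bellman target: immediate reward plus gamma-discounted best next-state value.
  with torch.no_grad() :
    nextValues = target_net(nextStates).max(1)[0]
  expected = rewards + gamma*nextValues
  loss = F.smooth_l1_loss(qValues, expected)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  return float(loss)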
@@ -16,15 +16,15 @@ import Config
 from conll18_ud_eval import load_conllu, evaluate
 ################################################################################
-def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, silent=False) :
+def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, silent=False) :
   sentences = Config.readConllu(filename)
   if type == "oracle" :
-    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, silent)
+    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, silent)
     return
   if type == "rl":
-    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, silent)
+    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, silent)
     return
   print("ERROR : unknown type '%s'"%type, file=sys.stderr)
@@ -92,7 +92,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
 ################################################################################
 ################################################################################
-def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, silent=False) :
+def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, silent=False) :
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM","UPOS"], 2)
   dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
@@ -107,7 +107,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
   examples = torch.stack(examples)
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
-  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
+  optimizer = torch.optim.Adam(network.parameters(), lr=lr)
   lossFct = torch.nn.CrossEntropyLoss()
   bestLoss = None
   bestScore = None
@@ -147,7 +147,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
 ################################################################################
 ################################################################################
-def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, silent=False) :
+def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, silent=False) :
   memory = None
   dicts = Dicts()
@@ -160,7 +160,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   target_net.load_state_dict(policy_net.state_dict())
   target_net.eval()
   policy_net.train()
-  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+  optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
   bestLoss = None
@@ -171,8 +171,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   sentIndex = 0
   for epoch in range(1,nbIter+1) :
-    probaRandom = round(0.5*math.exp((-epoch+1)/4)+0.1, 2)
-    probaOracle = round(0.3*math.exp((-epoch+1)/2), 2)
+    probaRandom = round((probas[0][0]-probas[0][2])*math.exp((-epoch+1)/probas[0][1])+probas[0][2], 2)
+    probaOracle = round((probas[1][0]-probas[1][2])*math.exp((-epoch+1)/probas[1][1])+probas[1][2], 2)
     i = 0
     totalLoss = 0.0
     while True :
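To make the new schedule concrete: each probability decays exponentially from a start value towards an end value, at a speed set by the middle number. A standalone reproduction of the formula with the CLI defaults (0.6,4,0.1 for probaRandom and 0.3,2,0.0 for probaOracle):

import math

probas = [[0.6, 4, 0.1], [0.3, 2, 0.0]]  # (start value, decay speed, end value)

def proba(p, epoch) :
  # Same formula as in the diff above.
  return round((p[0]-p[2])*math.exp((-epoch+1)/p[1])+p[2], 2)

for epoch in [1, 5, 9, 20] :
  print(epoch, proba(probas[0], epoch), proba(probas[1], epoch))
# epoch 1  -> 0.6,  0.3   (the start values)
# epoch 5  -> 0.28, 0.04
# epoch 9  -> 0.17, 0.01
# epoch 20 -> 0.1,  0.0   (converged to the end values)

With these defaults the two schedules are identical to the previously hard-coded 0.5*math.exp((-epoch+1)/4)+0.1 and 0.3*math.exp((-epoch+1)/2).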
@@ -214,7 +214,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
       state = newState
       if i % batchSize == 0 :
-        totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
+        totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
       if i % (1*batchSize) == 0 :
         target_net.load_state_dict(policy_net.state_dict())
         target_net.eval()
...
@@ -29,6 +29,10 @@ if __name__ == "__main__" :
     help="Size of each batch.")
   parser.add_argument("--seed", default=100,
     help="Random seed.")
+  parser.add_argument("--lr", default=0.0001,
+    help="Learning rate.")
+  parser.add_argument("--gamma", default=0.99,
+    help="Importance given to future rewards.")
   parser.add_argument("--bootstrap", default=None,
     help="If not none, extract examples in bootstrap mode (oracle train only).")
   parser.add_argument("--dev", default=None,
...@@ -43,6 +47,10 @@ if __name__ == "__main__" : ...@@ -43,6 +47,10 @@ if __name__ == "__main__" :
help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"") help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
parser.add_argument("--reward", default="A", parser.add_argument("--reward", default="A",
help="Reward function to use (A,B,C,D,E)") help="Reward function to use (A,B,C,D,E)")
parser.add_argument("--probaRandom", default="0.6,4,0.1",
help="Evolution of probability to chose action at random : (start value, decay speed, end value)")
parser.add_argument("--probaOracle", default="0.3,2,0.0",
help="Evolution of probability to chose action from oracle : (start value, decay speed, end value)")
args = parser.parse_args() args = parser.parse_args()
if args.debug : if args.debug :
@@ -65,7 +73,8 @@ if __name__ == "__main__" :
     json.dump([t.name for t in transitionSet], open(args.model+"/transitions.json", "w"))
     json.dump(strategy, open(args.model+"/strategy.json", "w"))
     print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr)
-    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, args.silent)
+    probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
+    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, args.silent)
   elif args.mode == "decode" :
     transNames = json.load(open(args.model+"/transitions.json", "r"))
     transitionSet = [Transition(elem) for elem in transNames]
...
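With these changes a training run can set all four hyperparameters from the command line. A hypothetical invocation (the entry-point name main.py and the mode argument are assumptions; only the four new flags come from this commit):

# Hypothetical usage; other required arguments omitted.
python main.py train --lr 0.0001 --gamma 0.99 --probaRandom 0.6,4,0.1 --probaOracle 0.3,2,0.0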