Commit 2038e1bc authored by Franck Dary

Added hyperparameters as program arguments

parent e8b9c9f0
@@ -51,8 +51,7 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
################################################################################
################################################################################
-def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
-gamma = 0.8
+def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
if len(memory) < batchSize :
return 0.0
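The new gamma argument replaces the discount factor that was hard-coded to 0.8 inside optimizeModel; it weights future rewards in the Q-learning target (see the new --gamma option further down, default 0.99). The body of optimizeModel is not shown in this hunk, so the following is only a minimal sketch of where such a discount factor typically enters a DQN-style update; the tensor names and the loss function are assumptions, not the repository's actual code.

import torch

# Sketch only : compute a DQN-style loss for one replay batch, ignoring
# terminal-state masking for brevity. All tensor arguments are assumed
# to be batched along dimension 0.
def dqnLossSketch(policy_net, target_net, states, actions, nextStates, rewards, gamma) :
  # Q(s, a) of the actions that were actually taken
  qValues = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
  # Value of the next state, estimated by the frozen target network
  with torch.no_grad() :
    nextValues = target_net(nextStates).max(1)[0]
  # Bellman target : immediate reward plus gamma-discounted future value
  expected = rewards + gamma * nextValues
  return torch.nn.functional.smooth_l1_loss(qValues, expected)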
@@ -16,15 +16,15 @@ import Config
from conll18_ud_eval import load_conllu, evaluate
################################################################################
-def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, silent=False) :
+def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, silent=False) :
sentences = Config.readConllu(filename)
if type == "oracle" :
-trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, silent)
+trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, silent)
return
if type == "rl":
-trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, silent)
+trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, silent)
return
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
@@ -92,7 +92,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
################################################################################
################################################################################
-def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, silent=False) :
+def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, silent=False) :
dicts = Dicts()
dicts.readConllu(filename, ["FORM","UPOS"], 2)
dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
@@ -107,7 +107,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
examples = torch.stack(examples)
print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
-optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
+optimizer = torch.optim.Adam(network.parameters(), lr=lr)
lossFct = torch.nn.CrossEntropyLoss()
bestLoss = None
bestScore = None
@@ -147,7 +147,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
################################################################################
################################################################################
-def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, silent=False) :
+def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, silent=False) :
memory = None
dicts = Dicts()
@@ -160,7 +160,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
policy_net.train()
-optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
bestLoss = None
@@ -171,8 +171,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
sentIndex = 0
for epoch in range(1,nbIter+1) :
-probaRandom = round(0.5*math.exp((-epoch+1)/4)+0.1, 2)
-probaOracle = round(0.3*math.exp((-epoch+1)/2), 2)
+probaRandom = round((probas[0][0]-probas[0][2])*math.exp((-epoch+1)/probas[0][1])+probas[0][2], 2)
+probaOracle = round((probas[1][0]-probas[1][2])*math.exp((-epoch+1)/probas[1][1])+probas[1][2], 2)
i = 0
totalLoss = 0.0
while True :
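Both schedules now come from the probas argument: each (start value, decay speed, end value) triple describes an exponential decay from the start value at epoch 1 towards the end value. With the defaults added later in this commit ("0.6,4,0.1" for --probaRandom and "0.3,2,0.0" for --probaOracle), the formulas reduce to exactly the previously hard-coded schedules 0.5*exp((-epoch+1)/4)+0.1 and 0.3*exp((-epoch+1)/2). A small standalone illustration (not part of the commit):

import math

# Reproduce the (start, decay, end) schedule used above for a given epoch.
def decayProba(start, decay, end, epoch) :
  return round((start - end) * math.exp((-epoch + 1) / decay) + end, 2)

# Default schedules : probaRandom 0.6, 0.49, 0.4, ... -> 0.1
#                     probaOracle 0.3, 0.18, 0.11, ... -> 0.0
for epoch in range(1, 6) :
  print(epoch, decayProba(0.6, 4, 0.1, epoch), decayProba(0.3, 2, 0.0, epoch))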
@@ -214,7 +214,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
state = newState
if i % batchSize == 0 :
-totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
+totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
if i % (1*batchSize) == 0 :
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
@@ -29,6 +29,10 @@ if __name__ == "__main__" :
help="Size of each batch.")
parser.add_argument("--seed", default=100,
help="Random seed.")
+parser.add_argument("--lr", default=0.0001,
+help="Learning rate.")
+parser.add_argument("--gamma", default=0.99,
+help="Importance given to future rewards.")
parser.add_argument("--bootstrap", default=None,
help="If not none, extract examples in bootstrap mode (oracle train only).")
parser.add_argument("--dev", default=None,
@@ -43,6 +47,10 @@ if __name__ == "__main__" :
help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
parser.add_argument("--reward", default="A",
help="Reward function to use (A,B,C,D,E)")
+parser.add_argument("--probaRandom", default="0.6,4,0.1",
+help="Evolution of probability to chose action at random : (start value, decay speed, end value)")
+parser.add_argument("--probaOracle", default="0.3,2,0.0",
+help="Evolution of probability to chose action from oracle : (start value, decay speed, end value)")
args = parser.parse_args()
if args.debug :
@@ -65,7 +73,8 @@ if __name__ == "__main__" :
json.dump([t.name for t in transitionSet], open(args.model+"/transitions.json", "w"))
json.dump(strategy, open(args.model+"/strategy.json", "w"))
print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr)
-Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, args.silent)
+probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
+Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, args.silent)
elif args.mode == "decode" :
transNames = json.load(open(args.model+"/transitions.json", "r"))
transitionSet = [Transition(elem) for elem in transNames]
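The probas list built just before the Train.trainMode call packs the two schedules into the nested structure that trainModelRl indexes as probas[0] (random actions) and probas[1] (oracle actions). A quick check of that structure using the default option values, not part of the commit:

# Shape of the structure consumed by trainModelRl, using the default strings.
probaRandom = "0.6,4,0.1"
probaOracle = "0.3,2,0.0"
probas = [list(map(float, probaRandom.split(','))), list(map(float, probaOracle.split(',')))]
print(probas)  # [[0.6, 4.0, 0.1], [0.3, 2.0, 0.0]]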