Newer
Older
from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp
import Networks
import Decode
import Config
from conll18_ud_eval import load_conllu, evaluate
################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
sentences = Config.readConllu(filename)
if type == "oracle" :
trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
return
if type == "rl":
trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
return
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
exit(1)
################################################################################
################################################################################
def extractExamples(ts, strat, config, dicts, debug=False) :
examples = []
EOS = Transition("EOS")
config.moveWordIndex(0)
moved = True
while moved :
missingLinks = getMissingLinks(config)
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
if len(candidates) == 0 :
break
candidate = candidates[0][1]
candidateIndex = [trans.name for trans in ts].index(candidate)
features = Features.extractFeatures(dicts, config)
example = torch.cat([torch.LongTensor([candidateIndex]), features])
examples.append(example)
if debug :
config.printForDebug(sys.stderr)
print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate)
EOS.apply(config)
return examples
################################################################################
################################################################################
def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
examples = []
dicts = Dicts()
dicts.readConllu(filename, ["FORM", "UPOS"])
dicts.save(modelDir+"/dicts.json")
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
examples += extractExamples(transitionSet, strategy, config, dicts, debug)
print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
examples = torch.stack(examples)
network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
lossFct = torch.nn.CrossEntropyLoss()
for iter in range(1,nbIter+1) :
examples = examples.index_select(0, torch.randperm(examples.size(0)))
totalLoss = 0.0
nbEx = 0
printInterval = 2000
advancement = 0
for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
batch = examples[batchIndex:batchIndex+batchSize]
targets = batch[:,:1].view(-1)
inputs = batch[:,1:]
nbEx += targets.size(0)
advancement += targets.size(0)
if not silent and advancement >= printInterval :
advancement = 0
print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
outputs = network(inputs)
loss = lossFct(outputs, targets)
network.zero_grad()
loss.backward()
optimizer.step()
totalLoss += float(loss)
devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", modelDir, network, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
UAS = res["UAS"][0].f1
score = UAS
saved = True if bestScore is None else score > bestScore
bestScore = score if bestScore is None else max(bestScore, score)
devScore = ", Dev : UAS=%.2f"%(UAS)
if saved :
torch.save(network, modelDir+"/network.pt")
print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################
################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
memory = ReplayMemory(1000)
dicts = Dicts()
dicts.readConllu(filename, ["FORM", "UPOS"])
dicts.save(modelDir + "/dicts.json")
policy_net = Networks.BaseNet(dicts, 13, len(transitionSet))
target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
bestLoss = None
bestScore = None
for epoch in range(nbIter) :
i = 0
totalLoss = 0.0
sentences = copy.deepcopy(sentencesOriginal)
for sentIndex in range(len(sentences)) :
if not silent :
print("Curent epoch %6.2f%%"%(100.0*sentIndex/len(sentences)), end="\r", file=sys.stderr)
sentence = sentences[sentIndex]
sentence.moveWordIndex(0)
state = Features.extractFeaturesPosExtended(dicts, sentence)
while True :
missingLinks = getMissingLinks(sentence)
if debug :
sentence.printForDebug(sys.stderr)
action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom=0.3, probaOracle=0.15)
if action is None :
break
reward = -1.0*action.getOracleScore(sentence, missingLinks)
reward = torch.FloatTensor([reward])
applyTransition(transitionSet, strategy, sentence, action.name)
newState = Features.extractFeaturesPosExtended(dicts, sentence)
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
memory.push((state, torch.LongTensor([transitionSet.index(action)]), newState, reward))
state = newState
if i % batchSize == 0 :
totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
if i % (2*batchSize) == 0 :
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
policy_net.train()
i += 1
# Fin epoch, compute score and save model
devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", modelDir, policy_net, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
UAS = res["UAS"][0].f1
score = UAS
saved = True if bestScore is None else score > bestScore
bestScore = score if bestScore is None else max(bestScore, score)
devScore = ", Dev : UAS=%.2f"%(UAS)
if saved :
torch.save(policy_net, modelDir+"/network.pt")
print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), epoch, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################