import sys

import torch

from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp
import Networks
import Decode
import Config
from conll18_ud_eval import load_conllu, evaluate
################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  # For each transition, how far the word index advances once it is applied.
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
sentences = Config.readConllu(filename)
if type == "oracle" :
trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
return
if type == "rl":
trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
return
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
exit(1)
################################################################################
################################################################################
def extractExamples(ts, strat, config, dicts, debug=False) :
  # Replay the static oracle over one sentence, collecting one training example
  # (gold transition index + feature vector) per parser configuration.
  examples = []
  EOS = Transition("EOS")
  config.moveWordIndex(0)
  moved = True
  while moved :
    missingLinks = getMissingLinks(config)
    # Sort appliable transitions by oracle score; the first (lowest-cost) one is kept as gold.
    candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
if len(candidates) == 0 :
break
candidate = candidates[0][1]
candidateIndex = [trans.name for trans in ts].index(candidate)
    features = Features.extractFeatures(dicts, config)
    # Column 0 holds the gold transition index, the remaining columns the features.
    example = torch.cat([torch.LongTensor([candidateIndex]), features])
    examples.append(example)
if debug :
config.printForDebug(sys.stderr)
print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate)
EOS.apply(config)
return examples
################################################################################
################################################################################
def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
examples = []
dicts = Dicts()
dicts.readConllu(filename, ["FORM", "UPOS"])
dicts.save(modelDir+"/dicts.json")
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
examples += extractExamples(transitionSet, strategy, config, dicts, debug)
print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
  examples = torch.stack(examples)
  # Input size = example length minus the gold-label column in front.
  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
  lossFct = torch.nn.CrossEntropyLoss()
  bestLoss = None
  bestScore = None
  for epoch in range(1, nbIter+1) :
    # Shuffle the examples at the start of every epoch.
    examples = examples.index_select(0, torch.randperm(examples.size(0)))
totalLoss = 0.0
nbEx = 0
printInterval = 2000
advancement = 0
for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
batch = examples[batchIndex:batchIndex+batchSize]
      targets = batch[:,:1].view(-1) # gold transition indices (column 0 of each example)
      inputs = batch[:,1:] # feature columns
nbEx += targets.size(0)
advancement += targets.size(0)
if not silent and advancement >= printInterval :
advancement = 0
print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
outputs = network(inputs)
loss = lossFct(outputs, targets)
network.zero_grad()
loss.backward()
optimizer.step()
totalLoss += float(loss)
devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu"
Decode.decodeMode(debug, devFile, "model", modelDir, network, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
UAS = res["UAS"][0].f1
score = UAS
saved = True if bestScore is None else score > bestScore
bestScore = score if bestScore is None else max(bestScore, score)
devScore = ", Dev : UAS=%.2f"%(UAS)
if saved :
torch.save(network, modelDir+"/network.pt")
print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################
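################################################################################
# The four helpers below (ReplayMemory, selectAction, getReward, optimizeModel)
# are referenced by trainModelRl but their original definitions fall outside
# this excerpt. What follows is a minimal, conventional deep Q-learning sketch
# of each one, with the original names kept and the bodies assumed : the
# epsilon, reward values, batch size and gamma are illustrative defaults, not
# the original settings.
import random

class ReplayMemory :
  # Fixed-capacity cyclic buffer of (state, action, newState, reward) tuples.
  def __init__(self, capacity) :
    self.capacity = capacity
    self.memory = []
    self.position = 0

  def push(self, transition) :
    if len(self.memory) < self.capacity :
      self.memory.append(None)
    self.memory[self.position] = transition
    self.position = (self.position+1)%self.capacity

  def sample(self, batchSize) :
    return random.sample(self.memory, batchSize)

  def __len__(self) :
    return len(self.memory)
################################################################################

################################################################################
def selectAction(network, state, ts, epsilon=0.1) :
  # Epsilon-greedy policy : with probability epsilon pick a random transition,
  # otherwise pick the transition with the highest predicted Q value.
  if random.random() < epsilon :
    return ts[random.randrange(len(ts))].name
  with torch.no_grad() :
    output = network(torch.stack([state]))
  return ts[int(torch.argmax(output))].name
################################################################################

################################################################################
def getReward(state, newState) :
  # Placeholder reward : penalize reaching a dead end, stay neutral otherwise.
  # The reward actually used for parsing is not part of this excerpt.
  return -1.0 if newState is None else 0.0
################################################################################

################################################################################
def optimizeModel(memory, policyNet, targetNet, optimizer, ts, batchSize=32, gamma=0.99) :
  # One Q-learning step : regress Q(s, a) towards r + gamma*max Q_target(s', .)
  # over a random minibatch drawn from the replay memory.
  if len(memory) < batchSize :
    return
  batch = memory.sample(batchSize)
  names = [trans.name for trans in ts]
  states = torch.stack([b[0] for b in batch])
  actions = torch.LongTensor([names.index(b[1]) for b in batch]).unsqueeze(1)
  rewards = torch.cat([b[3] for b in batch])
  # Terminal states have no successor, so their future value is zero.
  nonFinalMask = torch.tensor([b[2] is not None for b in batch], dtype=torch.bool)
  nextValues = torch.zeros(batchSize)
  if bool(nonFinalMask.any()) :
    nonFinalNext = torch.stack([b[2] for b in batch if b[2] is not None])
    nextValues[nonFinalMask] = targetNet(nonFinalNext).max(1)[0].detach()
  expected = rewards + gamma*nextValues
  qValues = policyNet(states).gather(1, actions).squeeze(1)
  loss = torch.nn.functional.smooth_l1_loss(qValues, expected)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
################################################################################

################################################################################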
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
memory = ReplayMemory(1000)
dicts = Dicts()
dicts.readConllu(filename, ["FORM", "UPOS"])
dicts.save(modelDir + "/dicts.json")
  # 13 must match the length of the feature vector returned by
  # Features.extractFeaturesPosExtended.
  policy_net = Networks.BaseNet(dicts, 13, len(transitionSet))
  target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
  # The target network is a frozen copy of the policy network, used to compute
  # stable Q-value targets.
  target_net.load_state_dict(policy_net.state_dict())
  target_net.eval()
optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
lossFct = torch.nn.CrossEntropyLoss()
bestLoss = None
bestScore = None
  for i_episode in range(nbIter) :
    sentence = sentences[i_episode%len(sentences)]
    # Start parsing from the first word, as extractExamples does.
    sentence.moveWordIndex(0)
    state = Features.extractFeaturesPosExtended(dicts, sentence)
    notDone = True
    while notDone :
      action = selectAction(policy_net, state, transitionSet)
      if debug :
        print(action, file=sys.stderr)
      notDone = applyTransition(transitionSet, strategy, sentence, action)
      # The successor state must exist before the reward can be computed from it.
      newState = Features.extractFeaturesPosExtended(dicts, sentence) if notDone else None
      reward = torch.tensor([getReward(state, newState)])
      memory.push((state, action, newState, reward))
      state = newState
      optimizeModel(memory, policy_net, target_net, optimizer, transitionSet)
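    # Not part of this excerpt : DQN training conventionally refreshes the frozen
    # target network at a fixed interval. The 50-episode period is an assumption.
    if (i_episode+1)%50 == 0 :
      target_net.load_state_dict(policy_net.state_dict())
  # Also an assumption : persist the learned policy the way trainModelOracle does.
  torch.save(policy_net, modelDir+"/network.pt")
################################################################################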