Skip to content
Snippets Groups Projects
Commit 67435b57 authored by Franck Dary's avatar Franck Dary
Browse files

States working for oracle learning

parent bbdc365a
No related branches found
No related tags found
No related merge requests found
......@@ -54,7 +54,7 @@ def extractExamples(debug, transitionSets, strat, config, dicts, network, dynami
config.printForDebug(sys.stderr)
print(str([[c[0],str(c[1])] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
if dynamic :
netword.setState(config.state)
network.setState(config.state)
output = network(features.unsqueeze(0).to(getDevice()))
scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
......@@ -117,7 +117,9 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, False)
for e in range(len(examples)) :
examples[e] += extracted[e]
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
totalNbExamples = sum(map(len,examples))
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr)
for e in range(len(examples)) :
examples[e] = torch.stack(examples[e])
......@@ -128,32 +130,50 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
bestScore = None
for epoch in range(1,nbEpochs+1) :
if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 :
examples = []
examples = [[] for _ in transitionSets]
sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, True)
for e in range(len(examples)) :
examples[e] += extracted[e]
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
totalNbExamples = sum(map(len,examples))
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr)
for e in range(len(examples)) :
examples[e] = torch.stack(examples[e])
network.train()
examples = examples.index_select(0, torch.randperm(examples.size(0)))
for e in range(len(examples)) :
examples[e] = examples[e].index_select(0, torch.randperm(examples[e].size(0)))
totalLoss = 0.0
nbEx = 0
printInterval = 2000
advancement = 0
for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
batch = examples[batchIndex:batchIndex+batchSize].to(getDevice())
distribution = [len(e)/totalNbExamples for e in examples]
curIndexes = [0 for _ in examples]
while True :
state = random.choices(population=range(len(examples)), weights=distribution, k=1)[0]
if curIndexes[state] >= len(examples[state]) :
state = -1
for i in range(len(examples)) :
if curIndexes[i] < len(examples[i]) :
state = i
if state == -1 :
break
batch = examples[state][curIndexes[state]:curIndexes[state]+batchSize].to(getDevice())
curIndexes[state] += batchSize
targets = batch[:,:1].view(-1)
inputs = batch[:,1:]
nbEx += targets.size(0)
advancement += targets.size(0)
if not silent and advancement >= printInterval :
advancement = 0
print("Current epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
print("Current epoch %6.2f%%"%(100.0*nbEx/totalNbExamples), end="\r", file=sys.stderr)
network.setState(state)
outputs = network(inputs)
loss = lossFct(outputs, targets)
network.zero_grad()
......
0% — Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment