diff --git a/Train.py b/Train.py index 7c009f6763e1df48ab4995ecd8956e7eab70cb47..de709df3a709fecb7917723722fac1a50fc48466 100644 --- a/Train.py +++ b/Train.py @@ -54,7 +54,7 @@ def extractExamples(debug, transitionSets, strat, config, dicts, network, dynami config.printForDebug(sys.stderr) print(str([[c[0],str(c[1])] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr) if dynamic : - netword.setState(config.state) + network.setState(config.state) output = network(features.unsqueeze(0).to(getDevice())) scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1] candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1] @@ -117,7 +117,9 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, False) for e in range(len(examples)) : examples[e] += extracted[e] - print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr) + + totalNbExamples = sum(map(len,examples)) + print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr) for e in range(len(examples)) : examples[e] = torch.stack(examples[e]) @@ -128,32 +130,50 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr bestScore = None for epoch in range(1,nbEpochs+1) : if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 : - examples = [] + examples = [[] for _ in transitionSets] sentences = copy.deepcopy(sentencesOriginal) print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr) for config in sentences : extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, True) for e in range(len(examples)) : examples[e] += extracted[e] - print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr) + totalNbExamples = sum(map(len,examples)) + print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr) for e in range(len(examples)) : examples[e] = torch.stack(examples[e]) network.train() - examples = examples.index_select(0, torch.randperm(examples.size(0))) + for e in range(len(examples)) : + examples[e] = examples[e].index_select(0, torch.randperm(examples[e].size(0))) totalLoss = 0.0 nbEx = 0 printInterval = 2000 advancement = 0 - for batchIndex in range(0,examples.size(0)-batchSize,batchSize) : - batch = examples[batchIndex:batchIndex+batchSize].to(getDevice()) + + distribution = [len(e)/totalNbExamples for e in examples] + curIndexes = [0 for _ in examples] + + while True : + state = random.choices(population=range(len(examples)), weights=distribution, k=1)[0] + if curIndexes[state] >= len(examples[state]) : + state = -1 + for i in range(len(examples)) : + if curIndexes[i] < len(examples[i]) : + state = i + if state == -1 : + break + + batch = examples[state][curIndexes[state]:curIndexes[state]+batchSize].to(getDevice()) + curIndexes[state] += batchSize targets = batch[:,:1].view(-1) inputs = batch[:,1:] nbEx += targets.size(0) advancement += targets.size(0) if not silent and advancement >= printInterval : advancement = 0 - print("Current epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr) + print("Current epoch %6.2f%%"%(100.0*nbEx/totalNbExamples), end="\r", file=sys.stderr) + + network.setState(state) outputs = network(inputs) loss = lossFct(outputs, targets) network.zero_grad()