diff --git a/Train.py b/Train.py index 7c009f6763e1df48ab4995ecd8956e7eab70cb47..de709df3a709fecb7917723722fac1a50fc48466 100644 --- a/Train.py +++ b/Train.py @@ -54,7 +54,7 @@ def extractExamples(debug, transitionSets, strat, config, dicts, network, dynami config.printForDebug(sys.stderr) print(str([[c[0],str(c[1])] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr) if dynamic : - netword.setState(config.state) + network.setState(config.state) output = network(features.unsqueeze(0).to(getDevice())) scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1] candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1] @@ -117,7 +117,9 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, False) for e in range(len(examples)) : examples[e] += extracted[e] - print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr) + + totalNbExamples = sum(map(len,examples)) + print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr) for e in range(len(examples)) : examples[e] = torch.stack(examples[e]) @@ -128,32 +130,50 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr bestScore = None for epoch in range(1,nbEpochs+1) : if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 : - examples = [] + examples = [[] for _ in transitionSets] sentences = copy.deepcopy(sentencesOriginal) print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr) for config in sentences : extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, True) for e in range(len(examples)) : examples[e] += extracted[e] - print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr) + totalNbExamples = sum(map(len,examples)) + print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr) for e in range(len(examples)) : examples[e] = torch.stack(examples[e]) network.train() - examples = examples.index_select(0, torch.randperm(examples.size(0))) + for e in range(len(examples)) : + examples[e] = examples[e].index_select(0, torch.randperm(examples[e].size(0))) totalLoss = 0.0 nbEx = 0 printInterval = 2000 advancement = 0 - for batchIndex in range(0,examples.size(0)-batchSize,batchSize) : - batch = examples[batchIndex:batchIndex+batchSize].to(getDevice()) + + distribution = [len(e)/totalNbExamples for e in examples] + curIndexes = [0 for _ in examples] + + while True : + state = random.choices(population=range(len(examples)), weights=distribution, k=1)[0] + if curIndexes[state] >= len(examples[state]) : + state = -1 + for i in range(len(examples)) : + if curIndexes[i] < len(examples[i]) : + state = i + if state == -1 : + break + + batch = examples[state][curIndexes[state]:curIndexes[state]+batchSize].to(getDevice()) + curIndexes[state] += batchSize targets = batch[:,:1].view(-1) inputs = batch[:,1:] nbEx += targets.size(0) advancement += targets.size(0) if not silent and advancement >= printInterval : advancement = 0 - print("Current epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr) + print("Current epoch %6.2f%%"%(100.0*nbEx/totalNbExamples), end="\r", file=sys.stderr) + + network.setState(state) outputs = network(inputs) loss = lossFct(outputs, targets) network.zero_grad()