diff --git a/Config.py b/Config.py index e7776808d6c31cd8fd1f9064af61b76a7647d847..7874c4ab2e745db42607440a77b92523bbbcd4e6 100644 --- a/Config.py +++ b/Config.py @@ -86,17 +86,16 @@ class Config : def __len__(self) : return len(self.lines) + # This print format is used by the script readTrace.py, avoid changes def printForDebug(self, output) : printedCols = ["ID","FORM","UPOS","HEAD","DEPREL"] left = 5 right = 5 - historySize = 8 - historyPopSize = 6 print("state :", self.state, file=output) - print("stack :",[int(self.getAsFeature(ind, "ID")) for ind in self.stack], file=output) + print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output) print("nbUndone :", self.nbUndone, file=output) - print("history :",[str(trans) for trans in self.history[-historySize:]], file=output) - print("historyPop :",[(str(c[0]),"reward:"+str(c[3])) for c in self.historyPop[-historyPopSize:]], file=output) + print("history :",[str(trans) for trans in self.history], file=output) + print("historyPop :",[(str(c[0]),"dat:"+str(c[1]),"mvt:"+str(c[2]),"reward:"+str(c[3]),"state:"+str(c[4])) for c in self.historyPop], file=output) toPrint = [] for lineIndex in range(self.wordIndex-left, self.wordIndex+right) : if lineIndex not in range(len(self.lines)) : diff --git a/Train.py b/Train.py index de8482e62c7b40b7b4f4187f2cc76708855ed3d5..ca196577125a532909212daf27bfb76c53e1407d 100644 --- a/Train.py +++ b/Train.py @@ -271,7 +271,6 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF probaOracle = list_probas[fromState][1] if debug : - print("-"*80, file=sys.stderr) sentence.printForDebug(sys.stderr) action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)