From 22de30acf7603540b45d3905ec93d5582b34d445 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 4 May 2021 15:28:31 +0200 Subject: [PATCH] Improved debug print --- Config.py | 2 +- Train.py | 4 ++-- main.py | 3 +++ 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/Config.py b/Config.py index da9c361..1cbf462 100644 --- a/Config.py +++ b/Config.py @@ -82,7 +82,7 @@ class Config : right = 5 print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output) print("history :",[trans.name for trans in self.history[-10:]], file=output) - print("historyPop :",[(c[0].name,c[1],c[2]) for c in self.historyPop[-10:]], file=output) + print("historyPop :",[(c[0].name,"dat:"+str(c[1]),"mvt:"+str(c[2]),"reward:"+str(c[3])) for c in self.historyPop[-10:]], file=output) toPrint = [] for lineIndex in range(self.wordIndex-left, self.wordIndex+right) : if lineIndex not in range(len(self.lines)) : diff --git a/Train.py b/Train.py index 09e1737..cac81e2 100644 --- a/Train.py +++ b/Train.py @@ -135,7 +135,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr advancement += targets.size(0) if not silent and advancement >= printInterval : advancement = 0 - print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr) + print("Current epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr) outputs = network(inputs) loss = lossFct(outputs, targets) network.zero_grad() @@ -182,7 +182,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti sentIndex = 0 if not silent : - print("Curent epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr) + print("Current epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr) sentence = sentences[sentIndex] sentence.moveWordIndex(0) state = policy_net.extractFeatures(dicts, sentence).to(getDevice()) diff --git a/main.py b/main.py index be3dfaf..8a18b49 100755 --- a/main.py +++ b/main.py @@ -38,6 +38,9 @@ if __name__ == "__main__" : help="Don't print advancement infos.") args = parser.parse_args() + if args.debug : + args.silent = True + os.makedirs(args.model, exist_ok=True) Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu")) -- GitLab