Skip to content
Snippets Groups Projects
Commit 22de30ac authored by Franck Dary's avatar Franck Dary
Browse files

Improved debug print

parent 394637aa
No related branches found
No related tags found
No related merge requests found
...@@ -82,7 +82,7 @@ class Config : ...@@ -82,7 +82,7 @@ class Config :
right = 5 right = 5
print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output) print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
print("history :",[trans.name for trans in self.history[-10:]], file=output) print("history :",[trans.name for trans in self.history[-10:]], file=output)
print("historyPop :",[(c[0].name,c[1],c[2]) for c in self.historyPop[-10:]], file=output) print("historyPop :",[(c[0].name,"dat:"+str(c[1]),"mvt:"+str(c[2]),"reward:"+str(c[3])) for c in self.historyPop[-10:]], file=output)
toPrint = [] toPrint = []
for lineIndex in range(self.wordIndex-left, self.wordIndex+right) : for lineIndex in range(self.wordIndex-left, self.wordIndex+right) :
if lineIndex not in range(len(self.lines)) : if lineIndex not in range(len(self.lines)) :
......
...@@ -135,7 +135,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr ...@@ -135,7 +135,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
advancement += targets.size(0) advancement += targets.size(0)
if not silent and advancement >= printInterval : if not silent and advancement >= printInterval :
advancement = 0 advancement = 0
print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr) print("Current epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
outputs = network(inputs) outputs = network(inputs)
loss = lossFct(outputs, targets) loss = lossFct(outputs, targets)
network.zero_grad() network.zero_grad()
...@@ -182,7 +182,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti ...@@ -182,7 +182,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
sentIndex = 0 sentIndex = 0
if not silent : if not silent :
print("Curent epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr) print("Current epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
sentence = sentences[sentIndex] sentence = sentences[sentIndex]
sentence.moveWordIndex(0) sentence.moveWordIndex(0)
state = policy_net.extractFeatures(dicts, sentence).to(getDevice()) state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
......
...@@ -38,6 +38,9 @@ if __name__ == "__main__" : ...@@ -38,6 +38,9 @@ if __name__ == "__main__" :
help="Don't print advancement infos.") help="Don't print advancement infos.")
args = parser.parse_args() args = parser.parse_args()
if args.debug :
args.silent = True
os.makedirs(args.model, exist_ok=True) os.makedirs(args.model, exist_ok=True)
Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu")) Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment