Skip to content
Snippets Groups Projects
Commit 17aac208 authored by Franck Dary's avatar Franck Dary
Browse files

Improved script to print results when loss is used instead of f1score

parent 442d956c
No related branches found
No related tags found
No related merge requests found
...@@ -18,9 +18,10 @@ if __name__ == "__main__" : ...@@ -18,9 +18,10 @@ if __name__ == "__main__" :
print("\t"+line,end="", file=sys.stderr) print("\t"+line,end="", file=sys.stderr)
for pathToFile in glob.iglob("" + '*stdout') : for pathToFile in glob.iglob("" + '*stdout') :
model = pathToFile.split('.')[0] splited = pathToFile.split('.')
corpus = (".".join(pathToFile.split('.')[1:])).split('.')[0] model = ".".join(splited[:-3])
index = (".".join(pathToFile.split('.')[1:])).split('.')[1] corpus = splited[-3]
index = splited[-2]
if corpus not in outputByModelScore : if corpus not in outputByModelScore :
outputByModelScore[corpus] = dict() outputByModelScore[corpus] = dict()
...@@ -50,29 +51,32 @@ if __name__ == "__main__" : ...@@ -50,29 +51,32 @@ if __name__ == "__main__" :
standardDeviation += (float(exp[2])-score)**2 standardDeviation += (float(exp[2])-score)**2
standardDeviation /= len(outputByModelScore[corpus][model][metric]) standardDeviation /= len(outputByModelScore[corpus][model][metric])
standardDeviation = math.sqrt(standardDeviation) standardDeviation = math.sqrt(standardDeviation)
baseScore = score
if standardDeviation > 0 : if standardDeviation > 0 :
score = "%.2f[±%.2f]%%"%(score,standardDeviation) score = "%.2f[±%.2f]%%"%(score,standardDeviation)
else : else :
score = "%.2f%%"%score score = "%.2f%%"%score
if '-' in score :
score = score.replace('-','').replace('%','')
output.append(outputByModelScore[corpus][model][metric][0]) output.append(outputByModelScore[corpus][model][metric][0])
output[-1][2] = score output[-1][2] = score
output[-1] = [baseScore] + output[-1]
if len(output) == 0 : if len(output) == 0 :
print("ERROR : Output length is 0", file=sys.stderr) print("ERROR : Output length is 0", file=sys.stderr)
print(" did you run evaluate.sh ?", file=sys.stderr) print(" did you run evaluate.sh ?", file=sys.stderr)
exit(1) exit(1)
output.sort()
output = [val[1:] for val in output]
maxColLens = [0 for _ in range(len(output[0]))] maxColLens = [0 for _ in range(len(output[0]))]
output = [["Corpus","Metric","F1.score","Model"]] + output output = [["Corpus","Metric","F1.score","Model"]] + output
for line in output : for line in output :
for i in range(len(line)) : for i in range(len(line)) :
maxColLens[i] = max(maxColLens[i], len(line[i])) maxColLens[i] = max(maxColLens[i], len(str(line[i])))
output = output[1:]
output.sort()
output = [["Corpus","Metric","F1.score","Model"]] + output
dashLine = '-' * 80 dashLine = '-' * 80
for i in range(len(output)) : for i in range(len(output)) :
...@@ -81,6 +85,6 @@ if __name__ == "__main__" : ...@@ -81,6 +85,6 @@ if __name__ == "__main__" :
elif i > 0 and output[i][1] != output[i-1][1] : elif i > 0 and output[i][1] != output[i-1][1] :
print("") print("")
for j in range(len(output[i])) : for j in range(len(output[i])) :
padding = (' '*(maxColLens[j]-len(output[i][j])))+" "*3 padding = (' '*(maxColLens[j]-len(str(output[i][j]))))+" "*3
print(output[i][j], end=padding) print(output[i][j], end=padding)
print("") print("")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment