Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import random
import sys
from Transition import Transition, getMissingLinks, applyTransition
from Features import extractFeatures
import torch
################################################################################
def randomDecode(ts, strat, config, debug=False) :
    """
    Decode `config` by applying randomly chosen appliable transitions.

    Starting from word index 0, repeatedly gathers the transitions of `ts`
    whose `appliable(config)` is true, picks one uniformly at random and
    applies it through `applyTransition`. The loop ends as soon as no
    transition is appliable, after which an "EOS" transition is applied.

    Args:
        ts: collection of transitions; it is iterated once per step, and each
            element must expose `appliable(config)` and `name`.
        strat: forwarded unchanged to `applyTransition`.
        config: configuration being decoded; it is modified in place.
        debug: when true, prints the configuration and the name of the chosen
            transition to stderr before every step.

    Returns:
        None.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)

    while True :
        candidates = [trans for trans in ts if trans.appliable(config)]
        if not candidates :
            break

        # random.choice draws uniformly over all candidates. The previous
        # `randint(0, 100) % len(candidates)` was modulo-biased and could
        # never reach an index above 100.
        candidate = random.choice(candidates)

        if debug :
            config.printForDebug(sys.stderr)
            print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)

        applyTransition(ts, strat, config, candidate.name)

    EOS.apply(config)
################################################################################
################################################################################
def oracleDecode(ts, strat, config, debug=False) :
    """
    Decode `config` by always applying the transition with the lowest oracle score.

    At every step the appliable transitions of `ts` are scored with
    `getOracleScore(config, missingLinks)` and ordered ascending as
    `[score, name]` pairs, so equal scores are ordered by name. The first
    entry is applied through `applyTransition`. Decoding stops when no
    transition is appliable or when `applyTransition` returns a falsy value;
    an "EOS" transition is then applied.

    Args:
        ts: collection of transitions, iterated once per step.
        strat: forwarded unchanged to `applyTransition`.
        config: configuration being decoded; it is modified in place.
        debug: when true, prints the configuration and the sorted
            `[score, name]` list to stderr before every step.

    Returns:
        None.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)

    while True :
        missingLinks = getMissingLinks(config)

        # Build the [score, name] pairs for appliable transitions only, then
        # order them in place (ascending: best oracle score first).
        scored = []
        for trans in ts :
            if trans.appliable(config) :
                scored.append([trans.getOracleScore(config, missingLinks), trans.name])
        scored.sort()

        if not scored :
            break

        bestName = scored[0][1]

        if debug :
            config.printForDebug(sys.stderr)
            print(str(scored)+"\n"+("-"*80)+"\n", file=sys.stderr)

        # A falsy result from applyTransition ends the decoding loop.
        if not applyTransition(ts, strat, config, bestName) :
            break

    EOS.apply(config)
################################################################################
################################################################################
def decodeModel(ts, strat, config, network, dicts, debug=False) :
    """
    Decode `config` greedily using the transition probabilities of `network`.

    The network is switched to eval mode and run without gradient tracking.
    At every step the features extracted from `config` are fed to the
    network, a softmax is taken over dimension 1, and the appliable
    transition with the highest probability is applied through
    `applyTransition`. Decoding stops when no transition is appliable or when
    `applyTransition` returns a falsy value; an "EOS" transition is then
    applied.

    Args:
        ts: indexable collection of transitions; position `i` in `ts` is
            matched with entry `i` of the network output.
        strat: forwarded unchanged to `applyTransition`.
        config: configuration being decoded; it is modified in place.
        network: callable torch module; left in eval mode after the call.
        dicts: forwarded unchanged to `extractFeatures`.
        debug: when true, prints the configuration and the ranked list of
            appliable transition names to stderr before every step.

    Returns:
        None.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    moved = True
    network.eval()

    with torch.no_grad():
        while moved :
            # unsqueeze(0) adds a leading dimension of size 1 before the
            # features are given to the network.
            features = extractFeatures(dicts, config).unsqueeze(0)
            output = torch.nn.functional.softmax(network(features), dim=1)
            scores = output[0].tolist()

            # Rank on the raw probability. Ranking on the "%.2f" string made
            # probabilities that differ only past the second decimal compare
            # equal, so the winner was decided by transition name instead.
            # The name is kept as the tie-breaker for exactly equal scores.
            ranked = sorted(
                ((scores[index], trans.name)
                 for index, trans in enumerate(ts)
                 if trans.appliable(config)),
                reverse=True,
            )
            candidates = [name for _, name in ranked]
            if not candidates :
                break

            candidate = candidates[0]

            if debug :
                config.printForDebug(sys.stderr)
                print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)

            moved = applyTransition(ts, strat, config, candidate)

    EOS.apply(config)
################################################################################