Skip to content
Snippets Groups Projects
Commit 1a127719 authored by Franck Dary's avatar Franck Dary
Browse files

Corrected BACK action : 1) cannot create cycle 2) correct backward movement

parent bf43cbf5
No related branches found
No related tags found
No related merge requests found
......@@ -82,7 +82,7 @@ class Config :
right = 5
print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
print("history :",[trans.name for trans in self.history[-10:]], file=output)
print("historyPop :",[(c[0].name,c[1]) for c in self.historyPop[-10:]], file=output)
print("historyPop :",[(c[0].name,c[1],c[2]) for c in self.historyPop[-10:]], file=output)
toPrint = []
for lineIndex in range(self.wordIndex-left, self.wordIndex+right) :
if lineIndex not in range(len(self.lines)) :
......
......@@ -64,7 +64,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
candidate = candidates[0][1]
if debug :
config.printForDebug(sys.stderr)
print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr)
print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate)
EOS.apply(config, strat)
......
......@@ -18,8 +18,7 @@ class Transition :
def apply(self, config, strategy) :
data = None
if "BACK" not in self.name :
config.historyHistory.add(str([t[0].name for t in config.historyPop]))
config.historyHistory.add(str([t[0].name for t in config.historyPop]))
if self.name == "RIGHT" :
applyRight(config)
......@@ -39,7 +38,7 @@ class Transition :
exit(1)
config.history.append(self)
if "BACK" not in self.name :
config.historyPop.append((self,data))
config.historyPop.append((self,data,None))
def appliable(self, config) :
if self.name == "RIGHT" :
......@@ -146,8 +145,8 @@ def scoreOracleReduce(config, ml) :
################################################################################
def applyBack(config, strategy, size) :
for i in range(size) :
trans, data = config.historyPop.pop()
config.moveWordIndex(-strategy[trans.name])
trans, data, movement = config.historyPop.pop()
config.moveWordIndex(-movement)
if trans.name == "RIGHT" :
applyBackRight(config)
elif trans.name == "LEFT" :
......@@ -236,6 +235,10 @@ def applyTransition(ts, strat, config, name) :
transition = [trans for trans in ts if trans.name == name][0]
movement = strat[transition.name] if transition.name in strat else 0
transition.apply(config, strat)
return config.moveWordIndex(movement)
moved = config.moveWordIndex(movement)
movement = movement if moved else 0
if len(config.historyPop) > 0 and "BACK" not in name :
config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement)
return moved
################################################################################
......@@ -41,7 +41,7 @@ if __name__ == "__main__" :
os.makedirs(args.model, exist_ok=True)
Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
print("Using device : %s"%Util.getDevice())
print("Using device : %s"%Util.getDevice(), file=sys.stderr)
random.seed(args.seed)
torch.manual_seed(args.seed)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment