Skip to content
Snippets Groups Projects
Commit 1a127719 authored by Franck Dary's avatar Franck Dary
Browse files

Corrected BACK action : 1) cannot create cycle 2) correct backward movement

parent bf43cbf5
No related branches found
No related tags found
No related merge requests found
...@@ -82,7 +82,7 @@ class Config : ...@@ -82,7 +82,7 @@ class Config :
right = 5 right = 5
print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output) print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
print("history :",[trans.name for trans in self.history[-10:]], file=output) print("history :",[trans.name for trans in self.history[-10:]], file=output)
print("historyPop :",[(c[0].name,c[1]) for c in self.historyPop[-10:]], file=output) print("historyPop :",[(c[0].name,c[1],c[2]) for c in self.historyPop[-10:]], file=output)
toPrint = [] toPrint = []
for lineIndex in range(self.wordIndex-left, self.wordIndex+right) : for lineIndex in range(self.wordIndex-left, self.wordIndex+right) :
if lineIndex not in range(len(self.lines)) : if lineIndex not in range(len(self.lines)) :
......
...@@ -64,7 +64,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) : ...@@ -64,7 +64,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
candidate = candidates[0][1] candidate = candidates[0][1]
if debug : if debug :
config.printForDebug(sys.stderr) config.printForDebug(sys.stderr)
print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr) print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate) moved = applyTransition(ts, strat, config, candidate)
EOS.apply(config, strat) EOS.apply(config, strat)
......
...@@ -18,8 +18,7 @@ class Transition : ...@@ -18,8 +18,7 @@ class Transition :
def apply(self, config, strategy) : def apply(self, config, strategy) :
data = None data = None
if "BACK" not in self.name : config.historyHistory.add(str([t[0].name for t in config.historyPop]))
config.historyHistory.add(str([t[0].name for t in config.historyPop]))
if self.name == "RIGHT" : if self.name == "RIGHT" :
applyRight(config) applyRight(config)
...@@ -39,7 +38,7 @@ class Transition : ...@@ -39,7 +38,7 @@ class Transition :
exit(1) exit(1)
config.history.append(self) config.history.append(self)
if "BACK" not in self.name : if "BACK" not in self.name :
config.historyPop.append((self,data)) config.historyPop.append((self,data,None))
def appliable(self, config) : def appliable(self, config) :
if self.name == "RIGHT" : if self.name == "RIGHT" :
...@@ -146,8 +145,8 @@ def scoreOracleReduce(config, ml) : ...@@ -146,8 +145,8 @@ def scoreOracleReduce(config, ml) :
################################################################################ ################################################################################
def applyBack(config, strategy, size) : def applyBack(config, strategy, size) :
for i in range(size) : for i in range(size) :
trans, data = config.historyPop.pop() trans, data, movement = config.historyPop.pop()
config.moveWordIndex(-strategy[trans.name]) config.moveWordIndex(-movement)
if trans.name == "RIGHT" : if trans.name == "RIGHT" :
applyBackRight(config) applyBackRight(config)
elif trans.name == "LEFT" : elif trans.name == "LEFT" :
...@@ -236,6 +235,10 @@ def applyTransition(ts, strat, config, name) : ...@@ -236,6 +235,10 @@ def applyTransition(ts, strat, config, name) :
transition = [trans for trans in ts if trans.name == name][0] transition = [trans for trans in ts if trans.name == name][0]
movement = strat[transition.name] if transition.name in strat else 0 movement = strat[transition.name] if transition.name in strat else 0
transition.apply(config, strat) transition.apply(config, strat)
return config.moveWordIndex(movement) moved = config.moveWordIndex(movement)
movement = movement if moved else 0
if len(config.historyPop) > 0 and "BACK" not in name :
config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement)
return moved
################################################################################ ################################################################################
...@@ -41,7 +41,7 @@ if __name__ == "__main__" : ...@@ -41,7 +41,7 @@ if __name__ == "__main__" :
os.makedirs(args.model, exist_ok=True) os.makedirs(args.model, exist_ok=True)
Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu")) Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
print("Using device : %s"%Util.getDevice()) print("Using device : %s"%Util.getDevice(), file=sys.stderr)
random.seed(args.seed) random.seed(args.seed)
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment