From 61e4dc79247201bde9d6dce803f37efb12b57c10 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 12 Oct 2021 17:21:02 +0200
Subject: [PATCH] feature canback now uses the right back action

---
 Networks.py | 10 +++++-----
 Train.py    |  4 ++--
 main.py     | 10 +++++-----
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/Networks.py b/Networks.py
index bf8a264..5612b99 100644
--- a/Networks.py
+++ b/Networks.py
@@ -120,7 +120,7 @@ class BaseNet(nn.Module) :
     self.inputSize = (self.historyNb+self.historyPopNb)*embSizes.get("HISTORY",0)+(self.suffixSize+self.prefixSize)*embSizes.get("LETTER",0) + sum([self.nbTargets*embSizes.get(col,0) for col in self.columns])
     self.fc1 = nn.Linear(self.inputSize, hiddenSize)
     for i in range(len(outputSizes)) :
-      self.add_module("output_"+str(i), nn.Linear(hiddenSize+(1 if self.hasBack else 0), outputSizes[i]))
+      self.add_module("output_"+str(i), nn.Linear(hiddenSize+(1 if self.hasBack > 0 else 0), outputSizes[i]))
     self.dropout = nn.Dropout(0.3)
 
     self.apply(self.initWeights)
@@ -130,7 +130,7 @@ class BaseNet(nn.Module) :
 
   def forward(self, x) :
     embeddings = []
-    if self.hasBack :
+    if self.hasBack > 0 :
       canBack = x[...,0:1]
       x = x[...,1:]
 
@@ -156,7 +156,7 @@ class BaseNet(nn.Module) :
       curIndex = curIndex+self.suffixSize
     y = self.dropout(y)
     y = F.relu(self.dropout(self.fc1(y)))
-    if self.hasBack :
+    if self.hasBack > 0 :
       y = torch.cat([y,canBack], 1)
     y = getattr(self, "output_"+str(self.state))(y)
     return y
@@ -176,8 +176,8 @@ class BaseNet(nn.Module) :
     prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
     suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
     backAction = None
-    if self.hasBack :
-      backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK 1").appliable(config) else torch.zeros(1, dtype=torch.int)
+    if self.hasBack > 0 :
+      backAction = torch.ones(1, dtype=torch.int) if Transition.Transition("BACK %d"%self.hasBack).appliable(config) else torch.zeros(1, dtype=torch.int)
     allFeatures = [f for f in [backAction, colsValues, historyValues, historyPopValues, prefixValues, suffixValues] if f is not None]
     return torch.cat(allFeatures)
 ################################################################################
diff --git a/Train.py b/Train.py
index d83f2a6..a15c600 100644
--- a/Train.py
+++ b/Train.py
@@ -111,7 +111,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
 ################################################################################
 
 ################################################################################
-def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent=False, hasBack=False) :
+def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, pretrained, silent=False, hasBack=0) :
   dicts = Dicts()
   dicts.readConllu(filename, Networks.getNeededDicts(networkName), 2, pretrained)
   transitionNames = {}
@@ -198,7 +198,7 @@ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize
 ################################################################################
 
 ################################################################################
-def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False, hasBack=False) :
+def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, pretrained, silent=False, hasBack=0) :
 
   memory = None
   dicts = Dicts()
diff --git a/main.py b/main.py
index 2fa2127..bc8e44d 100755
--- a/main.py
+++ b/main.py
@@ -85,7 +85,7 @@ if __name__ == "__main__" :
     args.bootstrap = int(args.bootstrap)
 
   networkName = args.network
-  hasBack = False
+  hasBack = 0
 
   if args.transitions == "tagger" :
     tmpDicts = Dicts()
@@ -99,7 +99,7 @@ if __name__ == "__main__" :
       networkName = "tagger"
     probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "taggerbt" :
-    hasBack = True
+    hasBack = int(args.backSize)
     tmpDicts = Dicts()
     tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
     tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
@@ -120,7 +120,7 @@ if __name__ == "__main__" :
       networkName = "base"
     probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "eagerbt" :
-    hasBack = True
+    hasBack = int(args.backSize)
     transitionSets = [[Transition("NOBACK"),Transition("BACK "+args.backSize)], [Transition(elem) for elem in ["SHIFT","REDUCE","LEFT","RIGHT"] if len(elem) > 0]]
     args.predictedStr = "HEAD"
     args.states = ["backer", "parser"]
@@ -173,7 +173,7 @@ if __name__ == "__main__" :
               [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
               [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "tagparserbt" :
-    hasBack = True
+    hasBack = int(args.backSize)
     tmpDicts = Dicts()
     tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
     tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
@@ -187,7 +187,7 @@ if __name__ == "__main__" :
              [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
               [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "recovery" :
-    hasBack = True
+    hasBack = int(args.backSize)
     tmpDicts = Dicts()
     tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
     tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
-- 
GitLab