Skip to content
Snippets Groups Projects
Commit 62b38a8c authored by Franck Dary
Browse files

Added incremental mode (-i) : parser cannot see right context if it has not...

Added incremental mode (-i) : parser cannot see right context if it has not been fixated (so only accessible after a BACK)
parent d06eac84
No related branches found
No related tags found
No related merge requests found
...@@ -11,6 +11,7 @@ class Config : ...@@ -11,6 +11,7 @@ class Config :
self.index2col = index2col self.index2col = index2col
self.predicted = set({"HEAD", "DEPREL"}) self.predicted = set({"HEAD", "DEPREL"})
self.wordIndex = 0 self.wordIndex = 0
self.maxWordIndex = 0 # Keeps track of the highest wordIndex reached so far, in case of backtrack
self.stack = [] self.stack = []
self.comments = [] self.comments = []
self.history = [] self.history = []
...@@ -64,10 +65,12 @@ class Config : ...@@ -64,10 +65,12 @@ class Config :
if self.wordIndex+relMov in range(0, len((self.lines))) : if self.wordIndex+relMov in range(0, len((self.lines))) :
self.wordIndex += relMov self.wordIndex += relMov
else : else :
self.maxWordIndex = max(self.maxWordIndex, self.wordIndex)
return False return False
if self.isMultiword(self.wordIndex) : if self.isMultiword(self.wordIndex) :
self.wordIndex += relMov self.wordIndex += relMov
done += 1 done += 1
self.maxWordIndex = max(self.maxWordIndex, self.wordIndex)
return True return True
def isMultiword(self, index) : def isMultiword(self, index) :
......
...@@ -76,7 +76,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) : ...@@ -76,7 +76,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
################################################################################ ################################################################################
################################################################################ ################################################################################
def decodeMode(debug, filename, type, transitionSet, strategy, modelDir = None, network=None, dicts=None, output=sys.stdout) : def decodeMode(debug, filename, type, transitionSet, strategy, modelDir=None, network=None, dicts=None, output=sys.stdout) :
sentences = Config.readConllu(filename) sentences = Config.readConllu(filename)
......
...@@ -12,6 +12,7 @@ class Dicts : ...@@ -12,6 +12,7 @@ class Dicts :
self.noDepLeft = "__nodepleft__" self.noDepLeft = "__nodepleft__"
self.noDepRight = "__nodepright__" self.noDepRight = "__nodepright__"
self.noGov = "__nogov__" self.noGov = "__nogov__"
self.notSeen = "__notseen__"
def addDict(self, name, d) : def addDict(self, name, d) :
if name in self.dicts : if name in self.dicts :
...@@ -38,7 +39,7 @@ class Dicts : ...@@ -38,7 +39,7 @@ class Dicts :
targetColumns = list(col2index.keys()) targetColumns = list(col2index.keys())
else : else :
targetColumns = list(colsSet) targetColumns = list(colsSet)
self.dicts = {col : {self.unkToken : (0,minCount), self.nullToken : (1,minCount), self.noStackToken : (2,minCount), self.oobToken : (3,minCount), self.noDepLeft : (4,minCount), self.noDepRight : (5,minCount), self.noGov : (6,minCount)} for col in targetColumns} self.dicts = {col : {self.unkToken : (0,minCount), self.nullToken : (1,minCount), self.noStackToken : (2,minCount), self.oobToken : (3,minCount), self.noDepLeft : (4,minCount), self.noDepRight : (5,minCount), self.noGov : (6,minCount), self.notSeen : (7,minCount)} for col in targetColumns}
splited = line.split('\t') splited = line.split('\t')
for col in targetColumns : for col in targetColumns :
......
...@@ -11,7 +11,10 @@ from Util import isEmpty ...@@ -11,7 +11,10 @@ from Util import isEmpty
# -3 : No dependent left # -3 : No dependent left
# -4 : No dependent right # -4 : No dependent right
# -5 : No gov # -5 : No gov
def extractIndexes(config, featureFunction) : # -6 : Not seen
# If incremental is true, only words that have been 'seen' (at wordIndex) can be used
# others will be marked as not seen.
def extractIndexes(config, featureFunction, incremental) :
features = featureFunction.split() features = featureFunction.split()
res = [] res = []
for feature in features : for feature in features :
...@@ -27,6 +30,8 @@ def extractIndexes(config, featureFunction) : ...@@ -27,6 +30,8 @@ def extractIndexes(config, featureFunction) :
index = -2 index = -2
else : else :
index = config.stack[-1-index] index = config.stack[-1-index]
if incremental and index > config.maxWordIndex :
index = -6
for depIndex in map(int,splited[2:]) : for depIndex in map(int,splited[2:]) :
if index < 0 : if index < 0 :
break break
...@@ -56,10 +61,10 @@ def extractIndexes(config, featureFunction) : ...@@ -56,10 +61,10 @@ def extractIndexes(config, featureFunction) :
################################################################################ ################################################################################
# For each element of the feature function and for each column, concatenate the dict index # For each element of the feature function and for each column, concatenate the dict index
def extractColsFeatures(dicts, config, featureFunction, cols) : def extractColsFeatures(dicts, config, featureFunction, cols, incremental) :
specialValues = {-1 : dicts.oobToken, -2 : dicts.noStackToken, -3 : dicts.noDepLeft, -4 : dicts.noDepRight, -5 : dicts.noGov} specialValues = {-1 : dicts.oobToken, -2 : dicts.noStackToken, -3 : dicts.noDepLeft, -4 : dicts.noDepRight, -5 : dicts.noGov, -6 : dicts.notSeen}
indexes = extractIndexes(config, featureFunction) indexes = extractIndexes(config, featureFunction, incremental)
totalSize = len(cols)*len(indexes) totalSize = len(cols)*len(indexes)
result = torch.zeros(totalSize, dtype=torch.int) result = torch.zeros(totalSize, dtype=torch.int)
......
...@@ -5,10 +5,11 @@ import Features ...@@ -5,10 +5,11 @@ import Features
################################################################################ ################################################################################
class BaseNet(nn.Module): class BaseNet(nn.Module):
def __init__(self, dicts, outputSize) : def __init__(self, dicts, outputSize, incremental) :
super().__init__() super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.incremental = incremental
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
self.historyNb = 5 self.historyNb = 5
self.columns = ["UPOS", "FORM"] self.columns = ["UPOS", "FORM"]
...@@ -47,7 +48,7 @@ class BaseNet(nn.Module): ...@@ -47,7 +48,7 @@ class BaseNet(nn.Module):
m.bias.data.fill_(0.01) m.bias.data.fill_(0.01)
def extractFeatures(self, dicts, config) : def extractFeatures(self, dicts, config) :
colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns) colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb) historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
return torch.cat([colsValues, historyValues]) return torch.cat([colsValues, historyValues])
......
...@@ -16,15 +16,15 @@ import Config ...@@ -16,15 +16,15 @@ import Config
from conll18_ud_eval import load_conllu, evaluate from conll18_ud_eval import load_conllu, evaluate
################################################################################ ################################################################################
def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, silent=False) : def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, silent=False) :
sentences = Config.readConllu(filename) sentences = Config.readConllu(filename)
if type == "oracle" : if type == "oracle" :
trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, silent) trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, silent)
return return
if type == "rl": if type == "rl":
trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent) trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, silent)
return return
print("ERROR : unknown type '%s'"%type, file=sys.stderr) print("ERROR : unknown type '%s'"%type, file=sys.stderr)
...@@ -70,7 +70,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : ...@@ -70,7 +70,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
################################################################################ ################################################################################
################################################################################ ################################################################################
def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) : def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental) :
devScore = "" devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss saved = True if bestLoss is None else totalLoss < bestLoss
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
...@@ -92,12 +92,12 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss ...@@ -92,12 +92,12 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
################################################################################ ################################################################################
################################################################################ ################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) : def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, silent=False) :
dicts = Dicts() dicts = Dicts()
dicts.readConllu(filename, ["FORM","UPOS"], 2) dicts.readConllu(filename, ["FORM","UPOS"], 2)
dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.save(modelDir+"/dicts.json") dicts.save(modelDir+"/dicts.json")
network = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice()) network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
examples = [] examples = []
sentences = copy.deepcopy(sentencesOriginal) sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr) print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
...@@ -143,11 +143,11 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr ...@@ -143,11 +143,11 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
optimizer.step() optimizer.step()
totalLoss += float(loss) totalLoss += float(loss)
bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs) bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental)
################################################################################ ################################################################################
################################################################################ ################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) : def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, silent=False) :
memory = None memory = None
dicts = Dicts() dicts = Dicts()
...@@ -155,8 +155,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti ...@@ -155,8 +155,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.save(modelDir + "/dicts.json") dicts.save(modelDir + "/dicts.json")
policy_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice()) policy_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
target_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice()) target_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
target_net.load_state_dict(policy_net.state_dict()) target_net.load_state_dict(policy_net.state_dict())
target_net.eval() target_net.eval()
policy_net.train() policy_net.train()
...@@ -226,6 +226,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti ...@@ -226,6 +226,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
if i >= nbExByEpoch : if i >= nbExByEpoch :
break break
sentIndex += 1 sentIndex += 1
bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental)
################################################################################ ################################################################################
...@@ -4,7 +4,6 @@ from Util import isEmpty ...@@ -4,7 +4,6 @@ from Util import isEmpty
################################################################################ ################################################################################
class Transition : class Transition :
available = lambda self,x: x in {"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"} or ("BACK" in x and len(x.split()) == 2)
def __init__(self, name) : def __init__(self, name) :
if not self.available(name) : if not self.available(name) :
...@@ -14,6 +13,9 @@ class Transition : ...@@ -14,6 +13,9 @@ class Transition :
def __lt__(self, other) : def __lt__(self, other) :
return self.name < other.name return self.name < other.name
def available(self, x) :
return x in {"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"} or ("BACK" in x and len(x.split()) == 2)
def apply(self, config, strategy) : def apply(self, config, strategy) :
data = None data = None
......
...@@ -5,6 +5,7 @@ import os ...@@ -5,6 +5,7 @@ import os
import argparse import argparse
import random import random
import torch import torch
import json
import Util import Util
import Train import Train
...@@ -32,6 +33,8 @@ if __name__ == "__main__" : ...@@ -32,6 +33,8 @@ if __name__ == "__main__" :
help="If not none, extract examples in bootstrap mode (oracle train only).") help="If not none, extract examples in bootstrap mode (oracle train only).")
parser.add_argument("--dev", default=None, parser.add_argument("--dev", default=None,
help="Name of the CoNLL-U file of the dev corpus.") help="Name of the CoNLL-U file of the dev corpus.")
parser.add_argument("--incr", "-i", default=False, action="store_true",
help="If true, the neural network will be 'incremenal' i.e. will not see right context words if they have never been the word index.")
parser.add_argument("--debug", "-d", default=False, action="store_true", parser.add_argument("--debug", "-d", default=False, action="store_true",
help="Print debug infos on stderr.") help="Print debug infos on stderr.")
parser.add_argument("--silent", "-s", default=False, action="store_true", parser.add_argument("--silent", "-s", default=False, action="store_true",
...@@ -57,8 +60,13 @@ if __name__ == "__main__" : ...@@ -57,8 +60,13 @@ if __name__ == "__main__" :
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
if args.mode == "train" : if args.mode == "train" :
Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.silent) json.dump([t.name for t in transitionSet], open(args.model+"/transitions.json", "w"))
json.dump(strategy, open(args.model+"/strategy.json", "w"))
Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.silent)
elif args.mode == "decode" : elif args.mode == "decode" :
transNames = json.load(open(args.model+"/transitions.json", "r"))
transitionSet = [Transition(elem) for elem in transNames]
strategy = json.load(open(args.model+"/strategy.json", "r"))
Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model) Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model)
else : else :
print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr) print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment