Skip to content
Snippets Groups Projects
Commit eca1095c authored by Franck Dary's avatar Franck Dary
Browse files

Merge branch 'master' into Rl

parents e59c75b3 8cc02de5
No related branches found
No related tags found
No related merge requests found
...@@ -3,4 +3,4 @@ bin/* ...@@ -3,4 +3,4 @@ bin/*
.idea .idea
total_test_gold.conllu total_test_gold.conllu
total_test_predicted.conllu total_test_predicted.conllu
venv/* venv
\ No newline at end of file
...@@ -11,6 +11,7 @@ class Config : ...@@ -11,6 +11,7 @@ class Config :
self.index2col = index2col self.index2col = index2col
self.predicted = set({"HEAD", "DEPREL"}) self.predicted = set({"HEAD", "DEPREL"})
self.wordIndex = 0 self.wordIndex = 0
self.maxWordIndex = 0 #To keep a track of the max value, in case of backtrack
self.stack = [] self.stack = []
self.comments = [] self.comments = []
self.history = [] self.history = []
...@@ -64,10 +65,12 @@ class Config : ...@@ -64,10 +65,12 @@ class Config :
if self.wordIndex+relMov in range(0, len((self.lines))) : if self.wordIndex+relMov in range(0, len((self.lines))) :
self.wordIndex += relMov self.wordIndex += relMov
else : else :
self.maxWordIndex = max(self.maxWordIndex, self.wordIndex)
return False return False
if self.isMultiword(self.wordIndex) : if self.isMultiword(self.wordIndex) :
self.wordIndex += relMov self.wordIndex += relMov
done += 1 done += 1
self.maxWordIndex = max(self.maxWordIndex, self.wordIndex)
return True return True
def isMultiword(self, index) : def isMultiword(self, index) :
......
...@@ -12,6 +12,7 @@ class Dicts : ...@@ -12,6 +12,7 @@ class Dicts :
self.noDepLeft = "__nodepleft__" self.noDepLeft = "__nodepleft__"
self.noDepRight = "__nodepright__" self.noDepRight = "__nodepright__"
self.noGov = "__nogov__" self.noGov = "__nogov__"
self.notSeen = "__notseen__"
def addDict(self, name, d) : def addDict(self, name, d) :
if name in self.dicts : if name in self.dicts :
...@@ -38,7 +39,7 @@ class Dicts : ...@@ -38,7 +39,7 @@ class Dicts :
targetColumns = list(col2index.keys()) targetColumns = list(col2index.keys())
else : else :
targetColumns = list(colsSet) targetColumns = list(colsSet)
self.dicts = {col : {self.unkToken : (0,minCount), self.nullToken : (1,minCount), self.noStackToken : (2,minCount), self.oobToken : (3,minCount), self.noDepLeft : (4,minCount), self.noDepRight : (5,minCount), self.noGov : (6,minCount)} for col in targetColumns} self.dicts = {col : {self.unkToken : (0,minCount), self.nullToken : (1,minCount), self.noStackToken : (2,minCount), self.oobToken : (3,minCount), self.noDepLeft : (4,minCount), self.noDepRight : (5,minCount), self.noGov : (6,minCount), self.notSeen : (7,minCount)} for col in targetColumns}
splited = line.split('\t') splited = line.split('\t')
for col in targetColumns : for col in targetColumns :
......
...@@ -11,7 +11,10 @@ from Util import isEmpty ...@@ -11,7 +11,10 @@ from Util import isEmpty
# -3 : No dependent left # -3 : No dependent left
# -4 : No dependent right # -4 : No dependent right
# -5 : No gov # -5 : No gov
def extractIndexes(config, featureFunction) : # -6 : Not seen
# If incremental is true, only words that have been 'seen' (at wordIndex) can be used
# others will be marked as not seen.
def extractIndexes(config, featureFunction, incremental) :
features = featureFunction.split() features = featureFunction.split()
res = [] res = []
for feature in features : for feature in features :
...@@ -27,6 +30,8 @@ def extractIndexes(config, featureFunction) : ...@@ -27,6 +30,8 @@ def extractIndexes(config, featureFunction) :
index = -2 index = -2
else : else :
index = config.stack[-1-index] index = config.stack[-1-index]
if incremental and index > config.maxWordIndex :
index = -6
for depIndex in map(int,splited[2:]) : for depIndex in map(int,splited[2:]) :
if index < 0 : if index < 0 :
break break
...@@ -56,10 +61,10 @@ def extractIndexes(config, featureFunction) : ...@@ -56,10 +61,10 @@ def extractIndexes(config, featureFunction) :
################################################################################ ################################################################################
# For each element of the feature function and for each column, concatenante the dict index # For each element of the feature function and for each column, concatenante the dict index
def extractColsFeatures(dicts, config, featureFunction, cols) : def extractColsFeatures(dicts, config, featureFunction, cols, incremental) :
specialValues = {-1 : dicts.oobToken, -2 : dicts.noStackToken, -3 : dicts.noDepLeft, -4 : dicts.noDepRight, -5 : dicts.noGov} specialValues = {-1 : dicts.oobToken, -2 : dicts.noStackToken, -3 : dicts.noDepLeft, -4 : dicts.noDepRight, -5 : dicts.noGov, -6 : dicts.notSeen}
indexes = extractIndexes(config, featureFunction) indexes = extractIndexes(config, featureFunction, incremental)
totalSize = len(cols)*len(indexes) totalSize = len(cols)*len(indexes)
result = torch.zeros(totalSize, dtype=torch.int) result = torch.zeros(totalSize, dtype=torch.int)
......
...@@ -5,10 +5,11 @@ import Features ...@@ -5,10 +5,11 @@ import Features
################################################################################ ################################################################################
class BaseNet(nn.Module): class BaseNet(nn.Module):
def __init__(self, dicts, outputSize) : def __init__(self, dicts, outputSize, incremental) :
super().__init__() super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.incremental = incremental
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
self.historyNb = 5 self.historyNb = 5
self.columns = ["UPOS", "FORM"] self.columns = ["UPOS", "FORM"]
...@@ -47,7 +48,7 @@ class BaseNet(nn.Module): ...@@ -47,7 +48,7 @@ class BaseNet(nn.Module):
m.bias.data.fill_(0.01) m.bias.data.fill_(0.01)
def extractFeatures(self, dicts, config) : def extractFeatures(self, dicts, config) :
colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns) colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb) historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
return torch.cat([colsValues, historyValues]) return torch.cat([colsValues, historyValues])
......
...@@ -16,15 +16,15 @@ import Config ...@@ -16,15 +16,15 @@ import Config
from conll18_ud_eval import load_conllu, evaluate from conll18_ud_eval import load_conllu, evaluate
################################################################################ ################################################################################
def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, silent=False) : def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, silent=False) :
sentences = Config.readConllu(filename) sentences = Config.readConllu(filename)
if type == "oracle" : if type == "oracle" :
trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, silent) trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, silent)
return return
if type == "rl": if type == "rl":
trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent) trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, silent)
return return
print("ERROR : unknown type '%s'"%type, file=sys.stderr) print("ERROR : unknown type '%s'"%type, file=sys.stderr)
...@@ -70,7 +70,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : ...@@ -70,7 +70,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
################################################################################ ################################################################################
################################################################################ ################################################################################
def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) : def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental) :
devScore = "" devScore = ""
saved = True if bestLoss is None else totalLoss < bestLoss saved = True if bestLoss is None else totalLoss < bestLoss
bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss) bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
...@@ -92,12 +92,12 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss ...@@ -92,12 +92,12 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
################################################################################ ################################################################################
################################################################################ ################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) : def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, silent=False) :
dicts = Dicts() dicts = Dicts()
dicts.readConllu(filename, ["FORM","UPOS"], 2) dicts.readConllu(filename, ["FORM","UPOS"], 2)
dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.save(modelDir+"/dicts.json") dicts.save(modelDir+"/dicts.json")
network = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice()) network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
examples = [] examples = []
sentences = copy.deepcopy(sentencesOriginal) sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr) print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
...@@ -143,11 +143,11 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr ...@@ -143,11 +143,11 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
optimizer.step() optimizer.step()
totalLoss += float(loss) totalLoss += float(loss)
bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs) bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental)
################################################################################ ################################################################################
################################################################################ ################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) : def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, silent=False) :
memory = None memory = None
dicts = Dicts() dicts = Dicts()
...@@ -155,8 +155,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti ...@@ -155,8 +155,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}}) dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.save(modelDir + "/dicts.json") dicts.save(modelDir + "/dicts.json")
policy_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice()) policy_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
target_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice()) target_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
target_net.load_state_dict(policy_net.state_dict()) target_net.load_state_dict(policy_net.state_dict())
target_net.eval() target_net.eval()
policy_net.train() policy_net.train()
...@@ -226,6 +226,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti ...@@ -226,6 +226,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
if i >= nbExByEpoch : if i >= nbExByEpoch :
break break
sentIndex += 1 sentIndex += 1
bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental)
################################################################################ ################################################################################
...@@ -4,7 +4,6 @@ from Util import isEmpty ...@@ -4,7 +4,6 @@ from Util import isEmpty
################################################################################ ################################################################################
class Transition : class Transition :
available = lambda self,x: x in {"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"} or ("BACK" in x and len(x.split()) == 2)
def __init__(self, name) : def __init__(self, name) :
if not self.available(name) : if not self.available(name) :
...@@ -14,11 +13,12 @@ class Transition : ...@@ -14,11 +13,12 @@ class Transition :
def __lt__(self, other) : def __lt__(self, other) :
return self.name < other.name return self.name < other.name
def available(self, x) :
return x in {"RIGHT", "LEFT", "SHIFT", "REDUCE", "EOS"} or ("BACK" in x and len(x.split()) == 2)
def apply(self, config, strategy) : def apply(self, config, strategy) :
data = None data = None
config.historyHistory.add(str([t[0].name for t in config.historyPop]))
if self.name == "RIGHT" : if self.name == "RIGHT" :
applyRight(config) applyRight(config)
elif self.name == "LEFT" : elif self.name == "LEFT" :
...@@ -30,6 +30,7 @@ class Transition : ...@@ -30,6 +30,7 @@ class Transition :
elif self.name == "EOS" : elif self.name == "EOS" :
applyEOS(config) applyEOS(config)
elif "BACK" in self.name : elif "BACK" in self.name :
config.historyHistory.add(str([t[0].name for t in config.historyPop]))
size = int(self.name.split()[-1]) size = int(self.name.split()[-1])
applyBack(config, strategy, size) applyBack(config, strategy, size)
else : else :
......
...@@ -5,6 +5,7 @@ import os ...@@ -5,6 +5,7 @@ import os
import argparse import argparse
import random import random
import torch import torch
import json
import Util import Util
import Train import Train
...@@ -32,6 +33,8 @@ if __name__ == "__main__" : ...@@ -32,6 +33,8 @@ if __name__ == "__main__" :
help="If not none, extract examples in bootstrap mode (oracle train only).") help="If not none, extract examples in bootstrap mode (oracle train only).")
parser.add_argument("--dev", default=None, parser.add_argument("--dev", default=None,
help="Name of the CoNLL-U file of the dev corpus.") help="Name of the CoNLL-U file of the dev corpus.")
parser.add_argument("--incr", "-i", default=False, action="store_true",
help="If true, the neural network will be 'incremenal' i.e. will not see right context words if they have never been the word index.")
parser.add_argument("--debug", "-d", default=False, action="store_true", parser.add_argument("--debug", "-d", default=False, action="store_true",
help="Print debug infos on stderr.") help="Print debug infos on stderr.")
parser.add_argument("--silent", "-s", default=False, action="store_true", parser.add_argument("--silent", "-s", default=False, action="store_true",
...@@ -59,8 +62,15 @@ if __name__ == "__main__" : ...@@ -59,8 +62,15 @@ if __name__ == "__main__" :
print("Transition Set :", [trans.name for trans in transitionSet]) print("Transition Set :", [trans.name for trans in transitionSet])
if args.mode == "train" : if args.mode == "train" :
Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.silent) json.dump([t.name for t in transitionSet], open(args.model+"/transitions.json", "w"))
json.dump(strategy, open(args.model+"/strategy.json", "w"))
print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr)
Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.silent)
elif args.mode == "decode" : elif args.mode == "decode" :
transNames = json.load(open(args.model+"/transitions.json", "r"))
transitionSet = [Transition(elem) for elem in transNames]
strategy = json.load(open(args.model+"/strategy.json", "r"))
print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr)
Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model) Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model)
else : else :
print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr) print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment