Commit 3d80804a authored by Franck Dary
Refactored feature extraction, and added extractFeature function to NeuralNetwork

parent b1c976ae
import random
import sys
from Transition import Transition, getMissingLinks, applyTransition
-from Features import extractFeatures
from Dicts import Dicts
from Util import getDevice
import Config
@@ -56,7 +55,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
  with torch.no_grad():
    while moved :
-      features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
+      features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
      output = network(features)
      scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
      candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
...
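To make the selection step above concrete, here is a small self-contained sketch of the same sort-and-filter logic with toy scores and transition names (all values invented for illustration, not from the repository):

# Toy illustration of the candidate selection used in decodeModel.
output = [[0.2, 0.9, 0.1]]                                  # pretend network scores for 3 transitions
ts = [("SHIFT", True), ("LEFT", False), ("RIGHT", True)]    # (name, appliable?) stand-ins
scores = sorted([[output[0][i], ts[i][1], ts[i][0]] for i in range(len(ts))])[::-1]
candidates = [[cand[0], cand[2]] for cand in scores if cand[1]]
print(candidates[0])  # [0.2, 'SHIFT'] : LEFT scores highest but is not appliable, so SHIFT is chosen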
@@ -9,6 +9,9 @@ class Dicts :
    self.nullToken = "__null__"
    self.noStackToken = "__nostack__"
    self.oobToken = "__oob__"
+    self.noDepLeft = "__nodepleft__"
+    self.noDepRight = "__nodepright__"
+    self.noGov = "__nogov__"

  def readConllu(self, filename, colsSet=None) :
    defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
@@ -30,7 +33,7 @@ class Dicts :
        targetColumns = list(col2index.keys())
      else :
        targetColumns = list(colsSet)
-      self.dicts = {col : {self.unkToken : 0, self.nullToken : 1} for col in targetColumns}
+      self.dicts = {col : {self.unkToken : 0, self.nullToken : 1, self.noStackToken : 2, self.oobToken : 3, self.noDepLeft : 4, self.noDepRight : 5, self.noGov : 6} for col in targetColumns}
      splited = line.split('\t')
      for col in targetColumns :
...
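A minimal sketch of what this initialization buys downstream (the corpus path is hypothetical; the token strings follow the attributes defined above):

# Every target column's dict now reserves indexes 0-6 for the special tokens,
# so feature extraction can embed "no stack", "out of bounds", etc. like any corpus value.
dicts = Dicts()
dicts.readConllu("train.conllu", ["UPOS"])      # hypothetical file name
assert dicts.dicts["UPOS"][dicts.oobToken] == 3
assert dicts.dicts["UPOS"][dicts.noGov] == 6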
@@ -3,60 +3,78 @@ import sys
from Util import isEmpty
################################################################################
-def extractFeatures(dicts, config) :
-  return extractFeaturesPosExtended(dicts, config)
-################################################################################
-
-################################################################################
-def extractFeaturesPos(dicts, config) :
-  bufferWindow = range(-2,2+1)
-  stackWindow = range(0,3+1)
-  totalSize = len(bufferWindow)+len(stackWindow)
-  result = torch.zeros(totalSize, dtype=torch.int)
-  insertIndex = 0
-  for i in bufferWindow :
-    index = config.wordIndex + i
-    bufferPos = dicts.oobToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS")
-    result[insertIndex] = dicts.get("UPOS", bufferPos)
-    insertIndex += 1
-  for i in stackWindow :
-    stackPos = dicts.noStackToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS")
-    result[insertIndex] = dicts.get("UPOS", stackPos)
-    insertIndex += 1
-  return result
+# Input : b=buffer s=stack .0=governor .x=rightChild#x+1 .-x=leftChild#-x-1
+# Output : list of sentence indexes pointing to elements of featureFunction
+# Special output values :
+# -1 : Out of bounds
+# -2 : Not in stack
+# -3 : No dependent left
+# -4 : No dependent right
+# -5 : No gov
+def extractIndexes(config, featureFunction) :
+  features = featureFunction.split()
+  res = []
+  for feature in features :
+    splited = feature.split('.')
+    obj = splited[0]
+    index = int(splited[1])
+    if obj == "b" :
+      index = config.wordIndex + index
+      if index not in (range(len(config.lines))) :
+        index = -1
+    elif obj == "s" :
+      if index not in range(len(config.stack)) :
+        index = -2
+      else :
+        index = config.stack[-1-index]
+    for depIndex in map(int,splited[2:]) :
+      if index < 0 :
+        break
+      if depIndex == 0 :
+        head = config.getAsFeature(index, "HEAD")
+        if isEmpty(head) :
+          index = -5
+        else :
+          index = int(head)
+        continue
+      if depIndex > 0 :
+        rightChilds = [child for child in config.predChilds[index] if child > index]
+        if depIndex-1 in range(len(rightChilds)) :
+          index = rightChilds[depIndex-1]
+        else :
+          index = -4
+      else :
+        leftChilds = [child for child in config.predChilds[index] if child < index]
+        if abs(depIndex)-1 in range(len(leftChilds)) :
+          index = leftChilds[abs(depIndex)-1]
+        else :
+          index = -3
+    res.append(index)
+  return res
################################################################################

################################################################################
-# For each stack element, add its POS and the POS of its governor
-def extractFeaturesPosExtended(dicts, config) :
-  bufferWindow = range(-2,2+1)
-  stackWindow = range(0,3+1)
-  totalSize = len(bufferWindow)+2*len(stackWindow)
-  result = torch.zeros(totalSize, dtype=torch.int)
-  insertIndex = 0
-  for i in bufferWindow :
-    index = config.wordIndex + i
-    bufferPos = dicts.oobToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS")
-    result[insertIndex] = dicts.get("UPOS", bufferPos)
-    insertIndex += 1
-  for i in stackWindow :
-    stackPos = dicts.noStackToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS")
-    stackGovHead = dicts.nullToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "HEAD")
-    stackGovPos = dicts.nullToken
-    if not isEmpty(stackGovHead) and stackGovHead != dicts.nullToken :
-      stackGovPos = config.getAsFeature(int(stackGovHead), "UPOS")
-    elif stackGovHead == dicts.nullToken :
-      stackGovPos = dicts.noStackToken
-    result[insertIndex] = dicts.get("UPOS", stackPos)
-    insertIndex += 1
-    result[insertIndex] = dicts.get("UPOS", stackGovPos)
-    insertIndex += 1
-  return result
+# For each element of the feature function and for each column, concatenate the dict index
+def extractColsFeatures(dicts, config, featureFunction, cols) :
+  specialValues = {-1 : dicts.oobToken, -2 : dicts.noStackToken, -3 : dicts.noDepLeft, -4 : dicts.noDepRight, -5 : dicts.noGov}
+  indexes = extractIndexes(config, featureFunction)
+  totalSize = len(cols)*len(indexes)
+  result = torch.zeros(totalSize, dtype=torch.int)
+  insertIndex = 0
+  for index in indexes :
+    if index < 0 :
+      for col in cols :
+        result[insertIndex] = dicts.get(col, specialValues[index])
+        insertIndex += 1
+    else :
+      for col in cols :
+        value = config.getAsFeature(index, col)
+        if isEmpty(value) :
+          value = dicts.nullToken
+        result[insertIndex] = dicts.get(col, value)
+        insertIndex += 1
+  return result
...
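As a reading aid for the template syntax documented above, here is a minimal sketch of extractIndexes on a stubbed configuration. The _StubConfig class and all its values are invented for illustration (the project's real Config is richer), and it assumes Util.isEmpty treats the empty string as empty:

import Features

class _StubConfig :
  def __init__(self) :
    self.wordIndex = 3
    self.lines = [None]*5            # pretend the sentence has 5 words
    self.stack = []                  # empty stack
    self.predChilds = {3 : [2, 4]}   # word 3 has one left child (2) and one right child (4)
  def getAsFeature(self, index, col) :
    return ""                        # pretend HEAD has not been predicted yet

print(Features.extractIndexes(_StubConfig(), "b.0 b.3 s.0 b.0.-1 b.0.1 b.0.0"))
# Expected: [3, -1, -2, 2, 4, -5]
# current word, out of bounds, empty stack, left child, right child, no governor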
import torch
import torch.nn as nn
import torch.nn.functional as F
+import Features
################################################################################
class BaseNet(nn.Module):
-  def __init__(self, dicts, inputSize, outputSize) :
+  def __init__(self, dicts, outputSize) :
    super().__init__()
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
+    self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
+    self.columns = ["UPOS"]
    self.embSize = 64
-    self.inputSize = inputSize
+    self.inputSize = len(self.columns)*len(self.featureFunction.split())
    self.outputSize = outputSize
    for name in dicts.dicts :
      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
-    self.fc1 = nn.Linear(inputSize * self.embSize, 1600)
+    self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
    self.fc2 = nn.Linear(1600, outputSize)
    self.dropout = nn.Dropout(0.3)
@@ -32,5 +36,9 @@ class BaseNet(nn.Module):
    if type(m) == nn.Linear:
      torch.nn.init.xavier_uniform_(m.weight)
      m.bias.data.fill_(0.01)
+  def extractFeatures(self, dicts, config) :
+    return Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns)
################################################################################
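For the default attributes above, the derived sizes work out as follows (a quick arithmetic check with illustrative local names, not part of the commit):

featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
columns = ["UPOS"]
inputSize = len(columns) * len(featureFunction.split())   # 1 * 17 = 17 feature slots
fc1_inputs = inputSize * 64                               # 17 * 64 = 1088 embedding values fed to fc1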
@@ -34,7 +34,7 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, boots
################################################################################

################################################################################
-def extractExamples(debug, ts, strat, config, dicts, network=None) :
+def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
  examples = []
  with torch.no_grad() :
    EOS = Transition("EOS")
@@ -47,12 +47,12 @@ def extractExamples(debug, ts, strat, config, dicts, network=None) :
        break
      best = min([cand[0] for cand in candidates])
      candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
-      features = Features.extractFeatures(dicts, config)
+      features = network.extractFeatures(dicts, config)
      candidate = candidateOracle.name
      if debug :
        config.printForDebug(sys.stderr)
        print(str([[c[0],c[1].name] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
-      if network is not None :
+      if dynamic :
        output = network(features.unsqueeze(0).to(getDevice()))
        scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
        candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
@@ -95,17 +95,17 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss
################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) :
  dicts = Dicts()
-  dicts.readConllu(filename, ["FORM", "UPOS"])
+  dicts.readConllu(filename, ["UPOS"])
  dicts.save(modelDir+"/dicts.json")
+  network = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
  examples = []
  sentences = copy.deepcopy(sentencesOriginal)
  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
  for config in sentences :
-    examples += extractExamples(debug, transitionSet, strategy, config, dicts)
+    examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, False)
  print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
  examples = torch.stack(examples)
-  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)).to(getDevice())
  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
  lossFct = torch.nn.CrossEntropyLoss()
@@ -117,7 +117,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
      sentences = copy.deepcopy(sentencesOriginal)
      print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
      for config in sentences :
-        examples += extractExamples(debug, transitionSet, strategy, config, dicts, network)
+        examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, True)
      print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
      examples = torch.stack(examples)
@@ -154,9 +154,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
  dicts.readConllu(filename, ["FORM", "UPOS"])
  dicts.save(modelDir + "/dicts.json")
-  policy_net = None
-  target_net = None
-  optimizer = None
+  policy_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
+  target_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
+  target_net.load_state_dict(policy_net.state_dict())
+  target_net.eval()
+  policy_net.train()
+  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
  bestLoss = None
  bestScore = None
@@ -178,16 +182,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
      print("Current epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
      sentence = sentences[sentIndex]
      sentence.moveWordIndex(0)
-      state = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
-      if policy_net is None :
-        policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
-        target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
-        target_net.load_state_dict(policy_net.state_dict())
-        target_net.eval()
-        policy_net.train()
-        optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
-        print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
+      state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
      while True :
        missingLinks = getMissingLinks(sentence)
@@ -209,7 +204,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
        newState = None
        if appliable :
          applyTransition(transitionSet, strategy, sentence, action.name)
-          newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
+          newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
        if memory is None :
          memory = ReplayMemory(5000, state.numel())
...
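A hedged summary of the new extractExamples calling convention (names as in the diff above; the surrounding loops are elided):

# Static oracle pass: the gold oracle chooses the transition; the network is only
# needed to extract features, so dynamic is False.
examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, False)
# Bootstrap pass: the current network chooses among appliable transitions, so
# dynamic is True and the extracted examples follow the model's own trajectory.
examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, True)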
@@ -44,6 +44,9 @@ if __name__ == "__main__" :
  random.seed(args.seed)
  torch.manual_seed(args.seed)
+  if args.bootstrap is not None :
+    args.bootstrap = int(args.bootstrap)
  if args.mode == "train" :
    Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.silent)
  elif args.mode == "decode" :
...