Commit 3d80804a authored by Franck Dary

Refactored feature extraction and added an extractFeatures function to NeuralNetwork

parent b1c976ae
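In short: callers no longer import a module-level extractFeatures from Features; each network now carries its own featureFunction and columns and exposes an extractFeatures(dicts, config) method. A minimal sketch of the new call pattern, assuming dicts, config and transitionSet objects built as elsewhere in this repository (the commented setup lines are placeholders, not part of this commit):

import torch
import Networks
from Util import getDevice

# Hypothetical setup, mirroring names used in Train.py / Decode.py:
# dicts = Dicts(); dicts.readConllu("train.conllu", ["UPOS"])
# config = a parser configuration; transitionSet = list of Transition objects

network = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())  # input size is now derived internally

with torch.no_grad() :
  features = network.extractFeatures(dicts, config)        # was: Features.extractFeatures(dicts, config)
  output = network(features.unsqueeze(0).to(getDevice()))  # add a batch dimension before the forward pass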
import random
import sys
from Transition import Transition, getMissingLinks, applyTransition
from Features import extractFeatures
from Dicts import Dicts
from Util import getDevice
import Config
@@ -56,7 +55,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
with torch.no_grad():
while moved :
features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
output = network(features)
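# Rank all transitions by model score (descending), then keep only those appliable in the current configuration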
scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
......
@@ -9,6 +9,9 @@ class Dicts :
self.nullToken = "__null__"
self.noStackToken = "__nostack__"
self.oobToken = "__oob__"
self.noDepLeft = "__nodepleft__"
self.noDepRight = "__nodepright__"
self.noGov = "__nogov__"
def readConllu(self, filename, colsSet=None) :
defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
@@ -30,7 +33,7 @@ class Dicts :
targetColumns = list(col2index.keys())
else :
targetColumns = list(colsSet)
self.dicts = {col : {self.unkToken : 0, self.nullToken : 1} for col in targetColumns}
self.dicts = {col : {self.unkToken : 0, self.nullToken : 1, self.noStackToken : 2, self.oobToken : 3, self.noDepLeft : 4, self.noDepRight : 5, self.noGov : 6} for col in targetColumns}
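# Each column dictionary is pre-seeded with the 7 special tokens, so indices 0-6 are reserved and corpus values come after them (assumption: new entries are appended after these seeds)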
splited = line.split('\t')
for col in targetColumns :
......
@@ -3,60 +3,78 @@ import sys
from Util import isEmpty
################################################################################
def extractFeatures(dicts, config) :
return extractFeaturesPosExtended(dicts, config)
################################################################################
# Input : b=buffer s=stack .0=governor .x (x>0)=x-th right dependent .-x=x-th left dependent
# Output : list of word indexes in the sentence, one per element of featureFunction
# Special output values :
# -1 : Out of bounds
# -2 : Not in stack
# -3 : No dependent left
# -4 : No dependent right
# -5 : No gov
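# Worked example (not part of this commit) :
#   "s.0"    -> the word on top of the stack, or -2 if the stack is empty
#   "s.0.-1" -> first left-side dependent (in predChilds order) of the stack top, or -3 if it has none
#   "b.1.0"  -> governor (HEAD) of the word right after the current buffer head, or -5 if it has no governor yet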
def extractIndexes(config, featureFunction) :
features = featureFunction.split()
res = []
for feature in features :
splited = feature.split('.')
obj = splited[0]
index = int(splited[1])
if obj == "b" :
index = config.wordIndex + index
if index not in (range(len(config.lines))) :
index = -1
elif obj == "s" :
if index not in range(len(config.stack)) :
index = -2
else :
index = config.stack[-1-index]
for depIndex in map(int,splited[2:]) :
if index < 0 :
break
if depIndex == 0 :
head = config.getAsFeature(index, "HEAD")
if isEmpty(head) :
index = -5
else :
index = int(head)
continue
if depIndex > 0 :
rightChilds = [child for child in config.predChilds[index] if child > index]
if depIndex-1 in range(len(rightChilds)) :
index = rightChilds[depIndex-1]
else :
index = -4
else :
leftChilds = [child for child in config.predChilds[index] if child < index]
if abs(depIndex)-1 in range(len(leftChilds)) :
index = leftChilds[abs(depIndex)-1]
else :
index = -3
res.append(index)
################################################################################
def extractFeaturesPos(dicts, config) :
bufferWindow = range(-2,2+1)
stackWindow = range(0,3+1)
totalSize = len(bufferWindow)+len(stackWindow)
result = torch.zeros(totalSize, dtype=torch.int)
insertIndex = 0
for i in bufferWindow :
index = config.wordIndex + i
bufferPos = dicts.oobToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS")
result[insertIndex] = dicts.get("UPOS", bufferPos)
insertIndex += 1
for i in stackWindow :
stackPos = dicts.noStackToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS")
result[insertIndex] = dicts.get("UPOS", stackPos)
insertIndex += 1
return result
return res
################################################################################
################################################################################
# For each stack element, add its POS and the POS of its governor
def extractFeaturesPosExtended(dicts, config) :
bufferWindow = range(-2,2+1)
stackWindow = range(0,3+1)
totalSize = len(bufferWindow)+2*len(stackWindow)
# For each element of the feature function and for each column, concatenate the dict index
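# Example (hypothetical feature function, not the one used by BaseNet) : with featureFunction "s.0 b.0" and cols ["UPOS", "FORM"],
# the result is a tensor of 4 dictionary indices ordered [UPOS of s.0, FORM of s.0, UPOS of b.0, FORM of b.0]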
def extractColsFeatures(dicts, config, featureFunction, cols) :
specialValues = {-1 : dicts.oobToken, -2 : dicts.noStackToken, -3 : dicts.noDepLeft, -4 : dicts.noDepRight, -5 : dicts.noGov}
indexes = extractIndexes(config, featureFunction)
totalSize = len(cols)*len(indexes)
result = torch.zeros(totalSize, dtype=torch.int)
insertIndex = 0
for i in bufferWindow :
index = config.wordIndex + i
bufferPos = dicts.oobToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS")
result[insertIndex] = dicts.get("UPOS", bufferPos)
insertIndex += 1
for i in stackWindow :
stackPos = dicts.noStackToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS")
stackGovHead = dicts.nullToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "HEAD")
stackGovPos = dicts.nullToken
if not isEmpty(stackGovHead) and stackGovHead != dicts.nullToken :
stackGovPos = config.getAsFeature(int(stackGovHead), "UPOS")
elif stackGovHead == dicts.nullToken :
stackGovPos = dicts.noStackToken
result[insertIndex] = dicts.get("UPOS", stackPos)
for index in indexes :
if index < 0 :
for col in cols :
result[insertIndex] = dicts.get(col, specialValues[index])
insertIndex += 1
result[insertIndex] = dicts.get("UPOS", stackGovPos)
else :
for col in cols :
value = config.getAsFeature(index, col)
if isEmpty(value) :
value = dicts.nullToken
result[insertIndex] = dicts.get(col, value)
insertIndex += 1
return result
......
import torch
import torch.nn as nn
import torch.nn.functional as F
import Features
################################################################################
class BaseNet(nn.Module):
def __init__(self, dicts, inputSize, outputSize) :
def __init__(self, dicts, outputSize) :
super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
self.columns = ["UPOS"]
self.embSize = 64
self.inputSize = inputSize
self.inputSize = len(self.columns)*len(self.featureFunction.split())
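# with the default featureFunction above : 17 feature elements * 1 column (UPOS) = 17 input indices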
self.outputSize = outputSize
for name in dicts.dicts :
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
self.fc1 = nn.Linear(inputSize * self.embSize, 1600)
self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
self.fc2 = nn.Linear(1600, outputSize)
self.dropout = nn.Dropout(0.3)
@@ -32,5 +36,9 @@ class BaseNet(nn.Module):
if type(m) == nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
def extractFeatures(self, dicts, config) :
return Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns)
################################################################################
@@ -34,7 +34,7 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, boots
################################################################################
################################################################################
def extractExamples(debug, ts, strat, config, dicts, network=None) :
def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
examples = []
with torch.no_grad() :
EOS = Transition("EOS")
@@ -47,12 +47,12 @@ def extractExamples(debug, ts, strat, config, dicts, network=None) :
break
best = min([cand[0] for cand in candidates])
candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
features = Features.extractFeatures(dicts, config)
features = network.extractFeatures(dicts, config)
candidate = candidateOracle.name
if debug :
config.printForDebug(sys.stderr)
print(str([[c[0],c[1].name] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
if network is not None :
if dynamic :
output = network(features.unsqueeze(0).to(getDevice()))
scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
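# dynamic == True : the candidate is taken from the network's own scores (restricted to appliable transitions) instead of the static oracle's choice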
@@ -95,17 +95,17 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss
################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) :
dicts = Dicts()
dicts.readConllu(filename, ["FORM", "UPOS"])
dicts.readConllu(filename, ["UPOS"])
dicts.save(modelDir+"/dicts.json")
network = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
examples = []
sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
examples += extractExamples(debug, transitionSet, strategy, config, dicts)
examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, False)
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
examples = torch.stack(examples)
network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)).to(getDevice())
print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
lossFct = torch.nn.CrossEntropyLoss()
@@ -117,7 +117,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
examples += extractExamples(debug, transitionSet, strategy, config, dicts, network)
examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, True)
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
examples = torch.stack(examples)
@@ -154,9 +154,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
dicts.readConllu(filename, ["FORM", "UPOS"])
dicts.save(modelDir + "/dicts.json")
policy_net = None
target_net = None
optimizer = None
policy_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
target_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
policy_net.train()
optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
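# policy_net is the network being optimized; target_net starts as a frozen, eval-mode copy of it (the usual DQN target-network pattern for stable value estimates)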
print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
bestLoss = None
bestScore = None
@@ -178,16 +182,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
print("Curent epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
sentence = sentences[sentIndex]
sentence.moveWordIndex(0)
state = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
if policy_net is None :
policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
policy_net.train()
optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
while True :
missingLinks = getMissingLinks(sentence)
@@ -209,7 +204,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
newState = None
if appliable :
applyTransition(transitionSet, strategy, sentence, action.name)
newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
if memory is None :
memory = ReplayMemory(5000, state.numel())
......
@@ -44,6 +44,9 @@ if __name__ == "__main__" :
random.seed(args.seed)
torch.manual_seed(args.seed)
if args.bootstrap is not None :
args.bootstrap = int(args.bootstrap)
if args.mode == "train" :
Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.silent)
elif args.mode == "decode" :
......