Commit ade901ad authored by Franck Dary

Added argument to choose neural network type. Added separated network, where weights are not shared between tasks.
parent f336968b
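
The new --network argument selects which class the createNetwork factory in Networks.py instantiates. A minimal usage sketch, not part of the commit; dicts and outputSizes are placeholders that Train.py normally builds from the corpus and the transition sets:

    import Networks

    # name is one of "base" | "lstm" | "separated"; any other value raises
    # Exception("Unknown network name '...'").
    # dicts must be a Dicts instance, outputSizes holds one size per transition set.
    net = Networks.createNetwork("separated", dicts, outputSizes, incremental=False)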
......@@ -3,6 +3,18 @@ import torch.nn as nn
import torch.nn.functional as F
import Features
################################################################################
def createNetwork(name, dicts, outputSizes, incremental) :
if name == "base" :
return BaseNet(dicts, outputSizes, incremental)
elif name == "lstm" :
return LSTMNet(dicts, outputSizes, incremental)
elif name == "separated" :
return SeparatedNet(dicts, outputSizes, incremental)
raise Exception("Unknown network name '%s'"%name)
################################################################################
################################################################################
class BaseNet(nn.Module):
def __init__(self, dicts, outputSizes, incremental) :
......@@ -73,6 +85,77 @@ class BaseNet(nn.Module):
################################################################################
################################################################################
class SeparatedNet(nn.Module):
def __init__(self, dicts, outputSizes, incremental) :
super().__init__()
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.incremental = incremental
self.state = 0
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
self.historyNb = 5
self.suffixSize = 4
self.prefixSize = 4
self.columns = ["UPOS", "FORM"]
self.embSize = 64
self.nbTargets = len(self.featureFunction.split())
self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
self.outputSizes = outputSizes
for i in range(len(outputSizes)) :
for name in dicts.dicts :
self.add_module("emb_"+name+"_"+str(i), nn.Embedding(len(dicts.dicts[name]), self.embSize))
self.add_module("fc1_"+str(i), nn.Linear(self.inputSize * self.embSize, 1600))
self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
self.dropout = nn.Dropout(0.3)
self.apply(self.initWeights)
def setState(self, state) :
self.state = state
def forward(self, x) :
embeddings = []
for i in range(len(self.columns)) :
embeddings.append(getattr(self, "emb_"+self.columns[i]+"_"+str(self.state))(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
y = torch.cat(embeddings,-1).view(x.size(0),-1)
curIndex = len(self.columns)*self.nbTargets
if self.historyNb > 0 :
historyEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
y = torch.cat([y, historyEmb],-1)
curIndex = curIndex+self.historyNb
if self.prefixSize > 0 :
prefixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
y = torch.cat([y, prefixEmb],-1)
curIndex = curIndex+self.prefixSize
if self.suffixSize > 0 :
suffixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
y = torch.cat([y, suffixEmb],-1)
curIndex = curIndex+self.suffixSize
y = self.dropout(y)
y = F.relu(self.dropout(getattr(self, "fc1_"+str(self.state))(y)))
y = getattr(self, "output_"+str(self.state))(y)
return y
def currentDevice(self) :
return self.dummyParam.device
def initWeights(self,m) :
if type(m) == nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
def extractFeatures(self, dicts, config) :
colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
################################################################################
################################################################################
class LSTMNet(nn.Module):
def __init__(self, dicts, outputSizes, incremental) :
......@@ -91,7 +174,7 @@ class LSTMNet(nn.Module):
self.nbInputBase = len(self.featureFunction.split())
self.nbTargets = self.nbInputBase + self.nbInputLSTM
self.inputSize = len(self.columns)*self.nbTargets+self.historyNb
self.outputSize = outputSize
self.outputSizes = outputSizes
for name in dicts.dicts :
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
self.lstmFeat = nn.LSTM(len(self.columns)*self.embSize, len(self.columns)*int(self.embSize/2), 1, batch_first=True, bidirectional = True)
......
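
Unlike BaseNet, SeparatedNet registers one set of emb_*, fc1_ and output_ modules per transition set and routes forward() through whichever set the current state selects, so tasks do not share weights. A hedged sketch, assuming net was built with name "separated" and at least two output sizes, and that batch is a LongTensor of features as returned by extractFeatures:

    net.setState(0)
    scores0 = net(batch)   # routes through emb_*_0, fc1_0, output_0
    net.setState(1)
    scores1 = net(batch)   # routes through emb_*_1, fc1_1, output_1
    # each task has its own parameters:
    assert net.fc1_0.weight.data_ptr() != net.fc1_1.weight.data_ptr()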
......@@ -16,15 +16,15 @@ import Config
from conll18_ud_eval import load_conllu, evaluate
################################################################################
def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
def trainMode(debug, networkName, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
sentences = Config.readConllu(filename, predicted)
if type == "oracle" :
trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
trainModelOracle(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
return
if type == "rl":
trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
return
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
......@@ -98,7 +98,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
################################################################################
################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
dicts = Dicts()
dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
transitionNames = {}
......@@ -109,7 +109,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
dicts.addDict("HISTORY", transitionNames)
dicts.save(modelDir+"/dicts.json")
network = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
network = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
examples = [[] for _ in transitionSets]
sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
......@@ -185,7 +185,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
################################################################################
################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
memory = None
dicts = Dicts()
......@@ -198,8 +198,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
dicts.addDict("HISTORY", transitionNames)
dicts.save(modelDir + "/dicts.json")
policy_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
target_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
policy_net.train()
......
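
For reference, networkName is now the second parameter of trainMode and is threaded through to trainModelOracle / trainModelRl, which pass it on to Networks.createNetwork. A hedged sketch of a direct call; all values are illustrative placeholders, not repository defaults, and transitionSets, strategy and probas must be built as in the main script:

    import Train

    Train.trainMode(debug=False, networkName="separated", filename="train.conllu",
                    type="oracle", transitionSet=transitionSets, strategy=strategy,
                    modelDir="model", nbIter=10, batchSize=64, devFile="dev.conllu",
                    bootstrapInterval=None, incremental=False, rewardFunc="A",
                    lr=0.0001, gamma=0.99, probas=probas, countBreak=1,
                    predicted="", silent=False)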
......@@ -55,6 +55,8 @@ if __name__ == "__main__" :
help="Transition set to use (eager | swift | tagparser).")
parser.add_argument("--ts", default="",
help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
parser.add_argument("--network", default="base",
help="Name of the neural network to use (base | lstm | separated).")
parser.add_argument("--reward", default="A",
help="Reward function to use (A,B,C,D,E)")
parser.add_argument("--probaRandom", default="0.6,4,0.1",
......@@ -105,7 +107,7 @@ if __name__ == "__main__" :
json.dump(strategy, open(args.model+"/strategy.json", "w"))
printTS(transitionSets, sys.stderr)
probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
Train.trainMode(args.debug, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
elif args.mode == "decode" :
transInfos = json.load(open(args.model+"/transitions.json", "r"))
transNames = json.load(open(args.model+"/transitions.json", "r"))[1]
......