Skip to content
Snippets Groups Projects
Commit b699477b authored by Franck Dary's avatar Franck Dary
Browse files

Added history of transitions in feature function

parent 1a127719
Branches
No related tags found
No related merge requests found
......@@ -13,6 +13,11 @@ class Dicts :
self.noDepRight = "__nodepright__"
self.noGov = "__nogov__"
def addDict(self, name, d) :
if name in self.dicts :
raise(Exception(name+" already in dicts"))
self.dicts[name] = d
def readConllu(self, filename, colsSet=None, minCount=0) :
defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
col2index, index2col = readMCD(defaultMCD)
......
......@@ -77,6 +77,19 @@ def extractColsFeatures(dicts, config, featureFunction, cols) :
result[insertIndex] = dicts.get(col, value)
insertIndex += 1
if insertIndex != totalSize :
raise(Exception("Missing features"))
return result
################################################################################
################################################################################
def extractHistoryFeatures(dicts, config, nbElements) :
result = torch.zeros(nbElements, dtype=torch.int)
for i in range(nbElements) :
name = config.history[-i].name if i in range(len(config.history)) else dicts.nullToken
result[i] = dicts.get("HISTORY", name)
return result
################################################################################
......@@ -10,11 +10,12 @@ class BaseNet(nn.Module):
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
self.historyNb = 5
self.columns = ["UPOS", "FORM"]
self.embSize = 64
self.nbTargets = len(self.featureFunction.split())
self.inputSize = len(self.columns)*self.nbTargets
self.inputSize = len(self.columns)*self.nbTargets+self.historyNb
self.outputSize = outputSize
for name in dicts.dicts :
self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
......@@ -28,11 +29,14 @@ class BaseNet(nn.Module):
embeddings = []
for i in range(len(self.columns)) :
embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
x = torch.cat(embeddings,-1)
x = self.dropout(x.view(x.size(0),-1))
x = F.relu(self.dropout(self.fc1(x)))
x = self.fc2(x)
return x
y = torch.cat(embeddings,-1).view(x.size(0),-1)
if self.historyNb > 0 :
historyEmb = getattr(self, "emb_HISTORY")(x[...,len(self.columns)*self.nbTargets:len(self.columns)*self.nbTargets+self.historyNb]).view(x.size(0),-1)
y = torch.cat([y, historyEmb],-1)
y = self.dropout(y)
y = F.relu(self.dropout(self.fc1(y)))
y = self.fc2(y)
return y
def currentDevice(self) :
return self.dummyParam.device
......@@ -43,7 +47,9 @@ class BaseNet(nn.Module):
m.bias.data.fill_(0.01)
def extractFeatures(self, dicts, config) :
return Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns)
colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns)
historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
return torch.cat([colsValues, historyValues])
################################################################################
......@@ -95,6 +95,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) :
dicts = Dicts()
dicts.readConllu(filename, ["FORM","UPOS"], 2)
dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.save(modelDir+"/dicts.json")
network = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
examples = []
......@@ -151,6 +152,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
memory = None
dicts = Dicts()
dicts.readConllu(filename, ["FORM","UPOS"], 2)
dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
dicts.save(modelDir + "/dicts.json")
policy_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment