Skip to content
Snippets Groups Projects
Commit 39e488e9 authored by Franck Dary's avatar Franck Dary
Browse files

Added history in config

parent fb6f6ffa
Branches
Tags
No related merge requests found
......@@ -13,6 +13,7 @@ class Config :
self.wordIndex = 0
self.stack = []
self.comments = []
self.history = []
def addLine(self, cols) :
self.lines.append([[val,""] for val in cols])
......@@ -79,6 +80,7 @@ class Config :
left = 5
right = 5
print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
print("history :",[trans.name for trans in self.history[-10:]], file=output)
toPrint = []
for lineIndex in range(self.wordIndex-left, self.wordIndex+right) :
if lineIndex not in range(len(self.lines)) :
......
......@@ -10,7 +10,7 @@ class BaseNet(nn.Module):
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
self.columns = ["UPOS"]
self.columns = ["UPOS", "FORM"]
self.embSize = 64
self.nbTargets = len(self.featureFunction.split())
......
......@@ -88,7 +88,8 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss
devScore = ", Dev : UAS=%.2f"%(UAS)
if saved :
torch.save(model, modelDir+"/network.pt")
print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
for out in [sys.stderr, open(modelDir+"train.log", "w" if epoch == 1 else "a")] :
print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=out)
return bestLoss, bestScore
################################################################################
......@@ -215,7 +216,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
state = newState
if i % batchSize == 0 :
totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
if i % (2*batchSize) == 0 :
if i % (1*batchSize) == 0 :
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
policy_net.train()
......
......@@ -18,22 +18,18 @@ class Transition :
def apply(self, config) :
if self.name == "RIGHT" :
applyRight(config)
return
if self.name == "LEFT" :
elif self.name == "LEFT" :
applyLeft(config)
return
if self.name == "SHIFT" :
elif self.name == "SHIFT" :
applyShift(config)
return
if self.name == "REDUCE" :
elif self.name == "REDUCE" :
applyReduce(config)
return
if self.name == "EOS" :
elif self.name == "EOS" :
applyEOS(config)
return
else :
print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr)
exit(1)
config.history.append(self)
def appliable(self, config) :
if self.name == "RIGHT" :
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment