Skip to content
Snippets Groups Projects
Commit 39e488e9 authored by Franck Dary's avatar Franck Dary
Browse files

Added history in config

parent fb6f6ffa
No related branches found
No related tags found
No related merge requests found
...@@ -13,6 +13,7 @@ class Config : ...@@ -13,6 +13,7 @@ class Config :
self.wordIndex = 0 self.wordIndex = 0
self.stack = [] self.stack = []
self.comments = [] self.comments = []
self.history = []
def addLine(self, cols) : def addLine(self, cols) :
self.lines.append([[val,""] for val in cols]) self.lines.append([[val,""] for val in cols])
...@@ -79,6 +80,7 @@ class Config : ...@@ -79,6 +80,7 @@ class Config :
left = 5 left = 5
right = 5 right = 5
print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output) print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
print("history :",[trans.name for trans in self.history[-10:]], file=output)
toPrint = [] toPrint = []
for lineIndex in range(self.wordIndex-left, self.wordIndex+right) : for lineIndex in range(self.wordIndex-left, self.wordIndex+right) :
if lineIndex not in range(len(self.lines)) : if lineIndex not in range(len(self.lines)) :
......
...@@ -10,7 +10,7 @@ class BaseNet(nn.Module): ...@@ -10,7 +10,7 @@ class BaseNet(nn.Module):
self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False) self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1" self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
self.columns = ["UPOS"] self.columns = ["UPOS", "FORM"]
self.embSize = 64 self.embSize = 64
self.nbTargets = len(self.featureFunction.split()) self.nbTargets = len(self.featureFunction.split())
......
...@@ -88,7 +88,8 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss ...@@ -88,7 +88,8 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss
devScore = ", Dev : UAS=%.2f"%(UAS) devScore = ", Dev : UAS=%.2f"%(UAS)
if saved : if saved :
torch.save(model, modelDir+"/network.pt") torch.save(model, modelDir+"/network.pt")
print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr) for out in [sys.stderr, open(modelDir+"train.log", "w" if epoch == 1 else "a")] :
print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=out)
return bestLoss, bestScore return bestLoss, bestScore
################################################################################ ################################################################################
...@@ -215,7 +216,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti ...@@ -215,7 +216,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
state = newState state = newState
if i % batchSize == 0 : if i % batchSize == 0 :
totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer) totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
if i % (2*batchSize) == 0 : if i % (1*batchSize) == 0 :
target_net.load_state_dict(policy_net.state_dict()) target_net.load_state_dict(policy_net.state_dict())
target_net.eval() target_net.eval()
policy_net.train() policy_net.train()
......
...@@ -18,22 +18,18 @@ class Transition : ...@@ -18,22 +18,18 @@ class Transition :
def apply(self, config) : def apply(self, config) :
if self.name == "RIGHT" : if self.name == "RIGHT" :
applyRight(config) applyRight(config)
return elif self.name == "LEFT" :
if self.name == "LEFT" :
applyLeft(config) applyLeft(config)
return elif self.name == "SHIFT" :
if self.name == "SHIFT" :
applyShift(config) applyShift(config)
return elif self.name == "REDUCE" :
if self.name == "REDUCE" :
applyReduce(config) applyReduce(config)
return elif self.name == "EOS" :
if self.name == "EOS" :
applyEOS(config) applyEOS(config)
return else :
print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr) print("ERROR : nothing to apply for '%s'"%self.name, file=sys.stderr)
exit(1) exit(1)
config.history.append(self)
def appliable(self, config) : def appliable(self, config) :
if self.name == "RIGHT" : if self.name == "RIGHT" :
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment