Skip to content
Snippets Groups Projects
Commit 74a8e16e authored by Franck Dary's avatar Franck Dary
Browse files

Working neural network training with oracle

parent b20da207
No related branches found
No related tags found
No related merge requests found
__pycache__ __pycache__
data bin/*
...@@ -91,7 +91,8 @@ class Config : ...@@ -91,7 +91,8 @@ class Config :
def print(self, output, header=False) : def print(self, output, header=False) :
if header : if header :
print("# global.columns = %s"%(" ".join(self.col2index.keys())), file=output) print("# global.columns = %s"%(" ".join(self.col2index.keys())), file=output)
print("\n".join(self.comments)) if len(self.comments) > 0 :
print("\n".join(self.comments), file=output)
for index in range(len(self.lines)) : for index in range(len(self.lines)) :
toPrint = [] toPrint = []
for colIndex in range(len(self.lines[index])) : for colIndex in range(len(self.lines[index])) :
...@@ -104,7 +105,7 @@ class Config : ...@@ -104,7 +105,7 @@ class Config :
value = "0" value = "0"
toPrint.append(value) toPrint.append(value)
print("\t".join(toPrint), file=output) print("\t".join(toPrint), file=output)
print("") print("", file=output)
################################################################################ ################################################################################
################################################################################ ################################################################################
......
import random
import sys
from Transition import Transition, getMissingLinks, applyTransition
from Features import extractFeatures
import torch
################################################################################
def randomDecode(ts, strat, config, debug=False) :
EOS = Transition("EOS")
config.moveWordIndex(0)
while True :
candidates = [trans for trans in ts if trans.appliable(config)]
if len(candidates) == 0 :
break
candidate = candidates[random.randint(0, 100) % len(candidates)]
if debug :
config.printForDebug(sys.stderr)
print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
applyTransition(ts, strat, config, candidate.name)
EOS.apply(config)
################################################################################
################################################################################
def oracleDecode(ts, strat, config, debug=False) :
EOS = Transition("EOS")
config.moveWordIndex(0)
moved = True
while moved :
missingLinks = getMissingLinks(config)
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
if len(candidates) == 0 :
break
candidate = candidates[0][1]
if debug :
config.printForDebug(sys.stderr)
print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate)
EOS.apply(config)
################################################################################
################################################################################
def decodeModel(ts, strat, config, network, dicts, debug) :
EOS = Transition("EOS")
config.moveWordIndex(0)
moved = True
network.eval()
with torch.no_grad():
while moved :
features = extractFeatures(dicts, config).unsqueeze(0)
output = torch.nn.functional.softmax(network(features), dim=1)
candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index].name] for index in range(len(ts))])[::-1]
candidates = [cand[2] for cand in candidates if cand[0]]
if len(candidates) == 0 :
break
candidate = candidates[0]
if debug :
config.printForDebug(sys.stderr)
print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate)
EOS.apply(config)
################################################################################
Dicts.py 0 → 100644
import json
from readMCD import readMCD
################################################################################
class Dicts :
def __init__(self) :
self.dicts = {}
self.unkToken = "__unknown__"
self.nullToken = "__null__"
def readConllu(self, filename, colsSet=None) :
defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
col2index, index2col = readMCD(defaultMCD)
targetColumns = []
for line in open(filename, "r") :
line = line.strip()
if "# global.columns =" in line :
mcd = line.split('=')[-1].strip()
col2index, index2col = readMCD(mcd)
continue
if len(line) == 0 or line[0] == '#' :
continue
if len(targetColumns) == 0 :
if colsSet is None :
targetColumns = list(col2index.keys())
else :
targetColumns = list(colsSet)
self.dicts = {col : {self.unkToken : 0, self.nullToken : 1} for col in targetColumns}
splited = line.split('\t')
for col in targetColumns :
value = splited[col2index[col]]
if value not in self.dicts[col] :
self.dicts[col][value] = len(self.dicts[col])
def get(self, col, value) :
if value in self.dicts[col] :
return self.dicts[col][value]
if value.lower() in self.dicts[col] :
return self.dicts[col][value.lower()]
return self.dicts[col][self.unkToken]
def save(self, target) :
json.dump(self.dicts, open(target, "w"))
def load(self, target) :
self.dicts = json.load(open(target, "r"))
################################################################################
import torch
import sys
################################################################################
def extractFeatures(dicts, config) :
return extractFeaturesPos(dicts, config)
################################################################################
################################################################################
def extractFeaturesPos(dicts, config) :
bufferWindow = range(-2,2+1)
stackWindow = range(0,3+1)
totalSize = len(bufferWindow)+len(stackWindow)
result = torch.zeros(totalSize, dtype=torch.int)
insertIndex = 0
for i in bufferWindow :
index = config.wordIndex + i
bufferPos = dicts.nullToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS")
result[insertIndex] = dicts.get("UPOS", bufferPos)
insertIndex += 1
for i in stackWindow :
stackPos = dicts.nullToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS")
result[insertIndex] = dicts.get("UPOS", stackPos)
insertIndex += 1
return result
################################################################################
import torch.nn as nn
import torch.nn.functional as F
################################################################################
class BaseNet(nn.Module):
def __init__(self, dicts, inputSize, outputSize) :
super().__init__()
self.embSize = 64
self.inputSize = inputSize
self.outputSize = outputSize
self.embeddings = {name : nn.Embedding(len(dicts.dicts[name]), self.embSize) for name in dicts.dicts.keys()}
self.fc1 = nn.Linear(inputSize * self.embSize, 128)
self.fc2 = nn.Linear(128, outputSize)
self.dropout = nn.Dropout(0.3)
def forward(self, x) :
x = self.dropout(self.embeddings["UPOS"](x).view(x.size(0), -1))
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
################################################################################
Train.py 0 → 100644
import sys
import random
from Transition import Transition, getMissingLinks, applyTransition
import Features
import torch
################################################################################
def extractExamples(ts, strat, config, dicts, debug=False) :
examples = []
EOS = Transition("EOS")
config.moveWordIndex(0)
moved = True
while moved :
missingLinks = getMissingLinks(config)
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
if len(candidates) == 0 :
break
candidate = candidates[0][1]
candidateIndex = [trans.name for trans in ts].index(candidate)
features = Features.extractFeatures(dicts, config)
example = torch.cat([torch.LongTensor([candidateIndex]), features])
examples.append(example)
if debug :
config.printForDebug(sys.stderr)
print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate)
EOS.apply(config)
return examples
################################################################################
...@@ -68,7 +68,7 @@ class Transition : ...@@ -68,7 +68,7 @@ class Transition :
################################################################################ ################################################################################
# Compute numeric values that will be used in the oracle to decide score of transitions # Compute numeric values that will be used in the oracle to decide score of transitions
def getMissingLinks(config) : def getMissingLinks(config) :
return {"StackRight" : nbLinksStackRight(config), "BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config)} return {"StackRight" : nbLinksStackRight(config), "BufferRight" : nbLinksBufferRight(config), "BufferStack" : nbLinksBufferStack(config), "BufferRightHead" : nbLinksBufferRightHead(config)}
################################################################################ ################################################################################
################################################################################ ################################################################################
...@@ -78,6 +78,12 @@ def nbLinksBufferRight(config) : ...@@ -78,6 +78,12 @@ def nbLinksBufferRight(config) :
return head + len([c for c in config.childs[config.wordIndex] if c > config.wordIndex]) return head + len([c for c in config.childs[config.wordIndex] if c > config.wordIndex])
################################################################################ ################################################################################
################################################################################
# Number of missing childs between wordIndex and the right of the sentence
def nbLinksBufferRightHead(config) :
return 1 if int(config.getGold(config.wordIndex, "HEAD")) > config.wordIndex else 0
################################################################################
################################################################################ ################################################################################
# Number of missing links between stack top and the right of the sentence # Number of missing links between stack top and the right of the sentence
def nbLinksStackRight(config) : def nbLinksStackRight(config) :
...@@ -107,7 +113,7 @@ def linkCauseCycle(config, fromIndex, toIndex) : ...@@ -107,7 +113,7 @@ def linkCauseCycle(config, fromIndex, toIndex) :
################################################################################ ################################################################################
def scoreOracleRight(config, ml) : def scoreOracleRight(config, ml) :
return 0 if config.getGold(config.wordIndex, "HEAD") == config.stack[-1] else (ml["BufferStack"] + ml["BufferRight"]) return 0 if config.getGold(config.wordIndex, "HEAD") == config.stack[-1] else (ml["BufferStack"] + ml["BufferRightHead"])
################################################################################ ################################################################################
################################################################################ ################################################################################
...@@ -168,3 +174,11 @@ def applyEOS(config) : ...@@ -168,3 +174,11 @@ def applyEOS(config) :
config.set(index, "HEAD", str(rootIndex)) config.set(index, "HEAD", str(rootIndex))
################################################################################ ################################################################################
################################################################################
def applyTransition(ts, strat, config, name) :
transition = [trans for trans in ts if trans.name == name][0]
movement = strat[transition.name]
transition.apply(config)
return config.moveWordIndex(movement)
################################################################################
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
#! /usr/bin/env python3 #! /usr/bin/env python3
import sys import sys
import random import os
import argparse import argparse
from datetime import datetime
import Config import Config
from Transition import Transition, getMissingLinks import Decode
import Train
from Transition import Transition
import Networks
from Dicts import Dicts
from conll18_ud_eval import load_conllu, evaluate
import torch
################################################################################ ################################################################################
def applyTransition(ts, strat, config, name) : def timeStamp() :
transition = [trans for trans in ts if trans.name == name][0] return "[%s]"%datetime.now().strftime("%H:%M:%S")
movement = strat[transition.name]
transition.apply(config)
return config.moveWordIndex(movement)
################################################################################ ################################################################################
################################################################################ ################################################################################
def randomDecode(ts, strat, config) : def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile) :
EOS = Transition("EOS") transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
config.moveWordIndex(0) strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
while True :
candidates = [trans for trans in transitionSet if trans.appliable(config)] sentences = Config.readConllu(filename)
if len(candidates) == 0 :
break if type == "oracle" :
candidate = candidates[random.randint(0, 100) % len(candidates)] examples = []
if args.debug : dicts = Dicts()
config.printForDebug(sys.stderr) dicts.readConllu(filename, ["FORM", "UPOS"])
print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr) dicts.save(modelDir+"/dicts.json")
applyTransition(transitionSet, strategy, config, candidate.name) print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
EOS.apply(config) examples += Train.extractExamples(transitionSet, strategy, config, dicts, args.debug)
print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
examples = torch.stack(examples)
network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
network.train()
optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
lossFct = torch.nn.CrossEntropyLoss()
for iter in range(1,nbIter+1) :
examples = examples.index_select(0, torch.randperm(examples.size(0)))
totalLoss = 0.0
for batchIndex in range(0,examples.size(0)-6,6) :
batch = examples[batchIndex:batchIndex+batchSize]
targets = batch[:,:1].view(-1)
inputs = batch[:,1:]
outputs = network(inputs)
loss = lossFct(outputs, targets)
network.zero_grad()
loss.backward()
optimizer.step()
totalLoss += float(loss)
devScore = ""
if devFile is not None :
outFilename = modelDir+"/predicted_dev.conllu"
decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w"))
res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1)
print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr)
decodeMode(debug, filename, "model", network, dicts)
return
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
exit(1)
################################################################################ ################################################################################
################################################################################ ################################################################################
def oracleDecode(ts, strat, config) : def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) :
EOS = Transition("EOS") transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
config.moveWordIndex(0) strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
moved = True
while moved : sentences = Config.readConllu(filename)
missingLinks = getMissingLinks(config)
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in transitionSet if trans.appliable(config)]) if type in ["random", "oracle"] :
if len(candidates) == 0 : decodeFunc = Decode.oracleDecode if type == "oracle" else Decode.randomDecode
break for config in sentences :
candidate = candidates[0][1] decodeFunc(transitionSet, strategy, config, args.debug)
if args.debug : sentences[0].print(sys.stdout, header=True)
config.printForDebug(sys.stderr) for config in sentences[1:] :
print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr) config.print(sys.stdout, header=False)
moved = applyTransition(transitionSet, strategy, config, candidate) elif type == "model" :
for config in sentences :
EOS.apply(config) Decode.decodeModel(transitionSet, strategy, config, network, dicts, args.debug)
sentences[0].print(output, header=True)
for config in sentences[1:] :
config.print(output, header=False)
else :
print("ERROR : unknown type '%s'"%type, file=sys.stderr)
exit(1)
################################################################################ ################################################################################
################################################################################ ################################################################################
if __name__ == "__main__" : if __name__ == "__main__" :
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("trainCorpus", type=str, parser.add_argument("mode", type=str,
help="Name of the CoNLL-U training file.") help="What to do : train | decode")
parser.add_argument("type", type=str,
help="Type of train or decode. random | oracle")
parser.add_argument("corpus", type=str,
help="Name of the CoNLL-U file. Train file for train mode and input file for decode mode.")
parser.add_argument("model", type=str,
help="Path to the model directory.")
parser.add_argument("--iter", "-n", default=5,
help="Number of training epoch.")
parser.add_argument("--batchSize", default=64,
help="Size of each batch.")
parser.add_argument("--dev", default=None,
help="Name of the CoNLL-U file of the dev corpus.")
parser.add_argument("--debug", "-d", default=False, action="store_true", parser.add_argument("--debug", "-d", default=False, action="store_true",
help="Print debug infos on stderr.") help="Print debug infos on stderr.")
args = parser.parse_args() args = parser.parse_args()
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] os.makedirs(args.model, exist_ok=True)
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
sentences = Config.readConllu(sys.argv[1]) if args.mode == "train" :
trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev)
first = True elif args.mode == "decode" :
for config in sentences : decodeMode(args.debug, args.corpus, args.type)
oracleDecode(transitionSet, strategy, config) else :
config.print(sys.stdout, header=first) print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
first = False exit(1)
################################################################################ ################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment