Skip to content
Snippets Groups Projects
Commit 74a8e16e authored by Franck Dary's avatar Franck Dary
Browse files

Working neural network training with oracle

parent b20da207
No related branches found
No related tags found
No related merge requests found
__pycache__
data
bin/*
......@@ -91,7 +91,8 @@ class Config :
def print(self, output, header=False) :
if header :
print("# global.columns = %s"%(" ".join(self.col2index.keys())), file=output)
print("\n".join(self.comments))
if len(self.comments) > 0 :
print("\n".join(self.comments), file=output)
for index in range(len(self.lines)) :
toPrint = []
for colIndex in range(len(self.lines[index])) :
......@@ -104,7 +105,7 @@ class Config :
value = "0"
toPrint.append(value)
print("\t".join(toPrint), file=output)
print("")
print("", file=output)
################################################################################
################################################################################
......
import random
import sys
from Transition import Transition, getMissingLinks, applyTransition
from Features import extractFeatures
import torch
################################################################################
def randomDecode(ts, strat, config, debug=False) :
    """Parse `config` by repeatedly applying a random appliable transition.

    ts     : list of transitions to choose from (each exposes `name` and `appliable(config)`).
    strat  : dict passed to applyTransition together with the chosen transition name.
    config : configuration being parsed; the transitions are applied to it.
    debug  : when True, print the configuration and the chosen transition on stderr
             before each step.

    The loop stops as soon as no transition is appliable; an "EOS" transition
    is then applied to the configuration.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    while True :
        candidates = [trans for trans in ts if trans.appliable(config)]
        if not candidates :
            break
        # random.choice draws uniformly; randint(0, 100) % len(candidates)
        # slightly favoured the first candidates (101 is not a multiple of len).
        candidate = random.choice(candidates)
        if debug :
            config.printForDebug(sys.stderr)
            print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
        applyTransition(ts, strat, config, candidate.name)
    EOS.apply(config)
################################################################################
################################################################################
def oracleDecode(ts, strat, config, debug=False) :
    """Parse `config` by always applying the appliable transition with the lowest oracle score.

    ts     : list of transitions (each exposes `name`, `appliable` and `getOracleScore`).
    strat  : dict passed to applyTransition together with the chosen transition name.
    config : configuration being parsed.
    debug  : when True, print the configuration and the sorted [score, name] list on stderr.

    Stops when no transition is appliable or when applyTransition returns a
    false value, then applies an "EOS" transition.
    """
    eosTransition = Transition("EOS")
    config.moveWordIndex(0)
    hasMoved = True
    while hasMoved :
        links = getMissingLinks(config)
        scored = []
        for trans in ts :
            if trans.appliable(config) :
                scored.append([trans.getOracleScore(config, links), trans.name])
        # Ascending order : the best (lowest) score comes first, ties broken by name
        scored.sort()
        if not scored :
            break
        bestName = scored[0][1]
        if debug :
            config.printForDebug(sys.stderr)
            print(str(scored)+"\n"+("-"*80)+"\n", file=sys.stderr)
        hasMoved = applyTransition(ts, strat, config, bestName)
    eosTransition.apply(config)
################################################################################
################################################################################
def decodeModel(ts, strat, config, network, dicts, debug) :
    """Parse `config` by applying, at each step, the appliable transition preferred by `network`.

    ts      : list of transitions; output index i of the network is read as the score of ts[i].
    strat   : dict passed to applyTransition together with the chosen transition name.
    config  : configuration being parsed.
    network : torch module, switched to eval mode here and called on the features
              of the configuration with a leading batch dimension of 1.
    dicts   : dictionaries handed to extractFeatures.
    debug   : when True, print the configuration and the ranked candidate names on stderr.

    Stops when no transition is appliable or when applyTransition returns a
    false value, then applies an "EOS" transition.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    moved = True
    network.eval()
    with torch.no_grad() :
        while moved :
            features = extractFeatures(dicts, config).unsqueeze(0)
            probas = torch.nn.functional.softmax(network(features), dim=1)[0]
            # Rank on the raw probability. Ranking on its "%.2f" rendering made
            # scores closer than 0.01 compare equal, so the winner was then
            # decided by the transition name instead of by the network.
            scored = sorted([[float(probas[index]), ts[index].name] for index in range(len(ts)) if ts[index].appliable(config)])[::-1]
            candidates = [name for _, name in scored]
            if not candidates :
                break
            candidate = candidates[0]
            if debug :
                config.printForDebug(sys.stderr)
                print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
            moved = applyTransition(ts, strat, config, candidate)
    EOS.apply(config)
################################################################################
Dicts.py 0 → 100644
import json
from readMCD import readMCD
################################################################################
class Dicts :
    """Vocabularies mapping the values of CoNLL-U columns to integer indexes.

    `self.dicts` maps a column name to a {value : index} dict. In every column
    built by readConllu, unkToken has index 0 and nullToken has index 1.
    """

    def __init__(self) :
        self.dicts = {}
        self.unkToken = "__unknown__"
        self.nullToken = "__null__"

    def readConllu(self, filename, colsSet=None) :
        """Build the vocabularies from the CoNLL-U file `filename`.

        colsSet : iterable of column names to index; when None, every column
                  of the file is indexed.

        A "# global.columns =" line overrides the default column layout.
        The vocabularies are (re)initialised when the first data line is met,
        so a file without any data line leaves `self.dicts` untouched.
        """
        defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
        col2index, index2col = readMCD(defaultMCD)

        targetColumns = []

        # The context manager closes the file even if a line is malformed
        with open(filename, "r") as conlluFile :
            for line in conlluFile :
                line = line.strip()
                if "# global.columns =" in line :
                    mcd = line.split('=')[-1].strip()
                    col2index, index2col = readMCD(mcd)
                    continue
                if len(line) == 0 or line[0] == '#' :
                    continue

                # Columns are only known once the optional header has been read,
                # hence the lazy initialisation on the first data line.
                if len(targetColumns) == 0 :
                    if colsSet is None :
                        targetColumns = list(col2index.keys())
                    else :
                        targetColumns = list(colsSet)
                    self.dicts = {col : {self.unkToken : 0, self.nullToken : 1} for col in targetColumns}

                splited = line.split('\t')
                for col in targetColumns :
                    value = splited[col2index[col]]
                    if value not in self.dicts[col] :
                        self.dicts[col][value] = len(self.dicts[col])

    def get(self, col, value) :
        """Return the index of `value` in column `col`.

        Falls back to the lowercased value, then to the index of unkToken.
        """
        if value in self.dicts[col] :
            return self.dicts[col][value]
        if value.lower() in self.dicts[col] :
            return self.dicts[col][value.lower()]
        return self.dicts[col][self.unkToken]

    def save(self, target) :
        """Write the vocabularies as JSON into the file `target`."""
        # Closing the file explicitly guarantees the JSON is flushed on return
        with open(target, "w") as targetFile :
            json.dump(self.dicts, targetFile)

    def load(self, target) :
        """Replace the vocabularies by the JSON content of the file `target`."""
        with open(target, "r") as targetFile :
            self.dicts = json.load(targetFile)
################################################################################
import torch
import sys
################################################################################
def extractFeatures(dicts, config) :
    """Return the features of `config`; currently a plain delegation to extractFeaturesPos."""
    features = extractFeaturesPos(dicts, config)
    return features
################################################################################
################################################################################
def extractFeaturesPos(dicts, config) :
    """Return a 1-D int tensor of UPOS indexes describing `config`.

    The first 5 values cover the words at offsets -2..+2 around
    config.wordIndex; the next 4 cover the stack from its top (depth 0) down
    to depth 3. Positions falling outside the sentence or the stack are
    encoded with dicts.nullToken.
    """
    bufferOffsets = range(-2, 3)
    stackDepths = range(0, 4)

    ids = []
    for offset in bufferOffsets :
        index = config.wordIndex + offset
        if 0 <= index < len(config.lines) :
            token = config.getAsFeature(index, "UPOS")
        else :
            token = dicts.nullToken
        ids.append(dicts.get("UPOS", token))

    for depth in stackDepths :
        if 0 <= depth < len(config.stack) :
            # depth 0 is the top of the stack, i.e. its last element
            token = config.getAsFeature(config.stack[-1-depth], "UPOS")
        else :
            token = dicts.nullToken
        ids.append(dicts.get("UPOS", token))

    return torch.tensor(ids, dtype=torch.int)
################################################################################
import torch.nn as nn
import torch.nn.functional as F
################################################################################
class BaseNet(nn.Module):
    """Feed-forward classifier over embedded feature indexes.

    dicts      : object whose `dicts` attribute maps a column name to its
                 vocabulary; one embedding table is created per column.
    inputSize  : number of feature indexes per example.
    outputSize : number of output scores per example.
    """

    def __init__(self, dicts, inputSize, outputSize) :
        super().__init__()
        self.embSize = 64
        self.inputSize = inputSize
        self.outputSize = outputSize
        # ModuleDict registers the embeddings as submodules. Held in a plain
        # dict they were invisible to parameters(), hence never optimised,
        # never saved in state_dict() and never moved by .to().
        self.embeddings = nn.ModuleDict({name : nn.Embedding(len(dicts.dicts[name]), self.embSize) for name in dicts.dicts.keys()})
        self.fc1 = nn.Linear(inputSize * self.embSize, 128)
        self.fc2 = nn.Linear(128, outputSize)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x) :
        """Return the raw (unnormalised) scores for a batch of feature indexes.

        x is embedded with the "UPOS" table, then flattened to one row per
        example before the two linear layers.
        """
        x = self.dropout(self.embeddings["UPOS"](x).view(x.size(0), -1))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
################################################################################
Train.py 0 → 100644
import sys
import random
from Transition import Transition, getMissingLinks, applyTransition
import Features
import torch
################################################################################
def extractExamples(ts, strat, config, dicts, debug=False) :
    """Follow the oracle on `config` and return one training example per step.

    Each example is a 1-D tensor whose first value is the index (in `ts`) of
    the transition chosen by the oracle, followed by the features extracted
    from the configuration before that transition is applied.

    ts     : list of transitions (each exposes `name`, `appliable` and `getOracleScore`).
    strat  : dict passed to applyTransition together with the chosen transition name.
    config : configuration being parsed.
    dicts  : dictionaries handed to Features.extractFeatures.
    debug  : when True, print the configuration and the sorted [score, name] list on stderr.
    """
    collected = []
    eosTransition = Transition("EOS")
    config.moveWordIndex(0)
    keepGoing = True
    while keepGoing :
        links = getMissingLinks(config)
        ranked = [[trans.getOracleScore(config, links), trans.name] for trans in ts if trans.appliable(config)]
        # Lowest oracle score first
        ranked.sort()
        if not ranked :
            break
        bestName = ranked[0][1]
        # Position of the first transition of ts carrying that name : the gold class
        bestIndex = next(position for position, trans in enumerate(ts) if trans.name == bestName)
        features = Features.extractFeatures(dicts, config)
        goldClass = torch.LongTensor([bestIndex])
        collected.append(torch.cat([goldClass, features]))
        if debug :
            config.printForDebug(sys.stderr)
            print(str(ranked)+"\n"+("-"*80)+"\n", file=sys.stderr)
        keepGoing = applyTransition(ts, strat, config, bestName)
    eosTransition.apply(config)
    return collected
################################################################################
......@@ -68,7 +68,7 @@ class Transition :
################################################################################
# Compute numeric values that will be used in the oracle to decide score of transitions
def getMissingLinks(config) :
    """Return the numeric values used by the oracle to score transitions.

    The result maps "StackRight", "BufferRight", "BufferStack" and
    "BufferRightHead" to the value of the matching nbLinks* function.
    """
    # A stale return placed before this one used to hide the
    # "BufferRightHead" entry, making it unreachable.
    return {
        "StackRight" : nbLinksStackRight(config),
        "BufferRight" : nbLinksBufferRight(config),
        "BufferStack" : nbLinksBufferStack(config),
        "BufferRightHead" : nbLinksBufferRightHead(config),
    }
################################################################################
################################################################################
......@@ -78,6 +78,12 @@ def nbLinksBufferRight(config) :
return head + len([c for c in config.childs[config.wordIndex] if c > config.wordIndex])
################################################################################
################################################################################
# 1 if the gold head of the word at wordIndex is located to its right in the sentence, else 0
def nbLinksBufferRightHead(config) :
    """Return 1 when the gold HEAD of the current word is greater than wordIndex, else 0."""
    goldHead = int(config.getGold(config.wordIndex, "HEAD"))
    if goldHead > config.wordIndex :
        return 1
    return 0
################################################################################
################################################################################
# Number of missing links between stack top and the right of the sentence
def nbLinksStackRight(config) :
......@@ -107,7 +113,7 @@ def linkCauseCycle(config, fromIndex, toIndex) :
################################################################################
def scoreOracleRight(config, ml) :
    """Return the oracle score of the RIGHT transition (0 is the best score).

    config : current configuration.
    ml     : dict of numeric values as returned by getMissingLinks.

    The score is 0 when the gold HEAD of the current word equals the top of
    the stack; otherwise it is ml["BufferStack"] + ml["BufferRightHead"].
    """
    if config.getGold(config.wordIndex, "HEAD") == config.stack[-1] :
        return 0
    # A stale return based on ml["BufferRight"] used to shadow this formula.
    # The fallback keeps the function usable with a dict lacking the
    # "BufferRightHead" entry.
    rightCost = ml["BufferRightHead"] if "BufferRightHead" in ml else ml["BufferRight"]
    return ml["BufferStack"] + rightCost
################################################################################
################################################################################
......@@ -168,3 +174,11 @@ def applyEOS(config) :
config.set(index, "HEAD", str(rootIndex))
################################################################################
################################################################################
def applyTransition(ts, strat, config, name) :
    """Apply the transition called `name` to `config`, then move the word index.

    ts    : list of transitions searched for `name` (the first match is used;
            IndexError is raised when there is none).
    strat : dict mapping a transition name to the movement given to
            config.moveWordIndex.

    Returns whatever config.moveWordIndex returns.
    """
    matching = [trans for trans in ts if trans.name == name]
    chosen = matching[0]
    # Looked up before applying, so an unknown name fails before config is changed
    step = strat[chosen.name]
    chosen.apply(config)
    return config.moveWordIndex(step)
################################################################################
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
#! /usr/bin/env python3
import sys
import random
import os
import argparse
from datetime import datetime
import Config
from Transition import Transition, getMissingLinks
import Decode
import Train
from Transition import Transition
import Networks
from Dicts import Dicts
from conll18_ud_eval import load_conllu, evaluate
import torch
################################################################################
def applyTransition(ts, strat, config, name) :
    """Apply the transition named `name` on `config` and move the word index accordingly.

    The first transition of `ts` whose name is `name` is used (IndexError when
    there is none); the movement is read from `strat` under that name.
    Returns the result of config.moveWordIndex.
    """
    found = [candidate for candidate in ts if candidate.name == name]
    selected = found[0]
    shift = strat[selected.name]
    selected.apply(config)
    hasMoved = config.moveWordIndex(shift)
    return hasMoved
def timeStamp() :
    """Return the current local time formatted as "[HH:MM:SS]"."""
    clock = datetime.now().strftime("%H:%M:%S")
    return "[%s]"%clock
################################################################################
################################################################################
def randomDecode(ts, strat, config, debug=None) :
    """Parse `config` by repeatedly applying a random appliable transition.

    ts     : list of transitions to choose from.
    strat  : dict passed to applyTransition together with the chosen transition name.
    config : configuration being parsed.
    debug  : when True, print the configuration and the chosen transition on
             stderr; when None (default), the module-level `args.debug` is used.

    The loop stops as soon as no transition is appliable; an "EOS" transition
    is then applied.
    """
    if debug is None :
        debug = args.debug
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    while True :
        # Use the ts/strat arguments : the globals transitionSet/strategy
        # formerly read here only exist when the file is run as a script.
        candidates = [trans for trans in ts if trans.appliable(config)]
        if not candidates :
            break
        # Uniform draw; randint(0, 100) % len(candidates) was slightly biased
        candidate = random.choice(candidates)
        if debug :
            config.printForDebug(sys.stderr)
            print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
        applyTransition(ts, strat, config, candidate.name)
    EOS.apply(config)
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile) :
    """Train a network on `filename` and decode that corpus with it on stdout.

    debug     : forwarded to example extraction and decoding (debug traces on stderr).
    filename  : CoNLL-U training corpus.
    type      : kind of training; only "oracle" is supported, anything else
                prints an error and exits with status 1.
    modelDir  : directory receiving "dicts.json" and, when devFile is given,
                "predicted_dev.conllu".
    nbIter    : number of epochs.
    batchSize : number of examples per optimisation step.
    devFile   : CoNLL-U dev corpus decoded and evaluated (UAS) after each
                epoch, or None to skip the evaluation.
    """
    transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
    strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

    sentences = Config.readConllu(filename)

    if type != "oracle" :
        print("ERROR : unknown type '%s'"%type, file=sys.stderr)
        sys.exit(1)

    dicts = Dicts()
    dicts.readConllu(filename, ["FORM", "UPOS"])
    dicts.save(modelDir+"/dicts.json")

    print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
    examples = []
    for config in sentences :
        # Use the debug argument rather than the script-only global `args`
        examples += Train.extractExamples(transitionSet, strategy, config, dicts, debug)
    print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
    # One row per example : column 0 is the gold class, the rest are the features
    examples = torch.stack(examples)

    network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
    optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
    lossFct = torch.nn.CrossEntropyLoss()
    for epoch in range(1, nbIter+1) :
        # Decoding the dev corpus switches the network to eval mode, so
        # training mode (dropout) has to be restored at every epoch.
        network.train()
        examples = examples.index_select(0, torch.randperm(examples.size(0)))
        totalLoss = 0.0
        # Step by batchSize : the former hard-coded stride of 6 produced
        # overlapping batches and skipped the last examples.
        for batchIndex in range(0, examples.size(0), batchSize) :
            batch = examples[batchIndex:batchIndex+batchSize]
            targets = batch[:,:1].view(-1)
            inputs = batch[:,1:]
            outputs = network(inputs)
            loss = lossFct(outputs, targets)
            network.zero_grad()
            loss.backward()
            optimizer.step()
            totalLoss += float(loss)
        devScore = ""
        if devFile is not None :
            outFilename = modelDir+"/predicted_dev.conllu"
            # The predictions must be flushed and closed before being read back
            with open(outFilename, "w") as outFile :
                decodeMode(debug, devFile, "model", network, dicts, outFile)
            with open(devFile, "r") as goldFile, open(outFilename, "r") as predFile :
                res = evaluate(load_conllu(goldFile), load_conllu(predFile), [])
            devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1)
        print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), epoch, totalLoss, devScore), file=sys.stderr)
    decodeMode(debug, filename, "model", network, dicts)
################################################################################
################################################################################
def oracleDecode(ts, strat, config, debug=None) :
    """Parse `config` by always applying the appliable transition with the lowest oracle score.

    ts     : list of transitions (each exposes `name`, `appliable` and `getOracleScore`).
    strat  : dict passed to applyTransition together with the chosen transition name.
    config : configuration being parsed.
    debug  : when True, print the configuration and the sorted [score, name]
             list on stderr; when None (default), the module-level
             `args.debug` is used.

    Stops when no transition is appliable or when applyTransition returns a
    false value, then applies an "EOS" transition.
    """
    if debug is None :
        debug = args.debug
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    moved = True
    while moved :
        missingLinks = getMissingLinks(config)
        # Use the ts/strat arguments : the globals transitionSet/strategy
        # formerly read here only exist when the file is run as a script.
        candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
        if not candidates :
            break
        candidate = candidates[0][1]
        if debug :
            config.printForDebug(sys.stderr)
            print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
        moved = applyTransition(ts, strat, config, candidate)
    EOS.apply(config)
def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) :
    """Parse every sentence of `filename` and print the result in CoNLL-U format.

    debug    : forwarded to the decoding function (debug traces on stderr).
    filename : CoNLL-U corpus to parse.
    type     : "random", "oracle" or "model"; anything else prints an error
               and exits with status 1.
    network  : network used when type is "model".
    dicts    : dictionaries used when type is "model".
    output   : file object receiving the parsed sentences.
    """
    transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
    strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

    sentences = Config.readConllu(filename)

    # Use the debug argument rather than the script-only global `args`
    if type in ["random", "oracle"] :
        decodeFunc = Decode.oracleDecode if type == "oracle" else Decode.randomDecode
        for config in sentences :
            decodeFunc(transitionSet, strategy, config, debug)
    elif type == "model" :
        for config in sentences :
            Decode.decodeModel(transitionSet, strategy, config, network, dicts, debug)
    else :
        print("ERROR : unknown type '%s'"%type, file=sys.stderr)
        sys.exit(1)

    # Only the first sentence carries the header; every type honours `output`
    # and an empty corpus simply prints nothing.
    for index, config in enumerate(sentences) :
        config.print(output, header=(index == 0))
################################################################################
################################################################################
if __name__ == "__main__" :
    # Command line : trainCorpus mode type corpus model [--iter N] [--batchSize N] [--dev FILE] [--debug]
    parser = argparse.ArgumentParser()
    parser.add_argument("trainCorpus", type=str,
        help="Name of the CoNLL-U training file.")
    parser.add_argument("mode", type=str,
        help="What to do : train | decode")
    parser.add_argument("type", type=str,
        help="Type of train or decode. random | oracle")
    parser.add_argument("corpus", type=str,
        help="Name of the CoNLL-U file. Train file for train mode and input file for decode mode.")
    parser.add_argument("model", type=str,
        help="Path to the model directory.")
    # type=int makes argparse reject non-numeric values with a usage message
    parser.add_argument("--iter", "-n", type=int, default=5,
        help="Number of training epoch.")
    parser.add_argument("--batchSize", type=int, default=64,
        help="Size of each batch.")
    parser.add_argument("--dev", default=None,
        help="Name of the CoNLL-U file of the dev corpus.")
    parser.add_argument("--debug", "-d", default=False, action="store_true",
        help="Print debug infos on stderr.")

    args = parser.parse_args()

    transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
    strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

    # args.trainCorpus instead of sys.argv[1] : the latter is an option
    # whenever one is given before the positional arguments.
    sentences = Config.readConllu(args.trainCorpus)

    os.makedirs(args.model, exist_ok=True)

    # NOTE(review): this oracle pass over trainCorpus runs before the requested
    # mode and prints on stdout; it looks like a leftover of the script as it
    # was before trainMode/decodeMode existed - confirm it is still wanted.
    first = True
    for config in sentences :
        oracleDecode(transitionSet, strategy, config)
        config.print(sys.stdout, header=first)
        first = False

    if args.mode == "train" :
        trainMode(args.debug, args.corpus, args.type, args.model, args.iter, args.batchSize, args.dev)
    elif args.mode == "decode" :
        decodeMode(args.debug, args.corpus, args.type)
    else :
        print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
        sys.exit(1)
################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment