Skip to content
Snippets Groups Projects
Commit 3bf253c8 authored by Alexis Nasr's avatar Alexis Nasr
Browse files

adaptation du parser à pytorch WIP

parent c97f050b
No related branches found
No related tags found
No related merge requests found
##POS
NULL
ROOT
SCONJ
ADP
NOUN
ADJ
PUNCT
DET
VERB
AUX
PROPN
NUM
CCONJ
PRON
ADV
SYM
PART
##LABEL
NULL
ROOT
case
fixed
obl
amod
punct
det
nsubj
root
aux
xcomp
obj
nmod
nummod
cc
conj
mark
advcl
advmod
csubj
appos
flat
ccomp
acl
cop
compound
iobj
expl
orphan
parataxis
##EOS
NULL
ROOT
0
1
......@@ -54,7 +54,7 @@ class FeatModel:
def getFeatLabel(self, featIndex):
return self.featArray[featIndex][3]
def buildInputVector(self, featVec, dicos):
def buildInputVectorOneHot(self, featVec, dicos):
inputVector = np.zeros(self.inputVectorSize, dtype="int32")
origin = 0
for i in range(self.getNbFeat()):
......@@ -65,3 +65,16 @@ class FeatModel:
inputVector[origin + position] = 1
origin += size
return inputVector
def buildInputVector(self, featVec, dicos):
inputVector = np.zeros(self.getNbFeat(), dtype="int32")
origin = 0
for i in range(self.getNbFeat()):
label = self.getFeatLabel(i)
size = dicos.getDico(label).getSize()
position = dicos.getCode(label, featVec[i])
# print('featureName = ', featureName, 'value =', featVec[i], 'size =', size, 'position =', position, 'origin =', origin)
# print('value =', featVec[i], 'size =', size, 'position =', position, 'origin =', origin)
inputVector[i] = position
origin += size
return inputVector
......@@ -36,8 +36,14 @@ class Moves:
labelCode = int((mvt_Code - 3)/ 2)
return ('RIGHT', self.dicoLabels.getSymbol(labelCode))
def buildOutputVector(self, mvt):
def buildOutputVectorOneHot(self, mvt):
outputVector = np.zeros(self.nb, dtype="int32")
codeMvt = self.mvtCode(mvt)
outputVector[codeMvt] = 1
return outputVector
def buildOutputVector(self, mvt):
outputVector = np.zeros(1, dtype="int32")
codeMvt = self.mvtCode(mvt)
outputVector[0] = codeMvt
return outputVector
......@@ -66,7 +66,7 @@ def oracle(c):
#print("no movement possible return SHIFT")
if not c.getBuffer().endReached():
return('SHIFT', '')
print("The machine is stucked")
print("The machine is stuck")
exit(1)
......
......@@ -41,13 +41,16 @@ def prepareData(mcd, mcfFile, featModel, moves, filename, wordsLimit) :
prepareWordBufferForTrain(c.getBuffer())
while True :
mvt = Oracle.oracle(c)
outputVector = moves.buildOutputVector(mvt)
code = moves.mvtCode(mvt)
# print("mvt = ", mvt, "code = ", code)
# outputVector = moves.buildOutputVector(mvt)
featVec = c.extractFeatVec(featModel)
inputVector = featModel.buildInputVector(featVec, dicos)
# np.savetxt(dataFile, inputVector, fmt="%s", delimiter=' ', newline=' ')
# dataFile.write('\n')
np.savetxt(dataFile, [code], fmt="%s", delimiter=' ', newline=' ')
np.savetxt(dataFile, inputVector, fmt="%s", delimiter=' ', newline=' ')
dataFile.write('\n')
np.savetxt(dataFile, outputVector, fmt="%s", delimiter=' ', newline=' ')
dataFile.write('\n')
if(verbose == True) :
print("------------------------------------------")
......
......@@ -3,7 +3,7 @@ from WordBuffer import WordBuffer
from Word import Word
if len(sys.argv) < 2 :
print('usage:', sys.argv[0], 'conllFile mcdFile')
print('usage:', sys.argv[0], 'conllFile')
exit(1)
......
File moved
import sys
import Oracle
from Dicos import Dicos
from Config import Config
from Word import Word
from Mcd import Mcd
from Moves import Moves
from FeatModel import FeatModel
import torch
import numpy as np
def prepareWordBufferForDecode(buffer):
"""Add to every word of the buffer features GOVREF and LABELREF.
GOVEREF is a copy of feature GOV and LABELREF a copy of LABEL
GOV and LABEL are set to initialization values
"""
for word in buffer.array:
word.setFeat('GOV', str(Word.invalidGov()))
word.setFeat('LABEL', Word.invalidLabel())
verbose = False
if len(sys.argv) != 7 :
print('usage:', sys.argv[0], 'mcf_file model_file dicos_file feat_model mcd_file words_limit')
exit(1)
mcf_file = sys.argv[1]
model_file = sys.argv[2]
dicos_file = sys.argv[3]
feat_model = sys.argv[4]
mcd_file = sys.argv[5]
wordsLimit = int(sys.argv[6])
sys.stderr.write('reading mcd from file :')
sys.stderr.write(mcd_file)
sys.stderr.write('\n')
mcd = Mcd(mcd_file)
sys.stderr.write('loading dicos\n')
dicos = Dicos(fileName=dicos_file)
moves = Moves(dicos)
sys.stderr.write('reading feature model from file :')
sys.stderr.write(feat_model)
sys.stderr.write('\n')
featModel = FeatModel(feat_model, dicos)
sys.stderr.write('loading model :')
sys.stderr.write(model_file)
sys.stderr.write('\n')
model = load_model(model_file)
inputSize = featModel.getInputSize()
outputSize = moves.getNb()
c = Config(mcf_file, mcd, dicos)
numSent = 0
verbose = False
numWords = 0
while c.getBuffer().readNextSentence() and numWords < wordsLimit :
c.getStack().empty()
prepareWordBufferForDecode(c.getBuffer())
numWords += c.getBuffer().getLength()
while True :
featVec = c.extractFeatVec(featModel)
inputVector = featModel.buildInputVector(featVec, dicos)
outputVector = model.predict(inputVector.reshape((1,inputSize)), batch_size=1, verbose=0, steps=None)
mvt_Code = outputVector.argmax()
mvt = moves.mvtDecode(mvt_Code)
if(verbose == True) :
print("------------------------------------------")
c.affiche()
print('predicted move', mvt[0], mvt[1])
print(mvt, featVec)
res = c.applyMvt(mvt)
if not res :
sys.stderr.write("cannot apply predicted movement\n")
mvt_type = mvt[0]
mvt_label = mvt[1]
if mvt_type != "SHIFT" :
sys.stderr.write("try to force SHIFT\n")
res = c.shift()
if res == False :
sys.stderr.write("try to force REDUCE\n")
res = c.red()
if res == False :
sys.stderr.write("abort sentence\n")
break
if(c.isFinal()):
break
for i in range(1, c.getBuffer().getLength()):
w = c.getBuffer().getWord(i)
w.affiche(mcd)
print('')
# print('\t', w.getFeat("GOV"), end='\t')
# print(w.getFeat("LABEL"))
numSent += 1
# if numSent % 10 == 0:
# print ("Sent : ", numSent)
File moved
import sys
import numpy as np
import torch
from torch import nn
def readData(dataFilename) :
allX = []
allY = []
try:
# dataFile = open(dataFilename, encoding='utf-8')
dataFile = open(dataFilename)
except IOError:
print(dataFilename, " : ce fichier n'existe pas")
exit(1)
inputSize = int(dataFile.readline())
print("input size = ", inputSize)
outputSize = int(dataFile.readline())
print("output size = ", outputSize)
inputLine = True
for ligne in dataFile:
# print(ligne)
vector = ligne.split()
vector[:] = list(map(int, vector))
if inputLine == True:
#print("input ", vector)
allX.append(vector)
inputLine = False
else:
#print("output ", vector)
allY.append(vector)
inputLine = True
# x_train and y_train are Numpy arrays
x_train = np.array(allX)
y_train = np.array(allY)
return (inputSize, outputSize, x_train, y_train)
if len(sys.argv) < 3 :
print('usage:', sys.argv[0], 'cffTrainFileName cffDevFileName pytorchModelFileName')
exit(1)
cffTrainFileName = sys.argv[1]
cffDevFileName = sys.argv[2]
kerasModelFileName = sys.argv[3]
inputSize, outputSize, x_train, y_train = readData(cffTrainFileName)
devInputSize, devOutputSize, x_dev, y_dev = readData(cffDevFileName)
model = mlp()
model = nn.Sequential()
model.add_module("dense1", nn.Linear(8,12))
model = Sequential()
model.add(Dense(units=128, activation='relu', input_dim=inputSize))
model.add(Dropout(0.4))
model.add(Dense(units=outputSize, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, batch_size=32, validation_data=(x_dev,y_dev))
#if len(sys.argv) == 5 :
# model.fit(x_train, y_train, epochs=5, batch_size=32, validation_data=(x_dev,y_dev))
#else :
# model.fit(x_train, y_train, epochs=10, batch_size=32)
model.save(kerasModelFileName)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment