diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ea9142a4cd47121e9eaf3f1901fa05758ac0070d --- /dev/null +++ b/.gitignore @@ -0,0 +1,173 @@ +*.tgz + +data/ + +expe/out/ + +slides/ + +*.pytorch + + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ \ No newline at end of file diff --git a/README.md b/README.md index 7ea5d2ee284d976ffc0e4560751bcadd6b0bec55..8c7885984884af32b6ff7469ebe8a38a75b7f92d 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,19 @@ # tbp + +## List of todos + + +### Todo + +- [ ] complete implemented models with word and POS embeddings @tatiana.bladier + - [ ] word embeddings + - [ ] POS embeddings + +### In Progress + +- [ ] + +### Done ✓ + +- [x] diff --git a/expe/PLE.dic b/expe/PLE.dic index 6a580e2be22de852b0ccb6a7abeb8228da83adc7..a28118463aa57b214f047e29d708268f99c17059 100644 --- a/expe/PLE.dic +++ b/expe/PLE.dic @@ -1,9 +1,62 @@ ##POS NULL ROOT +DET +NOUN +ADP +VERB +ADJ +PUNCT +CCONJ +PROPN +AUX +ADV +PRON +NUM +X +SCONJ +PART +INTJ +SYM ##LABEL NULL ROOT +det +nsubj +case +nmod +root +amod +obl +punct +cc +conj +aux +advmod +xcomp +appos +acl +dep +obj +iobj +cop +mark +advcl +flat +fixed +nummod +compound +ccomp +parataxis +expl +csubj +dislocated +orphan +discourse +vocative +goeswith ##EOS NULL ROOT +0 +1 diff --git a/expe/vazy.sh b/expe/vazy.sh index f0baec955868ff111a6b9b3ee11dec601de7f494..c0c7a7fd3f349a8601a58cab416cb4d846abebd9 100644 --- a/expe/vazy.sh +++ b/expe/vazy.sh @@ -1,4 +1,5 @@ lang=$1 +mkdir -p out train_conll="../data/train_${lang}.conllu" train_proj_conll="./out/train_${lang}_proj.conllu" train_mcf="./out/train_${lang}_pgle.mcf" @@ -12,6 +13,7 @@ dev_cff="./out/dev_${lang}.cff" dev_word_limit="5000" test_conll="../data/test_${lang}.conllu" + test_mcf="./out/test_${lang}_pgle.mcf" test_mcf_hyp="./out/test_${lang}_hyp.mcf" test_word_limit="700" @@ -20,7 +22,7 @@ feat_model="basic.fm" dicos="./out/${lang}_train.dic" dicos="PLE.dic" -model="./out/${lang}.keras" +model="./out/${lang}.pytorch" results="./out/${lang}.res" mcd_pgle="PGLE.mcd" @@ -41,11 +43,17 @@ python3 ../src/mcf2cff.py $dev_mcf $feat_model $mcd_pgle $dicos $dev_cff $dev_wo python3 ../src/mcf2cff.py $train_mcf $feat_model $mcd_pgle $dicos $train_cff $train_word_limit -python3 ../src/tbp_train.py $train_cff $dev_cff $model +#python3 ../src/tbp_train.py $train_cff $dev_cff $model +python3 ../src/tbp_train_pytorch.py $train_cff $dev_cff $model +#python3 ../src/tbp_train_keras.py $train_cff $dev_cff $model + + +#python3 ../src/tbp_decode.py $test_mcf $model $dicos $feat_model $mcd_pgle $test_word_limit > $test_mcf_hyp +python3 ../src/tbp_decode_pytorch.py $test_mcf $model $dicos $feat_model $mcd_pgle $test_word_limit $train_cff $dev_cff > $test_mcf_hyp +#python3 ../src/tbp_decode_keras.py $test_mcf $model $dicos $feat_model $mcd_pgle $test_word_limit > $test_mcf_hyp -python3 ../src/tbp_decode.py $test_mcf $model $dicos $feat_model $mcd_pgle $test_word_limit > $test_mcf_hyp -python3 ../src/eval_mcf.py $test_mcf $test_mcf_hyp $mcd_pgle $mcd_pgle $lang > $results +python3 ../src/eval_mcf.py $test_mcf $test_mcf_hyp $mcd_pgle $mcd_pgle $lang verbose > $results diff --git a/src/Config.py b/src/Config.py index 212a34506d42ade61949469c8e79f71d00d6ba33..0c1b9200bda405a887fb00990e7913ccf66770e8 100644 --- a/src/Config.py +++ b/src/Config.py @@ -37,7 +37,8 @@ class Config: if(self.getStack().isEmpty()): sys.stderr.write("cannot reduce an empty stack !\n") return False - + + #print('AHHHHHHHHH', int(self.getBuffer().getWord(self.getStack().top()).getFeat('GOV'))) if int(self.getBuffer().getWord(self.getStack().top()).getFeat('GOV')) == Word.invalidGov() : sys.stderr.write("cannot reduce the stack if top element does not have a governor !\n") return False @@ -52,6 +53,7 @@ class Config: govIndex = self.getStack().top() depIndex = 
self.getBuffer().currentIndex + #print('560 govIndex, depIndex ', govIndex, depIndex) self.getBuffer().getCurrentWord().setFeat('LABEL', label) self.getBuffer().getCurrentWord().setFeat('GOV', str(govIndex - depIndex)) self.getBuffer().getWord(self.getStack().top()).addRightDaughter(depIndex) @@ -66,6 +68,7 @@ class Config: govIndex = self.getBuffer().currentIndex depIndex = self.getStack().top() + #print('561 govIndex, depIndex ', govIndex, depIndex) self.getBuffer().getWord(self.getStack().top()).setFeat('LABEL', label) self.getBuffer().getWord(self.getStack().top()).setFeat('GOV', str(govIndex - depIndex)) self.getBuffer().getCurrentWord().addLeftDaughter(depIndex) @@ -126,19 +129,19 @@ class Config: container = featTuple[2] index = featTuple[3] - word = self.getWordWithRelativeIndex(containe, index) + word = self.getWordWithRelativeIndex(container, index) if word == None : return 'NULL' - return string(len(word.getLeftDaughters())) + return str(len(word.getLeftDaughters())) def getNrDepFeat(self, featTuple): container = featTuple[2] index = featTuple[3] - word = self.getWordWithRelativeIndex(containe, index) + word = self.getWordWithRelativeIndex(container, index) if word == None : return 'NULL' - return string(len(word.getRightDaughters())) + return str(len(word.getRightDaughters())) def getLlDepFeat(self, featTuple): @@ -148,7 +151,7 @@ class Config: return 'NULL' def getStackHeightFeat(self, featTuple): - string(self.getStack().getLength()) + str(self.getStack().getLength()) def getDistFeat(self, featTuple): containerWord1 = featTuple[1] @@ -193,7 +196,7 @@ class Config: featVec = [] i = 0 for f in FeatModel.getFeatArray(): -# print(f, '=', self.getWordFeat(f)) + #print(f, '=', self.getWordFeat(f)) featVec.append(self.getFeat(f)) # featVec.append(self.getWordFeat(f)) i += 1 diff --git a/src/Dicos.py b/src/Dicos.py index a84cab76de8ecc1d4b46e82b832158575c4628a6..c0d8aa72cc56ab420cda183cfcf875acd5966fc6 100644 --- a/src/Dicos.py +++ b/src/Dicos.py @@ -5,9 +5,11 @@ class Dicos: self.content = {} if mcd : self.initializeWithMcd(mcd) + if fileName : self.initializeWithDicoFile(fileName) + def initializeWithMcd(self, mcd): for index in range(mcd.getNbCol()): if(mcd.getColStatus(index) == 'KEEP') and (mcd.getColType(index) == 'SYM') : diff --git a/src/FeatModel.py b/src/FeatModel.py index 8c344feca37f69d2f806140d6edf7c8c94446014..207e2d40c22e499e56acc07972f4b0a5939b09cb 100644 --- a/src/FeatModel.py +++ b/src/FeatModel.py @@ -61,7 +61,7 @@ class FeatModel: label = self.getFeatLabel(i) size = dicos.getDico(label).getSize() position = dicos.getCode(label, featVec[i]) - #print('featureName = ', featureName, 'value =', featVec[i], 'size =', size, 'position =', position, 'origin =', origin) + #print('featureName = ', label, 'value =', featVec[i], 'size =', size, 'position =', position, 'origin =', origin) inputVector[origin + position] = 1 origin += size return inputVector @@ -73,7 +73,7 @@ class FeatModel: label = self.getFeatLabel(i) size = dicos.getDico(label).getSize() position = dicos.getCode(label, featVec[i]) -# print('featureName = ', featureName, 'value =', featVec[i], 'size =', size, 'position =', position, 'origin =', origin) + #print('featureName = ', label, 'value =', featVec[i], 'size =', size, 'position =', position, 'origin =', origin) # print('value =', featVec[i], 'size =', size, 'position =', position, 'origin =', origin) inputVector[i] = position origin += size diff --git a/src/Moves.py b/src/Moves.py index 
4232bdb00fea8c6489a1c77573812befb1d82d21..84b71893505935dd807e3ac3e3fafb89fb1351f4 100644 --- a/src/Moves.py +++ b/src/Moves.py @@ -3,6 +3,7 @@ import numpy as np class Moves: def __init__(self, dicos): self.dicoLabels = dicos.getDico('LABEL') + if not self.dicoLabels : print("cannot find LABEL in dicos") exit(1) @@ -26,14 +27,18 @@ class Moves: if(mvtType == 'LEFT'): return 3 + 2 * labelCode + 1 def mvtDecode(self, mvt_Code): - if(mvt_Code == 0) : return ('SHIFT', 'NULL') + if(mvt_Code == 0) : + return ('SHIFT', 'NULL') if(mvt_Code == 1) : return ('REDUCE', 'NULL') if(mvt_Code == 2) : return ('ROOT', 'NULL') if mvt_Code % 2 == 0 : #LEFT labelCode = int((mvt_Code - 4) / 2) + #print("label code: ", labelCode) + return ('LEFT', self.dicoLabels.getSymbol(labelCode)) else : labelCode = int((mvt_Code - 3)/ 2) + #print("label code: ", labelCode) return ('RIGHT', self.dicoLabels.getSymbol(labelCode)) def buildOutputVectorOneHot(self, mvt): @@ -46,4 +51,5 @@ class Moves: outputVector = np.zeros(1, dtype="int32") codeMvt = self.mvtCode(mvt) outputVector[0] = codeMvt + return outputVector diff --git a/src/Word.py b/src/Word.py index fe92bd51f53c0062cfeee39030add45ff9250a12..f87b817735ecf5dc9538f8b950c84bd7d5466f8a 100644 --- a/src/Word.py +++ b/src/Word.py @@ -4,7 +4,7 @@ class Word: self.leftDaughters = [] # liste des indices des dépendants gauches self.rightDaughters = [] # liste des indices des dépendants droits self.index = self.invalidIndex() - + def getFeat(self, featName): if(not featName in self.featDic): print('WARNING : feat', featName, 'does not exist') @@ -40,7 +40,7 @@ class Word: else: print('\t', end='') print(self.getFeat(mcd.getColName(columnNb)), end='') -# print('') + print('') @staticmethod def fakeWordConll(): diff --git a/src/WordBuffer.py b/src/WordBuffer.py index f2fba582402d6627a4e91c81aa4d8ce0798855b1..9208b7f2597bd07c3a5fefa87cfded4bf79668d8 100644 --- a/src/WordBuffer.py +++ b/src/WordBuffer.py @@ -32,7 +32,8 @@ class WordBuffer: def affiche(self, mcd): for w in self.array: w.affiche(mcd) - print('') + #print('') + #print(w.affiche(mcd)) def getLength(self): return len(self.array) diff --git a/src/eval_mcf.py b/src/eval_mcf.py index 22c26336921ed089ae601b8fd1b8e3f240aef557..c236c6d83d6df0d68473d387703953b5778c8b00 100644 --- a/src/eval_mcf.py +++ b/src/eval_mcf.py @@ -19,10 +19,10 @@ else: verbose = False -#print('reading mcd from file :', refMcdFileName) +print('reading mcd from file :', refMcdFileName) refMcd = Mcd(refMcdFileName) -#print('reading mcd from file :', hypMcdFileName) +print('reading mcd from file :', hypMcdFileName) hypMcd = Mcd(hypMcdFileName) GovColIndex = refMcd.locateCol('GOV') diff --git a/src/mcf2cff.py b/src/mcf2cff.py index 3e34b47e79ef4dccfce1926bfb6f6bbdf73bb7a6..11f9f7592c08eb72ca5412444da486195482ad31 100644 --- a/src/mcf2cff.py +++ b/src/mcf2cff.py @@ -96,7 +96,7 @@ featModel = FeatModel(featModelFileName, dicos) inputSize = featModel.getInputSize() outputSize = moves.getNb() -print('input size = ', inputSize, 'outputSize =' , outputSize) +print('featModel input size = ', inputSize, 'outputSize =' , outputSize) print('preparing training data') prepareData(mcd, mcfFileName, featModel, moves, dataFileName, wordsLimit) diff --git a/src/plot_lib.py b/src/plot_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..f4152a38ae9d0f78d97e272176bd94edcf19def4 --- /dev/null +++ b/src/plot_lib.py @@ -0,0 +1,153 @@ +from matplotlib import pyplot as plt +import matplotlib.cm as cm +import numpy as np +import torch +from IPython.display 
import HTML, display + + +def set_default(figsize=(10, 10), dpi=100): + plt.style.use(['dark_background', 'bmh']) + plt.rc('axes', facecolor='k') + plt.rc('figure', facecolor='k') + plt.rc('figure', figsize=figsize, dpi=dpi) + + +def plot_data(X, y, d=0, auto=False, zoom=1): + X = X.cpu() + y = y.cpu() + plt.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y, s=20, cmap=plt.cm.Spectral) + plt.axis('square') + plt.axis(np.array((-1.1, 1.1, -1.1, 1.1)) * zoom) + if auto is True: plt.axis('equal') + plt.axis('off') + + _m, _c = 0, '.15' + plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0) + plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0) + + +def plot_model(X, y, model): + model.cpu() + mesh = np.arange(-1.1, 1.1, 0.01) + xx, yy = np.meshgrid(mesh, mesh) + with torch.no_grad(): + data = torch.from_numpy(np.vstack((xx.reshape(-1), yy.reshape(-1))).T).float() + Z = model(data).detach() + Z = np.argmax(Z, axis=1).reshape(xx.shape) + plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3) + plot_data(X, y) + + +def show_scatterplot(X, colors, title='', axis=True): + colors = cm.rainbow(np.linspace(0, 1, len(ys))) + + X = X.numpy() + # plt.figure() + plt.axis('equal') + plt.scatter(X[:, 0], X[:, 1], c=colors, s=30) + # plt.grid(True) + plt.title(title) + plt.axis('off') + _m, _c = 0, '.15' + if axis: + plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0) + plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0) + + +def plot_bases(bases, plotting=True, width=0.04): + bases[2:] -= bases[:2] + # if plot_bases.a: plot_bases.a.set_visible(False) + # if plot_bases.b: plot_bases.b.set_visible(False) + if plotting: + plot_bases.a = plt.arrow(*bases[0], *bases[2], width=width, color='r', zorder=10, alpha=1., length_includes_head=True) + plot_bases.b = plt.arrow(*bases[1], *bases[3], width=width, color='g', zorder=10, alpha=1., length_includes_head=True) + + +plot_bases.a = None +plot_bases.b = None + + +def show_mat(mat, vect, prod, threshold=-1): + # Subplot grid definition + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=False, sharey=True, + gridspec_kw={'width_ratios':[5,1,1]}) + # Plot matrices + cax1 = ax1.matshow(mat.numpy(), clim=(-1, 1)) + ax2.matshow(vect.numpy(), clim=(-1, 1)) + cax3 = ax3.matshow(prod.numpy(), clim=(threshold, 1)) + + # Set titles + ax1.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}') + ax2.set_title(f'a^(i): {vect.numel()}') + ax3.set_title(f'p: {prod.numel()}') + + # Remove xticks for vectors + ax2.set_xticks(tuple()) + ax3.set_xticks(tuple()) + + # Plot colourbars + fig.colorbar(cax1, ax=ax2) + fig.colorbar(cax3, ax=ax3) + + # Fix y-axis limits + ax1.set_ylim(bottom=max(len(prod), len(vect)) - 0.5) + + +colors = dict( + aqua='#8dd3c7', + yellow='#ffffb3', + lavender='#bebada', + red='#fb8072', + blue='#80b1d3', + orange='#fdb462', + green='#b3de69', + pink='#fccde5', + grey='#d9d9d9', + violet='#bc80bd', + unk1='#ccebc5', + unk2='#ffed6f', +) + + +def _cstr(s, color='black'): + if s == ' ': + return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>' + else: + return f'<text style=color:#000;background-color:{color}>{s} </text>' + +# print html +def _print_color(t): + display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t]))) + +# get appropriate color for value +def _get_clr(value): + colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8', + '#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8', + '#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f', + '#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e') + value = int((value * 100) / 5) + 
if value == len(colors): value -= 1 # fixing bugs... + return colors[value] + +def _visualise_values(output_values, result_list): + text_colours = [] + for i in range(len(output_values)): + text = (result_list[i], _get_clr(output_values[i])) + text_colours.append(text) + _print_color(text_colours) + +def print_colourbar(): + color_range = torch.linspace(-2.5, 2.5, 20) + to_print = [(f'{x:.2f}', _get_clr((x+2.5)/5)) for x in color_range] + _print_color(to_print) + + +# Let's only focus on the last time step for now +# First, the cell state (Long term memory) +def plot_state(data, state, b, decoder, idx_to_symbol): + actual_data = decoder(data[b, :, :].numpy(), idx_to_symbol) + seq_len = len(actual_data) + seq_len_w_pad = len(state) + for s in range(state.size(2)): + states = torch.sigmoid(state[:, b, s]) + _visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data)) \ No newline at end of file diff --git a/src/pytorch_utils.py b/src/pytorch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a0427d751c43daeae77697508706ba2d5f8dc79a --- /dev/null +++ b/src/pytorch_utils.py @@ -0,0 +1,198 @@ +import math + +import numpy as np +import six + +""" read file in cff format + + the contents of such files look like this: + 133 + 75 + 0 + 0 0 1 4 2 0 0 + 0 + 0 1 4 2 3 1 0 + + the first two lines represent TODO +""" +def readFile_cff(filepath): + lines = [] + with open(filepath) as file_in: + for line in file_in: + lines.append(line) + inputSize = int(lines[0].strip()) + outputSize = int(lines[1].strip()) + items_labels = [x.strip() for x in lines[2:]] + items = items_labels[1::2] + labels = items_labels[::2] + return items, labels, inputSize, outputSize + + +#flatten a list +def flatten(xss): + return [x for xs in xss for x in xs] + + +#find how many classes we have in data: +def find_unique_classes(*inputlists): + unique_classes = np.unique(flatten([*inputlists])).tolist() + unique_classes = [str(x) for x in unique_classes] + number_of_unique_classes = len(unique_classes) + return unique_classes, number_of_unique_classes + + +def find_unique_symbols(*inputlists): + flat_list = flatten([*inputlists]) + maxlen = len(max(flat_list, key=len)) + unique_elem_list = [] + for elem in flat_list: + tokens = [x for x in elem.split(' ')] + for t in tokens: + if not t in unique_elem_list: + unique_elem_list.append(t) + return maxlen, unique_elem_list + [' '] + + +def pad_sequences(sequences, maxlen=None, dtype='int32', + padding='pre', truncating='pre', value=0.): + if not hasattr(sequences, '__len__'): + raise ValueError('`sequences` must be iterable.') + lengths = [] + for x in sequences: + if not hasattr(x, '__len__'): + raise ValueError('`sequences` must be a list of iterables. ' + 'Found non-iterable: ' + str(x)) + lengths.append(len(x)) + + num_samples = len(sequences) + #print("NUM SAMPLES", num_samples) + if maxlen is None: + maxlen = np.max(lengths) + + # take the sample shape from the first non empty sequence + # checking for consistency in the main loop below. + sample_shape = tuple() + for s in sequences: + if len(s) > 0: + sample_shape = np.asarray(s).shape[1:] + break + + is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.unicode_) + if isinstance(value, six.string_types) and dtype != object and not is_dtype_str: + raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n" + "You should set `dtype=object` for variable length strings." 
+ .format(dtype, type(value))) + + x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype) + for idx, s in enumerate(sequences): + if not len(s): + continue # empty list/array was found + if truncating == 'pre': + trunc = s[-maxlen:] + elif truncating == 'post': + trunc = s[:maxlen] + else: + raise ValueError('Truncating type "%s" ' + 'not understood' % truncating) + + # check `trunc` has expected shape + trunc = np.asarray(trunc, dtype=dtype) + if trunc.shape[1:] != sample_shape: + raise ValueError('Shape of sample %s of sequence at position %s ' + 'is different from expected shape %s' % + (trunc.shape[1:], idx, sample_shape)) + + if padding == 'post': + x[idx, :len(trunc)] = trunc + elif padding == 'pre': + x[idx, -len(trunc):] = trunc + else: + raise ValueError('Padding type "%s" not understood' % padding) + return x + + +def to_categorical(y, num_classes=None, dtype='float32'): + y = np.array(y, dtype='int') + input_shape = y.shape + if input_shape and input_shape[-1] == 1 and len(input_shape) > 1: + input_shape = tuple(input_shape[:-1]) + y = y.ravel() + if not num_classes: + num_classes = np.max(y) + 1 + n = y.shape[0] + categorical = np.zeros((n, num_classes), dtype=dtype) + categorical[np.arange(n), y] = 1 + output_shape = input_shape + (num_classes,) + categorical = np.reshape(categorical, output_shape) + return categorical + + +def make_pytorch_dicts(*input_paths): + input_paths = [x for x in input_paths] + item_lsts = [] + label_lsts = [] + for inpath in input_paths: + items_list, labels_list, _, _ = readFile_cff(inpath) + item_lsts.append(items_list) + label_lsts.append(labels_list) + items_list = flatten(item_lsts) + labels_list = flatten(label_lsts) + + classes, n_classes = find_unique_classes(labels_list) # possible moves + maxlen, all_symbols = find_unique_symbols(items_list) # basically, the number of POS tags + n_symbols = len(all_symbols) + + #print('Number of classes: ', n_classes) + #print('Number of symbols: ', n_symbols) + #print('Max length of sequence: ', maxlen) + + symbol_to_idx = {s: n for n, s in enumerate(all_symbols)} + idx_to_symbol = {n: s for n, s in enumerate(all_symbols)} + + class_to_idx = {c: n for n, c in enumerate(classes)} + idx_to_class = {n: c for n, c in enumerate(classes)} + + return n_classes, maxlen, n_symbols, symbol_to_idx, idx_to_symbol, class_to_idx, idx_to_class, classes + +def encode_x_batch(x_batch, symbol_to_idx, n_symbols): + return pad_sequences([encode_x(x, symbol_to_idx, n_symbols) for x in x_batch],maxlen=15) # TODO read maxlen from the file + + +def encode_x(x, symbol_to_idx, n_symbols): + idx_x = [symbol_to_idx[s] for s in x] + return to_categorical(idx_x, num_classes=n_symbols) + + +def encode_y_batch(y_batch, class_to_idx, n_classes): + return np.array([encode_y(y, class_to_idx, n_classes) for y in y_batch]) + + +def encode_y(y, class_to_idx, n_classes): + idx_y = class_to_idx[y] + return to_categorical(idx_y, num_classes=n_classes) + + +def preprocess_data(items_list, labels_list, batch_size, symbol_to_idx, class_to_idx, n_symbols, n_classes): + upp_bound = int(math.ceil(len(items_list) / batch_size)) + data_batches = [] + i = 0 + for i in range(upp_bound -1): + a = i * batch_size + b = (i + 1) * batch_size + data_batches.append((encode_x_batch(items_list[a:b], symbol_to_idx, n_symbols), encode_y_batch(labels_list[a:b], class_to_idx, n_classes))) + i+=1 + return data_batches + + +def decode_x(x, idx_to_symbol): + x = x[np.sum(x, axis=1) > 0] # remove padding + return ''.join([idx_to_symbol[pos] for pos in 
np.argmax(x, axis=1)]) + +def decode_y(y, idx_to_class): + return idx_to_class[np.argmax(y)] + +def decode_x_batch(x_batch, idx_to_symbol): + return [decode_x(x, idx_to_symbol) for x in x_batch] + +def decode_y_batch(y_batch, idx_to_class): + return [idx_to_class[pos] for pos in np.argmax(y_batch, axis=1)] \ No newline at end of file diff --git a/src/tbp_decode_pytorch.py b/src/tbp_decode_pytorch.py index c7e139992109bcf2fbae735241be9a2af078b53c..13c6b317e7602b6eb203998613556835512b9881 100644 --- a/src/tbp_decode_pytorch.py +++ b/src/tbp_decode_pytorch.py @@ -9,6 +9,7 @@ from FeatModel import FeatModel import torch import numpy as np +from pytorch_utils import * @@ -23,9 +24,9 @@ def prepareWordBufferForDecode(buffer): word.setFeat('LABEL', Word.invalidLabel()) -verbose = False -if len(sys.argv) != 7 : - print('usage:', sys.argv[0], 'mcf_file model_file dicos_file feat_model mcd_file words_limit') +verbose = True +if len(sys.argv) != 9 : + print('usage:', sys.argv[0], 'mcf_file model_file dicos_file feat_model mcd_file words_limit train_file dev_file') exit(1) mcf_file = sys.argv[1] @@ -34,76 +35,178 @@ dicos_file = sys.argv[3] feat_model = sys.argv[4] mcd_file = sys.argv[5] wordsLimit = int(sys.argv[6]) +train_fr_file = sys.argv[7] +dev_fr_file = sys.argv[8] +########################################################################## + +n_classes, maxlen, n_symbols, symbol_to_idx, idx_to_symbol, class_to_idx, idx_to_class, _ = make_pytorch_dicts(dev_fr_file, train_fr_file) +########################################################################## -sys.stderr.write('reading mcd from file :') -sys.stderr.write(mcd_file) -sys.stderr.write('\n') -mcd = Mcd(mcd_file) +#print('reading mcd from file : ', mcd_file) # PGLE.mcd + +mcd = Mcd(mcd_file) # MCD = multi column descriptions + -sys.stderr.write('loading dicos\n') -dicos = Dicos(fileName=dicos_file) + +#print('loading dicos from : ', dicos_file) # PLE.dic +dicos = Dicos(fileName=dicos_file) # dictionaries with class labels for columns from mcd moves = Moves(dicos) -sys.stderr.write('reading feature model from file :') -sys.stderr.write(feat_model) -sys.stderr.write('\n') +#print('reading feature model from file : ', feat_model) # basic.fm + featModel = FeatModel(feat_model, dicos) -sys.stderr.write('loading model :') -sys.stderr.write(model_file) -sys.stderr.write('\n') -model = load_model(model_file) +#print('loading pytorch model from :', model_file) # /home/taniabladier/Programming/AMU/tbp/expe/out/fr.pytorch + + +############################ + +""" +load saved pytorch model +""" + +batch_size = 32 + + +# Setup the RNN and training settings +import torch.nn as nn +import torch.nn.functional as F + +class SimpleLSTM(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.lstm = torch.nn.LSTM(input_size, hidden_size, batch_first=True) + self.linear = torch.nn.Linear(hidden_size, output_size) + + def forward(self, x): + h = self.lstm(x)[0] + x = self.linear(h) + return x + + def get_states_across_time(self, x): + h_c = None + h_list, c_list = list(), list() + with torch.no_grad(): + for t in range(x.size(1)): + h_c = self.lstm(x[:, [t], :], h_c)[1] + h_list.append(h_c[0]) + c_list.append(h_c[1]) + h = torch.cat(h_list) + c = torch.cat(c_list) + return h, c + + + + +input_size = 133 #n_symbols +hidden_size = 128 +output_size = 75 #n_classes + + +print('input output', input_size, output_size) +model = SimpleLSTM(input_size, hidden_size, output_size) +criterion = torch.nn.CrossEntropyLoss() 
+optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001) + + +checkpoint = torch.load(model_file, map_location=torch.device('cpu'), weights_only=False) +model.load_state_dict(checkpoint['model_state_dict']) +optimizer.load_state_dict(checkpoint['optimizer_state_dict']) +epoch = checkpoint['epoch'] +loss = checkpoint['loss'] -inputSize = featModel.getInputSize() -outputSize = moves.getNb() +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +model = model.to(device) +model.eval(); + + +############################ + + +#model = load_model(model_file) + +#inputSize = featModel.getInputSize() +#outputSize = moves.getNb() c = Config(mcf_file, mcd, dicos) + + + numSent = 0 -verbose = False +verbose = True numWords = 0 + while c.getBuffer().readNextSentence() and numWords < wordsLimit : c.getStack().empty() prepareWordBufferForDecode(c.getBuffer()) + numWords += c.getBuffer().getLength() while True : featVec = c.extractFeatVec(featModel) inputVector = featModel.buildInputVector(featVec, dicos) - outputVector = model.predict(inputVector.reshape((1,inputSize)), batch_size=1, verbose=0, steps=None) - mvt_Code = outputVector.argmax() + + ############### + inputVector = ' '.join(str(x) for x in inputVector) + inputVector = encode_x_batch([inputVector], symbol_to_idx, 133) #n_symbols) + inputVector = torch.from_numpy(inputVector).float().to(device) + + + with torch.no_grad(): + output = model(inputVector) + + # Pick only the output corresponding to last sequence element (input is pre padded) + output = output[:, -1, :] + preds = output.argmax(dim=1) + mvt_Code = int([idx_to_class[y.item()] for y in preds][0]) + + + ################## + + #mvt_Code = outputVector.argmax() + #mvt = moves.mvtDecode(mvt_Code) + mvt = moves.mvtDecode(mvt_Code) + verbose = True + if(verbose == True) : - print("------------------------------------------") - c.affiche() - print('predicted move', mvt[0], mvt[1]) - print(mvt, featVec) - res = c.applyMvt(mvt) + """ + print("\n\n------------------------------------------\n") + print('inputVector ::', featModel.buildInputVector(featVec, dicos)) + + print('mvt code: ', mvt_Code, ' move ::: ', mvt) + + c.affiche() + + print('predicted move: ', mvt[0], mvt[1], 'for', str(featVec)) + """ + + res = c.applyMvt(mvt) #result, True or False + #print("RES", res) if not res : - sys.stderr.write("cannot apply predicted movement\n") + print("cannot apply predicted movement\n") mvt_type = mvt[0] mvt_label = mvt[1] if mvt_type != "SHIFT" : - sys.stderr.write("try to force SHIFT\n") + print("try to force SHIFT\n") res = c.shift() if res == False : - sys.stderr.write("try to force REDUCE\n") + print("try to force REDUCE\n") res = c.red() if res == False : - sys.stderr.write("abort sentence\n") + print("abort sentence\n") break if(c.isFinal()): break for i in range(1, c.getBuffer().getLength()): w = c.getBuffer().getWord(i) w.affiche(mcd) - print('') -# print('\t', w.getFeat("GOV"), end='\t') -# print(w.getFeat("LABEL")) numSent += 1 # if numSent % 10 == 0: diff --git a/src/tbp_train_pytorch.py b/src/tbp_train_pytorch.py index 314948443c060006198f0fe6c1dbd9241aa46109..2ed492617a3fc641939bc55f6dafbd3114e69441 100644 --- a/src/tbp_train_pytorch.py +++ b/src/tbp_train_pytorch.py @@ -1,8 +1,15 @@ import sys import numpy as np import torch -from torch import nn +import torch.nn as nn +import torch.nn.functional as F +from pytorch_utils import * +from plot_lib import * +import os +"""## 1. 
Reading Data Files""" + +""" def readData(dataFilename) : allX = [] allY = [] @@ -37,6 +44,7 @@ def readData(dataFilename) : x_train = np.array(allX) y_train = np.array(allY) return (inputSize, outputSize, x_train, y_train) +""" @@ -46,8 +54,437 @@ if len(sys.argv) < 3 : cffTrainFileName = sys.argv[1] cffDevFileName = sys.argv[2] -kerasModelFileName = sys.argv[3] +pytorchModelFileName = sys.argv[3] + + +batch_size = 32 + + +n_classes, maxlen, n_symbols, symbol_to_idx, idx_to_symbol, \ + class_to_idx, idx_to_class, classes = make_pytorch_dicts(cffTrainFileName, cffDevFileName) + + +train_items_list, train_labels_list, train_inputSize, train_outputSize = readFile_cff(cffTrainFileName) +dev_items_list, dev_labels_list, dev_inputSize, dev_outputSize = readFile_cff(cffDevFileName) + + +train_data_gen = preprocess_data(train_items_list[:800000], train_labels_list[:800000], batch_size, symbol_to_idx, class_to_idx, train_inputSize, train_outputSize)#, n_symbols, n_classes) +dev_data_gen = preprocess_data(dev_items_list[:200000], dev_labels_list[:200000], batch_size, symbol_to_idx, class_to_idx, train_inputSize, train_outputSize)#, n_symbols, n_classes) + + +"""## 2. Defining the Model""" + + +# Set the random seed for reproducible results +torch.manual_seed(1) + +# Setup the RNN and training settings +input_size = 133 #train_inputSize #n_symbols +hidden_size = 128 +output_size = 75 #train_outputSize #n_classes + +class SimpleMLP(nn.Module): + def __init__(self, input_size, output_size): + super(SimpleMLP, self).__init__() + self.fc1 = nn.Linear(input_size, 200) # Dense layer with 128 units + self.dropout = nn.Dropout(0.5) # Dropout with 0.4 probability + self.fc2 = nn.Linear(200, output_size) # Dense layer with outputSize units + + def forward(self, x): + x = F.relu(self.fc1(x)) # Apply ReLU activation + x = self.dropout(x) # Apply Dropout + x = F.softmax(self.fc2(x), dim=1) # Apply Softmax activation + return x + + def get_states_across_time(self, x): + h_c = None + h_list, c_list = list(), list() + with torch.no_grad(): + for t in range(x.size(1)): + h_c = self.lstm(x[:, [t], :], h_c)[1] + h_list.append(h_c[0]) + c_list.append(h_c[1]) + h = torch.cat(h_list) + c = torch.cat(c_list) + return h, c + + +class SimpleRNN(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + # This just calls the base class constructor + super().__init__() + # Neural network layers assigned as attributes of a Module subclass + # have their parameters registered for training automatically. + self.rnn = torch.nn.RNN(input_size, hidden_size, nonlinearity='relu', batch_first=True) + self.linear = torch.nn.Linear(hidden_size, output_size) + + def forward(self, x): + # The RNN also returns its hidden state but we don't use it. + # While the RNN can also take a hidden state as input, the RNN + # gets passed a hidden state initialized with zeros by default. 
+ h = self.rnn(x)[0] + x = self.linear(h) + return x + +class SimpleLSTM(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.lstm = torch.nn.LSTM(input_size, hidden_size, batch_first=True) + self.linear = torch.nn.Linear(hidden_size, output_size) + + def forward(self, x): + h = self.lstm(x)[0] + x = self.linear(h) + return x + + def get_states_across_time(self, x): + h_c = None + h_list, c_list = list(), list() + with torch.no_grad(): + for t in range(x.size(1)): + h_c = self.lstm(x[:, [t], :], h_c)[1] + h_list.append(h_c[0]) + c_list.append(h_c[1]) + h = torch.cat(h_list) + c = torch.cat(c_list) + return h, c + + +"""## 3. Defining the Training Loop""" + +def train(model, train_data_gen, criterion, optimizer, device): + # Set the model to training mode. This will turn on layers that would + # otherwise behave differently during evaluation, such as dropout. + model.train() + + # Store the number of sequences that were classified correctly + num_correct = 0 + + # Iterate over every batch of sequences. Note that the length of a data generator + # is defined as the number of batches required to produce a total of roughly 1000 + # sequences given a batch size. + for batch_idx in range(len(train_data_gen)): + + + # Request a batch of sequences and class labels, convert them into tensors + # of the correct type, and then send them to the appropriate device. + data, target = train_data_gen[batch_idx] + + data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device) + + + # Perform the forward pass of the model + output = model(data) # Step ① + + # Pick only the output corresponding to last sequence element (input is pre padded) + output = output[:, -1, :] + + # Compute the value of the loss for this batch. For loss functions like CrossEntropyLoss, + # the second argument is actually expected to be a tensor of class indices rather than + # one-hot encoded class labels. One approach is to take advantage of the one-hot encoding + # of the target and call argmax along its second dimension to create a tensor of shape + # (batch_size) containing the index of the class label that was hot for each sequence. + target = target.argmax(dim=1) + + loss = criterion(output, target) # Step ② + + # Clear the gradient buffers of the optimized parameters. + # Otherwise, gradients from the previous batch would be accumulated. + optimizer.zero_grad() # Step ③ + + loss.backward() # Step ④ + + optimizer.step() # Step ⑤ + + y_pred = output.argmax(dim=1) + num_correct += (y_pred == target).sum().item() + + return num_correct, loss.item() + +"""## 4. Defining the Testing Loop""" + +def test(model, test_data_gen, criterion, device): + # Set the model to evaluation mode. This will turn off layers that would + # otherwise behave differently during training, such as dropout. + model.eval() + + # Store the number of sequences that were classified correctly + num_correct = 0 + + # A context manager is used to disable gradient calculations during inference + # to reduce memory usage, as we typically don't need the gradients at this point. 
+ with torch.no_grad(): + for batch_idx in range(len(test_data_gen)): + data, target = test_data_gen[batch_idx] + + data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device) + output = model(data) + # Pick only the output corresponding to last sequence element (input is pre padded) + output = output[:, -1, :] + + target = target.argmax(dim=1) + loss = criterion(output, target) + + y_pred = output.argmax(dim=1) + num_correct += (y_pred == target).sum().item() + + return num_correct, loss.item() + +"""## 5. Putting it All Together""" + +import matplotlib.pyplot as plt +from plot_lib import set_default, plot_state, print_colourbar + +set_default() + +def check_point(model, filename, epochs, optimizer, criterion): + #torch.save(model.state_dict(), filename) + + print(f"Saving model from epoch", epochs) + torch.save({ + 'epoch': epochs, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': criterion, + }, filename) + + +def resume(model, filename): + #model.load_state_dict(torch.load(filename)) + + checkpoint = torch.load(filename, map_location=torch.device('cpu'), weights_only=False) + model.load_state_dict(checkpoint['model_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + epoch = checkpoint['epoch'] + loss = checkpoint['loss'] + +def train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs, verbose=True): + # Automatically determine the device that PyTorch should use for computation + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + + # Move model to the device which will be used for train and test + model.to(device) + + # Track the value of the loss function and model accuracy across epochs + history_train = {'loss': [], 'acc': []} + history_test = {'loss': [], 'acc': []} + + + early_stop_thresh = 5 + best_accuracy = -1 + best_epoch = -1 + for epoch in range(max_epochs): + # Run the training loop and calculate the accuracy. + # Remember that the length of a data generator is the number of batches, + # so we multiply it by the batch size to recover the total number of sequences. 
+ num_correct, loss = train(model, train_data_gen, criterion, optimizer, device) + accuracy = float(num_correct) / (len(train_data_gen) * batch_size) * 100 + history_train['loss'].append(loss) + history_train['acc'].append(accuracy) + # Do the same for the testing loop + num_correct, loss = test(model, test_data_gen, criterion, device) + accuracy = float(num_correct) / (len(test_data_gen) * batch_size) * 100 + history_test['loss'].append(loss) + history_test['acc'].append(accuracy) + + if history_test['acc'][-1] > best_accuracy: + best_accuracy = history_test['acc'][-1] + best_epoch = epoch + print('epoch', epoch) + check_point(model, pytorchModelFileName, epoch, optimizer, criterion) + elif epoch - best_epoch > early_stop_thresh: + print("Early stopped training at epoch %d" % epoch) + break # terminate the training loop + + if verbose or epoch + 1 == max_epochs: + print(f'[Epoch {epoch + 1}/{max_epochs}]' + f" loss: {history_train['loss'][-1]:.4f}, acc: {history_train['acc'][-1]:2.2f}%" + f" - test_loss: {history_test['loss'][-1]:.4f}, test_acc: {history_test['acc'][-1]:2.2f}%") + + + resume(model, pytorchModelFileName) + + # Generate diagnostic plots for the loss and accuracy + fig, axes = plt.subplots(ncols=2, figsize=(9, 4.5)) + for ax, metric in zip(axes, ['loss', 'acc']): + ax.plot(history_train[metric]) + ax.plot(history_test[metric]) + ax.set_xlabel('epoch', fontsize=12) + ax.set_ylabel(metric, fontsize=12) + ax.legend(['Train', 'Test'], loc='best') + plt.savefig(os.path.abspath('..') + '/expe/out/loss_accuracy.png') + #plt.show() + + return model + + + +"""## 5. Simple RNN: 10 Epochs""" +""" +# Setup the training and test data generators + + + +# Setup the RNN and training settings +input_size = n_symbols +hidden_size = 128 +output_size = n_classes +model = SimpleRNN(input_size, hidden_size, output_size) +criterion = torch.nn.CrossEntropyLoss() +optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001) +max_epochs = 100 + +# Train the model +model = train_and_test(model, train_data_gen, dev_data_gen, criterion, optimizer, max_epochs) + +for parameter_group in list(model.parameters()): + print(parameter_group.size()) + +""" + +"""## 6b SimpleMLP: 10 Epochs + +""" +""" +# Setup the MLP and training settings +input_size = n_symbols +output_size = n_classes +model = SimpleMLP(input_size, output_size) +criterion = torch.nn.CrossEntropyLoss() +#optimizer = torch.optim.Adam(model.parameters(), lr=0.01) +optimizer = torch.optim.Adamax(model.parameters(), lr=0.01) + +max_epochs = 30 + +# Train the model +model = train_and_test(model, train_data_gen, dev_data_gen, criterion, optimizer, max_epochs) + +for parameter_group in list(model.parameters()): + print(parameter_group.size()) + +""" + +"""## 6. Simple LSTM: 10 Epochs""" + + +model = SimpleLSTM(input_size, hidden_size, output_size) +criterion = torch.nn.CrossEntropyLoss() +optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001) + + + +max_epochs = 30 + +# Train the model +model = train_and_test(model, train_data_gen, dev_data_gen, criterion, optimizer, max_epochs) + +#for parameter_group in list(model.parameters()): +# print(parameter_group.size()) + + +"""## 7. 
Model Evaluation""" + +import collections +import random + +def evaluate_model(model, seed=9001, verbose=False): + # Define a dictionary that maps class indices to labels + #class_idx_to_label = {0: 'Q', 1: 'R', 2: 'S', 3: 'U', 4: '0', 5: '8'} + class_idx_to_label = {n: c for n, c in enumerate(classes)} + + # Create a new data generator + #data_generator = QRSU.prepare_data(seed=seed) + data_generator = preprocess_data(dev_items_list, dev_labels_list, batch_size, symbol_to_idx, class_to_idx, input_size, output_size)#, n_symbols, n_classes) + + # Track the number of times a class appears + count_classes = collections.Counter() + # Keep correctly classified and misclassified sequences, and their + # true and predicted class labels, for diagnostic information. + correct = [] + incorrect = [] + + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + + model.eval() + + with torch.no_grad(): + for batch_idx in range(len(data_generator)): + data, target = data_generator[batch_idx] + data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device) + + data_decoded = decode_x_batch(data.cpu().numpy(), idx_to_symbol) + target_decoded = decode_y_batch(target.cpu().numpy(), idx_to_class) + + output = model(data) + output = output[:, -1, :] + + target = target.argmax(dim=1) + y_pred = output.argmax(dim=1) + y_pred_decoded = [class_idx_to_label[y.item()] for y in y_pred] + + count_classes.update(target_decoded) + for i, (truth, prediction) in enumerate(zip(target_decoded, y_pred_decoded)): + if truth == prediction: + correct.append((data_decoded[i], truth, prediction)) + else: + incorrect.append((data_decoded[i], truth, prediction)) + + num_sequences = sum(count_classes.values()) + print('len correct', len(correct)) + print('num seqs', num_sequences) + accuracy = float(len(correct)) / num_sequences * 100 + print(f'The accuracy of the model is measured to be {accuracy:.2f}%.\n') + + # Report the accuracy by class + for label in sorted(count_classes): + num_correct = sum(1 for _, truth, _ in correct if truth == label) + #print(f'{label}: {num_correct} / {count_classes[label]} correct') + + print("Number of correct predictions: ", len(correct)) + + # Report some random sequences for examination + print('\nHere are some example sequences:') + if len(correct) > 0: + for i in range(0,3): + sequence, truth, prediction = correct[random.randrange(0, 2)] + #print(f'{sequence} -> {truth} was labelled {prediction}') + + print("Number of incorrect predictions: ", len(incorrect)) + + # Report misclassified sequences for investigation + if len(incorrect) > 0: + print('\nThe following sequences were misclassified:') + for i in range(0,3): + sequence, truth, prediction = incorrect[random.randrange(0, 2)] + print(f'{sequence} -> {truth} was labelled {prediction}') + else: + print('\nThere were no misclassified sequences.') + +evaluate_model(model) + + +""" Visualize Model """ + +# Get hidden (H) and cell (C) batch state given a batch input (X) +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') +model.eval() +with torch.no_grad(): + data = dev_data_gen[0][0] + X = torch.from_numpy(data).float().to(device) + H_t, C_t = model.get_states_across_time(X) + +#print("Color range is as follows:") +#print_colourbar() + +#plot_state(X.cpu(), C_t, b=31, decoder=decode_x, idx_to_symbol=idx_to_symbol) # 3, 6, 9 + +#plot_state(X.cpu(), H_t, b=31, decoder=decode_x, idx_to_symbol=idx_to_symbol) #b=9 + + + + +""" inputSize, outputSize, x_train, y_train = 
readData(cffTrainFileName) devInputSize, devOutputSize, x_dev, y_dev = readData(cffDevFileName) model = mlp() @@ -75,4 +512,4 @@ model.fit(x_train, y_train, epochs=10, batch_size=32, validation_data=(x_dev,y_d model.save(kerasModelFileName) - +"""
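For reference, the patch ties `tbp_train_pytorch.py` and `tbp_decode_pytorch.py` together through a shared model definition and a checkpoint dictionary. Below is a minimal, self-contained sketch (not part of the patch) of that save/load round trip; the sizes 133/128/75 and the `fr.pytorch` filename are illustrative assumptions taken from the hard-coded values and comments in the diff, and in the real scripts they are meant to come from `n_symbols` / `n_classes` and the `$model` variable in `vazy.sh`.

```python
# Sketch of the checkpoint contract between tbp_train_pytorch.py (check_point)
# and tbp_decode_pytorch.py (the loading code at the top of the script).
# Sizes and file name are assumptions mirroring the hard-coded values in the diff.
import torch
import torch.nn as nn


class SimpleLSTM(nn.Module):
    """The same architecture must be defined in the trainer and the decoder,
    otherwise load_state_dict() fails with a key/shape mismatch."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h = self.lstm(x)[0]        # (batch, seq_len, hidden_size)
        return self.linear(h)      # (batch, seq_len, output_size)


def save_checkpoint(model, optimizer, epoch, loss, path):
    # Mirrors check_point() in tbp_train_pytorch.py.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)


def load_checkpoint(path, input_size=133, hidden_size=128, output_size=75):
    # Mirrors the loading code in tbp_decode_pytorch.py.
    model = SimpleLSTM(input_size, hidden_size, output_size)
    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
    checkpoint = torch.load(path, map_location=torch.device('cpu'), weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.eval()
    return model, optimizer, checkpoint['epoch'], checkpoint['loss']


if __name__ == '__main__':
    model = SimpleLSTM(133, 128, 75)
    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
    save_checkpoint(model, optimizer, epoch=0, loss=None, path='fr.pytorch')

    model, optimizer, epoch, loss = load_checkpoint('fr.pytorch')
    # At decode time only the last time step of the LSTM output is used,
    # because the inputs are pre-padded sequences.
    scores = model(torch.zeros(1, 15, 133))[:, -1, :]
    print(scores.shape)  # torch.Size([1, 75])
```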