diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..ea9142a4cd47121e9eaf3f1901fa05758ac0070d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,173 @@
+*.tgz
+
+data/
+
+expe/out/
+
+slides/
+
+*.pytorch
+
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
\ No newline at end of file
diff --git a/README.md b/README.md
index 7ea5d2ee284d976ffc0e4560751bcadd6b0bec55..8c7885984884af32b6ff7469ebe8a38a75b7f92d 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,19 @@
 # tbp
 
+
+## List of todos
+
+
+### Todo
+
+- [ ] complete the implemented models with word and POS embeddings @tatiana.bladier
+  - [ ] word embeddings
+  - [ ] POS embeddings
+
+### In Progress
+
+- [ ] 
+
+### Done ✓
+
+- [x] 
diff --git a/expe/PLE.dic b/expe/PLE.dic
index 6a580e2be22de852b0ccb6a7abeb8228da83adc7..a28118463aa57b214f047e29d708268f99c17059 100644
--- a/expe/PLE.dic
+++ b/expe/PLE.dic
@@ -1,9 +1,62 @@
 ##POS
 NULL
 ROOT
+DET
+NOUN
+ADP
+VERB
+ADJ
+PUNCT
+CCONJ
+PROPN
+AUX
+ADV
+PRON
+NUM
+X
+SCONJ
+PART
+INTJ
+SYM
 ##LABEL
 NULL
 ROOT
+det
+nsubj
+case
+nmod
+root
+amod
+obl
+punct
+cc
+conj
+aux
+advmod
+xcomp
+appos
+acl
+dep
+obj
+iobj
+cop
+mark
+advcl
+flat
+fixed
+nummod
+compound
+ccomp
+parataxis
+expl
+csubj
+dislocated
+orphan
+discourse
+vocative
+goeswith
 ##EOS
 NULL
 ROOT
+0
+1
diff --git a/expe/vazy.sh b/expe/vazy.sh
index f0baec955868ff111a6b9b3ee11dec601de7f494..c0c7a7fd3f349a8601a58cab416cb4d846abebd9 100644
--- a/expe/vazy.sh
+++ b/expe/vazy.sh
@@ -1,4 +1,5 @@
 lang=$1
+mkdir -p out
 train_conll="../data/train_${lang}.conllu"
 train_proj_conll="./out/train_${lang}_proj.conllu"
 train_mcf="./out/train_${lang}_pgle.mcf"
@@ -12,6 +13,7 @@ dev_cff="./out/dev_${lang}.cff"
 dev_word_limit="5000"
 
 test_conll="../data/test_${lang}.conllu"
+
 test_mcf="./out/test_${lang}_pgle.mcf"
 test_mcf_hyp="./out/test_${lang}_hyp.mcf"
 test_word_limit="700"
@@ -20,7 +22,7 @@ feat_model="basic.fm"
 
 dicos="./out/${lang}_train.dic"
 dicos="PLE.dic"
-model="./out/${lang}.keras"
+model="./out/${lang}.pytorch"
 results="./out/${lang}.res"
 
 mcd_pgle="PGLE.mcd"
@@ -41,11 +43,18 @@ python3 ../src/mcf2cff.py $dev_mcf $feat_model $mcd_pgle $dicos $dev_cff $dev_wo
 
 python3 ../src/mcf2cff.py $train_mcf $feat_model $mcd_pgle $dicos $train_cff $train_word_limit
 
-python3 ../src/tbp_train.py $train_cff $dev_cff $model
+#python3 ../src/tbp_train.py $train_cff $dev_cff $model
+python3 ../src/tbp_train_pytorch.py $train_cff $dev_cff $model
+#python3 ../src/tbp_train_keras.py $train_cff $dev_cff $model
+
+
+#python3 ../src/tbp_decode.py $test_mcf $model $dicos $feat_model $mcd_pgle $test_word_limit > $test_mcf_hyp
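+# NOTE: the pytorch decoder also takes the train/dev cff files so it can rebuild the symbol/class dictionaries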
+python3 ../src/tbp_decode_pytorch.py $test_mcf $model $dicos $feat_model $mcd_pgle $test_word_limit $train_cff $dev_cff > $test_mcf_hyp
+#python3 ../src/tbp_decode_keras.py $test_mcf $model $dicos $feat_model $mcd_pgle $test_word_limit > $test_mcf_hyp
 
-python3 ../src/tbp_decode.py $test_mcf $model $dicos $feat_model $mcd_pgle $test_word_limit > $test_mcf_hyp
 
-python3 ../src/eval_mcf.py $test_mcf $test_mcf_hyp $mcd_pgle $mcd_pgle $lang > $results
+python3 ../src/eval_mcf.py $test_mcf $test_mcf_hyp $mcd_pgle $mcd_pgle $lang verbose > $results
 
 
 
diff --git a/src/Config.py b/src/Config.py
index 212a34506d42ade61949469c8e79f71d00d6ba33..0c1b9200bda405a887fb00990e7913ccf66770e8 100644
--- a/src/Config.py
+++ b/src/Config.py
@@ -37,7 +37,8 @@ class Config:
                 if(self.getStack().isEmpty()):
                         sys.stderr.write("cannot reduce an empty stack !\n")
                         return False
-        
+                
+                #print('AHHHHHHHHH', int(self.getBuffer().getWord(self.getStack().top()).getFeat('GOV')))
                 if int(self.getBuffer().getWord(self.getStack().top()).getFeat('GOV')) == Word.invalidGov() :
                         sys.stderr.write("cannot reduce the stack if top element does not have a governor !\n")
                         return False
@@ -52,6 +53,7 @@ class Config:
                 
                 govIndex = self.getStack().top()
                 depIndex = self.getBuffer().currentIndex
+                #print('560 govIndex, depIndex ', govIndex, depIndex)
                 self.getBuffer().getCurrentWord().setFeat('LABEL', label)
                 self.getBuffer().getCurrentWord().setFeat('GOV', str(govIndex - depIndex))
                 self.getBuffer().getWord(self.getStack().top()).addRightDaughter(depIndex)
@@ -66,6 +68,7 @@ class Config:
 		
                 govIndex = self.getBuffer().currentIndex
                 depIndex = self.getStack().top()
+                #print('561 govIndex, depIndex ', govIndex, depIndex)
                 self.getBuffer().getWord(self.getStack().top()).setFeat('LABEL', label)
                 self.getBuffer().getWord(self.getStack().top()).setFeat('GOV', str(govIndex - depIndex))
                 self.getBuffer().getCurrentWord().addLeftDaughter(depIndex)
@@ -126,19 +129,19 @@ class Config:
                 container = featTuple[2]
                 index = featTuple[3]
 
-                word = self.getWordWithRelativeIndex(containe, index)
+                word = self.getWordWithRelativeIndex(container, index)
                 if word == None :
                         return 'NULL'
-                return string(len(word.getLeftDaughters()))
+                return str(len(word.getLeftDaughters()))
 
         def getNrDepFeat(self, featTuple):
                 container = featTuple[2]
                 index = featTuple[3]
 
-                word = self.getWordWithRelativeIndex(containe, index)
+                word = self.getWordWithRelativeIndex(container, index)
                 if word == None :
                         return 'NULL'
-                return string(len(word.getRightDaughters()))
+                return str(len(word.getRightDaughters()))
 
 
         def getLlDepFeat(self, featTuple):
@@ -148,7 +151,7 @@ class Config:
                 return 'NULL'
 
         def getStackHeightFeat(self, featTuple):
-                string(self.getStack().getLength())
+                return str(self.getStack().getLength())
 
         def getDistFeat(self, featTuple):
                 containerWord1 = featTuple[1]
@@ -193,7 +196,7 @@ class Config:
                 featVec = []
                 i = 0
                 for f in FeatModel.getFeatArray():
-#                        print(f, '=', self.getWordFeat(f))
+                        #print(f, '=', self.getWordFeat(f))
                         featVec.append(self.getFeat(f))
 #                        featVec.append(self.getWordFeat(f))
                         i += 1
diff --git a/src/Dicos.py b/src/Dicos.py
index a84cab76de8ecc1d4b46e82b832158575c4628a6..c0d8aa72cc56ab420cda183cfcf875acd5966fc6 100644
--- a/src/Dicos.py
+++ b/src/Dicos.py
@@ -5,9 +5,11 @@ class Dicos:
                 self.content = {}
                 if mcd :
                         self.initializeWithMcd(mcd)
+
                 if fileName :
                         self.initializeWithDicoFile(fileName)
 
+
         def initializeWithMcd(self, mcd):
                 for index in range(mcd.getNbCol()):
                         if(mcd.getColStatus(index) == 'KEEP') and (mcd.getColType(index) == 'SYM') :
diff --git a/src/FeatModel.py b/src/FeatModel.py
index 8c344feca37f69d2f806140d6edf7c8c94446014..207e2d40c22e499e56acc07972f4b0a5939b09cb 100644
--- a/src/FeatModel.py
+++ b/src/FeatModel.py
@@ -61,7 +61,7 @@ class FeatModel:
             label = self.getFeatLabel(i)
             size = dicos.getDico(label).getSize()
             position = dicos.getCode(label, featVec[i])
-            #print('featureName = ', featureName, 'value =', featVec[i], 'size =', size, 'position =', position, 'origin =', origin)
+            #print('featureName = ', label, 'value =', featVec[i], 'size =', size, 'position =', position, 'origin =', origin)
             inputVector[origin + position] = 1
             origin += size
         return inputVector
@@ -73,7 +73,7 @@ class FeatModel:
             label = self.getFeatLabel(i)
             size = dicos.getDico(label).getSize()
             position = dicos.getCode(label, featVec[i])
-#            print('featureName = ', featureName, 'value =', featVec[i], 'size =', size, 'position =', position, 'origin =', origin)
+            #print('featureName = ', label, 'value =', featVec[i], 'size =', size, 'position =', position, 'origin =', origin)
 #            print('value =', featVec[i], 'size =', size, 'position =', position, 'origin =', origin)
             inputVector[i] = position
             origin += size
diff --git a/src/Moves.py b/src/Moves.py
index 4232bdb00fea8c6489a1c77573812befb1d82d21..84b71893505935dd807e3ac3e3fafb89fb1351f4 100644
--- a/src/Moves.py
+++ b/src/Moves.py
@@ -3,6 +3,7 @@ import numpy as np
 class Moves:
     def __init__(self, dicos):
         self.dicoLabels = dicos.getDico('LABEL')
+        
         if not self.dicoLabels :
             print("cannot find LABEL in dicos")
             exit(1)
@@ -26,14 +27,20 @@
         if(mvtType == 'LEFT'):   return 3 + 2 * labelCode + 1
 
     def mvtDecode(self, mvt_Code):
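+        # Inverse of mvtCode above: 0 = SHIFT, 1 = REDUCE, 2 = ROOT; labelled moves
+        # interleave as RIGHT = 3 + 2*labelCode (odd) and LEFT = 4 + 2*labelCode (even).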
-        if(mvt_Code == 0) : return ('SHIFT', 'NULL')
+        if(mvt_Code == 0) : 
+            return ('SHIFT', 'NULL')
         if(mvt_Code == 1) : return ('REDUCE', 'NULL')
         if(mvt_Code == 2) : return ('ROOT', 'NULL')
         if mvt_Code % 2 == 0 : #LEFT
             labelCode = int((mvt_Code - 4) / 2)
+            #print("label code: ", labelCode)
+
             return ('LEFT', self.dicoLabels.getSymbol(labelCode))
         else :
             labelCode = int((mvt_Code - 3)/ 2)
+            #print("label code: ", labelCode)
             return ('RIGHT', self.dicoLabels.getSymbol(labelCode))
 
     def buildOutputVectorOneHot(self, mvt):
@@ -46,4 +51,5 @@ class Moves:
         outputVector = np.zeros(1, dtype="int32")
         codeMvt = self.mvtCode(mvt)
         outputVector[0] = codeMvt
+
         return outputVector
diff --git a/src/Word.py b/src/Word.py
index fe92bd51f53c0062cfeee39030add45ff9250a12..f87b817735ecf5dc9538f8b950c84bd7d5466f8a 100644
--- a/src/Word.py
+++ b/src/Word.py
@@ -4,7 +4,7 @@ class Word:
         self.leftDaughters = []   # liste des indices des dépendants gauches
         self.rightDaughters = []  # liste des indices des dépendants droits
         self.index = self.invalidIndex()
-
+        
     def getFeat(self, featName):
         if(not featName in self.featDic):
             print('WARNING : feat', featName, 'does not exist')
@@ -40,7 +40,7 @@ class Word:
                 else:
                     print('\t', end='')
                 print(self.getFeat(mcd.getColName(columnNb)), end='')
-#        print('')
+        print('')
 
     @staticmethod
     def fakeWordConll():
diff --git a/src/WordBuffer.py b/src/WordBuffer.py
index f2fba582402d6627a4e91c81aa4d8ce0798855b1..9208b7f2597bd07c3a5fefa87cfded4bf79668d8 100644
--- a/src/WordBuffer.py
+++ b/src/WordBuffer.py
@@ -32,7 +32,8 @@ class WordBuffer:
         def affiche(self, mcd):
                 for w in self.array:
                         w.affiche(mcd)
-                        print('')
+                        #print('')
+                        #print(w.affiche(mcd))
 
         def getLength(self):
                 return len(self.array)
diff --git a/src/eval_mcf.py b/src/eval_mcf.py
index 22c26336921ed089ae601b8fd1b8e3f240aef557..c236c6d83d6df0d68473d387703953b5778c8b00 100644
--- a/src/eval_mcf.py
+++ b/src/eval_mcf.py
@@ -19,10 +19,10 @@ else:
     verbose = False
 
 
-#print('reading mcd from file :', refMcdFileName)
+print('reading mcd from file :', refMcdFileName)
 refMcd = Mcd(refMcdFileName)
 
-#print('reading mcd from file :', hypMcdFileName)
+print('reading mcd from file :', hypMcdFileName)
 hypMcd = Mcd(hypMcdFileName)
 
 GovColIndex = refMcd.locateCol('GOV')
diff --git a/src/mcf2cff.py b/src/mcf2cff.py
index 3e34b47e79ef4dccfce1926bfb6f6bbdf73bb7a6..11f9f7592c08eb72ca5412444da486195482ad31 100644
--- a/src/mcf2cff.py
+++ b/src/mcf2cff.py
@@ -96,7 +96,7 @@ featModel = FeatModel(featModelFileName, dicos)
 
 inputSize = featModel.getInputSize()
 outputSize = moves.getNb()
-print('input size = ', inputSize, 'outputSize =' , outputSize)
+print('featModel input size = ', inputSize, 'outputSize =' , outputSize)
 
 print('preparing training data')
 prepareData(mcd, mcfFileName, featModel, moves, dataFileName, wordsLimit)
diff --git a/src/plot_lib.py b/src/plot_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4152a38ae9d0f78d97e272176bd94edcf19def4
--- /dev/null
+++ b/src/plot_lib.py
@@ -0,0 +1,154 @@
+from matplotlib import pyplot as plt
+import matplotlib.cm as cm
+import numpy as np
+import torch
+from IPython.display import HTML, display
+
+
+def set_default(figsize=(10, 10), dpi=100):
+    plt.style.use(['dark_background', 'bmh'])
+    plt.rc('axes', facecolor='k')
+    plt.rc('figure', facecolor='k')
+    plt.rc('figure', figsize=figsize, dpi=dpi)
+
+
+def plot_data(X, y, d=0, auto=False, zoom=1):
+    X = X.cpu()
+    y = y.cpu()
+    plt.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y, s=20, cmap=plt.cm.Spectral)
+    plt.axis('square')
+    plt.axis(np.array((-1.1, 1.1, -1.1, 1.1)) * zoom)
+    if auto is True: plt.axis('equal')
+    plt.axis('off')
+
+    _m, _c = 0, '.15'
+    plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
+    plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)
+
+
+def plot_model(X, y, model):
+    model.cpu()
+    mesh = np.arange(-1.1, 1.1, 0.01)
+    xx, yy = np.meshgrid(mesh, mesh)
+    with torch.no_grad():
+        data = torch.from_numpy(np.vstack((xx.reshape(-1), yy.reshape(-1))).T).float()
+        Z = model(data).detach()
+    Z = np.argmax(Z, axis=1).reshape(xx.shape)
+    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3)
+    plot_data(X, y)
+
+
+def show_scatterplot(X, colors, title='', axis=True):
+    colors = cm.rainbow(np.linspace(0, 1, len(X)))  # one colour per point; overrides the colors argument
+
+    X = X.numpy()
+    # plt.figure()
+    plt.axis('equal')
+    plt.scatter(X[:, 0], X[:, 1], c=colors, s=30)
+    # plt.grid(True)
+    plt.title(title)
+    plt.axis('off')
+    _m, _c = 0, '.15'
+    if axis:
+        plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
+        plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)
+
+
+def plot_bases(bases, plotting=True, width=0.04):
+    bases[2:] -= bases[:2]
+    # if plot_bases.a: plot_bases.a.set_visible(False)
+    # if plot_bases.b: plot_bases.b.set_visible(False)
+    if plotting:
+        plot_bases.a = plt.arrow(*bases[0], *bases[2], width=width, color='r', zorder=10, alpha=1., length_includes_head=True)
+        plot_bases.b = plt.arrow(*bases[1], *bases[3], width=width, color='g', zorder=10, alpha=1., length_includes_head=True)
+
+
+plot_bases.a = None
+plot_bases.b = None
+
+
+def show_mat(mat, vect, prod, threshold=-1):
+    # Subplot grid definition
+    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=False, sharey=True,
+                                        gridspec_kw={'width_ratios':[5,1,1]})
+    # Plot matrices
+    cax1 = ax1.matshow(mat.numpy(), clim=(-1, 1))
+    ax2.matshow(vect.numpy(), clim=(-1, 1))
+    cax3 = ax3.matshow(prod.numpy(), clim=(threshold, 1))
+
+    # Set titles
+    ax1.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}')
+    ax2.set_title(f'a^(i): {vect.numel()}')
+    ax3.set_title(f'p: {prod.numel()}')
+
+    # Remove xticks for vectors
+    ax2.set_xticks(tuple())
+    ax3.set_xticks(tuple())
+
+    # Plot colourbars
+    fig.colorbar(cax1, ax=ax2)
+    fig.colorbar(cax3, ax=ax3)
+
+    # Fix y-axis limits
+    ax1.set_ylim(bottom=max(len(prod), len(vect)) - 0.5)
+
+
+colors = dict(
+    aqua='#8dd3c7',
+    yellow='#ffffb3',
+    lavender='#bebada',
+    red='#fb8072',
+    blue='#80b1d3',
+    orange='#fdb462',
+    green='#b3de69',
+    pink='#fccde5',
+    grey='#d9d9d9',
+    violet='#bc80bd',
+    unk1='#ccebc5',
+    unk2='#ffed6f',
+)
+
+
+def _cstr(s, color='black'):
+    if s == ' ':
+        return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>'
+    else:
+        return f'<text style=color:#000;background-color:{color}>{s} </text>'
+
+# print html
+def _print_color(t):
+    display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t])))
+
+# get appropriate color for value
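+# (e.g. _get_clr(0.5) -> colors[int(50 / 5)] = colors[10])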
+def _get_clr(value):
+    colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
+              '#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
+              '#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
+              '#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e')
+    value = int((value * 100) / 5)
+    if value == len(colors): value -= 1  # clamp the edge case value == 1.0
+    return colors[value]
+
+def _visualise_values(output_values, result_list):
+    text_colours = []
+    for i in range(len(output_values)):
+        text = (result_list[i], _get_clr(output_values[i]))
+        text_colours.append(text)
+    _print_color(text_colours)
+
+def print_colourbar():
+    color_range = torch.linspace(-2.5, 2.5, 20)
+    to_print = [(f'{x:.2f}', _get_clr((x+2.5)/5)) for x in color_range]
+    _print_color(to_print)
+
+
+# Let's only focus on the last time step for now
+# First, the cell state (Long term memory)
+def plot_state(data, state, b, decoder, idx_to_symbol):
+    actual_data = decoder(data[b, :, :].numpy(), idx_to_symbol)
+    seq_len = len(actual_data)
+    seq_len_w_pad = len(state)
+    for s in range(state.size(2)):
+        states = torch.sigmoid(state[:, b, s])
+        _visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data))
\ No newline at end of file
diff --git a/src/pytorch_utils.py b/src/pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0427d751c43daeae77697508706ba2d5f8dc79a
--- /dev/null
+++ b/src/pytorch_utils.py
@@ -0,0 +1,212 @@
+import math
+
+import numpy as np
+import six
+
+""" read file in cff format
+
+    the contents of such files look like this:
+    133
+    75
+    0
+    0 0 1 4 2 0 0
+    0
+    0 1 4 2 3 1 0
+
+    the first line gives the network input size and the second the number of
+    output classes (moves); the remaining lines then alternate between a move
+    code (the label) and a feature vector (the item)
+"""
+def readFile_cff(filepath):
+    lines = []
+    with open(filepath) as file_in:
+        for line in file_in:
+            lines.append(line)
+    inputSize = int(lines[0].strip())
+    outputSize = int(lines[1].strip())
+    items_labels = [x.strip() for x in lines[2:]]
+    items = items_labels[1::2]
+    labels = items_labels[::2]
+    return items, labels, inputSize, outputSize
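+# Example (for the sample contents above): readFile_cff would return
+# items = ['0 0 1 4 2 0 0', '0 1 4 2 3 1 0'], labels = ['0', '0'],
+# inputSize = 133 and outputSize = 75.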
+
+
+#flatten a list
+def flatten(xss):
+    return [x for xs in xss for x in xs]
+
+
+#find how many classes we have in data:
+def find_unique_classes(*inputlists):
+    unique_classes = np.unique(flatten([*inputlists])).tolist()
+    unique_classes = [str(x) for x in unique_classes]
+    number_of_unique_classes = len(unique_classes)
+    return unique_classes, number_of_unique_classes
+
+
+def find_unique_symbols(*inputlists):
+    flat_list = flatten([*inputlists])
+    maxlen = len(max(flat_list, key=len))
+    unique_elem_list = []
+    for elem in flat_list:
+        tokens = [x for x in elem.split(' ')]
+        for t in tokens:
+            if not t in unique_elem_list:
+                unique_elem_list.append(t)
+    return maxlen, unique_elem_list + [' ']
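+# Example: find_unique_symbols(['0 0 1', '2 1']) returns (5, ['0', '1', '2', ' ']):
+# the character length of the longest feature string and the unique tokens, plus
+# the blank, since encode_x below encodes feature strings character by character.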
+
+
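+# pad_sequences and to_categorical below appear to be vendored copies of the
+# Keras utilities of the same names, which avoids a TensorFlow dependency here.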
+def pad_sequences(sequences, maxlen=None, dtype='int32',
+                  padding='pre', truncating='pre', value=0.):
+    if not hasattr(sequences, '__len__'):
+        raise ValueError('`sequences` must be iterable.')
+    lengths = []
+    for x in sequences:
+        if not hasattr(x, '__len__'):
+            raise ValueError('`sequences` must be a list of iterables. '
+                             'Found non-iterable: ' + str(x))
+        lengths.append(len(x))
+
+    num_samples = len(sequences)
+    #print("NUM SAMPLES", num_samples)
+    if maxlen is None:
+        maxlen = np.max(lengths)
+
+    # take the sample shape from the first non empty sequence
+    # checking for consistency in the main loop below.
+    sample_shape = tuple()
+    for s in sequences:
+        if len(s) > 0:
+            sample_shape = np.asarray(s).shape[1:]
+            break
+
+    is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.unicode_)
+    if isinstance(value, six.string_types) and dtype != object and not is_dtype_str:
+        raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n"
+                         "You should set `dtype=object` for variable length strings."
+                         .format(dtype, type(value)))
+
+    x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
+    for idx, s in enumerate(sequences):
+        if not len(s):
+            continue  # empty list/array was found
+        if truncating == 'pre':
+            trunc = s[-maxlen:]
+        elif truncating == 'post':
+            trunc = s[:maxlen]
+        else:
+            raise ValueError('Truncating type "%s" '
+                             'not understood' % truncating)
+
+        # check `trunc` has expected shape
+        trunc = np.asarray(trunc, dtype=dtype)
+        if trunc.shape[1:] != sample_shape:
+            raise ValueError('Shape of sample %s of sequence at position %s '
+                             'is different from expected shape %s' %
+                             (trunc.shape[1:], idx, sample_shape))
+
+        if padding == 'post':
+            x[idx, :len(trunc)] = trunc
+        elif padding == 'pre':
+            x[idx, -len(trunc):] = trunc
+        else:
+            raise ValueError('Padding type "%s" not understood' % padding)
+    return x
+
+
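+# Example: to_categorical([1, 0], num_classes=3) returns
+# array([[0., 1., 0.], [1., 0., 0.]], dtype=float32).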
+def to_categorical(y, num_classes=None, dtype='float32'):
+    y = np.array(y, dtype='int')
+    input_shape = y.shape
+    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
+        input_shape = tuple(input_shape[:-1])
+    y = y.ravel()
+    if not num_classes:
+        num_classes = np.max(y) + 1
+    n = y.shape[0]
+    categorical = np.zeros((n, num_classes), dtype=dtype)
+    categorical[np.arange(n), y] = 1
+    output_shape = input_shape + (num_classes,)
+    categorical = np.reshape(categorical, output_shape)
+    return categorical
+
+
+def make_pytorch_dicts(*input_paths):
+    input_paths = [x for x in input_paths]
+    item_lsts = []
+    label_lsts = []
+    for inpath in input_paths:
+        items_list, labels_list, _, _ = readFile_cff(inpath)
+        item_lsts.append(items_list)
+        label_lsts.append(labels_list)
+    items_list = flatten(item_lsts)
+    labels_list = flatten(label_lsts)
+
+    classes, n_classes = find_unique_classes(labels_list) # possible moves
+    maxlen, all_symbols = find_unique_symbols(items_list) # basically, the number of POS tags
+    n_symbols = len(all_symbols)
+
+    #print('Number of classes: ', n_classes)
+    #print('Number of symbols: ', n_symbols)
+    #print('Max length of sequence: ', maxlen)
+
+    symbol_to_idx = {s: n for n, s in enumerate(all_symbols)}
+    idx_to_symbol = {n: s for n, s in enumerate(all_symbols)}
+
+    class_to_idx = {c: n for n, c in enumerate(classes)}
+    idx_to_class = {n: c for n, c in enumerate(classes)}
+
+    return n_classes, maxlen, n_symbols, symbol_to_idx, idx_to_symbol, class_to_idx, idx_to_class, classes
+
+def encode_x_batch(x_batch, symbol_to_idx, n_symbols):
+    return pad_sequences([encode_x(x, symbol_to_idx, n_symbols) for x in x_batch],maxlen=15) # TODO read maxlen from the file
+
+
+def encode_x(x, symbol_to_idx, n_symbols):
+    idx_x = [symbol_to_idx[s] for s in x]
+    return to_categorical(idx_x, num_classes=n_symbols)
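+# Note: encode_x iterates over the characters of the feature string (spaces
+# included), producing one one-hot row of width n_symbols per character;
+# encode_x_batch then pads or truncates each row sequence to a fixed length.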
+
+
+def encode_y_batch(y_batch, class_to_idx, n_classes):
+    return np.array([encode_y(y, class_to_idx, n_classes) for y in y_batch])
+
+
+def encode_y(y, class_to_idx, n_classes):
+    idx_y = class_to_idx[y]
+    return to_categorical(idx_y, num_classes=n_classes)
+
+
+def preprocess_data(items_list, labels_list, batch_size, symbol_to_idx, class_to_idx, n_symbols, n_classes):
+    upp_bound = int(math.ceil(len(items_list) / batch_size))
+    data_batches = []
+    # the last (possibly partial) batch is dropped so every batch has exactly batch_size items
+    for i in range(upp_bound - 1):
+        a = i * batch_size
+        b = (i + 1) * batch_size
+        data_batches.append((encode_x_batch(items_list[a:b], symbol_to_idx, n_symbols), encode_y_batch(labels_list[a:b], class_to_idx, n_classes)))
+    return data_batches
+
+
+def decode_x(x, idx_to_symbol):
+    x = x[np.sum(x, axis=1) > 0]    # remove padding
+    return ''.join([idx_to_symbol[pos] for pos in np.argmax(x, axis=1)])
+
+def decode_y(y, idx_to_class):
+    return idx_to_class[np.argmax(y)]
+
+def decode_x_batch(x_batch, idx_to_symbol):
+    return [decode_x(x, idx_to_symbol) for x in x_batch]
+
+def decode_y_batch(y_batch, idx_to_class):
+    return [idx_to_class[pos] for pos in np.argmax(y_batch, axis=1)]
\ No newline at end of file
diff --git a/src/tbp_decode_pytorch.py b/src/tbp_decode_pytorch.py
index c7e139992109bcf2fbae735241be9a2af078b53c..13c6b317e7602b6eb203998613556835512b9881 100644
--- a/src/tbp_decode_pytorch.py
+++ b/src/tbp_decode_pytorch.py
@@ -9,6 +9,7 @@ from FeatModel import FeatModel
 import torch
 
 import numpy as np
+from pytorch_utils import *
 
 
 
@@ -23,9 +24,9 @@ def prepareWordBufferForDecode(buffer):
         word.setFeat('LABEL', Word.invalidLabel())
 
 
-verbose = False
-if len(sys.argv) != 7 :
-    print('usage:', sys.argv[0], 'mcf_file model_file dicos_file feat_model mcd_file words_limit')
+verbose = True
+if len(sys.argv) != 9 :
+    print('usage:', sys.argv[0], 'mcf_file model_file dicos_file feat_model mcd_file words_limit train_file dev_file')
     exit(1)
 
 mcf_file =       sys.argv[1]
@@ -34,76 +35,177 @@ dicos_file =     sys.argv[3]
 feat_model =     sys.argv[4]
 mcd_file =       sys.argv[5]
 wordsLimit = int(sys.argv[6])
+train_cff_file = sys.argv[7]
+dev_cff_file = sys.argv[8]
 
+##########################################################################
+
+n_classes, maxlen, n_symbols, symbol_to_idx, idx_to_symbol, class_to_idx, idx_to_class, _ = make_pytorch_dicts(dev_cff_file, train_cff_file)
+##########################################################################
     
-sys.stderr.write('reading mcd from file :')
-sys.stderr.write(mcd_file)
-sys.stderr.write('\n')
-mcd = Mcd(mcd_file)
+#print('reading mcd from file : ', mcd_file) # PGLE.mcd
+
+mcd = Mcd(mcd_file) # MCD = multi column descriptions
+
 
-sys.stderr.write('loading dicos\n')
-dicos = Dicos(fileName=dicos_file)
+
+#print('loading dicos from : ', dicos_file) # PLE.dic
+dicos = Dicos(fileName=dicos_file) # dictionaries with class labels for columns from mcd
 
 moves = Moves(dicos)
 
-sys.stderr.write('reading feature model from file :')
-sys.stderr.write(feat_model)
-sys.stderr.write('\n')
+#print('reading feature model from file : ', feat_model) # basic.fm
+
 featModel = FeatModel(feat_model, dicos)
 
-sys.stderr.write('loading model :')
-sys.stderr.write(model_file)
-sys.stderr.write('\n')
-model = load_model(model_file)
+#print('loading pytorch model from :', model_file) # /home/taniabladier/Programming/AMU/tbp/expe/out/fr.pytorch
+
+
+############################
+
+"""
+load saved pytorch model
+"""
+
+batch_size = 32
+
+
+# Setup the RNN and training settings
+import torch.nn as nn
+import torch.nn.functional as F
+
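+# NOTE: this must define the same layers (names and sizes) as SimpleLSTM in
+# tbp_train_pytorch.py, or load_state_dict below will not match the checkpoint.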
+class SimpleLSTM(nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super().__init__()
+        self.lstm = torch.nn.LSTM(input_size, hidden_size, batch_first=True)
+        self.linear = torch.nn.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        h = self.lstm(x)[0]
+        x = self.linear(h)
+        return x
+
+    def get_states_across_time(self, x):
+        h_c = None
+        h_list, c_list = list(), list()
+        with torch.no_grad():
+            for t in range(x.size(1)):
+                h_c = self.lstm(x[:, [t], :], h_c)[1]
+                h_list.append(h_c[0])
+                c_list.append(h_c[1])
+            h = torch.cat(h_list)
+            c = torch.cat(c_list)
+        return h, c
+    
+
+
+
+input_size  = 133  # cff input size; must match the trained model (n_symbols)
+hidden_size = 128
+output_size = 75   # number of move classes; must match the trained model (n_classes)
+
+
+print('input and output sizes:', input_size, output_size, file=sys.stderr)
+model       = SimpleLSTM(input_size, hidden_size, output_size)
+criterion   = torch.nn.CrossEntropyLoss()
+optimizer   = torch.optim.RMSprop(model.parameters(), lr=0.001)
+
+
+checkpoint = torch.load(model_file, map_location=torch.device('cpu'), weights_only=False)
+model.load_state_dict(checkpoint['model_state_dict'])
+optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+epoch = checkpoint['epoch']
+loss = checkpoint['loss']
 
-inputSize = featModel.getInputSize()
-outputSize = moves.getNb()
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+model = model.to(device)
+model.eval()
+
+
+############################
+
+
+#model = load_model(model_file)
+
+#inputSize = featModel.getInputSize()
+#outputSize = moves.getNb()
 
 c = Config(mcf_file, mcd, dicos)
+
+
+
 numSent = 0
-verbose = False
+verbose = True
 numWords = 0
 
+
 while c.getBuffer().readNextSentence()  and numWords < wordsLimit :
     c.getStack().empty()
     prepareWordBufferForDecode(c.getBuffer())
+
     numWords += c.getBuffer().getLength()
 
     while True :
         featVec = c.extractFeatVec(featModel)
         inputVector = featModel.buildInputVector(featVec, dicos)
-        outputVector = model.predict(inputVector.reshape((1,inputSize)), batch_size=1, verbose=0, steps=None)
-        mvt_Code = outputVector.argmax()
+
+        ###############
+        inputVector = ' '.join(str(x) for x in inputVector)
+        inputVector = encode_x_batch([inputVector], symbol_to_idx, 133)  # TODO: use n_symbols
+        inputVector = torch.from_numpy(inputVector).float().to(device)
+
+        with torch.no_grad():
+            output = model(inputVector)
+
+            # Pick only the output corresponding to last sequence element (input is pre padded)
+            output = output[:, -1, :]
+            preds = output.argmax(dim=1)
+            mvt_Code = int(idx_to_class[preds[0].item()])  # batch of one
+
+
+        ##################
+        
+        #mvt_Code = outputVector.argmax()
+        #mvt = moves.mvtDecode(mvt_Code)
+
         mvt = moves.mvtDecode(mvt_Code)
 
         if(verbose == True) :
-            print("------------------------------------------")
-            c.affiche()
-            print('predicted move', mvt[0], mvt[1])
-            print(mvt, featVec)
 
-        res = c.applyMvt(mvt)
+            """
+            print("\n\n------------------------------------------\n")
+            print('inputVector ::', featModel.buildInputVector(featVec, dicos))
+
+            print('mvt code: ', mvt_Code, ' move ::: ', mvt)
+
+            c.affiche()            
+
+            print('predicted move: ', mvt[0], mvt[1], 'for', str(featVec))        
+            """
+
+        res = c.applyMvt(mvt)  # returns False if the move cannot be applied
+        #print("RES", res)
         if not res :
-            sys.stderr.write("cannot apply predicted movement\n")
+            print("cannot apply predicted movement", file=sys.stderr)
             mvt_type = mvt[0]
             mvt_label = mvt[1]
             if mvt_type != "SHIFT" :
-                sys.stderr.write("try to force SHIFT\n")
+                print("try to force SHIFT", file=sys.stderr)
                 res = c.shift()
                 if res == False :
-                    sys.stderr.write("try to force REDUCE\n")
+                    print("try to force REDUCE", file=sys.stderr)
                     res = c.red()
                     if res == False :
-                        sys.stderr.write("abort sentence\n")
+                        print("abort sentence", file=sys.stderr)
                         break
         if(c.isFinal()):
             break
     for i in range(1, c.getBuffer().getLength()):
         w = c.getBuffer().getWord(i)
         w.affiche(mcd)
-        print('')
-#        print('\t', w.getFeat("GOV"), end='\t')
-#        print(w.getFeat("LABEL"))
 
     numSent += 1
 #    if numSent % 10 == 0:
diff --git a/src/tbp_train_pytorch.py b/src/tbp_train_pytorch.py
index 314948443c060006198f0fe6c1dbd9241aa46109..2ed492617a3fc641939bc55f6dafbd3114e69441 100644
--- a/src/tbp_train_pytorch.py
+++ b/src/tbp_train_pytorch.py
@@ -1,8 +1,15 @@
 import sys
 import numpy as np
 import torch
-from torch import nn
+import torch.nn as nn
+import torch.nn.functional as F
+from pytorch_utils import *
+from plot_lib import *
+import os
 
+"""## 1. Reading Data Files"""
+
+"""
 def readData(dataFilename) :
     allX = []
     allY = []
@@ -37,6 +44,7 @@ def readData(dataFilename) :
     x_train = np.array(allX)
     y_train = np.array(allY)
     return (inputSize, outputSize, x_train, y_train)
+"""
 
 
 
@@ -46,8 +54,429 @@ if len(sys.argv) < 3 :
 
 cffTrainFileName =   sys.argv[1]
 cffDevFileName =     sys.argv[2]
-kerasModelFileName = sys.argv[3]
+pytorchModelFileName = sys.argv[3]
+
+
+batch_size = 32
+
+
+n_classes, maxlen, n_symbols, symbol_to_idx, idx_to_symbol, \
+      class_to_idx, idx_to_class, classes = make_pytorch_dicts(cffTrainFileName, cffDevFileName)
+
+
+train_items_list, train_labels_list, train_inputSize, train_outputSize  = readFile_cff(cffTrainFileName)
+dev_items_list, dev_labels_list, dev_inputSize, dev_outputSize  = readFile_cff(cffDevFileName)
+
+
+# NOTE: the cff header sizes are passed where preprocess_data expects
+# n_symbols/n_classes, so the one-hot width matches the model sizes hard-coded below.
+train_data_gen = preprocess_data(train_items_list[:800000], train_labels_list[:800000], batch_size, symbol_to_idx, class_to_idx, train_inputSize, train_outputSize)
+dev_data_gen  = preprocess_data(dev_items_list[:200000], dev_labels_list[:200000], batch_size, symbol_to_idx, class_to_idx, train_inputSize, train_outputSize)
+
+
+"""## 2. Defining the Model"""
+
+
+# Set the random seed for reproducible results
+torch.manual_seed(1)
+
+# Setup the RNN and training settings
+input_size  = 133  # train_inputSize (the cff input size)
+hidden_size = 128
+output_size = 75   # train_outputSize (the number of move classes)
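+# These sizes are also hard-coded in tbp_decode_pytorch.py; if the feature model
+# or label set changes, both scripts and the checkpoint must be regenerated together.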
+
+class SimpleMLP(nn.Module):
+    def __init__(self, input_size, output_size):
+        super(SimpleMLP, self).__init__()
+        self.fc1 = nn.Linear(input_size, 200)  # dense layer with 200 units
+        self.dropout = nn.Dropout(0.5)  # dropout with probability 0.5
+        self.fc2 = nn.Linear(200, output_size)  # dense layer with output_size units
+
+    def forward(self, x):
+        x = F.relu(self.fc1(x))  # Apply ReLU activation
+        x = self.dropout(x)  # Apply Dropout
+        x = F.softmax(self.fc2(x), dim=1)  # Apply Softmax activation
+        return x
+
+
+
+class SimpleRNN(nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        # This just calls the base class constructor
+        super().__init__()
+        # Neural network layers assigned as attributes of a Module subclass
+        # have their parameters registered for training automatically.
+        self.rnn = torch.nn.RNN(input_size, hidden_size, nonlinearity='relu', batch_first=True)
+        self.linear = torch.nn.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        # The RNN also returns its hidden state but we don't use it.
+        # While the RNN can also take a hidden state as input, the RNN
+        # gets passed a hidden state initialized with zeros by default.
+        h = self.rnn(x)[0]
+        x = self.linear(h)
+        return x
+
+class SimpleLSTM(nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super().__init__()
+        self.lstm = torch.nn.LSTM(input_size, hidden_size, batch_first=True)
+        self.linear = torch.nn.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        h = self.lstm(x)[0]
+        x = self.linear(h)
+        return x
+
+    def get_states_across_time(self, x):
+        h_c = None
+        h_list, c_list = list(), list()
+        with torch.no_grad():
+            for t in range(x.size(1)):
+                h_c = self.lstm(x[:, [t], :], h_c)[1]
+                h_list.append(h_c[0])
+                c_list.append(h_c[1])
+            h = torch.cat(h_list)
+            c = torch.cat(c_list)
+        return h, c
+    
+
+"""## 3. Defining the Training Loop"""
+
+def train(model, train_data_gen, criterion, optimizer, device):
+    # Set the model to training mode. This will turn on layers that would
+    # otherwise behave differently during evaluation, such as dropout.
+    model.train()
+
+    # Store the number of sequences that were classified correctly
+    num_correct = 0
+
+    # Iterate over every batch of sequences. Note that the length of a data generator
+    # is defined as the number of batches required to produce a total of roughly 1000
+    # sequences given a batch size.
+    for batch_idx in range(len(train_data_gen)):
+
+
+        # Request a batch of sequences and class labels, convert them into tensors
+        # of the correct type, and then send them to the appropriate device.
+        data, target = train_data_gen[batch_idx]
+
+        data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device)
+
+
+        # Perform the forward pass of the model
+        output = model(data)  # Step ①
+
+        # Pick only the output corresponding to last sequence element (input is pre padded)
+        output = output[:, -1, :]
+
+        # Compute the value of the loss for this batch. For loss functions like CrossEntropyLoss,
+        # the second argument is actually expected to be a tensor of class indices rather than
+        # one-hot encoded class labels. One approach is to take advantage of the one-hot encoding
+        # of the target and call argmax along its second dimension to create a tensor of shape
+        # (batch_size) containing the index of the class label that was hot for each sequence.
+        target = target.argmax(dim=1)
+
+        loss = criterion(output, target)  # Step ②
+
+        # Clear the gradient buffers of the optimized parameters.
+        # Otherwise, gradients from the previous batch would be accumulated.
+        optimizer.zero_grad()  # Step ③
+
+        loss.backward()  # Step ④
+
+        optimizer.step()  # Step ⑤
+
+        y_pred = output.argmax(dim=1)
+        num_correct += (y_pred == target).sum().item()
+
+    return num_correct, loss.item()
+
+"""## 4. Defining the Testing Loop"""
+
+def test(model, test_data_gen, criterion, device):
+    # Set the model to evaluation mode. This will turn off layers that would
+    # otherwise behave differently during training, such as dropout.
+    model.eval()
+
+    # Store the number of sequences that were classified correctly
+    num_correct = 0
+
+    # A context manager is used to disable gradient calculations during inference
+    # to reduce memory usage, as we typically don't need the gradients at this point.
+    with torch.no_grad():
+        for batch_idx in range(len(test_data_gen)):
+            data, target = test_data_gen[batch_idx]
+
+            data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device)
+            output = model(data)
+            # Pick only the output corresponding to last sequence element (input is pre padded)
+            output = output[:, -1, :]
+
+            target = target.argmax(dim=1)
+            loss = criterion(output, target)
+
+            y_pred = output.argmax(dim=1)
+            num_correct += (y_pred == target).sum().item()
+
+    return num_correct, loss.item()
+
+"""## 5. Putting it All Together"""
+
+import matplotlib.pyplot as plt
+from plot_lib import set_default, plot_state, print_colourbar
+
+set_default()
+
+def check_point(model, filename, epochs, optimizer, criterion):
+    #torch.save(model.state_dict(), filename)
+
+    print("Saving model from epoch", epochs)
+    torch.save({
+                'epoch': epochs,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'loss': criterion,
+                }, filename)
+
+
+def resume(model, filename):
+    #model.load_state_dict(torch.load(filename))
+
+    checkpoint = torch.load(filename, map_location=torch.device('cpu'), weights_only=False)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+    epoch = checkpoint['epoch']
+    loss = checkpoint['loss']
+
+def train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs, verbose=True):
+    # Automatically determine the device that PyTorch should use for computation
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+    # Move model to the device which will be used for train and test
+    model.to(device)
+
+    # Track the value of the loss function and model accuracy across epochs
+    history_train = {'loss': [], 'acc': []}
+    history_test = {'loss': [], 'acc': []}
+
+
+    early_stop_thresh = 5
+    best_accuracy = -1
+    best_epoch = -1
+    for epoch in range(max_epochs):
+        # Run the training loop and calculate the accuracy.
+        # Remember that the length of a data generator is the number of batches,
+        # so we multiply it by the batch size to recover the total number of sequences.
+        num_correct, loss = train(model, train_data_gen, criterion, optimizer, device)
+        accuracy = float(num_correct) / (len(train_data_gen) * batch_size) * 100
+        history_train['loss'].append(loss)
+        history_train['acc'].append(accuracy)
 
+        # Do the same for the testing loop
+        num_correct, loss = test(model, test_data_gen, criterion, device)
+        accuracy = float(num_correct) / (len(test_data_gen) * batch_size) * 100
+        history_test['loss'].append(loss)
+        history_test['acc'].append(accuracy)
+
+        if history_test['acc'][-1] > best_accuracy:
+            best_accuracy = history_test['acc'][-1]
+            best_epoch = epoch
+            print('epoch', epoch)
+            check_point(model, pytorchModelFileName, epoch, optimizer, criterion)
+        elif epoch - best_epoch > early_stop_thresh:
+            print("Early stopped training at epoch %d" % epoch)
+            break  # terminate the training loop
+
+        if verbose or epoch + 1 == max_epochs:
+            print(f'[Epoch {epoch + 1}/{max_epochs}]'
+                  f" loss: {history_train['loss'][-1]:.4f}, acc: {history_train['acc'][-1]:2.2f}%"
+                  f" - test_loss: {history_test['loss'][-1]:.4f}, test_acc: {history_test['acc'][-1]:2.2f}%")
+
+
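+        # reload the best checkpoint so the next epoch continues from the best model so far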
+        resume(model, pytorchModelFileName)
+
+    # Generate diagnostic plots for the loss and accuracy
+    fig, axes = plt.subplots(ncols=2, figsize=(9, 4.5))
+    for ax, metric in zip(axes, ['loss', 'acc']):
+        ax.plot(history_train[metric])
+        ax.plot(history_test[metric])
+        ax.set_xlabel('epoch', fontsize=12)
+        ax.set_ylabel(metric, fontsize=12)
+        ax.legend(['Train', 'Test'], loc='best')
+    plt.savefig(os.path.abspath('..') + '/expe/out/loss_accuracy.png')
+    #plt.show()
+
+    return model
+
+
+
+"""## 6a. Simple RNN: 100 Epochs (disabled)"""
+"""
+# Setup the training and test data generators
+
+
+
+# Setup the RNN and training settings
+input_size  = n_symbols
+hidden_size = 128
+output_size = n_classes
+model       = SimpleRNN(input_size, hidden_size, output_size)
+criterion   = torch.nn.CrossEntropyLoss()
+optimizer   = torch.optim.RMSprop(model.parameters(), lr=0.001)
+max_epochs  = 100
+
+# Train the model
+model = train_and_test(model, train_data_gen, dev_data_gen, criterion, optimizer, max_epochs)
+
+for parameter_group in list(model.parameters()):
+    print(parameter_group.size())
+
+"""
+
+"""## 6b. SimpleMLP: 30 Epochs (disabled)"""
+"""
+# Setup the MLP and training settings
+input_size  = n_symbols
+output_size = n_classes
+model = SimpleMLP(input_size, output_size)
+criterion   = torch.nn.CrossEntropyLoss()
+#optimizer   = torch.optim.Adam(model.parameters(), lr=0.01)
+optimizer = torch.optim.Adamax(model.parameters(), lr=0.01)
+
+max_epochs  = 30
+
+# Train the model
+model = train_and_test(model, train_data_gen, dev_data_gen, criterion, optimizer, max_epochs)
+
+for parameter_group in list(model.parameters()):
+    print(parameter_group.size())
+
+"""
+
+"""## 6c. Simple LSTM: 30 Epochs"""
+
+
+model       = SimpleLSTM(input_size, hidden_size, output_size)
+criterion   = torch.nn.CrossEntropyLoss()
+optimizer   = torch.optim.RMSprop(model.parameters(), lr=0.001)
+
+
+
+max_epochs  = 30
+
+# Train the model
+model = train_and_test(model, train_data_gen, dev_data_gen, criterion, optimizer, max_epochs)
+
+#for parameter_group in list(model.parameters()):
+#    print(parameter_group.size())
+
+
+"""## 7. Model Evaluation"""
+
+import collections
+import random
+
+def evaluate_model(model, seed=9001, verbose=False):
+    # Define a dictionary that maps class indices to labels
+    #class_idx_to_label = {0: 'Q', 1: 'R', 2: 'S', 3: 'U', 4: '0', 5: '8'}
+    class_idx_to_label = {n: c for n, c in enumerate(classes)}
+
+    # Create a new data generator
+    #data_generator = QRSU.prepare_data(seed=seed)
+    data_generator = preprocess_data(dev_items_list, dev_labels_list, batch_size, symbol_to_idx, class_to_idx, input_size, output_size)
+
+    # Track the number of times a class appears
+    count_classes = collections.Counter()
+    # Keep correctly classified and misclassified sequences, and their
+    # true and predicted class labels, for diagnostic information.
+    correct = []
+    incorrect = []
+
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+    model.eval()
+
+    with torch.no_grad():
+        for batch_idx in range(len(data_generator)):
+            data, target = data_generator[batch_idx]
+            data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device)
+
+            data_decoded = decode_x_batch(data.cpu().numpy(), idx_to_symbol)
+            target_decoded = decode_y_batch(target.cpu().numpy(), idx_to_class)
+
+            output = model(data)
+            output = output[:, -1, :]
+
+            target = target.argmax(dim=1)
+            y_pred = output.argmax(dim=1)
+            y_pred_decoded = [class_idx_to_label[y.item()] for y in y_pred]
+
+            count_classes.update(target_decoded)
+            for i, (truth, prediction) in enumerate(zip(target_decoded, y_pred_decoded)):
+                if truth == prediction:
+                    correct.append((data_decoded[i], truth, prediction))
+                else:
+                    incorrect.append((data_decoded[i], truth, prediction))
+
+    num_sequences = sum(count_classes.values())
+    print('len correct', len(correct))
+    print('num seqs', num_sequences)
+    accuracy = float(len(correct)) / num_sequences * 100
+    print(f'The accuracy of the model is measured to be {accuracy:.2f}%.\n')
+
+    # Report the accuracy by class
+    for label in sorted(count_classes):
+        num_correct = sum(1 for _, truth, _ in correct if truth == label)
+        #print(f'{label}: {num_correct} / {count_classes[label]} correct')
+
+    print("Number of correct predictions: ", len(correct))
+
+    # Report some random sequences for examination
+    print('\nHere are some example sequences:')
+    if len(correct) > 0:
+        for i in range(0,3):
+            sequence, truth, prediction = correct[random.randrange(len(correct))]
+            #print(f'{sequence} -> {truth} was labelled {prediction}')
+
+    print("Number of incorrect predictions: ", len(incorrect))
+
+    # Report misclassified sequences for investigation
+    if len(incorrect) > 0:
+        print('\nThe following sequences were misclassified:')
+        for i in range(0,3):
+            sequence, truth, prediction = incorrect[random.randrange(len(incorrect))]
+            print(f'{sequence} -> {truth} was labelled {prediction}')
+    else:
+        print('\nThere were no misclassified sequences.')
+
+evaluate_model(model)
+
+
+""" Visualize Model """
+
+# Get hidden (H) and cell (C) batch state given a batch input (X)
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+model.eval()
+with torch.no_grad():
+    data = dev_data_gen[0][0]
+    X = torch.from_numpy(data).float().to(device)
+    H_t, C_t = model.get_states_across_time(X)
+
+#print("Color range is as follows:")
+#print_colourbar()
+
+#plot_state(X.cpu(), C_t, b=31, decoder=decode_x, idx_to_symbol=idx_to_symbol)  # 3, 6, 9
+
+#plot_state(X.cpu(), H_t, b=31, decoder=decode_x, idx_to_symbol=idx_to_symbol) #b=9
+
+
+
+
+"""
 inputSize, outputSize, x_train, y_train = readData(cffTrainFileName)
 devInputSize, devOutputSize, x_dev, y_dev = readData(cffDevFileName)
 model = mlp()
@@ -75,4 +504,4 @@ model.fit(x_train, y_train, epochs=10, batch_size=32, validation_data=(x_dev,y_d
 
 model.save(kerasModelFileName)
         
-
+"""