diff --git a/tania_scripts/.ipynb_checkpoints/tania-some-other-metrics-checkpoint.ipynb b/tania_scripts/.ipynb_checkpoints/tania-some-other-metrics-checkpoint.ipynb
index e23180be4b2b3ebb0fa7b86798291f716322a6a5..ceb6052a49e387b6e1af94e850c3bffb5c13fcd0 100644
--- a/tania_scripts/.ipynb_checkpoints/tania-some-other-metrics-checkpoint.ipynb
+++ b/tania_scripts/.ipynb_checkpoints/tania-some-other-metrics-checkpoint.ipynb
@@ -88,7 +88,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "id": "2a882cc9-8f9d-4457-becb-d2e26ab3f14f",
    "metadata": {},
    "outputs": [
@@ -107,7 +107,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "id": "8897dcc3-4218-4ee5-9984-17b9a6d8dce2",
    "metadata": {},
    "outputs": [],
@@ -163,7 +163,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 6,
    "id": "1363f307-fa4b-43ba-93d5-2d1c11ceb9e4",
    "metadata": {},
    "outputs": [
@@ -184,7 +184,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 7,
    "id": "1362e192-514a-4a77-a8cb-5c012026e2bb",
    "metadata": {},
    "outputs": [],
@@ -238,7 +238,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 8,
    "id": "544ff6aa-4104-4580-a01f-97429ffcc228",
    "metadata": {},
    "outputs": [
@@ -328,7 +328,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 9,
    "id": "b9052dc2-ce45-4af4-a0a0-46c60a13da12",
    "metadata": {},
    "outputs": [],
@@ -388,7 +388,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 10,
    "id": "1e9dd0fb-db6a-47d1-8bfb-1015845f6d3e",
    "metadata": {},
    "outputs": [
@@ -398,7 +398,7 @@
        "{'Flesch-Douma': 88.68, 'LIX': 11.55, 'Kandel-Moles': 5.86}"
       ]
      },
-     "execution_count": 11,
+     "execution_count": 10,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -421,7 +421,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 11,
    "id": "24bc84a5-b2df-4194-838a-8f24302599bd",
    "metadata": {},
    "outputs": [],
@@ -467,7 +467,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 12,
    "id": "0cdb972f-31b6-4e7e-82a8-371eda344f2c",
    "metadata": {},
    "outputs": [
@@ -477,7 +477,7 @@
        "{'Average Word Length': 3.79, 'Average Sentence Length': 7.0}"
       ]
      },
-     "execution_count": 13,
+     "execution_count": 12,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -521,7 +521,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 13,
    "id": "56af520c-d56b-404a-aebf-ad7c2a9ca503",
    "metadata": {},
    "outputs": [],
@@ -567,7 +567,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 14,
    "id": "f7c8b125-4651-4b21-bcc4-93ef78a4239b",
    "metadata": {},
    "outputs": [
@@ -603,7 +603,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 15,
    "id": "daa17c33-adca-4695-90eb-741579382939",
    "metadata": {},
    "outputs": [],
@@ -622,7 +622,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 16,
    "id": "80d8fa08-6b7d-4ab7-85cd-987823639277",
    "metadata": {},
    "outputs": [
@@ -665,7 +665,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 18,
    "id": "3f9c7dc7-6820-4013-a85c-2af4f846d4f5",
    "metadata": {},
    "outputs": [
@@ -693,7 +693,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 19,
    "id": "65e1a630-c46e-4b18-9831-b97864de53ee",
    "metadata": {},
    "outputs": [],
@@ -713,7 +713,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 20,
    "id": "1612e911-12a8-47c9-b811-b2d6885c3647",
    "metadata": {},
    "outputs": [
@@ -742,7 +742,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 21,
    "id": "925a3a75-aaaa-4851-b77b-b42cb1e21e11",
    "metadata": {},
    "outputs": [],
@@ -757,7 +757,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 22,
    "id": "6fa60897-ad26-43b4-b8de-861290ca6bd3",
    "metadata": {},
    "outputs": [
@@ -783,25 +783,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 23,
    "id": "f3678462-e572-4ce5-8d3d-a5389b2356c8",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Defaulting to user installation because normal site-packages is not writeable\n",
-      "Collecting scipy\n",
-      "  Downloading scipy-1.15.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n",
-      "Requirement already satisfied: numpy<2.5,>=1.23.5 in /public/conda/Miniconda/envs/pytorch-2.6/lib/python3.11/site-packages (from scipy) (2.2.4)\n",
-      "Downloading scipy-1.15.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.7 MB)\n",
-      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m37.7/37.7 MB\u001b[0m \u001b[31m75.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
-      "\u001b[?25hInstalling collected packages: scipy\n",
-      "Successfully installed scipy-1.15.3\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "#!pip3 install seaborn\n",
     "#!pip3 install scipy"
@@ -809,7 +794,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 24,
    "id": "b621b2a8-488f-44db-b085-fe156f453943",
    "metadata": {},
    "outputs": [
@@ -830,7 +815,7 @@
     "import matplotlib.pyplot as plt\n",
     "from scipy.stats import spearmanr\n",
     "\n",
-    "# Sample data (replace with your real values)\n",
+    "# Sample data (to be replaced with real values)\n",
     "data = {\n",
     "    \"perplexity\": [32.5, 45.2, 28.1, 39.0, 50.3],\n",
     "    \"avg_word_length\": [4.1, 4.3, 4.0, 4.2, 4.5],\n",
@@ -853,10 +838,123 @@
     "plt.show()"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "45ee04fc-acab-4bba-ba06-e4cf4bca9fe5",
+   "metadata": {},
+   "source": [
+    "## Tree depth"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "79f99787-c220-4f1d-93a9-59230363ec3f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def parse_sentence_block(text):\n",
+    "    lines = text.strip().split('\\n')\n",
+    "    result = []\n",
+    "    tokenlist = []\n",
+    "    for line in lines:\n",
+    "        # Split the line by tab and strip whitespace\n",
+    "        parts = tuple(line.strip().split('\\t'))\n",
+    "        # Only include lines that have exactly 4 parts\n",
+    "        if len(parts) == 4:\n",
+    "            parentidx = int(parts[3])\n",
+    "            if '@@' in parts[2]:\n",
+    "                nonterm1 = parts[2].split('@@')[0]\n",
+    "                nonterm2 = parts[2].split('@@')[1]\n",
+    "            else:\n",
+    "                nonterm1 = parts[2]\n",
+    "                nonterm2 = '<nul>'\n",
+    "            postag = parts[1]\n",
+    "            token = parts[0]\n",
+    "            result.append((parentidx, nonterm1, nonterm2, postag))\n",
+    "            tokenlist.append(token)\n",
+    "    return result, tokenlist\n",
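+    "\n",
+    "# Illustrative usage sketch (hypothetical; `example_sentence` is defined in\n",
+    "# the next cell, and actions follow the 4-column format parsed above):\n",
+    "# actions, tokens = parse_sentence_block(example_sentence)\n",
+    "# -> each action is (parent_index, nonterm1, nonterm2, pos_tag)\n",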
+    "    "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "id": "f567efb0-8b0b-4782-9345-052cf1785776",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "example_sentence = \"\"\"\n",
+    "<s>\t<s>\t<s>\t1\n",
+    "--\tponct\t<nul>@@<nul>\t1\n",
+    "Eh\tnpp\t<nul>@@<nul>\t1\n",
+    "bien?\tadv\tAP@@<nul>\t1\n",
+    "fit\tv\tVN@@<nul>\t2\n",
+    "-il\tcls-suj\tVN@@VPinf-OBJ\t3\n",
+    ".\tponct\t<nul>@@<nul>\t4\n",
+    "</s>\t</s>\t</s>\t4\n",
+    "\"\"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "id": "8d4ecba9-89b8-4000-a061-aa16aa68a404",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transform import *\n",
+    "\n",
+    "def visualize_const_prediction(example_sent):\n",
+    "    parsed, tokenlist = parse_sentence_block(example_sent)\n",
+    "    tree = AttachJuxtaposeTree.totree(tokenlist, 'SENT')\n",
+    "    AttachJuxtaposeTree.action2tree(tree, parsed).pretty_print()\n",
+    "    nltk_tree = AttachJuxtaposeTree.action2tree(tree, parsed)\n",
+    "    #print(\"NLTK TREE\", nltk_tree)\n",
+    "    depth = nltk_tree.height() - 1  # NLTK includes the leaf level as height 1, so subtract 1 for tree depth \n",
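+    "    # e.g. nltk.Tree('S', [nltk.Tree('NP', ['a'])]).height() == 3, i.e. depth 2\n",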
+    "    print(\"Tree depth:\", depth)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "id": "bfd3abf3-b83a-4817-85ad-654daf72be88",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "               SENT                                 \n",
+      "                |                                    \n",
+      "               <s>                                  \n",
+      "                |                                    \n",
+      "               <s>                                  \n",
+      "  ______________|__________                          \n",
+      " |    |    |               AP                       \n",
+      " |    |    |     __________|________                 \n",
+      " |    |    |    |               VPinf-OBJ           \n",
+      " |    |    |    |     ______________|_______         \n",
+      " |    |    |    |    |                      VN      \n",
+      " |    |    |    |    |      ________________|____    \n",
+      " |    |    |    |    VN    |                |   </s>\n",
+      " |    |    |    |    |     |                |    |   \n",
+      " |  ponct npp  adv   v  cls-suj           ponct </s>\n",
+      " |    |    |    |    |     |                |    |   \n",
+      "<s>   --   Eh bien? fit   -il               .   </s>\n",
+      "\n",
+      "Tree depth: 8\n"
+     ]
+    }
+   ],
+   "source": [
+    "visualize_const_prediction(example_sentence)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "3a6e3b53-7104-45ef-a4b5-e831bdd6ca6f",
+   "id": "bc51ab44-6885-45cc-bad2-6a43a7791fdb",
    "metadata": {},
    "outputs": [],
    "source": []
diff --git a/tania_scripts/.ipynb_checkpoints/transform-checkpoint.py b/tania_scripts/.ipynb_checkpoints/transform-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..56e0f5158b6b3b214b7e99e8a479fbabc56bcbf6
--- /dev/null
+++ b/tania_scripts/.ipynb_checkpoints/transform-checkpoint.py
@@ -0,0 +1,459 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
+
+import nltk
+import torch
+
+from supar.models.const.crf.transform import Tree
+from supar.utils.common import NUL
+from supar.utils.logging import get_logger
+from supar.utils.tokenizer import Tokenizer
+from supar.utils.transform import Sentence
+
+if TYPE_CHECKING:
+    from supar.utils import Field
+
+logger = get_logger(__name__)
+
+
+class AttachJuxtaposeTree(Tree):
+    r"""
+    :class:`AttachJuxtaposeTree` is derived from the :class:`Tree` class,
+    supporting back-and-forth transformations between trees and AttachJuxtapose actions :cite:`yang-deng-2020-aj`.
+
+    Attributes:
+        WORD:
+            Words in the sentence.
+        POS:
+            Part-of-speech tags, or underscores if not available.
+        TREE:
+            The raw constituency tree in :class:`nltk.tree.Tree` format.
+        NODE:
+            The target node on each rightmost chain.
+        PARENT:
+            The label of the parent node of each terminal.
+        NEW:
+            The label of each newly inserted non-terminal with a target node and a terminal as juxtaposed children.
+            ``NUL`` represents the `Attach` action.
+    """
+
+    fields = ['WORD', 'POS', 'TREE', 'NODE', 'PARENT', 'NEW']
+
+    def __init__(
+        self,
+        WORD: Optional[Union[Field, Iterable[Field]]] = None,
+        POS: Optional[Union[Field, Iterable[Field]]] = None,
+        TREE: Optional[Union[Field, Iterable[Field]]] = None,
+        NODE: Optional[Union[Field, Iterable[Field]]] = None,
+        PARENT: Optional[Union[Field, Iterable[Field]]] = None,
+        NEW: Optional[Union[Field, Iterable[Field]]] = None
+    ) -> Tree:
+        super().__init__()
+
+        self.WORD = WORD
+        self.POS = POS
+        self.TREE = TREE
+        self.NODE = NODE
+        self.PARENT = PARENT
+        self.NEW = NEW
+
+    @property
+    def src(self):
+        return self.WORD, self.POS, self.TREE
+
+    @property
+    def tgt(self):
+        return self.NODE, self.PARENT, self.NEW
+
+    @classmethod
+    def tree2action(cls, tree: nltk.Tree):
+        r"""
+        Converts a constituency tree into AttachJuxtapose actions.
+
+        Args:
+            tree (nltk.tree.Tree):
+                A constituency tree in :class:`nltk.tree.Tree` format.
+
+        Returns:
+            A sequence of AttachJuxtapose actions.
+
+        Examples:
+            >>> from supar.models.const.aj.transform import AttachJuxtaposeTree
+            >>> tree = nltk.Tree.fromstring('''
+                                            (TOP
+                                              (S
+                                                (NP (_ Arthur))
+                                                (VP
+                                                  (_ is)
+                                                  (NP (NP (_ King)) (PP (_ of) (NP (_ the) (_ Britons)))))
+                                                (_ .)))
+                                            ''')
+            >>> tree.pretty_print()
+                            TOP
+                             |
+                             S
+               ______________|_______________________
+              |              VP                      |
+              |      ________|___                    |
+              |     |            NP                  |
+              |     |    ________|___                |
+              |     |   |            PP              |
+              |     |   |     _______|___            |
+              NP    |   NP   |           NP          |
+              |     |   |    |        ___|_____      |
+              _     _   _    _       _         _     _
+              |     |   |    |       |         |     |
+            Arthur  is King  of     the     Britons  .
+            >>> AttachJuxtaposeTree.tree2action(tree)
+            [(0, 'NP', '<nul>'), (0, 'VP', 'S'), (1, 'NP', '<nul>'),
+             (2, 'PP', 'NP'), (3, 'NP', '<nul>'), (4, '<nul>', '<nul>'),
+             (0, '<nul>', '<nul>')]
+        """
+
+        def isroot(node):
+            return node == tree[0]
+
+        def isterminal(node):
+            return len(node) == 1 and not isinstance(node[0], nltk.Tree)
+
+        def last_leaf(node):
+            pos = ()
+            while True:
+                pos += (len(node) - 1,)
+                node = node[-1]
+                if isterminal(node):
+                    return node, pos
+
+        def parent(position):
+            return tree[position[:-1]]
+
+        def grand(position):
+            return tree[position[:-2]]
+
+        def detach(tree):
+            last, last_pos = last_leaf(tree)
+            siblings = parent(last_pos)[:-1]
+
+            if len(siblings) > 0:
+                last_subtree = last
+                last_subtree_siblings = siblings
+                parent_label = NUL
+            else:
+                last_subtree, last_pos = parent(last_pos), last_pos[:-1]
+                last_subtree_siblings = [] if isroot(last_subtree) else parent(last_pos)[:-1]
+                parent_label = last_subtree.label()
+
+            target_pos, new_label, last_tree = 0, NUL, tree
+            if isroot(last_subtree):
+                last_tree = None
+
+            elif len(last_subtree_siblings) == 1 and not isterminal(last_subtree_siblings[0]):
+                new_label = parent(last_pos).label()
+                target = last_subtree_siblings[0]
+                last_grand = grand(last_pos)
+                if last_grand is None:
+                    last_tree = target
+                else:
+                    last_grand[-1] = target
+                target_pos = len(last_pos) - 2
+            else:
+                target = parent(last_pos)
+                target.pop()
+                target_pos = len(last_pos) - 2
+            action = target_pos, parent_label, new_label
+            return action, last_tree
+        if tree is None:
+            return []
+        action, last_tree = detach(tree)
+        return cls.tree2action(last_tree) + [action]
+
+    @classmethod
+    def action2tree(
+        cls,
+        tree: nltk.Tree,
+        actions: List[Tuple[int, str, str, str]],
+        join: str = '::',
+    ) -> nltk.Tree:
+        r"""
+        Recovers a constituency tree from a sequence of AttachJuxtapose actions.
+
+        Args:
+            tree (nltk.tree.Tree):
+                An empty tree that provides a base for building a result tree.
+            actions (List[Tuple[int, str, str, str]]):
+                A sequence of AttachJuxtapose actions; in this variant each action also carries the POS tag of the new leaf.
+            join (str):
+                A string used to connect collapsed node labels. Non-terminals containing this will be expanded to unary chains.
+                Default: ``'::'``.
+
+        Returns:
+            A result constituency tree.
+
+        Examples:
+            >>> from supar.models.const.aj.transform import AttachJuxtaposeTree
+            >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP')
+            >>> AttachJuxtaposeTree.action2tree(tree,
+                                                [(0, 'NP', '<nul>', '_'), (0, 'VP', 'S', '_'), (1, 'NP', '<nul>', '_'),
+                                                 (2, 'PP', 'NP', '_'), (3, 'NP', '<nul>', '_'), (4, '<nul>', '<nul>', '_'),
+                                                 (0, '<nul>', '<nul>', '_')]).pretty_print()
+                            TOP
+                             |
+                             S
+               ______________|_______________________
+              |              VP                      |
+              |      ________|___                    |
+              |     |            NP                  |
+              |     |    ________|___                |
+              |     |   |            PP              |
+              |     |   |     _______|___            |
+              NP    |   NP   |           NP          |
+              |     |   |    |        ___|_____      |
+              _     _   _    _       _         _     _
+              |     |   |    |       |         |     |
+            Arthur  is King  of     the     Britons  .
+        """
+
+        def target(node, depth):
+            node_pos = ()
+            for _ in range(depth):
+                node_pos += (len(node) - 1,)
+                node = node[-1]
+            return node, node_pos
+
+        def parent(tree, position):
+            return tree[position[:-1]]
+
+        def execute(tree: nltk.Tree, terminal: Tuple[str, str], action: Tuple[int, str, str, str]) -> nltk.Tree:
+            target_pos, parent_label, new_label, post = action
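+            # in this variant each action carries the POS tag of the new leaf
+            # (`post`) as its 4th element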
+            #print(target_pos, parent_label, new_label)
+            new_leaf = nltk.Tree(post, [terminal[0]])
+
+            # create the subtree to be inserted
+            new_subtree = new_leaf if parent_label == NUL else nltk.Tree(parent_label, [new_leaf])
+            # find the target position at which to insert the new subtree
+            target_node = tree
+            if target_node is not None:
+                target_node, target_pos = target(target_node, target_pos)
+
+            # Attach
+            if new_label == NUL:
+                # attach the first token
+                if target_node is None:
+                    return new_subtree
+                target_node.append(new_subtree)
+            # Juxtapose
+            else:
+                new_subtree = nltk.Tree(new_label, [target_node, new_subtree])
+                if len(target_pos) > 0:
+                    parent_node = parent(tree, target_pos)
+                    parent_node[-1] = new_subtree
+                else:
+                    tree = new_subtree
+            return tree
+
+        tree, root, terminals = None, tree.label(), tree.pos()
+        for terminal, action in zip(terminals, actions):
+            tree = execute(tree, terminal, action)
+        # recover unary chains
+        nodes = [tree]
+        while nodes:
+            node = nodes.pop()
+            if isinstance(node, nltk.Tree):
+                nodes.extend(node)
+                if join in node.label():
+                    labels = node.label().split(join)
+                    node.set_label(labels[0])
+                    subtree = nltk.Tree(labels[-1], node)
+                    for label in reversed(labels[1:-1]):
+                        subtree = nltk.Tree(label, [subtree])
+                    node[:] = [subtree]
+        return nltk.Tree(root, [tree])
+
+    @classmethod
+    def action2span(
+        cls,
+        action: torch.Tensor,
+        spans: torch.Tensor = None,
+        nul_index: int = -1,
+        mask: torch.BoolTensor = None
+    ) -> torch.Tensor:
+        r"""
+        Converts a batch of the tensorized action at a given step into spans.
+
+        Args:
+            action (~torch.Tensor): ``[3, batch_size]``.
+                A batch of the tensorized action at a given step, containing indices of target nodes, parent and new labels.
+            spans (~torch.Tensor):
+                Spans generated at previous steps, ``None`` at the first step. Default: ``None``.
+            nul_index (int):
+                The index for the obj:`NUL` token, representing the Attach action. Default: -1.
+            mask (~torch.BoolTensor): ``[batch_size]``.
+                The mask for covering the unpadded tokens.
+
+        Returns:
+            A tensor representing a batch of spans for the given step.
+
+        Examples:
+            >>> from collections import Counter
+            >>> from supar.models.const.aj.transform import AttachJuxtaposeTree, Vocab
+            >>> from supar.utils.common import NUL
+            >>> nodes, parents, news = zip(*[(0, 'NP', NUL), (0, 'VP', 'S'), (1, 'NP', NUL),
+                                             (2, 'PP', 'NP'), (3, 'NP', NUL), (4, NUL, NUL),
+                                             (0, NUL, NUL)])
+            >>> vocab = Vocab(Counter(sorted(set([*parents, *news]))))
+            >>> actions = torch.tensor([nodes, vocab[parents], vocab[news]]).unsqueeze(1)
+            >>> spans = None
+            >>> for action in actions.unbind(-1):
+            ...     spans = AttachJuxtaposeTree.action2span(action, spans, vocab[NUL])
+            ...
+            >>> spans
+            tensor([[[-1,  1, -1, -1, -1, -1, -1,  3],
+                     [-1, -1, -1, -1, -1, -1,  4, -1],
+                     [-1, -1, -1,  1, -1, -1,  1, -1],
+                     [-1, -1, -1, -1, -1, -1,  2, -1],
+                     [-1, -1, -1, -1, -1, -1,  1, -1],
+                     [-1, -1, -1, -1, -1, -1, -1, -1],
+                     [-1, -1, -1, -1, -1, -1, -1, -1],
+                     [-1, -1, -1, -1, -1, -1, -1, -1]]])
+            >>> sequence = torch.where(spans.ge(0))
+            >>> sequence = list(zip(sequence[1].tolist(), sequence[2].tolist(), vocab[spans[sequence]]))
+            >>> sequence
+            [(0, 1, 'NP'), (0, 7, 'S'), (1, 6, 'VP'), (2, 3, 'NP'), (2, 6, 'NP'), (3, 6, 'PP'), (4, 6, 'NP')]
+            >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP')
+            >>> AttachJuxtaposeTree.build(tree, sequence).pretty_print()
+                            TOP
+                             |
+                             S
+               ______________|_______________________
+              |              VP                      |
+              |      ________|___                    |
+              |     |            NP                  |
+              |     |    ________|___                |
+              |     |   |            PP              |
+              |     |   |     _______|___            |
+              NP    |   NP   |           NP          |
+              |     |   |    |        ___|_____      |
+              _     _   _    _       _         _     _
+              |     |   |    |       |         |     |
+            Arthur  is King  of     the     Britons  .
+
+        """
+
+        # [batch_size]
+        target, parent, new = action
+        if spans is None:
+            spans = action.new_full((action.shape[1], 2, 2), -1)
+            spans[:, 0, 1] = parent
+            return spans
+        if mask is None:
+            mask = torch.ones_like(target, dtype=bool)
+        juxtapose_mask = new.ne(nul_index) & mask
+        # ancestor nodes are those on the rightmost chain and higher than the target node
+        # [batch_size, seq_len]
+        rightmost_mask = spans[..., -1].ge(0)
+        ancestors = rightmost_mask.cumsum(-1).masked_fill_(~rightmost_mask, -1) - 1
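+        # `ancestors[i, j]` is the 0-based rank of node j on the rightmost chain
+        # of batch element i, and negative for nodes off the chain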
+        # should not include the target node for the Juxtapose action
+        ancestor_mask = mask.unsqueeze(-1) & ancestors.ge(0) & ancestors.le((target - juxtapose_mask.long()).unsqueeze(-1))
+        target_pos = torch.where(ancestors.eq(target.unsqueeze(-1))[juxtapose_mask])[-1]
+        # the right boundaries of ancestor nodes should be aligned with the new generated terminals
+        spans = torch.cat((spans, torch.where(ancestor_mask, spans[..., -1], -1).unsqueeze(-1)), -1)
+        spans[..., -2].masked_fill_(ancestor_mask, -1)
+        spans[juxtapose_mask, target_pos, -1] = new.masked_fill(new.eq(nul_index), -1)[juxtapose_mask]
+        spans[mask, -1, -1] = parent.masked_fill(parent.eq(nul_index), -1)[mask]
+        # [batch_size, seq_len+1, seq_len+1]
+        spans = torch.cat((spans, torch.full_like(spans[:, :1], -1)), 1)
+        return spans
+
+    def load(
+        self,
+        data: Union[str, Iterable],
+        lang: Optional[str] = None,
+        **kwargs
+    ) -> List[AttachJuxtaposeTreeSentence]:
+        r"""
+        Args:
+            data (Union[str, Iterable]):
+                A filename or a list of instances.
+            lang (str):
+                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``None``.
+
+        Returns:
+            A list of :class:`AttachJuxtaposeTreeSentence` instances.
+        """
+
+        if lang is not None:
+            tokenizer = Tokenizer(lang)
+        if isinstance(data, str) and os.path.exists(data):
+            if data.endswith('.txt'):
+                data = (s.split() if lang is None else tokenizer(s) for s in open(data) if len(s) > 1)
+            else:
+                data = open(data)
+        else:
+            if lang is not None:
+                data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)]
+            else:
+                data = [data] if isinstance(data[0], str) else data
+
+        index = 0
+        for s in data:
+            try:
+                tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root)
+                sentence = AttachJuxtaposeTreeSentence(self, tree, index)
+            except ValueError:
+                logger.warning(f"Error found while converting Sentence {index} to a tree:\n{s}\nDiscarding it!")
+                continue
+            except IndexError:
+                tree = nltk.Tree.fromstring('(S ' + s + ')')
+                sentence = AttachJuxtaposeTreeSentence(self, tree, index)
+                yield sentence
+                index += 1
+            else:
+                yield sentence
+                index += 1
+        self.root = tree.label()
+
+
+class AttachJuxtaposeTreeSentence(Sentence):
+    r"""
+    Args:
+        transform (AttachJuxtaposeTree):
+            A :class:`AttachJuxtaposeTree` object.
+        tree (nltk.tree.Tree):
+            A :class:`nltk.tree.Tree` object.
+        index (Optional[int]):
+            Index of the sentence in the corpus. Default: ``None``.
+    """
+
+    def __init__(
+        self,
+        transform: AttachJuxtaposeTree,
+        tree: nltk.Tree,
+        index: Optional[int] = None
+    ) -> AttachJuxtaposeTreeSentence:
+        super().__init__(transform, index)
+
+        words, tags = zip(*tree.pos())
+        nodes, parents, news = None, None, None
+        if transform.training:
+            oracle_tree = tree.copy(True)
+            # the root node must have a unary chain
+            if len(oracle_tree) > 1:
+                oracle_tree[:] = [nltk.Tree('*', oracle_tree)]
+            oracle_tree.collapse_unary(joinChar='::')
+            if len(oracle_tree) == 1 and not isinstance(oracle_tree[0][0], nltk.Tree):
+                oracle_tree[0] = nltk.Tree('*', [oracle_tree[0]])
+            nodes, parents, news = zip(*transform.tree2action(oracle_tree))
+        tags = [x.split("##")[0] for x in tags]
+        self.values = [words, tags, tree, nodes, parents, news]
+
+    def __repr__(self):
+        return self.values[-4].pformat(1000000)
+
+    def pretty_print(self):
+        self.values[-4].pretty_print()
diff --git a/tania_scripts/__pycache__/transform.cpython-311.pyc b/tania_scripts/__pycache__/transform.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0fa71196328a1158b12ae2fffe91d227ae067efc
Binary files /dev/null and b/tania_scripts/__pycache__/transform.cpython-311.pyc differ
diff --git a/tania_scripts/supar/.ipynb_checkpoints/__init__-checkpoint.py b/tania_scripts/supar/.ipynb_checkpoints/__init__-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..146f3df423731c6ccad8c46a95acf8c26649da6e
--- /dev/null
+++ b/tania_scripts/supar/.ipynb_checkpoints/__init__-checkpoint.py
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+
+from .models import (AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos,
+                     BiaffineDependencyParser,
+                     BiaffineSemanticDependencyParser, CRF2oDependencyParser,
+                     CRFConstituencyParser, CRFDependencyParser, SLDependencyParser, ArcEagerDependencyParser,
+                     TetraTaggingConstituencyParser, VIConstituencyParser, SLConstituentParser,
+                     VIDependencyParser, VISemanticDependencyParser)
+from .parser import Parser
+from .structs import (BiLexicalizedConstituencyCRF, ConstituencyCRF,
+                      ConstituencyLBP, ConstituencyMFVI, Dependency2oCRF,
+                      DependencyCRF, DependencyLBP, DependencyMFVI,
+                      LinearChainCRF, MatrixTree, SemanticDependencyLBP,
+                      SemanticDependencyMFVI, SemiMarkovCRF)
+
+
+__all__ = ['Parser',
+
+           # Dependency Parsing
+           'BiaffineDependencyParser',
+           'CRFDependencyParser',
+           'CRF2oDependencyParser',
+           'VIDependencyParser',
+           'SLDependencyParser',
+           'ArcEagerDependencyParser',
+
+           # Constituent Parsing
+           'AttachJuxtaposeConstituencyParser',
+           'CRFConstituencyParser',
+           'TetraTaggingConstituencyParser',
+           'VIConstituencyParser',
+           'SLConstituentParser',
+
+           # Semantic Parsing
+           'BiaffineSemanticDependencyParser',
+           'VISemanticDependencyParser',
+           'LinearChainCRF',
+           'SemiMarkovCRF',
+
+           # transforms
+           'MatrixTree',
+           'DependencyCRF',
+           'Dependency2oCRF',
+           'ConstituencyCRF',
+           'BiLexicalizedConstituencyCRF',
+           'DependencyLBP',
+           'DependencyMFVI',
+           'ConstituencyLBP',
+           'ConstituencyMFVI',
+           'SemanticDependencyLBP',
+           'SemanticDependencyMFVI']
+
+__version__ = '1.1.4'
+
+PARSER = {parser.NAME: parser for parser in [BiaffineDependencyParser,
+                                             CRFDependencyParser,
+                                             CRF2oDependencyParser,
+                                             VIDependencyParser,
+                                             SLDependencyParser,
+                                             ArcEagerDependencyParser,
+                                             AttachJuxtaposeConstituencyParser,
+                                             CRFConstituencyParser,
+                                             TetraTaggingConstituencyParser,
+                                             VIConstituencyParser,
+                                             SLConstituentParser,
+                                             BiaffineSemanticDependencyParser,
+                                             VISemanticDependencyParser]}
+
+SRC = {'github': 'https://github.com/yzhangcs/parser/releases/download',
+       'hlt': 'http://hlt.suda.edu.cn/~yzhang/supar'}
+NAME = {
+    'biaffine-dep-en': 'ptb.biaffine.dep.lstm.char',
+    'biaffine-dep-zh': 'ctb7.biaffine.dep.lstm.char',
+    'crf2o-dep-en': 'ptb.crf2o.dep.lstm.char',
+    'crf2o-dep-zh': 'ctb7.crf2o.dep.lstm.char',
+    'biaffine-dep-roberta-en': 'ptb.biaffine.dep.roberta',
+    'biaffine-dep-electra-zh': 'ctb7.biaffine.dep.electra',
+    'biaffine-dep-xlmr': 'ud.biaffine.dep.xlmr',
+    'crf-con-en': 'ptb.crf.con.lstm.char',
+    'crf-con-zh': 'ctb7.crf.con.lstm.char',
+    'crf-con-roberta-en': 'ptb.crf.con.roberta',
+    'crf-con-electra-zh': 'ctb7.crf.con.electra',
+    'crf-con-xlmr': 'spmrl.crf.con.xlmr',
+    'biaffine-sdp-en': 'dm.biaffine.sdp.lstm.tag-char-lemma',
+    'biaffine-sdp-zh': 'semeval16.biaffine.sdp.lstm.tag-char-lemma',
+    'vi-sdp-en': 'dm.vi.sdp.lstm.tag-char-lemma',
+    'vi-sdp-zh': 'semeval16.vi.sdp.lstm.tag-char-lemma',
+    'vi-sdp-roberta-en': 'dm.vi.sdp.roberta',
+    'vi-sdp-electra-zh': 'semeval16.vi.sdp.electra'
+}
+MODEL = {src: {n: f"{link}/v1.1.0/{m}.zip" for n, m in NAME.items()} for src, link in SRC.items()}
+CONFIG = {src: {n: f"{link}/v1.1.0/{m}.ini" for n, m in NAME.items()} for src, link in SRC.items()}
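+# e.g. MODEL['github']['biaffine-dep-en'] expands to
+# 'https://github.com/yzhangcs/parser/releases/download/v1.1.0/ptb.biaffine.dep.lstm.char.zip'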
+
+
+def compatible():
+    import sys
+    supar = sys.modules[__name__]
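+    # alias pre-1.2 module paths so that models pickled with older supar
+    # versions can still be unpickled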
+    if supar.__version__ < '1.2':
+        sys.modules['supar.utils.transform'].CoNLL = supar.models.dep.biaffine.transform.CoNLL
+        sys.modules['supar.utils.transform'].Tree = supar.models.const.crf.transform.Tree
+        sys.modules['supar.parsers'] = supar.models
+        sys.modules['supar.parsers.con'] = supar.models.const
+
+
+compatible()
diff --git a/tania_scripts/supar/.ipynb_checkpoints/parser-checkpoint.py b/tania_scripts/supar/.ipynb_checkpoints/parser-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1c372e52475d87780f414396606f71a7494653b
--- /dev/null
+++ b/tania_scripts/supar/.ipynb_checkpoints/parser-checkpoint.py
@@ -0,0 +1,620 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import contextlib
+import os
+import shutil
+import sys
+import tempfile
+import pickle
+from contextlib import contextmanager
+from datetime import datetime, timedelta
+from typing import Any, Iterable, Union
+
+import dill
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.cuda.amp import GradScaler
+from torch.optim import Adam, Optimizer
+from torch.optim.lr_scheduler import ExponentialLR, _LRScheduler
+
+import supar
+from supar.utils import Config, Dataset
+from supar.utils.field import Field
+from supar.utils.fn import download, get_rng_state, set_rng_state
+from supar.utils.logging import get_logger, init_logger, progress_bar
+from supar.utils.metric import Metric
+from supar.utils.optim import InverseSquareRootLR, LinearLR
+from supar.utils.parallel import DistributedDataParallel as DDP
+from supar.utils.parallel import gather, is_dist, is_master, reduce
+from supar.utils.transform import Batch
+
+logger = get_logger(__name__)
+
+
+class Parser(object):
+
+    NAME = None
+    MODEL = None
+
+    def __init__(self, args, model, transform):
+        self.args = args
+        self.model = model
+        self.transform = transform
+
+    @property
+    def device(self):
+        return 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    @property
+    def sync_grad(self):
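+        # gradients are synchronized only at gradient-accumulation boundaries
+        # or at the last batch of the epoch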
+        return self.step % self.args.update_steps == 0 or self.step % self.n_batches == 0
+
+    @contextmanager
+    def sync(self):
+        context = getattr(contextlib, 'suppress' if sys.version < '3.7' else 'nullcontext')
+        if is_dist() and not self.sync_grad:
+            context = self.model.no_sync
+        with context():
+            yield
+
+    @contextmanager
+    def join(self):
+        context = getattr(contextlib, 'suppress' if sys.version < '3.7' else 'nullcontext')
+        if not is_dist():
+            with context():
+                yield
+        elif self.model.training:
+            with self.model.join():
+                yield
+        else:
+            try:
+                dist_model = self.model
+                # https://github.com/pytorch/pytorch/issues/54059
+                if hasattr(self.model, 'module'):
+                    self.model = self.model.module
+                yield
+            finally:
+                self.model = dist_model
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int,
+        patience: int,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        clip: float = 5.0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ) -> None:
+        r"""
+        Args:
+            train/dev/test (Union[str, Iterable]):
+                Filenames of the train/dev/test datasets.
+            epochs (int):
+                The number of training iterations.
+            patience (int):
+                The number of consecutive iterations after which training is stopped early if no improvement is observed.
+            batch_size (int):
+                The number of tokens in each batch. Default: 5000.
+            update_steps (int):
+                Gradient accumulation steps. Default: 1.
+            buckets (int):
+                The number of buckets that sentences are assigned to. Default: 32.
+            workers (int):
+                The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
+            clip (float):
+                Clips gradient of an iterable of parameters at specified value. Default: 5.0.
+            amp (bool):
+                Specifies whether to use automatic mixed precision. Default: ``False``.
+            cache (bool):
+                If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
+            verbose (bool):
+                If ``True``, increases the output verbosity. Default: ``True``.
+        """
+
+        args = self.args.update(locals())
+        init_logger(logger, verbose=args.verbose)
+
+        self.transform.train()
+        batch_size = batch_size // update_steps
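+        # the per-step batch is shrunk so that `update_steps` accumulated
+        # batches together match the requested token budget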
+        eval_batch_size = args.get('eval_batch_size', batch_size)
+        if is_dist():
+            batch_size = batch_size // dist.get_world_size()
+            eval_batch_size = eval_batch_size // dist.get_world_size()
+        logger.info("Loading the data")
+        if args.cache:
+            args.bin = os.path.join(os.path.dirname(args.path), 'bin')
+        args.even = args.get('even', is_dist())
+        
+        train = Dataset(self.transform, args.train, **args).build(batch_size=batch_size,
+                                                                  n_buckets=buckets,
+                                                                  shuffle=True,
+                                                                  distributed=is_dist(),
+                                                                  even=args.even,
+                                                                  n_workers=workers)
+        dev = Dataset(self.transform, args.dev, **args).build(batch_size=eval_batch_size,
+                                                              n_buckets=buckets,
+                                                              shuffle=False,
+                                                              distributed=is_dist(),
+                                                              even=False,
+                                                              n_workers=workers)
+        logger.info(f"{'train:':6} {train}")
+        if not args.test:
+            logger.info(f"{'dev:':6} {dev}\n")
+        else:
+            test = Dataset(self.transform, args.test, **args).build(batch_size=eval_batch_size,
+                                                                    n_buckets=buckets,
+                                                                    shuffle=False,
+                                                                    distributed=is_dist(),
+                                                                    even=False,
+                                                                    n_workers=workers)
+            logger.info(f"{'dev:':6} {dev}")
+            logger.info(f"{'test:':6} {test}\n")
+        loader, sampler = train.loader, train.loader.batch_sampler
+        args.steps = len(loader) * epochs // args.update_steps
+        args.save(f"{args.path}.yaml")
+
+        self.optimizer = self.init_optimizer()
+        self.scheduler = self.init_scheduler()
+        self.scaler = GradScaler(enabled=args.amp)
+
+        if dist.is_initialized():
+            self.model = DDP(module=self.model,
+                             device_ids=[args.local_rank],
+                             find_unused_parameters=args.get('find_unused_parameters', True),
+                             static_graph=args.get('static_graph', False))
+            if args.amp:
+                from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook
+                self.model.register_comm_hook(dist.group.WORLD, fp16_compress_hook)
+        if args.wandb and is_master():
+            import wandb
+            # start a new wandb run to track this script
+            wandb.init(config=args.primitive_config,
+                       project=args.get('project', self.NAME),
+                       name=args.get('name', args.path),
+                       resume=self.args.checkpoint)
+        self.step, self.epoch, self.best_e, self.patience = 1, 1, 1, patience
+        # uneven batches are excluded
+        self.n_batches = min(gather(len(loader))) if is_dist() else len(loader)
+        self.best_metric, self.elapsed = Metric(), timedelta()
+        if args.checkpoint:
+            try:
+                self.optimizer.load_state_dict(self.checkpoint_state_dict.pop('optimizer_state_dict'))
+                self.scheduler.load_state_dict(self.checkpoint_state_dict.pop('scheduler_state_dict'))
+                self.scaler.load_state_dict(self.checkpoint_state_dict.pop('scaler_state_dict'))
+                set_rng_state(self.checkpoint_state_dict.pop('rng_state'))
+                for k, v in self.checkpoint_state_dict.items():
+                    setattr(self, k, v)
+                sampler.set_epoch(self.epoch)
+            except AttributeError:
+                logger.warning("No checkpoint found. Try re-launching the training procedure instead")
+
+        for epoch in range(self.epoch, args.epochs + 1):
+            start = datetime.now()
+            bar, metric = progress_bar(loader), Metric()
+
+            logger.info(f"Epoch {epoch} / {args.epochs}:")
+            self.model.train()
+            with self.join():
+                # we should reset `step` as the number of batches in different processes is not necessarily equal
+                self.step = 1
+                for batch in bar:
+                    with self.sync():
+                        with torch.autocast(self.device, enabled=args.amp):
+                            loss = self.train_step(batch)
+                        self.backward(loss)
+                    if self.sync_grad:
+                        self.clip_grad_norm_(self.model.parameters(), args.clip)
+                        self.scaler.step(self.optimizer)
+                        self.scaler.update()
+                        self.scheduler.step()
+                        self.optimizer.zero_grad(True)
+
+                    bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f}")
+                    # log metrics to wandb
+                    if args.wandb and is_master():
+                        wandb.log({'lr': self.scheduler.get_last_lr()[0], 'loss': loss})
+                    self.step += 1
+                logger.info(f"{bar.postfix}")
+            self.model.eval()
+            with self.join(), torch.autocast(self.device, enabled=args.amp):
+                metric = self.reduce(sum([self.eval_step(i) for i in progress_bar(dev.loader)], Metric()))
+                logger.info(f"{'dev:':5} {metric}")
+                if args.wandb and is_master():
+                    wandb.log({'dev': metric.values, 'epochs': epoch})
+                if args.test:
+                    test_metric = sum([self.eval_step(i) for i in progress_bar(test.loader)], Metric())
+                    logger.info(f"{'test:':5} {self.reduce(test_metric)}")
+                    if args.wandb and is_master():
+                        wandb.log({'test': test_metric.values, 'epochs': epoch})
+
+            t = datetime.now() - start
+            self.epoch += 1
+            self.patience -= 1
+            self.elapsed += t
+
+            if metric > self.best_metric:
+                self.best_e, self.patience, self.best_metric = epoch, patience, metric
+                if is_master():
+                    self.save_checkpoint(args.path)
+                logger.info(f"{t}s elapsed (saved)\n")
+            else:
+                logger.info(f"{t}s elapsed\n")
+            if self.patience < 1:
+                break
+        if is_dist():
+            dist.barrier()
+
+        best = self.load(**args)
+        # only allow the master device to save models
+        if is_master():
+            best.save(args.path)
+
+        logger.info(f"Epoch {self.best_e} saved")
+        logger.info(f"{'dev:':5} {self.best_metric}")
+        if args.test:
+            best.model.eval()
+            with best.join():
+                test_metric = sum([best.eval_step(i) for i in progress_bar(test.loader)], Metric())
+                logger.info(f"{'test:':5} {best.reduce(test_metric)}")
+        logger.info(f"{self.elapsed}s elapsed, {self.elapsed / epoch}s/epoch")
+        if args.wandb and is_master():
+            wandb.finish()
+
+        self.evaluate(data=args.test, batch_size=batch_size)
+        self.predict(args.test, batch_size=batch_size, buckets=buckets, workers=workers)
+
+        with open(f'{self.args.folder}/status', 'w') as file:
+            file.write('finished')
+
+
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        r"""
+        Args:
+            data (Union[str, Iterable]):
+                The data for evaluation. Both a filename and a list of instances are allowed.
+            batch_size (int):
+                The number of tokens in each batch. Default: 5000.
+            buckets (int):
+                The number of buckets that sentences are assigned to. Default: 8.
+            workers (int):
+                The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
+            amp (bool):
+                Specifies whether to use automatic mixed precision. Default: ``False``.
+            cache (bool):
+                If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
+            verbose (bool):
+                If ``True``, increases the output verbosity. Default: ``True``.
+
+        Returns:
+            The evaluation results.
+        """
+
+        args = self.args.update(locals())
+        init_logger(logger, verbose=args.verbose)
+
+        self.transform.train()
+        logger.info("Loading the data")
+        if args.cache:
+            args.bin = os.path.join(os.path.dirname(args.path), 'bin')
+        if is_dist():
+            batch_size = batch_size // dist.get_world_size()
+        data = Dataset(self.transform, **args)
+        data.build(batch_size=batch_size,
+                   n_buckets=buckets,
+                   shuffle=False,
+                   distributed=is_dist(),
+                   even=False,
+                   n_workers=workers)
+        logger.info(f"\n{data}")
+
+        logger.info("Evaluating the data")
+        start = datetime.now()
+        self.model.eval()
+        with self.join():
+            bar, metric = progress_bar(data.loader), Metric()
+            for batch in bar:
+                metric += self.eval_step(batch)
+                bar.set_postfix_str(metric)
+            metric = self.reduce(metric)
+        elapsed = datetime.now() - start
+        logger.info(f"{metric}")
+        logger.info(f"{elapsed}s elapsed, "
+                    f"{sum(data.sizes)/elapsed.total_seconds():.2f} Tokens/s, "
+                    f"{len(data)/elapsed.total_seconds():.2f} Sents/s")
+        os.makedirs(os.path.dirname(self.args.folder + '/metrics.pickle'), exist_ok=True)
+        with open(f'{self.args.folder}/metrics.pickle', 'wb') as file:
+            pickle.dump(obj=metric, file=file)
+
+        return metric
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        r"""
+        Args:
+            data (Union[str, Iterable]):
+                The data for prediction.
+                - a filename. If it ends with `.txt`, the parser will make predictions line by line from plain texts.
+                - a list of instances.
+            pred (str):
+                If specified, the predicted results will be saved to the file. Default: ``None``.
+            lang (str):
+                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``None``.
+            prob (bool):
+                If ``True``, outputs the probabilities. Default: ``False``.
+            batch_size (int):
+                The number of tokens in each batch. Default: 5000.
+            buckets (int):
+                The number of buckets that sentences are assigned to. Default: 8.
+            workers (int):
+                The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
+            amp (bool):
+                Specifies whether to use automatic mixed precision. Default: ``False``.
+            cache (bool):
+                If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
+            verbose (bool):
+                If ``True``, increases the output verbosity. Default: ``True``.
+
+        Returns:
+            A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None``.
+        """
+
+        args = self.args.update(locals())
+        init_logger(logger, verbose=args.verbose)
+
+        if self.args.use_vq:
+            self.model.passes_remaining = 0
+            self.model.vq.observe_steps_remaining = 0
+
+        self.transform.eval()
+        if args.prob:
+            self.transform.append(Field('probs'))
+
+        #logger.info("Loading the data")
+        if args.cache:
+            args.bin = os.path.join(os.path.dirname(args.path), 'bin')
+        if is_dist():
+            batch_size = batch_size // dist.get_world_size()
+        data = Dataset(self.transform, **args)
+        data.build(batch_size=batch_size,
+                   n_buckets=buckets,
+                   shuffle=False,
+                   distributed=is_dist(),
+                   even=False,
+                   n_workers=workers)
+
+        #logger.info(f"\n{data}")
+
+        #logger.info("Making predictions on the data")
+        start = datetime.now()
+        self.model.eval()
+        # sentences are clustered by length here to speed up prediction,
+        # so the order of the yielded sentences can't be guaranteed
+        for batch in progress_bar(data.loader):
+            *predicted_values, = self.pred_step(batch)
+            elapsed = datetime.now() - start
+            # wait until all processes have finished the current batch
+            if is_dist():
+                dist.barrier()
+
+        if not cache:
+            return *predicted_values,
+
+    def backward(self, loss: torch.Tensor, **kwargs):
+        loss /= self.args.update_steps
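+        # scale the loss down so accumulated gradients average over the
+        # accumulation window instead of summing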
+        if hasattr(self, 'scaler'):
+            self.scaler.scale(loss).backward(**kwargs)
+        else:
+            loss.backward(**kwargs)
+
+    def clip_grad_norm_(
+        self,
+        params: Union[Iterable[torch.Tensor], torch.Tensor],
+        max_norm: float,
+        norm_type: float = 2
+    ) -> torch.Tensor:
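+        # AMP: gradients must be unscaled before measuring/clipping their norm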
+        self.scaler.unscale_(self.optimizer)
+        return nn.utils.clip_grad_norm_(params, max_norm, norm_type)
+
+    def clip_grad_value_(
+        self,
+        params: Union[Iterable[torch.Tensor], torch.Tensor],
+        clip_value: float
+    ) -> None:
+        self.scaler.unscale_(self.optimizer)
+        return nn.utils.clip_grad_value_(params, clip_value)
+
+    def reduce(self, obj: Any) -> Any:
+        if not is_dist():
+            return obj
+        return reduce(obj)
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        ...
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> Metric:
+        ...
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        ...
+
+    def init_optimizer(self) -> Optimizer:
+        if self.args.encoder in ('lstm', 'transformer'):
+            optimizer = Adam(params=self.model.parameters(),
+                             lr=self.args.lr,
+                             betas=(self.args.get('mu', 0.9), self.args.get('nu', 0.999)),
+                             eps=self.args.get('eps', 1e-8),
+                             weight_decay=self.args.get('weight_decay', 0))
+        else:
+            # we found that Huggingface's AdamW is more robust and empirically better than the native implementation
+            from transformers import AdamW
+            optimizer = AdamW(params=[{'params': p, 'lr': self.args.lr * (1 if n.startswith('encoder') else self.args.lr_rate)}
+                                      for n, p in self.model.named_parameters()],
+                              lr=self.args.lr,
+                              betas=(self.args.get('mu', 0.9), self.args.get('nu', 0.999)),
+                              eps=self.args.get('eps', 1e-8),
+                              weight_decay=self.args.get('weight_decay', 0))
+        return optimizer
+
+    def init_scheduler(self) -> _LRScheduler:
+        if self.args.encoder == 'lstm':
+            scheduler = ExponentialLR(optimizer=self.optimizer,
+                                      gamma=self.args.decay**(1/self.args.decay_steps))
+        elif self.args.encoder == 'transformer':
+            scheduler = InverseSquareRootLR(optimizer=self.optimizer,
+                                            warmup_steps=self.args.warmup_steps)
+        else:
+            scheduler = LinearLR(optimizer=self.optimizer,
+                                 warmup_steps=self.args.get('warmup_steps', int(self.args.steps*self.args.get('warmup', 0))),
+                                 steps=self.args.steps)
+        return scheduler
+
+    @classmethod
+    def build(cls, path, **kwargs):
+        ...
+
+    @classmethod
+    def load(
+        cls,
+        path: str,
+        reload: bool = True,
+        src: str = 'github',
+        checkpoint: bool = False,
+        **kwargs
+    ) -> Parser:
+        r"""
+        Loads a parser with data fields and pretrained model parameters.
+
+        Args:
+            path (str):
+                - a string with the shortcut name of a pretrained model defined in ``supar.MODEL``
+                  to load from cache or download, e.g., ``'biaffine-dep-en'``.
+                - a local path to a pretrained model, e.g., ``./<path>/model``.
+            reload (bool):
+                Whether to discard the existing cache and force a fresh download. Default: ``True``.
+            src (str):
+                Specifies where to download the model.
+                ``'github'``: github release page.
+                ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
+                Default: ``'github'``.
+            checkpoint (bool):
+                If ``True``, loads all checkpoint states to restore the training process. Default: ``False``.
+
+        Examples:
+            >>> from supar import Parser
+            >>> parser = Parser.load('biaffine-dep-en')
+            >>> parser = Parser.load('./ptb.biaffine.dep.lstm.char')
+        """
+
+        args = Config(**locals())
+        if not os.path.exists(path):
+            path = download(supar.MODEL[src].get(path, path), reload=reload)
+        state = torch.load(path, map_location='cpu', weights_only=False)
+        cls = supar.PARSER[state['name']] if cls.NAME is None else cls
+        args = state['args'].update(args)
+        model = cls.MODEL(**args)
+        model.load_pretrained(state['pretrained'])
+        model.load_state_dict(state['state_dict'], True)
+        transform = state['transform']
+        parser = cls(args, model, transform)
+        parser.checkpoint_state_dict = state.get('checkpoint_state_dict', None) if checkpoint else None
+        parser.model.to(parser.device)
+        return parser
+
+    def save(self, path: str) -> None:
+        model = self.model
+        if hasattr(model, 'module'):
+            model = self.model.module
+        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
+        pretrained = state_dict.pop('pretrained.weight', None)
+        state = {'name': self.NAME,
+                 'args': model.args,
+                 'state_dict': state_dict,
+                 'pretrained': pretrained,
+                 'transform': self.transform}
+        torch.save(state, path, pickle_module=dill)
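+
+    # note: `save` moves all weights to CPU, splits out any pretrained
+    # embeddings, and stores the transform alongside the args, so the file
+    # can be restored with `Parser.load` above.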
+
+    def save_checkpoint(self, path: str) -> None:
+        model = self.model
+        if hasattr(model, 'module'):
+            model = self.model.module
+        checkpoint_state_dict = {k: getattr(self, k) for k in ['epoch', 'best_e', 'patience', 'best_metric', 'elapsed']}
+        checkpoint_state_dict.update({'optimizer_state_dict': self.optimizer.state_dict(),
+                                      'scheduler_state_dict': self.scheduler.state_dict(),
+                                      'scaler_state_dict': self.scaler.state_dict(),
+                                      'rng_state': get_rng_state()})
+        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
+        pretrained = state_dict.pop('pretrained.weight', None)
+        state = {'name': self.NAME,
+                 'args': model.args,
+                 'state_dict': state_dict,
+                 'pretrained': pretrained,
+                 'checkpoint_state_dict': checkpoint_state_dict,
+                 'transform': self.transform}
+        torch.save(state, path, pickle_module=dill)
diff --git a/tania_scripts/supar/__init__.py b/tania_scripts/supar/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..146f3df423731c6ccad8c46a95acf8c26649da6e
--- /dev/null
+++ b/tania_scripts/supar/__init__.py
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+
+from .models import (AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos,
+                     BiaffineDependencyParser,
+                     BiaffineSemanticDependencyParser, CRF2oDependencyParser,
+                     CRFConstituencyParser, CRFDependencyParser, SLDependencyParser, ArcEagerDependencyParser,
+                     TetraTaggingConstituencyParser, VIConstituencyParser, SLConstituentParser,
+                     VIDependencyParser, VISemanticDependencyParser)
+from .parser import Parser
+from .structs import (BiLexicalizedConstituencyCRF, ConstituencyCRF,
+                      ConstituencyLBP, ConstituencyMFVI, Dependency2oCRF,
+                      DependencyCRF, DependencyLBP, DependencyMFVI,
+                      LinearChainCRF, MatrixTree, SemanticDependencyLBP,
+                      SemanticDependencyMFVI, SemiMarkovCRF)
+
+
+__all__ = ['Parser',
+
+           # Dependency Parsing
+           'BiaffineDependencyParser',
+           'CRFDependencyParser',
+           'CRF2oDependencyParser',
+           'VIDependencyParser',
+           'SLDependencyParser',
+           'ArcEagerDependencyParser',
+
+           # Constituent Parsing
+           'AttachJuxtaposeConstituencyParser',
+           'CRFConstituencyParser',
+           'TetraTaggingConstituencyParser',
+           'VIConstituencyParser',
+           'SLConstituentParser',
+
+           # Semantic Parsing
+           'BiaffineSemanticDependencyParser',
+           'VISemanticDependencyParser',
+
+           # structured distributions
+           'LinearChainCRF',
+           'SemiMarkovCRF',
+           'MatrixTree',
+           'DependencyCRF',
+           'Dependency2oCRF',
+           'ConstituencyCRF',
+           'BiLexicalizedConstituencyCRF',
+           'DependencyLBP',
+           'DependencyMFVI',
+           'ConstituencyLBP',
+           'ConstituencyMFVI',
+           'SemanticDependencyLBP',
+           'SemanticDependencyMFVI']
+
+__version__ = '1.1.4'
+
+PARSER = {parser.NAME: parser for parser in [BiaffineDependencyParser,
+                                             CRFDependencyParser,
+                                             CRF2oDependencyParser,
+                                             VIDependencyParser,
+                                             SLDependencyParser,
+                                             ArcEagerDependencyParser,
+                                             AttachJuxtaposeConstituencyParser,
+                                             CRFConstituencyParser,
+                                             TetraTaggingConstituencyParser,
+                                             VIConstituencyParser,
+                                             SLConstituentParser,
+                                             BiaffineSemanticDependencyParser,
+                                             VISemanticDependencyParser]}
+
+SRC = {'github': 'https://github.com/yzhangcs/parser/releases/download',
+       'hlt': 'http://hlt.suda.edu.cn/~yzhang/supar'}
+NAME = {
+    'biaffine-dep-en': 'ptb.biaffine.dep.lstm.char',
+    'biaffine-dep-zh': 'ctb7.biaffine.dep.lstm.char',
+    'crf2o-dep-en': 'ptb.crf2o.dep.lstm.char',
+    'crf2o-dep-zh': 'ctb7.crf2o.dep.lstm.char',
+    'biaffine-dep-roberta-en': 'ptb.biaffine.dep.roberta',
+    'biaffine-dep-electra-zh': 'ctb7.biaffine.dep.electra',
+    'biaffine-dep-xlmr': 'ud.biaffine.dep.xlmr',
+    'crf-con-en': 'ptb.crf.con.lstm.char',
+    'crf-con-zh': 'ctb7.crf.con.lstm.char',
+    'crf-con-roberta-en': 'ptb.crf.con.roberta',
+    'crf-con-electra-zh': 'ctb7.crf.con.electra',
+    'crf-con-xlmr': 'spmrl.crf.con.xlmr',
+    'biaffine-sdp-en': 'dm.biaffine.sdp.lstm.tag-char-lemma',
+    'biaffine-sdp-zh': 'semeval16.biaffine.sdp.lstm.tag-char-lemma',
+    'vi-sdp-en': 'dm.vi.sdp.lstm.tag-char-lemma',
+    'vi-sdp-zh': 'semeval16.vi.sdp.lstm.tag-char-lemma',
+    'vi-sdp-roberta-en': 'dm.vi.sdp.roberta',
+    'vi-sdp-electra-zh': 'semeval16.vi.sdp.electra'
+}
+MODEL = {src: {n: f"{link}/v1.1.0/{m}.zip" for n, m in NAME.items()} for src, link in SRC.items()}
+CONFIG = {src: {n: f"{link}/v1.1.0/{m}.ini" for n, m in NAME.items()} for src, link in SRC.items()}
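+
+# resolution sketch: a shortcut name maps to a versioned release asset, e.g.
+# MODEL['github']['biaffine-dep-en'] ==
+# 'https://github.com/yzhangcs/parser/releases/download/v1.1.0/ptb.biaffine.dep.lstm.char.zip'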
+
+
+def compatible():
+    # alias pre-1.2 module paths so that checkpoints pickled by older supar
+    # releases can still be unpickled against the current package layout
+    import sys
+    supar = sys.modules[__name__]
+    if supar.__version__ < '1.2':
+        sys.modules['supar.utils.transform'].CoNLL = supar.models.dep.biaffine.transform.CoNLL
+        sys.modules['supar.utils.transform'].Tree = supar.models.const.crf.transform.Tree
+        sys.modules['supar.parsers'] = supar.models
+        sys.modules['supar.parsers.con'] = supar.models.const
+
+
+compatible()
diff --git a/tania_scripts/supar/cmds/__init__.py b/tania_scripts/supar/cmds/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tania_scripts/supar/cmds/const/aj.py b/tania_scripts/supar/cmds/const/aj.py
new file mode 100644
index 0000000000000000000000000000000000000000..7952168e66d32f58ac2e1e66442699b2f7a1fe2e
--- /dev/null
+++ b/tania_scripts/supar/cmds/const/aj.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import AttachJuxtaposeConstituencyParser
+from supar.cmds.run import init
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create AttachJuxtapose Constituency Parser.')
+    parser.set_defaults(Parser=AttachJuxtaposeConstituencyParser)
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
+    subparser.add_argument('--max-len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file')
+    subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file')
+    subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file')
+    subparser.add_argument('--embed', default=None, help='file or embeddings available at `supar.utils.Embedding`')
+    subparser.add_argument('--use_vq', action='store_true', default=False, help='whether to use vector quantization')
+    subparser.add_argument('--delay', type=int, default=0)
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.pid', help='path to predicted result')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    init(parser)
+
+
+if __name__ == "__main__":
+    main()
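+
+# invocation sketch (file defaults as above; the model path is illustrative):
+#   python -m supar.cmds.const.aj train -b -p exp/aj/model \
+#       --train data/ptb/train.pid --dev data/ptb/dev.pid --test data/ptb/test.pid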
diff --git a/tania_scripts/supar/cmds/const/crf.py b/tania_scripts/supar/cmds/const/crf.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3809317c47506101aebe6cc29fee25d7f8e5c07
--- /dev/null
+++ b/tania_scripts/supar/cmds/const/crf.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import CRFConstituencyParser
+from supar.cmds.run import init
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create CRF Constituency Parser.')
+    parser.set_defaults(Parser=CRFConstituencyParser)
+    parser.add_argument('--mbr', action='store_true', help='whether to use MBR decoding')
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
+    subparser.add_argument('--implicit', action='store_true', help='whether to conduct implicit binarization')
+    subparser.add_argument('--max-len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file')
+    subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file')
+    subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file')
+    subparser.add_argument('--embed', default=None, help='file or embeddings available at `supar.utils.Embedding`')
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.pid', help='path to predicted result')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    init(parser)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tania_scripts/supar/cmds/const/sl.py b/tania_scripts/supar/cmds/const/sl.py
new file mode 100644
index 0000000000000000000000000000000000000000..97c9d47000b9c6a7465827cd9455f0d44e2503cc
--- /dev/null
+++ b/tania_scripts/supar/cmds/const/sl.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import SLConstituentParser
+from supar.cmds.run import init
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create SL Constituency Parser.')
+    parser.set_defaults(Parser=SLConstituentParser)
+    parser.add_argument('--mbr', action='store_true', help='whether to use MBR decoding')
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
+    subparser.add_argument('--implicit', action='store_true', help='whether to conduct implicit binarization')
+    subparser.add_argument('--max_len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file')
+    subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file')
+    subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file')
+    subparser.add_argument('--embed', default=None, help='file or embeddings available at `supar.utils.Embedding`')
+    subparser.add_argument('--use_vq', action='store_true', default=False, help='whether to use vector quantization')
+    subparser.add_argument('--decoder', choices=['mlp', 'lstm'], default='mlp', help='incremental decoder to use')
+    subparser.add_argument('--codes', choices=['abs', 'rel'], default=None, help='sequential encoding used')
+    subparser.add_argument('--delay', type=int, default=0)
+    subparser.add_argument('--root_node', type=str, default='S')
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.pid', help='path to predicted result')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    init(parser)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tania_scripts/supar/cmds/const/tt.py b/tania_scripts/supar/cmds/const/tt.py
new file mode 100644
index 0000000000000000000000000000000000000000..286c402ec1d85ea7b6f648906438ed1a469551ee
--- /dev/null
+++ b/tania_scripts/supar/cmds/const/tt.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import TetraTaggingConstituencyParser
+from supar.cmds.run import init
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create Tetra-tagging Constituency Parser.')
+    parser.set_defaults(Parser=TetraTaggingConstituencyParser)
+    parser.add_argument('--depth', default=8, type=int, help='stack depth')
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'elmo', 'bert'], nargs='*', help='features to use')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
+    subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use')
+    subparser.add_argument('--max-len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file')
+    subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file')
+    subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file')
+    subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
+    subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.pid', help='path to predicted result')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    init(parser)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tania_scripts/supar/cmds/const/vi.py b/tania_scripts/supar/cmds/const/vi.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b63a3b3ed207feffe3939a09f026041a47a8520
--- /dev/null
+++ b/tania_scripts/supar/cmds/const/vi.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import VIConstituencyParser
+from supar.cmds.run import init
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create Constituency Parser using Variational Inference.')
+    parser.set_defaults(Parser=VIConstituencyParser)
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'elmo', 'bert'], nargs='*', help='features to use')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
+    subparser.add_argument('--implicit', action='store_true', help='whether to conduct implicit binarization')
+    subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use')
+    subparser.add_argument('--max-len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file')
+    subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file')
+    subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file')
+    subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
+    subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
+    subparser.add_argument('--inference', default='mfvi', choices=['mfvi', 'lbp'], help='approximate inference methods')
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.pid', help='path to predicted result')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    init(parser)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tania_scripts/supar/cmds/dep/biaffine.py b/tania_scripts/supar/cmds/dep/biaffine.py
new file mode 100644
index 0000000000000000000000000000000000000000..315ae6e89bf7e217c8f6c6ffc8f1af9c170b3b7b
--- /dev/null
+++ b/tania_scripts/supar/cmds/dep/biaffine.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import BiaffineDependencyParser
+from supar.cmds.run import init
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create Biaffine Dependency Parser.')
+    parser.add_argument('--tree', action='store_true', help='whether to ensure well-formedness')
+    parser.add_argument('--proj', action='store_true', help='whether to projectivize the data')
+    parser.add_argument('--partial', action='store_true', help='whether partial annotation is included')
+    parser.set_defaults(Parser=BiaffineDependencyParser)
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'elmo', 'bert'], nargs='*', help='features to use')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
+    subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use')
+    subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
+    subparser.add_argument('--max-len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file')
+    subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file')
+    subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file')
+    subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
+    subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.conllx', help='path to predicted result')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    init(parser)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tania_scripts/supar/cmds/dep/crf.py b/tania_scripts/supar/cmds/dep/crf.py
new file mode 100644
index 0000000000000000000000000000000000000000..1229ae1fb03c7351a0f4f403b7319d2dcc840ec1
--- /dev/null
+++ b/tania_scripts/supar/cmds/dep/crf.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import CRFDependencyParser
+from supar.cmds.run import init
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create first-order CRF Dependency Parser.')
+    parser.set_defaults(Parser=CRFDependencyParser)
+    parser.add_argument('--mbr', action='store_true', help='whether to use MBR decoding')
+    parser.add_argument('--tree', action='store_true', help='whether to ensure well-formedness')
+    parser.add_argument('--proj', action='store_true', help='whether to projectivize the data')
+    parser.add_argument('--partial', action='store_true', help='whether partial annotation is included')
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'elmo', 'bert'], nargs='*', help='features to use')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
+    subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use')
+    subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
+    subparser.add_argument('--max-len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file')
+    subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file')
+    subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file')
+    subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
+    subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.conllx', help='path to predicted result')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    init(parser)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tania_scripts/supar/cmds/dep/crf2o.py b/tania_scripts/supar/cmds/dep/crf2o.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc066ec057258b2c4c48c4d51b6b9bcf90044a7d
--- /dev/null
+++ b/tania_scripts/supar/cmds/dep/crf2o.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import CRF2oDependencyParser
+from supar.cmds.run import init
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create second-order CRF Dependency Parser.')
+    parser.set_defaults(Parser=CRF2oDependencyParser)
+    parser.add_argument('--mbr', action='store_true', help='whether to use MBR decoding')
+    parser.add_argument('--tree', action='store_true', help='whether to ensure well-formedness')
+    parser.add_argument('--proj', action='store_true', help='whether to projectivize the data')
+    parser.add_argument('--partial', action='store_true', help='whether partial annotation is included')
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'elmo', 'bert'], nargs='*', help='features to use')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
+    subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use')
+    subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
+    subparser.add_argument('--max-len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file')
+    subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file')
+    subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file')
+    subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
+    subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.conllx', help='path to predicted result')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    init(parser)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tania_scripts/supar/cmds/dep/eager.py b/tania_scripts/supar/cmds/dep/eager.py
new file mode 100644
index 0000000000000000000000000000000000000000..74ae87918887bedbf673cc9dafa5b99f9de5de3b
--- /dev/null
+++ b/tania_scripts/supar/cmds/dep/eager.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import ArcEagerDependencyParser
+from supar.cmds.run import init
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create Arc-Eager Transition-based Dependency Parser.')
+    parser.add_argument('--tree', action='store_true', help='whether to ensure well-formedness')
+    parser.add_argument('--proj', action='store_true', help='whether to projectivize the data')
+    parser.add_argument('--partial', action='store_true', help='whether partial annotation is included')
+    parser.set_defaults(Parser=ArcEagerDependencyParser)
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
+    subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
+    subparser.add_argument('--max-len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file')
+    subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file')
+    subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file')
+    subparser.add_argument('--embed', default=None, help='file or embeddings available at `supar.utils.Embedding`')
+    subparser.add_argument('--use_vq', action='store_true', default=False, help='whether to use vector quantization')
+    subparser.add_argument('--decoder', choices=['mlp', 'lstm'], default='mlp', help='incremental decoder to use')
+    subparser.add_argument('--delay', type=int, default=0)
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.conllx', help='path to predicted result')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    init(parser)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tania_scripts/supar/cmds/dep/sl.py b/tania_scripts/supar/cmds/dep/sl.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f852ba600e150a2a64b3354b1f416c86b91868
--- /dev/null
+++ b/tania_scripts/supar/cmds/dep/sl.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import SLDependencyParser
+from supar.cmds.run import init
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create SL Dependency Parser.')
+    parser.set_defaults(Parser=SLDependencyParser)
+    parser.add_argument('--tree', action='store_true', help='whether to ensure well-formedness')
+    parser.add_argument('--proj', action='store_true', help='whether to projectivize the data')
+    parser.add_argument('--partial', action='store_true', help='whether partial annotation is included')
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
+    subparser.add_argument('--implicit', action='store_true', help='whether to conduct implicit binarization')
+    subparser.add_argument('--max_len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file')
+    subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file')
+    subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file')
+    subparser.add_argument('--embed', default=None, help='file or embeddings available at `supar.utils.Embedding`')
+    subparser.add_argument('--use_vq', action='store_true', default=False, help='whether to use vector quantization')
+    subparser.add_argument('--decoder', choices=['mlp', 'lstm'], default='mlp', help='incremental decoder to use')
+    subparser.add_argument('--codes', choices=['abs', 'rel', 'pos', '1p', '2p'], default=None, help='SL coding used')
+    subparser.add_argument('--delay', type=int, default=0)
+    subparser.add_argument('--root_node', type=str, default='S')
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.pid', help='path to predicted result')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    init(parser)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tania_scripts/supar/cmds/dep/vi.py b/tania_scripts/supar/cmds/dep/vi.py
new file mode 100644
index 0000000000000000000000000000000000000000..1175977b10b39c97cd28321e947dac09c7211c7b
--- /dev/null
+++ b/tania_scripts/supar/cmds/dep/vi.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import VIDependencyParser
+from supar.cmds.run import init
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create Dependency Parser using Variational Inference.')
+    parser.add_argument('--tree', action='store_true', help='whether to ensure well-formedness')
+    parser.add_argument('--proj', action='store_true', help='whether to projectivize the data')
+    parser.add_argument('--partial', action='store_true', help='whether partial annotation is included')
+    parser.set_defaults(Parser=VIDependencyParser)
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'elmo', 'bert'], nargs='*', help='features to use')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
+    subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use')
+    subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
+    subparser.add_argument('--max-len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file')
+    subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file')
+    subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file')
+    subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
+    subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
+    subparser.add_argument('--inference', default='mfvi', choices=['mfvi', 'lbp'], help='approximate inference methods')
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.conllx', help='path to predicted result')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    init(parser)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tania_scripts/supar/cmds/run.py b/tania_scripts/supar/cmds/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..df93b7227bc82c3990b8bde877c23bb09e432dcd
--- /dev/null
+++ b/tania_scripts/supar/cmds/run.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+
+import os
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from supar.utils import Config
+from supar.utils.logging import init_logger, logger
+from supar.utils.parallel import get_device_count, get_free_port
+
+
+def init(parser):
+    parser.add_argument('--path', '-p', help='path to model file')
+    parser.add_argument('--conf', '-c', default='', help='path to config file')
+    parser.add_argument('--device', '-d', default='0', help='ID of GPU to use')
+    parser.add_argument('--seed', '-s', default=1, type=int, help='seed for generating random numbers')
+    parser.add_argument('--threads', '-t', default=16, type=int, help='num of threads')
+    parser.add_argument('--workers', '-w', default=0, type=int, help='num of processes used for data loading')
+    parser.add_argument('--cache', action='store_true', help='cache the data for fast loading')
+    parser.add_argument('--binarize', action='store_true', help='binarize the data first')
+    parser.add_argument('--amp', action='store_true', help='use automatic mixed precision for parsing')
+    parser.add_argument('--dist', choices=['ddp', 'fsdp'], default='ddp', help='distributed training types')
+    parser.add_argument('--wandb', action='store_true', help='wandb for tracking experiments')
+    args, unknown = parser.parse_known_args()
+    args, unknown = parser.parse_known_args(unknown, args)
+    args = Config.load(**vars(args), unknown=unknown)
+
+    args.folder = os.path.dirname(args.path)
+    if args.folder and not os.path.exists(args.folder):
+        os.makedirs(args.folder)
+
+    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
+    if get_device_count() > 1:
+        os.environ['MASTER_ADDR'] = 'tcp://localhost'
+        os.environ['MASTER_PORT'] = get_free_port()
+        mp.spawn(parse, args=(args,), nprocs=get_device_count())
+    else:
+        parse(0 if torch.cuda.is_available() else -1, args)
+
+
+def parse(local_rank, args):
+    Parser = args.pop('Parser')
+    torch.set_num_threads(args.threads)
+    torch.manual_seed(args.seed)
+    if get_device_count() > 1:
+        dist.init_process_group(backend='nccl',
+                                init_method=f"{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}",
+                                world_size=get_device_count(),
+                                rank=local_rank)
+    torch.cuda.set_device(local_rank)
+    # init logger after dist has been initialized
+    init_logger(logger, f"{args.path}.{args.mode}.log", 'a' if args.get('checkpoint') else 'w')
+    logger.info('\n' + str(args))
+
+    args.local_rank = local_rank
+    os.environ['RANK'] = os.environ['LOCAL_RANK'] = f'{local_rank}'
+    if args.mode == 'train':
+        parser = Parser.load(**args) if args.checkpoint else Parser.build(**args)
+        parser.train(**args)
+    elif args.mode == 'evaluate':
+        parser = Parser.load(**args)
+        parser.evaluate(**args)
+    elif args.mode == 'predict':
+        parser = Parser.load(**args)
+        parser.predict(**args)
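+
+# flow sketch: `init` parses the CLI into a Config, then either spawns one
+# `parse` worker per visible GPU (DDP over NCCL) or runs `parse` in-process;
+# e.g. `-d 0,1` on a two-GPU machine trains with a world size of 2.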
diff --git a/tania_scripts/supar/cmds/sdp/biaffine.py b/tania_scripts/supar/cmds/sdp/biaffine.py
new file mode 100644
index 0000000000000000000000000000000000000000..a36ab6a40c4ccba8abe47ea32732e3efc468c398
--- /dev/null
+++ b/tania_scripts/supar/cmds/sdp/biaffine.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import BiaffineSemanticDependencyParser
+from supar.cmds.run import init
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create Biaffine Semantic Dependency Parser.')
+    parser.set_defaults(Parser=BiaffineSemanticDependencyParser)
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'lemma', 'elmo', 'bert'], nargs='*', help='features to use')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
+    subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use')
+    subparser.add_argument('--max-len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/sdp/DM/train.conllu', help='path to train file')
+    subparser.add_argument('--dev', default='data/sdp/DM/dev.conllu', help='path to dev file')
+    subparser.add_argument('--test', default='data/sdp/DM/test.conllu', help='path to test file')
+    subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
+    subparser.add_argument('--n-embed-proj', default=125, type=int, help='dimension of projected embeddings')
+    subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/sdp/DM/test.conllu', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/sdp/DM/test.conllu', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.conllu', help='path to predicted result')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    init(parser)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tania_scripts/supar/cmds/sdp/vi.py b/tania_scripts/supar/cmds/sdp/vi.py
new file mode 100644
index 0000000000000000000000000000000000000000..26fee77c40d1475e93f12c187ed66f8f3810aac1
--- /dev/null
+++ b/tania_scripts/supar/cmds/sdp/vi.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import VISemanticDependencyParser
+from supar.cmds.run import init
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create Semantic Dependency Parser using Variational Inference.')
+    parser.set_defaults(Parser=VISemanticDependencyParser)
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'lemma', 'elmo', 'bert'], nargs='*', help='features to use')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
+    subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use')
+    subparser.add_argument('--max-len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/sdp/DM/train.conllu', help='path to train file')
+    subparser.add_argument('--dev', default='data/sdp/DM/dev.conllu', help='path to dev file')
+    subparser.add_argument('--test', default='data/sdp/DM/test.conllu', help='path to test file')
+    subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
+    subparser.add_argument('--n-embed-proj', default=125, type=int, help='dimension of projected embeddings')
+    subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
+    subparser.add_argument('--inference', default='mfvi', choices=['mfvi', 'lbp'], help='approximate inference methods')
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/sdp/DM/test.conllu', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/sdp/DM/test.conllu', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.conllu', help='path to predicted result')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    init(parser)
+
+
+if __name__ == "__main__":
+    main()
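+
+# Example invocation (illustrative; only flags defined above are shown):
+#   python -m supar.cmds.sdp.vi train -b -f char \
+#       --train data/sdp/DM/train.conllu --dev data/sdp/DM/dev.conllu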
diff --git a/tania_scripts/supar/codelin/__init__.py b/tania_scripts/supar/codelin/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..847de4fc4f73b6f0f337f356df357cbfbe570bf1
--- /dev/null
+++ b/tania_scripts/supar/codelin/__init__.py
@@ -0,0 +1,40 @@
+from .encs.enc_deps import D_NaiveAbsoluteEncoding, D_NaiveRelativeEncoding, D_PosBasedEncoding, D_BrkBasedEncoding, D_Brk2PBasedEncoding
+from .encs.enc_const import C_NaiveAbsoluteEncoding, C_NaiveRelativeEncoding
+from .utils.constants import D_2P_GREED, D_2P_PROP
+
+# import structures for encoding/decoding
+from .models.const_label import C_Label
+from .models.const_tree import C_Tree
+from .models.linearized_tree import LinearizedTree
+from .models.deps_label import D_Label
+from .models.deps_tree import D_Tree
+
+LABEL_SEPARATOR = '€'
+UNARY_JOINER = '@'
+
+def get_con_encoder(encoding: str, sep: str = LABEL_SEPARATOR, unary_joiner: str = UNARY_JOINER):
+    if encoding == 'abs':
+        return C_NaiveAbsoluteEncoding(sep, unary_joiner)
+    elif encoding == 'rel':
+        return C_NaiveRelativeEncoding(sep, unary_joiner)
+    raise NotImplementedError(f'Unknown constituent encoding: {encoding}')
+
+def get_dep_encoder(encoding: str, sep: str, displacement: bool = False):
+    if encoding == 'abs':
+        return D_NaiveAbsoluteEncoding(sep)
+    elif encoding == 'rel':
+        return D_NaiveRelativeEncoding(sep, hang_from_root=False)
+    elif encoding == 'pos':
+        return D_PosBasedEncoding(sep)
+    elif encoding == '1p':
+        return D_BrkBasedEncoding(sep, displacement)
+    elif encoding == '2p':
+        return D_Brk2PBasedEncoding(sep, displacement, D_2P_PROP)
+    raise NotImplementedError(f'Unknown dependency encoding: {encoding}')
+
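+# Illustrative usage of the factories above (a sketch; the defaults for the
+# separator and unary joiner are the module constants defined here):
+#
+#   from supar.codelin import get_con_encoder, get_dep_encoder
+#   con_enc = get_con_encoder('rel')                 # C_NaiveRelativeEncoding
+#   dep_enc = get_dep_encoder('2p', LABEL_SEPARATOR, displacement=True)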
diff --git a/tania_scripts/supar/codelin/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/codelin/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..21b03f9cdf5de941345a04314d4e726119e9dfd2
Binary files /dev/null and b/tania_scripts/supar/codelin/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/codelin/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..737b648aa6c9eb05010e8759390d493979dff7fc
Binary files /dev/null and b/tania_scripts/supar/codelin/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/__pycache__/abstract_encoding.cpython-310.pyc b/tania_scripts/supar/codelin/encs/__pycache__/abstract_encoding.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..163660a55c0e295e147a3eb4645efc78e2a5e088
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/__pycache__/abstract_encoding.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/__pycache__/abstract_encoding.cpython-311.pyc b/tania_scripts/supar/codelin/encs/__pycache__/abstract_encoding.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7faea83a2faeba0be37ddc0648b2989d60fcc074
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/__pycache__/abstract_encoding.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/abstract_encoding.py b/tania_scripts/supar/codelin/encs/abstract_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e3fa035e7f5883068f9e641d164e366cc0c37a4
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/abstract_encoding.py
@@ -0,0 +1,50 @@
+from abc import ABC, abstractmethod
+
+class ADEncoding(ABC):
+    '''
+    Abstract class for dependency encodings.
+    Sets the main constructor method and declares the abstract methods
+        - encode
+        - decode
+    A new dependency encoding must extend this class
+    and implement both methods.
+    '''
+    def __init__(self, separator):
+        self.separator = separator
+
+    @abstractmethod
+    def encode(self, nodes):
+        pass
+
+    @abstractmethod
+    def decode(self, labels, postags, words):
+        pass
+
+class ACEncoding(ABC):
+    '''
+    Abstract class for constituent encodings.
+    Sets the main constructor method and declares the abstract methods
+        - encode
+        - decode
+    A new constituent encoding must extend this class
+    and implement both methods.
+    '''
+    def __init__(self, separator, ujoiner):
+        self.separator = separator
+        self.unary_joiner = ujoiner
+
+    @abstractmethod
+    def encode(self, constituent_tree):
+        pass
+
+    @abstractmethod
+    def decode(self, linearized_tree):
+        pass
+
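+# Sketch of how a new encoding plugs in (hypothetical class, for illustration):
+#
+#   class D_MyEncoding(ADEncoding):
+#       def encode(self, nodes):                   # D_Tree -> LinearizedTree
+#           ...
+#       def decode(self, labels, postags, words):  # predictions -> D_Tree
+#           ...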
diff --git a/tania_scripts/supar/codelin/encs/constituent.py b/tania_scripts/supar/codelin/encs/constituent.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dbe8f0b18270ebf769078c825e3441e22acd766
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/constituent.py
@@ -0,0 +1,125 @@
+from supar.codelin.models.linearized_tree import LinearizedTree
+from supar.codelin.encs.enc_const import *
+from supar.codelin.utils.extract_feats import extract_features_const
+from supar.codelin.utils.constants import C_INCREMENTAL_ENCODING, C_ABSOLUTE_ENCODING, C_RELATIVE_ENCODING, C_DYNAMIC_ENCODING
+
+import stanza
+from supar.codelin.models.const_tree import C_Tree
+
+
+## Encoding and decoding
+
+def encode_constituent(in_path, out_path, encoding_type, separator, unary_joiner, features):
+    '''
+    Encodes the selected file according to the specified parameters:
+    :param in_path: Path of the file to be encoded
+    :param out_path: Path where to write the encoded labels
+    :param encoding_type: Encoding used
+    :param separator: string used to separate label fields
+    :param unary_joiner: string used to separate nodes from unary chains
+    :param features: features to add as columns to the labels file
+    '''
+
+    if encoding_type == C_ABSOLUTE_ENCODING:
+        encoder = C_NaiveAbsoluteEncoding(separator, unary_joiner)
+    elif encoding_type == C_RELATIVE_ENCODING:
+        encoder = C_NaiveRelativeEncoding(separator, unary_joiner)
+    elif encoding_type == C_DYNAMIC_ENCODING:
+        encoder = C_NaiveDynamicEncoding(separator, unary_joiner)
+    elif encoding_type == C_INCREMENTAL_ENCODING:
+        encoder = C_NaiveIncrementalEncoding(separator, unary_joiner)
+    else:
+        raise Exception("Unknown encoding type")
+
+    # build feature index dictionary
+    f_idx_dict = {}
+    if features:
+        if features == ["ALL"]:
+            features = extract_features_const(in_path)
+        f_idx_dict = {f: i for i, f in enumerate(features)}
+
+    file_out = open(out_path, "w")
+    file_in = open(in_path, "r")
+
+    tree_counter = 0
+    labels_counter = 0
+    label_set = set()
+
+    for line in file_in:
+        line = line.rstrip()
+        tree = C_Tree.from_string(line)
+        linearized_tree = encoder.encode(tree)
+        file_out.write(linearized_tree.to_string(f_idx_dict))
+        file_out.write("\n")
+        tree_counter += 1
+        labels_counter += len(linearized_tree)
+        for lbl in linearized_tree.get_labels():
+            label_set.add(str(lbl))
+
+    file_in.close()
+    file_out.close()
+    return labels_counter, tree_counter, len(label_set)
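+
+# Illustrative round trip (a sketch; paths and the conflict strategy are
+# hypothetical placeholders):
+#
+#   encode_constituent('train.trees', 'train.labels', C_ABSOLUTE_ENCODING,
+#                      separator='_', unary_joiner='+', features=None)
+#   decode_constituent('train.labels', 'decoded.trees', C_ABSOLUTE_ENCODING,
+#                      separator='_', unary_joiner='+', conflicts='strat_max',
+#                      nulls=True, postags=False, lang='en')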
+
+def decode_constituent(in_path, out_path, encoding_type, separator, unary_joiner, conflicts, nulls, postags, lang):
+    '''
+    Decodes the selected file according to the specified parameters:
+    :param in_path: Path of the labels file to be decoded
+    :param out_path: Path where to write the decoded tree
+    :param encoding_type: Encoding used
+    :param separator: string used to separate label fields
+    :param unary_joiner: string used to separate nodes from unary chains
+    :param conflicts: conflict resolution heuristics to apply
+    '''
+
+    if encoding_type == C_ABSOLUTE_ENCODING:
+        decoder = C_NaiveAbsoluteEncoding(separator, unary_joiner)
+    elif encoding_type == C_RELATIVE_ENCODING:
+        decoder = C_NaiveRelativeEncoding(separator, unary_joiner)
+    elif encoding_type == C_DYNAMIC_ENCODING:
+        decoder = C_NaiveDynamicEncoding(separator, unary_joiner)
+    elif encoding_type == C_INCREMENTAL_ENCODING:
+        decoder = C_NaiveIncrementalEncoding(separator, unary_joiner)
+    else:
+        raise Exception("Unknown encoding type")
+
+    if postags:
+        stanza.download(lang=lang)
+        nlp = stanza.Pipeline(lang=lang, processors='tokenize,pos')
+
+    f_in = open(in_path)
+    f_out = open(out_path,"w+")
+    
+    tree_string   = ""
+    labels_counter = 0
+    tree_counter = 0
+
+    for line in f_in:
+        if line == "\n":
+            tree_string = tree_string.rstrip()
+            current_tree = LinearizedTree.from_string(tree_string, mode="CONST", separator=separator, unary_joiner=unary_joiner)
+            
+            if postags:
+                c_tags = nlp(current_tree.get_sentence())
+                current_tree.set_postags([word.pos for word in c_tags])
+
+            decoded_tree = decoder.decode(current_tree)
+            decoded_tree = decoded_tree.postprocess_tree(conflicts, nulls)
+
+            f_out.write(str(decoded_tree).replace('\n','')+'\n')
+            tree_string   = ""
+            tree_counter+=1
+        tree_string += line
+        labels_counter += 1
+
+    f_in.close()
+    f_out.close()
+    return tree_counter, labels_counter
diff --git a/tania_scripts/supar/codelin/encs/dependency.py b/tania_scripts/supar/codelin/encs/dependency.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2a49d41dbc2f038812142ee05741099e8460a28
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/dependency.py
@@ -0,0 +1,137 @@
+import stanza
+from supar.codelin.models.linearized_tree import LinearizedTree
+from supar.codelin.models.deps_label import D_Label
+from supar.codelin.utils.extract_feats import extract_features_deps
+from supar.codelin.encs.enc_deps import *
+from supar.codelin.utils.constants import *
+from supar.codelin.models.deps_tree import D_Tree
+
+# Encoding
+def encode_dependencies(in_path, out_path, encoding_type, separator, displacement, planar_alg, root_enc, features):
+    '''
+    Encodes the selected file according to the specified parameters:
+    :param in_path: Path of the file to be encoded
+    :param out_path: Path where to write the encoded labels
+    :param encoding_type: Encoding used
+    :param separator: string used to separate label fields
+    :param displacement: boolean to indicate if use displacement in bracket based encodings
+    :param planar_alg: string used to choose the plane separation algorithm
+    :param features: features to add as columns to the labels file
+    '''
+
+    # Create the encoder
+    if encoding_type == D_ABSOLUTE_ENCODING:
+            encoder = D_NaiveAbsoluteEncoding(separator)
+    elif encoding_type == D_RELATIVE_ENCODING:
+            encoder = D_NaiveRelativeEncoding(separator, root_enc)
+    elif encoding_type == D_POS_ENCODING:
+            encoder = D_PosBasedEncoding(separator)
+    elif encoding_type == D_BRACKET_ENCODING:
+            encoder = D_BrkBasedEncoding(separator, displacement)
+    elif encoding_type == D_BRACKET_ENCODING_2P:
+            encoder = D_Brk2PBasedEncoding(separator, displacement, planar_alg)
+    else:
+        raise Exception("Unknown encoding type")
+    
+    # build feature index dictionary
+    f_idx_dict = {}
+    if features:
+        if features == ["ALL"]:
+            features = extract_features_deps(in_path)
+        f_idx_dict = {f: i for i, f in enumerate(features)}
+
+    file_out = open(out_path,"w+")
+    label_set = set()
+    tree_counter = 0
+    label_counter = 0
+    trees = D_Tree.read_conllu_file(in_path, filter_projective=False)
+    
+    for t in trees:
+        # encode labels
+        linearized_tree = encoder.encode(t)        
+        file_out.write(linearized_tree.to_string(f_idx_dict))
+        file_out.write("\n")
+        
+        tree_counter+=1
+        label_counter+=len(linearized_tree)
+        
+        for lbl in linearized_tree.get_labels():
+            label_set.add(str(lbl))      
+
+    file_out.close()
+    return tree_counter, label_counter, len(label_set)
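+
+# Illustrative call (a sketch; paths are hypothetical):
+#
+#   encode_dependencies('train.conllu', 'train.labels', D_BRACKET_ENCODING_2P,
+#                       separator='_', displacement=True, planar_alg=D_2P_GREED,
+#                       root_enc=True, features=None)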
+
+# Decoding
+
+def decode_dependencies(in_path, out_path, encoding_type, separator, displacement, multiroot, root_search, root_enc, postags, lang):
+    '''
+    Decodes the selected file according to the specified parameters:
+    :param in_path: Path of the labels file to be decoded
+    :param out_path: Path where to write the decoded trees
+    :param encoding_type: Encoding used
+    :param separator: string used to separate label fields
+    :param displacement: boolean to indicate if use displacement in bracket based encodings
+    :param multiroot: boolean to indicate if multiroot conll trees are allowed
+    :param root_search: strategy to select how to search the root if no root found in decoded tree
+    '''
+
+    if encoding_type == D_ABSOLUTE_ENCODING:
+        decoder = D_NaiveAbsoluteEncoding(separator)
+    elif encoding_type == D_RELATIVE_ENCODING:
+        decoder = D_NaiveRelativeEncoding(separator, root_enc)
+    elif encoding_type == D_POS_ENCODING:
+        decoder = D_PosBasedEncoding(separator)
+    elif encoding_type == D_BRACKET_ENCODING:
+        decoder = D_BrkBasedEncoding(separator, displacement)
+    elif encoding_type == D_BRACKET_ENCODING_2P:
+        decoder = D_Brk2PBasedEncoding(separator, displacement, None)
+    else:
+        raise Exception("Unknown encoding type")
+
+    f_in = open(in_path)
+    f_out = open(out_path, "w+")
+
+    tree_counter = 0
+    labels_counter = 0
+
+    tree_string = ""
+
+    if postags:
+        stanza.download(lang=lang)
+        nlp = stanza.Pipeline(lang=lang, processors='tokenize,pos')
+
+    for line in f_in:
+        if line == "\n":
+            tree_string = tree_string.rstrip()
+            current_tree = LinearizedTree.from_string(tree_string, mode="DEPS", separator=separator)
+            
+            if postags:
+                c_tags = nlp(current_tree.get_sentence())
+                current_tree.set_postags([word.pos for word in c_tags])
+
+            decoded_tree = decoder.decode(current_tree)
+            decoded_tree.postprocess_tree(root_search, multiroot)
+            f_out.write("# text = "+decoded_tree.get_sentence()+"\n")
+            f_out.write(str(decoded_tree))
+            
+            tree_string = ""
+            tree_counter+=1
+        
+        tree_string += line
+        labels_counter += 1
+
+    f_in.close()
+    f_out.close()
+    return tree_counter, labels_counter
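+
+# Illustrative call (a sketch; paths and the root-search value are hypothetical):
+#
+#   decode_dependencies('pred.labels', 'pred.conllu', D_BRACKET_ENCODING_2P,
+#                       separator='_', displacement=True, multiroot=False,
+#                       root_search='start', root_enc=True, postags=False, lang='en')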
diff --git a/tania_scripts/supar/codelin/encs/enc_const/__init__.py b/tania_scripts/supar/codelin/encs/enc_const/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d803a3b0a700fe22bdc565435e46c84b885edf38
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/enc_const/__init__.py
@@ -0,0 +1,4 @@
+from .naive_absolute import C_NaiveAbsoluteEncoding
+from .naive_relative import C_NaiveRelativeEncoding
+from .naive_dynamic import C_NaiveDynamicEncoding
+from .naive_incremental import C_NaiveIncrementalEncoding
diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..41d16cbfa21181d065d0c1f0e255961bc5fc54b8
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..12f0189eb2fabe6bb71d622e2452f312b08a7c8c
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_absolute.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_absolute.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e654e52ff2f574d7c86285d440042167471c068d
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_absolute.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_absolute.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_absolute.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb3ffaca5b999a39848ef14e042cb37fd4f4ebae
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_absolute.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_dynamic.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_dynamic.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d6495f2c7305a9257f7a536493cc1034c193804
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_dynamic.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_dynamic.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_dynamic.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..30ad5b6badac7cd3836fb4078ae00b30e8c29bd2
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_dynamic.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_incremental.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_incremental.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0df5670f680c988a0e1be47f68f977f4e1915c30
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_incremental.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_incremental.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_incremental.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f6374cacf05ac756434ee357b9b18a208fcc681
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_incremental.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_relative.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_relative.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..58ce1fee263b5b20fbc8fe069e0ca2f9155706e4
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_relative.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_relative.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_relative.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b42a4c47d2c9a3e7745917396a876bf0af990d50
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_relative.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_const/naive_absolute.py b/tania_scripts/supar/codelin/encs/enc_const/naive_absolute.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dbb9844109e33406fb7056bdfe3e4e6e2e07c17
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/enc_const/naive_absolute.py
@@ -0,0 +1,147 @@
+from supar.codelin.encs.abstract_encoding import ACEncoding
+from supar.codelin.utils.constants import C_ABSOLUTE_ENCODING, C_ROOT_LABEL, C_CONFLICT_SEPARATOR, C_NONE_LABEL
+from supar.codelin.models.const_label import C_Label
+from supar.codelin.models.linearized_tree import LinearizedTree
+from supar.codelin.models.const_tree import C_Tree
+
+import re
+
+class C_NaiveAbsoluteEncoding(ACEncoding):
+    def __init__(self, separator, unary_joiner):
+        self.separator = separator
+        self.unary_joiner = unary_joiner
+
+    def __str__(self):
+        return "Constituent Naive Absolute Encoding"
+
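+    # Worked example (illustrative): for (S (NP (DT the) (NN dog)) (VP (VBD ran))),
+    # consecutive leaf paths share S-NP for 'the' vs 'dog', giving the label
+    # (n_commons=2, last_common=NP); 'dog' vs 'ran' share only S, giving (1, S).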
+    def encode(self, constituent_tree):        
+        lc_tree = LinearizedTree.empty_tree()
+        leaf_paths = constituent_tree.path_to_leaves(collapse_unary=True, unary_joiner=self.unary_joiner)
+        
+        for i in range(0, len(leaf_paths)-1):
+            path_a = leaf_paths[i]
+            path_b = leaf_paths[i+1]
+            
+            last_common = ""
+            n_commons   = 0
+            for a,b in zip(path_a, path_b):
+
+                if (a!=b):
+                    # Remove the digits and additional feats in the last common node
+                    last_common = re.sub(r'[0-9]+', '', last_common)
+                    last_common = last_common.split("##")[0]
+
+                    # Get word and POS tag
+                    word = path_a[-1]
+                    postag = path_a[-2]
+                    
+                    # Build the Leaf Unary Chain
+                    unary_chain = None
+                    leaf_unary_chain = postag.split(self.unary_joiner)
+                    if len(leaf_unary_chain)>1:
+                        unary_list = []
+                        for element in leaf_unary_chain[:-1]:
+                            unary_list.append(element.split("##")[0])
+
+                        unary_chain = self.unary_joiner.join(unary_list)
+                        postag = leaf_unary_chain[len(leaf_unary_chain)-1]
+                    
+                    # Clean the POS Tag and extract additional features
+                    postag_split = postag.split("##")
+                    feats = [None]
+
+                    if len(postag_split) > 1:
+                        postag = re.sub(r'[0-9]+', '', postag_split[0])
+                        feats = postag_split[1].split("|")
+                    else:
+                        postag = re.sub(r'[0-9]+', '', postag)
+
+                    c_label = C_Label(n_commons, last_common, unary_chain, C_ABSOLUTE_ENCODING,
+                                      self.separator, self.unary_joiner)
+                    
+                    # Append the data
+                    lc_tree.add_row(word, postag, feats, c_label)
+                
+                    break
+                
+                n_commons  += len(a.split(self.unary_joiner))
+                last_common = a
+            
+        # n = max number of features of the tree
+        lc_tree.n = max([len(f) for f in lc_tree.additional_feats])
+        return lc_tree
+
+    def decode(self, linearized_tree):
+        # Check valid labels
+        if not linearized_tree:
+            print("[!] Error while decoding: Null tree.")
+            return
+
+        # Create constituent tree
+        tree = C_Tree(C_ROOT_LABEL, [])
+        current_level = tree
+
+        old_n_commons = 0
+        old_level = None
+        for word, postag, feats, label in linearized_tree.iterrows():
+
+            # Descend through the tree until reach the level indicated by last_common
+            current_level = tree
+            for level_index in range(label.n_commons):
+                if (current_level.is_terminal()) or (level_index >= old_n_commons):
+                    current_level.add_child(C_Tree(C_NONE_LABEL, []))
+                
+                current_level = current_level.r_child()
+
+            # Split the Last Common field of the Label in case it has a Unary Chain Collapsed
+            label.last_common = label.last_common.split(self.unary_joiner)
+
+            if len(label.last_common) == 1:
+                # If current level has no label yet, put the label
+                # If current level has label but different than this one, set it as a conflict
+                if (current_level.label == C_NONE_LABEL):
+                    current_level.label = label.last_common[0].rstrip()
+                else:
+                    current_level.label = current_level.label + C_CONFLICT_SEPARATOR + label.last_common[0]
+            if len(label.last_common)>1:
+                current_level = tree
+                
+                # problem when n_commons predicted is LESS than the number of last commons predicted
+                descend_levels = label.n_commons - (len(label.last_common)) + 1
+                for level_index in range(descend_levels):
+                    current_level = current_level.r_child()
+                for i in range(len(label.last_common)-1):
+                    if (current_level.label == C_NONE_LABEL):
+                        current_level.label = label.last_common[i]
+                    else:
+                        current_level.label = current_level.label+C_CONFLICT_SEPARATOR+label.last_common[i]
+
+                    if len(current_level.children)>0:
+                        current_level = current_level.r_child()
+
+                # If we reach a POS tag, set it as child of the current chain
+                if current_level.is_preterminal():
+                    temp_current_level_children = current_level.children
+                    current_level.label = label.last_common[i+1]
+                    current_level.children = temp_current_level_children
+                    for c in temp_current_level_children:
+                        c.parent = current_level
+                else:
+                    current_level.label=label.last_common[i+1]
+            
+            
+            # Fill POS tag in this node or previous one
+            if (label.n_commons >= old_n_commons):
+                current_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner)
+            else:
+                old_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner)
+
+            old_n_commons=label.n_commons
+            old_level=current_level
+
+        tree.inherit_tree()
+        
+        return tree
diff --git a/tania_scripts/supar/codelin/encs/enc_const/naive_dynamic.py b/tania_scripts/supar/codelin/encs/enc_const/naive_dynamic.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4dc11f5ce16a283057c0852ca4062d21e92e8b6
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/enc_const/naive_dynamic.py
@@ -0,0 +1,170 @@
+from supar.codelin.encs.abstract_encoding import ACEncoding
+from supar.codelin.utils.constants import C_ABSOLUTE_ENCODING, C_RELATIVE_ENCODING, C_ROOT_LABEL, C_CONFLICT_SEPARATOR, C_NONE_LABEL
+from supar.codelin.models.const_label import C_Label
+from supar.codelin.models.linearized_tree import LinearizedTree
+from supar.codelin.models.const_tree import C_Tree
+
+import re
+
+class C_NaiveDynamicEncoding(ACEncoding):
+    def __init__(self, separator, unary_joiner):
+        self.separator = separator
+        self.unary_joiner = unary_joiner
+
+    def __str__(self):
+        return "Constituent Naive Dynamic Encoding"
+
+    def encode(self, constituent_tree):
+        lc_tree = LinearizedTree.empty_tree()
+        leaf_paths = constituent_tree.path_to_leaves(collapse_unary=True, unary_joiner=self.unary_joiner)
+
+        last_n_common=0
+        for i in range(0, len(leaf_paths)-1):
+            path_a=leaf_paths[i]
+            path_b=leaf_paths[i+1]
+            
+            last_common=""
+            n_commons=0
+            for a,b in zip(path_a, path_b):
+
+                if (a!=b):
+                    # Remove the digits and additional feats in the last common node
+                    last_common = re.sub(r'[0-9]+', '', last_common)
+                    last_common = last_common.split("##")[0]
+
+                    # Get word and POS tag
+                    word = path_a[-1]
+                    postag = path_a[-2]
+                    
+                    # Build the Leaf Unary Chain
+                    unary_chain = None
+                    leaf_unary_chain = postag.split(self.unary_joiner)
+                    if len(leaf_unary_chain)>1:
+                        unary_list = []
+                        for element in leaf_unary_chain[:-1]:
+                            unary_list.append(element.split("##")[0])
+
+                        unary_chain = self.unary_joiner.join(unary_list)
+                        postag = leaf_unary_chain[len(leaf_unary_chain)-1]
+                    
+                    # Clean the POS Tag and extract additional features
+                    postag_split = postag.split("##")
+                    feats = [None]
+
+                    if len(postag_split) > 1:
+                        postag = re.sub(r'[0-9]+', '', postag_split[0])
+                        feats = postag_split[1].split("|")
+                    else:
+                        postag = re.sub(r'[0-9]+', '', postag)
+
+                    # Compute the encoded value
+                    abs_val=n_commons
+                    rel_val=(n_commons-last_n_common)
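+                    # Dynamic criterion: prefer the absolute label for shallow
+                    # attachments (abs_val <= 3) with a steep relative descent
+                    # (rel_val <= -2); otherwise keep the relative label.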
+
+                    if (abs_val<=3 and rel_val<=-2):
+                        c_label = (C_Label(abs_val, last_common, unary_chain, C_ABSOLUTE_ENCODING, self.separator, self.unary_joiner))
+                    else:
+                        c_label = (C_Label(rel_val, last_common, unary_chain, C_RELATIVE_ENCODING, self.separator, self.unary_joiner))
+                    
+                    lc_tree.add_row(word, postag, feats, c_label)
+
+                    last_n_common=n_commons
+                    break
+                
+                # Store Last Common and increase n_commons 
+                # Note: When increasing n_commons use the number from split the collapsed chains
+                n_commons += len(a.split(self.unary_joiner))
+                last_common = a
+        
+        # n = max number of features of the tree
+        lc_tree.n = max([len(f) for f in lc_tree.additional_feats])
+        return lc_tree
+
+    def decode(self, linearized_tree):
+        # Check valid labels 
+        if not linearized_tree:
+            print("[*] Error while decoding: Null tree.")
+            return
+        
+        # Create constituent tree
+        tree = C_Tree(C_ROOT_LABEL, [])
+        current_level = tree
+
+        old_n_commons=0
+        old_last_common=''
+        old_level=None
+        is_first = True
+        last_label = None
+
+        for word, postag, feats, label in linearized_tree.iterrows():
+            
+            # Convert the labels to absolute scale
+            if last_label!=None and label.encoding_type==C_RELATIVE_ENCODING:
+                label.to_absolute(last_label)
+            
+            # First label must have a positive n_commons value
+            if is_first and label.n_commons <= 0:
+                label.n_commons = 1
+            is_first = False
+
+            # Descend through the tree until reach the level indicated by last_common
+            current_level = tree
+            for level_index in range(label.n_commons):
+                if (len(current_level.children)==0) or (level_index >= old_n_commons):
+                    current_level.add_child(C_Tree(C_NONE_LABEL, []))
+                
+                current_level = current_level.r_child()
+
+            # Split the Last Common field of the Label in case it has a Unary Chain Collapsed
+            label.last_common = label.last_common.split(self.unary_joiner)
+
+            if len(label.last_common)==1:
+                # If current level has no label yet, put the label
+                # If current level has label but different than this one, set it as a conflict
+                if (current_level.label==C_NONE_LABEL):
+                    current_level.label = label.last_common[0].rstrip()
+                else:
+                    current_level.label = current_level.label + C_CONFLICT_SEPARATOR + label.last_common[0]
+            if len(label.last_common)>1:
+                current_level = tree
+                
+                # problem when n_commons predicted is LESS than the number of last commons predicted
+                descend_levels = label.n_commons - (len(label.last_common)) + 1
+                for level_index in range(descend_levels):
+                    current_level = current_level.r_child()
+                for i in range(len(label.last_common)-1):
+                    if (current_level.label == C_NONE_LABEL):
+                        current_level.label = label.last_common[i]
+                    else:
+                        current_level.label = current_level.label+C_CONFLICT_SEPARATOR+label.last_common[i]
+
+                    if len(current_level.children)>0:
+                        current_level = current_level.r_child()
+
+                # If we reach a POS tag, set it as child of the current chain
+                if current_level.is_preterminal():
+                    temp_current_level_children = current_level.children
+                    current_level.label = label.last_common[i+1]
+                    current_level.children = temp_current_level_children
+                    for c in temp_current_level_children:
+                        c.parent = current_level
+                else:
+                    current_level.label=label.last_common[i+1]
+            
+            
+            # Fill POS tag in this node or previous one
+            if (label.n_commons >= old_n_commons):
+                current_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner)
+            else:
+                old_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner)
+
+            old_n_commons=label.n_commons
+            old_level=current_level
+            last_label=label
+        
+        tree.inherit_tree()
+        
+        return tree
diff --git a/tania_scripts/supar/codelin/encs/enc_const/naive_incremental.py b/tania_scripts/supar/codelin/encs/enc_const/naive_incremental.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ceb06b39916535b0e15110abf66fe7d5997387f
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/enc_const/naive_incremental.py
@@ -0,0 +1,165 @@
+from supar.codelin.encs.abstract_encoding import ACEncoding
+from supar.codelin.utils.constants import C_ABSOLUTE_ENCODING, C_ROOT_LABEL, C_CONFLICT_SEPARATOR, C_NONE_LABEL, C_DUMMY_END
+from supar.codelin.models.const_label import C_Label
+from supar.codelin.models.linearized_tree import LinearizedTree
+from supar.codelin.models.const_tree import C_Tree
+
+import re
+
+class C_NaiveIncrementalEncoding(ACEncoding):
+    def __init__(self, separator, unary_joiner):
+        self.separator = separator
+        self.unary_joiner = unary_joiner
+
+    def __str__(self):
+        return "Constituent Naive Incremental Encoding"
+
+    def get_unary_chain(self, postag):
+        unary_chain = None
+        leaf_unary_chain = postag.split(self.unary_joiner)
+
+        if len(leaf_unary_chain)>1:
+            unary_list = []
+            for element in leaf_unary_chain[:-1]:
+                unary_list.append(element.split("##")[0])
+
+            unary_chain = self.unary_joiner.join(unary_list)
+            postag = leaf_unary_chain[len(leaf_unary_chain)-1]
+        
+        return unary_chain, postag
+    
+    def get_features(self, node, feature_marker="##", feature_splitter="|"):
+        postag_split = node.split(feature_marker)
+        feats = None
+
+        if len(postag_split) > 1:
+            postag = re.sub(r'[0-9]+', '', postag_split[0])
+            feats = postag_split[1].split(feature_splitter)
+        else:
+            postag = re.sub(r'[0-9]+', '', node)
+        return postag, feats
+    
+    def clean_last_common(self, node, feature_marker="##"):
+        node = re.sub(r'[0-9]+', '', node)
+        last_common = node.split(feature_marker)[0]
+        return last_common
+
+    def encode(self, constituent_tree):
+        constituent_tree.reverse_tree()
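+        # The tree is reversed so the absolute-style machinery below runs
+        # right-to-left: each label is computed against the previous leaf
+        # (paths i-1 vs i) instead of the next one.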
+        leaf_paths = constituent_tree.path_to_leaves(collapse_unary=True, unary_joiner=self.unary_joiner)
+        lc_tree = LinearizedTree.empty_tree()
+
+        for i in range(1, len(leaf_paths)):
+            path_a = leaf_paths[i-1]
+            path_b = leaf_paths[i]
+            
+            last_common = ""
+            n_commons   = 0
+
+            for a,b in zip(path_a, path_b):
+                if (a!=b):
+                    # Remove the digits and additional feats in the last common node
+                    last_common = self.clean_last_common(last_common)
+
+                    # Get word and POS tag
+                    word   = path_a[-1]
+                    postag = path_a[-2]
+                    
+                    # Build the Leaf Unary Chain
+                    unary_chain, postag = self.get_unary_chain(postag)
+                    
+                    # Clean the POS Tag and extract additional features
+                    postag, feats = self.get_features(postag)
+
+                    # Append the data
+                    c_label = (C_Label(n_commons, last_common, unary_chain, C_ABSOLUTE_ENCODING, self.separator, self.unary_joiner))
+                    lc_tree.add_row(word, postag, feats, c_label)
+
+                    break
+                
+                # Store Last Common and increase n_commons 
+                # Note: When increasing n_commons use the number from split the collapsed chains
+                n_commons  += len(a.split(self.unary_joiner))
+                last_common = a
+        
+        # reverse and return
+        lc_tree.reverse_tree(ignore_bos_eos=False)
+        return lc_tree
+
+    def decode(self, linearized_tree):
+        # Check valid labels 
+        if not linearized_tree:
+            print("[*] Error while decoding: Null tree.")
+            return
+
+        # Create constituent tree
+        tree = C_Tree(C_ROOT_LABEL, [])
+        current_level = tree
+
+        old_n_commons=0
+        old_level=None
+
+        linearized_tree.reverse_tree(ignore_bos_eos=False)
+        for word, postag, feats, label in linearized_tree.iterrows():
+            
+            # Descend through the tree until reach the level indicated by last_common
+            current_level = tree
+            for level_index in range(label.n_commons):
+                if (current_level.is_terminal()) or (level_index >= old_n_commons):
+                    current_level.add_child(C_Tree(C_NONE_LABEL, []))
+                
+                current_level = current_level.r_child()
+
+            # Split the Last Common field of the Label in case it has a Unary Chain Collapsed
+            label.last_common = label.last_common.split(self.unary_joiner)
+
+            if len(label.last_common) == 1:
+                # If current level has no label yet, put the label
+                # If current level has label but different than this one, set it as a conflict
+                if (current_level.label == C_NONE_LABEL):
+                    current_level.label = label.last_common[0]
+                else:
+                    current_level.label = current_level.label + C_CONFLICT_SEPARATOR + label.last_common[0]
+            else:
+                current_level = tree
+                
+                # Descend to the beginning of the Unary Chain and fill it
+                descend_levels = label.n_commons - (len(label.last_common)) + 1
+                
+                for level_index in range(descend_levels):
+                    current_level = current_level.r_child()
+                
+                for i in range(len(label.last_common)-1):
+                    if (current_level.label == C_NONE_LABEL):
+                        current_level.label = label.last_common[i]
+                    else:
+                        current_level.label = current_level.label + C_CONFLICT_SEPARATOR + label.last_common[i]
+                    current_level = current_level.r_child()
+
+                # If we reach a POS tag, set it as child of the current chain
+                if current_level.is_preterminal():
+                    # copy the preterminal first; reusing the node itself
+                    # would make it its own child
+                    temp_current_level = C_Tree(current_level.label, current_level.children)
+                    current_level.label = label.last_common[i+1]
+                    current_level.children = [temp_current_level]
+                
+                else:
+                    current_level.label=label.last_common[i+1]
+            
+            # Fill POS tag in this node or previous one
+            if (label.n_commons >= old_n_commons):
+                current_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner)
+            
+            else:
+                old_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner)
+
+            old_n_commons=label.n_commons
+            old_level=current_level
+
+        tree.inherit_tree()
+        tree.reverse_tree()
+        return tree
diff --git a/tania_scripts/supar/codelin/encs/enc_const/naive_relative.py b/tania_scripts/supar/codelin/encs/enc_const/naive_relative.py
new file mode 100644
index 0000000000000000000000000000000000000000..0794e3be915243487d985a33af95a3be427552e4
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/enc_const/naive_relative.py
@@ -0,0 +1,156 @@
+from supar.codelin.encs.abstract_encoding import ACEncoding
+from supar.codelin.utils.constants import C_RELATIVE_ENCODING, C_ROOT_LABEL, C_CONFLICT_SEPARATOR, C_NONE_LABEL
+from supar.codelin.models.const_label import C_Label
+from supar.codelin.models.linearized_tree import LinearizedTree
+from supar.codelin.models.const_tree import C_Tree
+
+import re
+
+class C_NaiveRelativeEncoding(ACEncoding):
+    def __init__(self, separator, unary_joiner):
+        self.separator = separator
+        self.unary_joiner = unary_joiner
+
+    def __str__(self):
+        return "Constituent Naive Relative Encoding"
+
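+    # Worked example (illustrative): with absolute depths 2 and 1 for the
+    # leaves of (S (NP (DT the) (NN dog)) (VP (VBD ran))), the relative
+    # labels store the increments 2 and -1 (n_commons minus the previous one).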
+    def encode(self, constituent_tree):
+        leaf_paths = constituent_tree.path_to_leaves(collapse_unary=True, unary_joiner=self.unary_joiner)
+        lc_tree = LinearizedTree.empty_tree()
+
+        last_n_common=0
+        for i in range(0, len(leaf_paths)-1):
+            path_a=leaf_paths[i]
+            path_b=leaf_paths[i+1]
+            
+            last_common=""
+            n_commons=0
+            for a,b in zip(path_a, path_b):
+
+                if (a!=b):
+                    # Remove the digits and additional feats in the last common node
+                    last_common = re.sub(r'[0-9]+', '', last_common)
+                    last_common = last_common.split("##")[0]
+
+                    # Get word and POS tag
+                    word = path_a[-1]
+                    postag = path_a[-2]
+                    
+                    # Build the Leaf Unary Chain
+                    unary_chain = None
+                    leaf_unary_chain = postag.split(self.unary_joiner)
+                    if len(leaf_unary_chain)>1:
+                        unary_list = []
+                        for element in leaf_unary_chain[:-1]:
+                            unary_list.append(element.split("##")[0])
+
+                        unary_chain = self.unary_joiner.join(unary_list)
+                        postag = leaf_unary_chain[len(leaf_unary_chain)-1]
+                    
+                    # Clean the POS Tag and extract additional features
+                    postag_split = postag.split("##")
+                    feats = None
+
+                    if len(postag_split) > 1:
+                        postag = re.sub(r'[0-9]+', '', postag_split[0])
+                        feats = postag_split[1].split("|")
+                    else:
+                        postag = re.sub(r'[0-9]+', '', postag)
+
+                    c_label = C_Label((n_commons-last_n_common), last_common, unary_chain, C_RELATIVE_ENCODING, self.separator, self.unary_joiner)
+                    lc_tree.add_row(word, postag, feats, c_label)
+
+                    last_n_common=n_commons
+                    break
+                
+                # Store Last Common and increase n_commons 
+                # Note: When increasing n_commons use the number from split the collapsed chains
+                n_commons += len(a.split(self.unary_joiner))
+                last_common = a
+        
+        return lc_tree
+
+    def decode(self, linearized_tree):
+        # Check valid labels 
+        if not linearized_tree:
+            print("[*] Error while decoding: Null tree.")
+            return
+        
+        # Create constituent tree
+        tree = C_Tree(C_ROOT_LABEL, [])
+        current_level = tree
+
+        old_n_commons=0
+        old_level=None
+
+        is_first = True
+        last_label = None
+
+        for word, postag, feats, label in linearized_tree.iterrows():
+            # Convert the labels to absolute scale
+            if last_label!=None:
+                label.to_absolute(last_label)
+            
+            # First label must have a positive n_commons value
+            if is_first and label.n_commons <= 0:
+                label.n_commons = 1
+            is_first = False
+            
+            # Descend through the tree until reach the level indicated by last_common
+            current_level = tree
+            for level_index in range(label.n_commons):
+                if (current_level.is_terminal()) or (level_index >= old_n_commons):
+                    current_level.add_child(C_Tree(C_NONE_LABEL, []))
+                current_level = current_level.r_child()
+            
+            # Split the Last Common field of the Label in case it has a Unary Chain Collapsed
+            label.last_common = label.last_common.split(self.unary_joiner)            
+
+            if len(label.last_common)==1:
+                if (current_level.label==C_NONE_LABEL):
+                    current_level.label=label.last_common[0].rstrip()
+                else:
+                    current_level.label = current_level.label + C_CONFLICT_SEPARATOR + label.last_common[0]
+            
+            if len(label.last_common)>1:
+                current_level = tree
+                
+                # problem when n_commons predicted is LESS than the number of last commons predicted
+                descend_levels = label.n_commons - (len(label.last_common)) + 1
+                for level_index in range(descend_levels):
+                    current_level = current_level.r_child()
+                for i in range(len(label.last_common)-1):
+                    if (current_level.label == C_NONE_LABEL):
+                        current_level.label = label.last_common[i]
+                    else:
+                        current_level.label = current_level.label+C_CONFLICT_SEPARATOR+label.last_common[i]
+
+                    if len(current_level.children)>0:
+                        current_level = current_level.r_child()
+
+                # If we reach a POS tag, set it as child of the current chain
+                if current_level.is_preterminal():
+                    temp_current_level_children = current_level.children
+                    current_level.label = label.last_common[i+1]
+                    current_level.children = temp_current_level_children
+                    for c in temp_current_level_children:
+                        c.parent = current_level
+                else:
+                    current_level.label=label.last_common[i+1]
+            
+            # Fill POS tag in this node or previous one
+            if (label.n_commons >= old_n_commons):
+                current_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner)
+            else:
+                old_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner)
+
+            old_n_commons=label.n_commons
+            old_level=current_level
+            
+            last_label=label
+        
+        tree.inherit_tree()
+        return tree
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__init__.py b/tania_scripts/supar/codelin/encs/enc_deps/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dd652c52beb69640767bbcf8201eaa1800818c0
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/enc_deps/__init__.py
@@ -0,0 +1,6 @@
+## Dependency Encodings Package
+from .naive_absolute import D_NaiveAbsoluteEncoding
+from .naive_relative import D_NaiveRelativeEncoding
+from .brk_based import D_BrkBasedEncoding
+from .pos_based import D_PosBasedEncoding
+from .brk2p_based import D_Brk2PBasedEncoding
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5d526f8782acc6ddb352bc54ce8ef508f64fe80
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ce46dfd865d40f78b478a5a314786080bc9266c
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk2p_based.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk2p_based.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50aeb39bab298e183c8d9afec05cac65fb3f9a5e
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk2p_based.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk2p_based.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk2p_based.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbbe00aafec12095b89292d53a158512e1969458
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk2p_based.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk_based.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk_based.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08db178e6917bb38a99b0d745ffe3f08121eeeec
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk_based.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk_based.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk_based.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20c2dea9b223e0a9c2a7ac3e40c647b6abc9ca0a
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk_based.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_absolute.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_absolute.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d7369766076aec1dad859894574c0db23c85e29
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_absolute.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_absolute.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_absolute.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c239f14078cb1d2da6df29f582713a3c14d788e
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_absolute.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_relative.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_relative.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bda4d2e8d01c87252607b4b5da3424fdb7ff31f0
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_relative.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_relative.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_relative.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..269a35bd9bb6605006c7431ea0d320ed293302d0
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_relative.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/pos_based.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/pos_based.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82512a33f500fe06e096da4cc221deeac69f42c8
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/pos_based.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/pos_based.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/pos_based.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..741e084997009584513b4b9673da439210e63bce
Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/pos_based.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/brk2p_based.py b/tania_scripts/supar/codelin/encs/enc_deps/brk2p_based.py
new file mode 100644
index 0000000000000000000000000000000000000000..665e9e6c61fd6a729d023349d203f87ae3765546
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/enc_deps/brk2p_based.py
@@ -0,0 +1,218 @@
+from supar.codelin.encs.abstract_encoding import ADEncoding
+from supar.codelin.utils.constants import D_2P_GREED, D_2P_PROP, D_NONE_LABEL
+from supar.codelin.models.deps_label import D_Label
+from supar.codelin.models.deps_tree import D_Tree
+from supar.codelin.models.linearized_tree import LinearizedTree
+
+
+class D_Brk2PBasedEncoding(ADEncoding):
+    def __init__(self, separator, displacement, planar_alg):
+        if planar_alg and planar_alg not in [D_2P_GREED, D_2P_PROP]:
+            raise ValueError("Unknown planar separation algorithm")
+        super().__init__(separator)
+        self.displacement = displacement
+        self.planar_alg = planar_alg
+
+    def __str__(self):
+        return "Dependency 2-Planar Bracketing Based Encoding"
+
+    def get_next_edge(self, dep_tree, idx_l, idx_r):
+        next_arc=None
+
+        if dep_tree[idx_l].head==idx_r:
+            next_arc = dep_tree[idx_l]
+        
+        elif dep_tree[idx_r].head==idx_l:
+            next_arc = dep_tree[idx_r]
+        
+        return next_arc
+
+    def two_planar_propagate(self, nodes):
+        p1=[]
+        p2=[]
+        fp1=[]
+        fp2=[]
+
+        for i in range(0, (len(nodes))):
+            for j in range(i, -1, -1):
+                # if the node in position 'i' has an arc to 'j' 
+                # or node in position 'j' has an arc to 'i'
+                next_arc=self.get_next_edge(nodes, i, j)
+                if next_arc is None:
+                    continue
+                else:
+                    # check restrictions
+                    if next_arc not in fp1:
+                        p1.append(next_arc)
+                        fp1, fp2 = self.propagate(nodes, fp1, fp2, next_arc, 2)
+                    
+                    elif next_arc not in fp2:
+                        p2.append(next_arc)
+                        fp1, fp2 = self.propagate(nodes, fp1, fp2, next_arc, 1)
+        return p1, p2
+
+    def propagate(self, nodes, fp1, fp2, current_edge, i):
+        # add the current edge to the forbidden plane opposite to the plane
+        # where the node has already been added
+        fpi  = None
+        fp3mi= None
+        if i==1:
+            fpi  = fp1
+            fp3mi= fp2
+        if i==2:
+            fpi  = fp2
+            fp3mi= fp1
+
+        fpi.append(current_edge)
+        
+        # add every node whose arc crosses the current edge
+        # to the corresponding forbidden plane
+        for node in nodes:
+            if current_edge.check_cross(node):
+                if node not in fp3mi:
+                    (fp1, fp2)=self.propagate(nodes, fp1, fp2, node, 3-i)
+        
+        return fp1, fp2
+
+    def two_planar_greedy(self, dep_tree):    
+        plane_1 = []
+        plane_2 = []
+
+        for i in range(len(dep_tree)):
+            for j in range(i, -1, -1):
+                # if the node in position 'i' has an arc to 'j' 
+                # or node in position 'j' has an arc to 'i'
+                next_arc = self.get_next_edge(dep_tree, i, j)
+                if next_arc is None:
+                    continue
+
+                else:
+                    cross_plane_1 = False
+                    cross_plane_2 = False
+                    for node in plane_1:                
+                        cross_plane_1 = cross_plane_1 or next_arc.check_cross(node)
+                    for node in plane_2:        
+                        cross_plane_2 = cross_plane_2 or next_arc.check_cross(node)
+                    
+                    if not cross_plane_1:
+                        plane_1.append(next_arc)
+                    elif not cross_plane_2:
+                        plane_2.append(next_arc)
+
+        # process them separately
+        return plane_1, plane_2
+
+
+    def encode(self, dep_tree):
+        # create brackets array
+        n_nodes = len(dep_tree)
+        labels_brk     = [""] * (n_nodes + 1)
+
+        # separate the planes
+        if self.planar_alg==D_2P_GREED:
+            p1_nodes, p2_nodes = self.two_planar_greedy(dep_tree)
+        elif self.planar_alg==D_2P_PROP:
+            p1_nodes, p2_nodes = self.two_planar_propagate(dep_tree)
+            
+        # get brackets separately
+        labels_brk = self.encode_step(p1_nodes, labels_brk, ['>','/','\\','<'])
+        labels_brk = self.encode_step(p2_nodes, labels_brk, ['>*','/*','\\*','<*'])
+        
+        # merge and obtain labels
+        lbls=[]
+        dep_tree.remove_dummy()
+        for node in dep_tree:
+            current = D_Label(labels_brk[node.id], node.relation, self.separator)
+            lbls.append(current)
+        return LinearizedTree(dep_tree.get_words(), dep_tree.get_postags(), dep_tree.get_feats(), lbls, len(lbls))
+
+    def encode_step(self, p, lbl_brk, brk_chars):
+        for node in p:
+            # skip root relations (optional?)
+            if node.head==0:
+                continue
+            if node.id < node.head:
+                if self.displacement:
+                    lbl_brk[node.id+1]+=brk_chars[3]
+                else:
+                    lbl_brk[node.id]+=brk_chars[3]
+
+                lbl_brk[node.head]+=brk_chars[2]
+            else:
+                if self.displacement:
+                    lbl_brk[node.head+1]+=brk_chars[1]
+                else:
+                    lbl_brk[node.head]+=brk_chars[1]
+
+                lbl_brk[node.id]+=brk_chars[0]
+        return lbl_brk
+
+    def decode(self, lin_tree):
+        decoded_tree = D_Tree.empty_tree(len(lin_tree)+1)
+        
+        # create plane stacks
+        l_stack_p1=[]
+        l_stack_p2=[]
+        r_stack_p1=[]
+        r_stack_p2=[]
+        
+        current_node=1
+
+        for word, postag, features, label in lin_tree.iterrows():
+            brks = list(label.xi) if label.xi != D_NONE_LABEL else []
+            temp_brks=[]
+            
+            for i in range(0, len(brks)):
+                current_char=brks[i]
+                if brks[i]=="*":
+                    current_char=temp_brks.pop()+brks[i]
+                temp_brks.append(current_char)
+                    
+            brks=temp_brks
+            
+            # set parameters to the node
+            decoded_tree.update_word(current_node, word)
+            decoded_tree.update_upos(current_node, postag)
+            decoded_tree.update_relation(current_node, label.li)
+            
+            # fill the relation using brks
+            for char in brks:
+                if char == "<":
+                    node_id=current_node + (-1 if self.displacement else 0)
+                    r_stack_p1.append((node_id,char))
+                
+                if char == "\\":
+                    head_id = r_stack_p1.pop()[0] if len(r_stack_p1)>0 else 0
+                    decoded_tree.update_head(head_id, current_node)
+                
+                if char =="/":
+                    node_id=current_node + (-1 if self.displacement else 0)
+                    l_stack_p1.append((node_id,char))
+
+                if char == ">":
+                    head_id = l_stack_p1.pop()[0] if len(l_stack_p1)>0 else 0
+                    decoded_tree.update_head(current_node, head_id)
+
+                if char == "<*":
+                    node_id=current_node + (-1 if self.displacement else 0)
+                    r_stack_p2.append((node_id,char))
+                
+                if char == "\\*":
+                    head_id = r_stack_p2.pop()[0] if len(r_stack_p2)>0 else 0
+                    decoded_tree.update_head(head_id, current_node)
+                
+                if char =="/*":
+                    node_id=current_node + (-1 if self.displacement else 0)
+                    l_stack_p2.append((node_id,char))
+
+                if char == ">*":
+                    head_id = l_stack_p2.pop()[0] if len(l_stack_p2)>0 else 0
+                    decoded_tree.update_head(current_node, head_id)
+            
+            current_node+=1
+
+        decoded_tree.remove_dummy()
+        return decoded_tree   
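
A quick round trip may help readers of this file; the sketch below is illustrative only (not part of the patch), assumes the vendored `supar.codelin` package is importable, and uses a made-up two-word CoNLL-U fragment:

```python
# Illustrative round trip through D_Brk2PBasedEncoding (not part of the patch).
from supar.codelin.encs.enc_deps.brk2p_based import D_Brk2PBasedEncoding
from supar.codelin.models.deps_tree import D_Tree
from supar.codelin.utils.constants import D_2P_GREED

conllu = ("1\tJohn\t_\tPROPN\t_\t_\t2\tnsubj\t_\t_\n"
          "2\tdied\t_\tVERB\t_\t_\t0\troot\t_\t_\n")
tree = D_Tree.from_string(conllu)                # dummy root added by default
enc = D_Brk2PBasedEncoding(separator="_", displacement=False, planar_alg=D_2P_GREED)
lin = enc.encode(tree)                           # one bracket label per word
print(lin.get_labels())                          # John -> '<_nsubj', died -> '\_root'
print(enc.decode(lin).get_heads())               # [2, 0]
```
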
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/brk_based.py b/tania_scripts/supar/codelin/encs/enc_deps/brk_based.py
new file mode 100644
index 0000000000000000000000000000000000000000..58758155d91e6fe7b43544f6d8fb0eb262879e65
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/enc_deps/brk_based.py
@@ -0,0 +1,87 @@
+from supar.codelin.encs.abstract_encoding import ADEncoding
+from supar.codelin.models.deps_label import D_Label
+from supar.codelin.models.deps_tree import D_Tree
+from supar.codelin.utils.constants import D_NONE_LABEL
+from supar.codelin.models.linearized_tree import LinearizedTree
+
+class D_BrkBasedEncoding(ADEncoding):
+    
+    def __init__(self, separator, displacement):
+        super().__init__(separator)
+        self.displacement = displacement
+
+    def __str__(self):
+        return "Dependency Bracketing Based Encoding"
+
+
+    def encode(self, dep_tree):
+        n_nodes = len(dep_tree)
+        labels_brk     = [""] * (n_nodes + 1)
+        encoded_labels = []
+        
+        # compute brackets array
+        # brackets array should be sorted ?
+        dep_tree.remove_dummy()
+        for node in dep_tree:
+            # skip root relations (optional?)
+            if node.head == 0:
+                continue
+            
+            if node.is_left_arc():
+                labels_brk[node.id + (1 if self.displacement else 0)]+='<'
+                labels_brk[node.head]+='\\'
+            
+            else:
+                labels_brk[node.head + (1 if self.displacement else 0)]+='/'
+                labels_brk[node.id]+='>'
+        
+        # encode labels
+        for node in dep_tree:
+            li = node.relation
+            xi = labels_brk[node.id]
+
+            current = D_Label(xi, li, self.separator)
+            encoded_labels.append(current)
+
+        return LinearizedTree(dep_tree.get_words(), dep_tree.get_postags(), dep_tree.get_feats(), encoded_labels, len(encoded_labels))
+
+    def decode(self, lin_tree):
+        # Create an empty tree with n labels
+        decoded_tree = D_Tree.empty_tree(len(lin_tree)+1)
+        
+        l_stack = []
+        r_stack = []
+        
+        current_node = 1
+        for word, postag, features, label in lin_tree.iterrows():
+            
+            # get the brackets
+            brks = list(label.xi) if label.xi != D_NONE_LABEL else []
+                       
+            # set parameters to the node
+            decoded_tree.update_word(current_node, word)
+            decoded_tree.update_upos(current_node, postag)
+            decoded_tree.update_relation(current_node, label.li)
+
+            # fill the relation using brks
+            for char in brks:
+                if char == "<":
+                    node_id = current_node + (-1 if self.displacement else 0)
+                    r_stack.append(node_id)
+
+                if char == "\\":
+                    head_id = r_stack.pop() if len(r_stack) > 0 else 0
+                    decoded_tree.update_head(head_id, current_node)
+                
+                if char =="/":
+                    node_id = current_node + (-1 if self.displacement else 0)
+                    l_stack.append(node_id)
+
+                if char == ">":
+                    head_id = l_stack.pop() if len(l_stack) > 0 else 0
+                    decoded_tree.update_head(current_node, head_id)
+            
+            current_node+=1
+
+        decoded_tree.remove_dummy()
+        return decoded_tree
\ No newline at end of file
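
For the single-plane variant, the bracket labels are easiest to read off a small example; again a hedged sketch with an invented three-word sentence:

```python
# Illustrative: bracket labels assigned by D_BrkBasedEncoding (not part of the patch).
from supar.codelin.encs.enc_deps.brk_based import D_BrkBasedEncoding
from supar.codelin.models.deps_tree import D_Tree

conllu = ("1\tthe\t_\tDET\t_\t_\t2\tdet\t_\t_\n"
          "2\tdog\t_\tNOUN\t_\t_\t3\tnsubj\t_\t_\n"
          "3\tbarks\t_\tVERB\t_\t_\t0\troot\t_\t_\n")
enc = D_BrkBasedEncoding(separator="_", displacement=False)
lin = enc.encode(D_Tree.from_string(conllu))
for word, _, _, label in lin.iterrows():
    print(word, label)    # the '<_det' | dog '\<_nsubj' | barks '\_root'
```
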
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/naive_absolute.py b/tania_scripts/supar/codelin/encs/enc_deps/naive_absolute.py
new file mode 100644
index 0000000000000000000000000000000000000000..e624bb34feeddfa2cb8c149bfbf4f45f3b813abf
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/enc_deps/naive_absolute.py
@@ -0,0 +1,42 @@
+from supar.codelin.encs.abstract_encoding import ADEncoding
+from supar.codelin.models.deps_label import D_Label
+from supar.codelin.models.linearized_tree import LinearizedTree
+from supar.codelin.models.deps_tree import D_Tree
+from supar.codelin.utils.constants import D_NONE_LABEL
+
+class D_NaiveAbsoluteEncoding(ADEncoding):
+    def __init__(self, separator):
+        super().__init__(separator)
+
+    def __str__(self):
+        return "Dependency Naive Absolute Encoding"
+
+    def encode(self, dep_tree):
+        encoded_labels = []
+        dep_tree.remove_dummy()
+        
+        for node in dep_tree:
+            li = node.relation
+            xi = node.head
+            
+            current = D_Label(xi, li, self.separator)
+            encoded_labels.append(current)
+
+        return LinearizedTree(dep_tree.get_words(), dep_tree.get_postags(), dep_tree.get_feats(), encoded_labels, len(encoded_labels))
+    
+    def decode(self, lin_tree):
+        dep_tree = D_Tree.empty_tree(len(lin_tree)+1)
+        
+        i=1
+        for word, postag, features, label in lin_tree.iterrows(): 
+            if label.xi == D_NONE_LABEL:
+                label.xi = 0
+            
+            dep_tree.update_word(i, word)
+            dep_tree.update_upos(i, postag)
+            dep_tree.update_relation(i, label.li)
+            dep_tree.update_head(i, int(label.xi))
+            i+=1
+
+        dep_tree.remove_dummy()
+        return dep_tree
\ No newline at end of file
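
Here the label is simply the head index, which a short sketch (illustrative, not part of the patch) makes concrete:

```python
# Illustrative: naive absolute labels are just the head indices.
from supar.codelin.encs.enc_deps.naive_absolute import D_NaiveAbsoluteEncoding
from supar.codelin.models.deps_tree import D_Tree

conllu = ("1\tJohn\t_\tPROPN\t_\t_\t2\tnsubj\t_\t_\n"
          "2\tdied\t_\tVERB\t_\t_\t0\troot\t_\t_\n")
enc = D_NaiveAbsoluteEncoding(separator="_")
lin = enc.encode(D_Tree.from_string(conllu))
print([str(l) for l in lin.get_labels()])    # ['2_nsubj', '0_root']
```
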
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/naive_relative.py b/tania_scripts/supar/codelin/encs/enc_deps/naive_relative.py
new file mode 100644
index 0000000000000000000000000000000000000000..6880f271246ae0d1ef2ce07132b1ea255708b60c
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/enc_deps/naive_relative.py
@@ -0,0 +1,48 @@
+from supar.codelin.encs.abstract_encoding import ADEncoding
+from supar.codelin.models.deps_label import D_Label
+from supar.codelin.models.linearized_tree import LinearizedTree
+from supar.codelin.models.deps_tree import D_Tree
+from supar.codelin.utils.constants import D_NONE_LABEL
+
+class D_NaiveRelativeEncoding(ADEncoding):
+    def __init__(self, separator, hang_from_root):
+        super().__init__(separator)
+        self.hfr = hang_from_root
+
+    def __str__(self):
+        return "Dependency Naive Relative Encoding"
+
+    def encode(self, dep_tree):
+        encoded_labels = []
+        dep_tree.remove_dummy()
+        for node in dep_tree:
+            li = node.relation 
+            xi = node.delta_head()
+
+            if node.relation == 'root' and self.hfr:
+                xi = D_NONE_LABEL
+            
+            current = D_Label(xi, li, self.separator)
+            encoded_labels.append(current)
+
+        return LinearizedTree(dep_tree.get_words(), dep_tree.get_postags(), dep_tree.get_feats(), encoded_labels, len(encoded_labels))
+
+    def decode(self, lin_tree):
+        dep_tree = D_Tree.empty_tree(len(lin_tree)+1)
+
+        i = 1
+        for word, postag, features, label in lin_tree.iterrows():
+            if label.xi == D_NONE_LABEL:
+                # set as root
+                dep_tree.update_head(i, 0)
+            else:
+                dep_tree.update_head(i, int(label.xi)+(i))
+                
+            
+            dep_tree.update_word(i, word)
+            dep_tree.update_upos(i, postag)
+            dep_tree.update_relation(i, label.li)
+            i+=1
+
+        dep_tree.remove_dummy()
+        return dep_tree
\ No newline at end of file
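
The relative variant stores the signed offset `head - id` instead; a hedged sketch on the same toy sentence:

```python
# Illustrative: naive relative labels store the offset head - id; with
# hang_from_root=True the root word gets the -NONE- placeholder (not part of the patch).
from supar.codelin.encs.enc_deps.naive_relative import D_NaiveRelativeEncoding
from supar.codelin.models.deps_tree import D_Tree

conllu = ("1\tJohn\t_\tPROPN\t_\t_\t2\tnsubj\t_\t_\n"
          "2\tdied\t_\tVERB\t_\t_\t0\troot\t_\t_\n")
enc = D_NaiveRelativeEncoding(separator="_", hang_from_root=True)
lin = enc.encode(D_Tree.from_string(conllu))
print([str(l) for l in lin.get_labels()])    # ['1_nsubj', '-NONE-_root']
```
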
diff --git a/tania_scripts/supar/codelin/encs/enc_deps/pos_based.py b/tania_scripts/supar/codelin/encs/enc_deps/pos_based.py
new file mode 100644
index 0000000000000000000000000000000000000000..714efd8d1aff41234ef06409474d491ff96d23f9
--- /dev/null
+++ b/tania_scripts/supar/codelin/encs/enc_deps/pos_based.py
@@ -0,0 +1,89 @@
+from supar.codelin.encs.abstract_encoding import ADEncoding
+from supar.codelin.models.deps_label import D_Label
+from supar.codelin.models.deps_tree import D_Tree
+from supar.codelin.models.linearized_tree import LinearizedTree
+from supar.codelin.utils.constants import D_POSROOT, D_NONE_LABEL
+
+POS_ROOT_LABEL = "0--ROOT"
+
+class D_PosBasedEncoding(ADEncoding):
+    def __init__(self, separator):
+        super().__init__(separator)
+
+    def __str__(self) -> str:
+        return "Dependency Part-of-Speech Based Encoding"
+        
+    def encode(self, dep_tree):
+        
+        print("upar/codelin/encs/enc_deps/pos_based.py")
+        encoded_labels = []
+        
+        for node in dep_tree:
+            if node.id == 0:
+                # skip dummy root
+                continue
+
+            li = node.relation     
+            pi = dep_tree[node.head].upos           
+            oi = 0
+            
+            # move left or right depending on whether the node's
+            # dependency edge points to the left or to the right
+
+            step = 1 if node.id < node.head else -1
+            for i in range(node.id + step, node.head + step, step):
+                if pi == dep_tree[i].upos:
+                    oi += step
+
+            xi = str(oi)+"--"+pi
+
+            current = D_Label(xi, li, self.separator)
+            encoded_labels.append(current)
+
+        dep_tree.remove_dummy()
+        return LinearizedTree(dep_tree.get_words(), dep_tree.get_postags(), dep_tree.get_feats(), encoded_labels, len(encoded_labels))
+
+    def decode(self, lin_tree):
+        dep_tree = D_Tree.empty_tree(len(lin_tree)+1)
+
+        i = 1
+        postags = lin_tree.postags
+        for word, postag, features, label in lin_tree.iterrows():
+            node_id = i
+            if label.xi == D_NONE_LABEL:
+                label.xi = POS_ROOT_LABEL
+            
+            dep_tree.update_word(node_id, word)
+            dep_tree.update_upos(node_id, postag)
+            dep_tree.update_relation(node_id, label.li)
+            
+            oi, pi = label.xi.split('--')
+            oi = int(oi)
+
+            # Set head for root
+            if (pi==D_POSROOT or oi==0):
+                dep_tree.update_head(node_id, 0)
+                i+=1
+                continue
+
+            # Compute head position
+            target_oi = oi
+
+            step = 1 if oi > 0 else -1
+            stop_point = (len(postags)+1) if oi > 0 else 0
+
+            # fall back to the root if no occurrence of the target postag is found
+            head_id = 0
+            for j in range(node_id+step, stop_point, step):
+                if pi == postags[j-1]:
+                    target_oi -= step
+
+                if target_oi == 0:
+                    head_id = j
+                    break
+
+            dep_tree.update_head(node_id, head_id)
+            
+            i+=1
+
+        
+        dep_tree.remove_dummy()
+        return dep_tree
\ No newline at end of file
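
The label format here is `oi--pi`, read as "my head is the `oi`-th word with PoS tag `pi` in the arc direction"; a small illustrative sketch (not part of the patch):

```python
# Illustrative: PoS-based label for a dependent whose head is a VERB.
from supar.codelin.encs.enc_deps.pos_based import D_PosBasedEncoding
from supar.codelin.models.deps_tree import D_Tree

conllu = ("1\tJohn\t_\tPROPN\t_\t_\t2\tnsubj\t_\t_\n"
          "2\tdied\t_\tVERB\t_\t_\t0\troot\t_\t_\n")
enc = D_PosBasedEncoding(separator="_")
lin = enc.encode(D_Tree.from_string(conllu))
print(str(lin.get_labels()[0]))    # '1--VERB_nsubj': head is the 1st VERB to the right
```
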
diff --git a/tania_scripts/supar/codelin/models/__init__.py b/tania_scripts/supar/codelin/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tania_scripts/supar/codelin/models/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/codelin/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..055dc41308b5847ea209b4ae31a538031037b6ad
Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/models/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/codelin/models/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7f007831d62522735be047094f4cf4fdcc4000c
Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/models/__pycache__/const_label.cpython-310.pyc b/tania_scripts/supar/codelin/models/__pycache__/const_label.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ab31bdd2279ddff308fa26008115f5475640b1c
Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/const_label.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/models/__pycache__/const_label.cpython-311.pyc b/tania_scripts/supar/codelin/models/__pycache__/const_label.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6816f095c20760ddde9f1421313ff80f2f8e887
Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/const_label.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/models/__pycache__/const_tree.cpython-310.pyc b/tania_scripts/supar/codelin/models/__pycache__/const_tree.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..069e2ef3ba4d3d9cbe4ae8d6dea1e8105dfb3d47
Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/const_tree.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/models/__pycache__/const_tree.cpython-311.pyc b/tania_scripts/supar/codelin/models/__pycache__/const_tree.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..376f81b9df97423cdc912c3eac7e7063b9e838a0
Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/const_tree.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/models/__pycache__/deps_label.cpython-310.pyc b/tania_scripts/supar/codelin/models/__pycache__/deps_label.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a316c2c1296885e31867d3fd9fabd908a3cc177
Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/deps_label.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/models/__pycache__/deps_label.cpython-311.pyc b/tania_scripts/supar/codelin/models/__pycache__/deps_label.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9643deb605dccdecaca4931fdcf8ec1c1ad66a6b
Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/deps_label.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/models/__pycache__/deps_tree.cpython-310.pyc b/tania_scripts/supar/codelin/models/__pycache__/deps_tree.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8e4df612a26f2cf508b2bb5fa7e85648a86b36d
Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/deps_tree.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/models/__pycache__/deps_tree.cpython-311.pyc b/tania_scripts/supar/codelin/models/__pycache__/deps_tree.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca3679ea018cddc145eefe45675c608e6475204d
Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/deps_tree.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/models/__pycache__/linearized_tree.cpython-310.pyc b/tania_scripts/supar/codelin/models/__pycache__/linearized_tree.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6b35816c4e91140d23f0a05ada4956a3eca4872
Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/linearized_tree.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/models/__pycache__/linearized_tree.cpython-311.pyc b/tania_scripts/supar/codelin/models/__pycache__/linearized_tree.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..30e902912d8da97efe2c282ebdf01db36bf1e563
Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/linearized_tree.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/models/const_label.py b/tania_scripts/supar/codelin/models/const_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..1df73ca2919dc76c935c1e1641554be3869e8ae1
--- /dev/null
+++ b/tania_scripts/supar/codelin/models/const_label.py
@@ -0,0 +1,38 @@
+from supar.codelin.utils.constants import C_ABSOLUTE_ENCODING, C_RELATIVE_ENCODING, C_NONE_LABEL
+
+class C_Label:
+    def __init__(self, nc, lc, uc, et, sp, uj):
+        self.encoding_type = et
+
+        self.n_commons = int(nc)
+        self.last_common = lc
+        self.unary_chain = uc if uc != C_NONE_LABEL else None
+        self.separator = sp
+        self.unary_joiner = uj
+    
+    def __repr__(self):
+        unary_str = self.unary_joiner.join([self.unary_chain]) if self.unary_chain else ""
+        return (str(self.n_commons) + ("*" if self.encoding_type==C_RELATIVE_ENCODING else "")
+        + self.separator + self.last_common + (self.separator + unary_str if self.unary_chain else ""))
+    
+    def to_absolute(self, last_label):
+        self.n_commons+=last_label.n_commons
+        if self.n_commons<=0:
+            self.n_commons = 1
+        
+        self.encoding_type=C_ABSOLUTE_ENCODING
+    
+    @staticmethod
+    def from_string(l, sep, uj):
+        label_components = l.split(sep)
+        
+        if len(label_components)== 2:
+            nc, lc = label_components
+            uc = None
+        else:
+            nc, lc, uc = label_components
+        
+        et = C_RELATIVE_ENCODING if '*' in nc else C_ABSOLUTE_ENCODING
+        nc = nc.replace("*","")
+        return C_Label(nc, lc, uc, et, sep, uj)
+    
\ No newline at end of file
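
Constituent labels serialize as `<n_commons>[*]<sep><last_common>[<sep><unary_chain>]`, where `*` marks the relative encoding; a hedged parsing sketch (not part of the patch):

```python
# Illustrative: parsing and printing C_Label strings.
from supar.codelin.models.const_label import C_Label

lbl = C_Label.from_string("2_S_NP", sep="_", uj="+")
print(lbl.n_commons, lbl.last_common, lbl.unary_chain)    # 2 S NP
print(lbl)                                                # 2_S_NP

rel = C_Label.from_string("1*_S", sep="_", uj="+")        # '*' marks relative encoding
print(rel.encoding_type)                                  # 'REL'
```
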
diff --git a/tania_scripts/supar/codelin/models/const_tree.py b/tania_scripts/supar/codelin/models/const_tree.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5466582822ecf47bdf76131495191c49da94946
--- /dev/null
+++ b/tania_scripts/supar/codelin/models/const_tree.py
@@ -0,0 +1,414 @@
+from supar.codelin.utils.constants import C_END_LABEL, C_START_LABEL, C_NONE_LABEL
+from supar.codelin.utils.constants import C_CONFLICT_SEPARATOR, C_STRAT_MAX, C_STRAT_FIRST, C_STRAT_LAST, C_NONE_LABEL, C_ROOT_LABEL
+import copy
+
+class C_Tree:
+    def __init__(self, label, children=None, feats=None):
+        self.parent = None
+        self.label = label
+        # use None as default to avoid sharing one mutable list across instances
+        self.children = children if children is not None else []
+        self.features = feats
+
+# Adders and deleters
+    def add_child(self, child):
+        '''
+        Function that adds a child to the current tree
+        '''
+        if type(child) is list:
+            for c in child:
+                self.add_child(c)
+        elif type(child) is C_Tree:
+            self.children.append(child)
+            child.parent = self
+
+        else:
+            raise TypeError("[!] Child must be a ConstituentTree or a list of Constituent Trees")
+    
+    def add_left_child(self, child):
+        '''
+        Function that adds a child to the left of the current tree
+        '''
+        if type(child) is not C_Tree:
+            raise TypeError("[!] Child must be a ConstituentTree")
+        
+        self.children = [child] + self.children
+        child.parent = self
+
+    def del_child(self, child):
+        '''
+        Function that deletes a child from the current tree
+        without adding its children to the current tree
+        '''
+        if type(child) is not C_Tree:
+            raise TypeError("[!] Child must be a ConstituentTree")
+
+        self.children.remove(child)
+        child.parent = None
+
+# Getters
+    def r_child(self):
+        '''
+        Function that returns the rightmost child of a tree
+        '''
+        return self.children[len(self.children)-1]
+    
+    def l_child(self):
+        '''
+        Function that returns the leftmost child of a tree
+        '''
+        return self.children[0]
+
+    def r_siblings(self):
+        '''
+        Function that returns the right siblings of a tree
+        '''
+        return self.parent.children[self.parent.children.index(self)+1:]
+
+    def l_siblings(self):
+        '''
+        Function that returns the left siblings of a tree
+        '''
+        return self.parent.children[:self.parent.children.index(self)]
+
+    def get_root(self):
+        '''
+        Function that returns the root of a tree
+        '''
+        if self.parent is None:
+            return self
+        else:
+            return self.parent.get_root()
+
+    def extract_features(self, f_mark = "##", f_sep = "|"):
+        # go through all pre-terminal nodes
+        # of the tree
+        for node in self.get_preterminals():
+            
+            if f_mark in node.label:
+                node.features = {}
+                label = node.label.split(f_mark)[0]
+                features   = node.label.split(f_mark)[1]
+
+                node.label = label
+
+                # add features to the tree
+                for feature in features.split(f_sep):
+                    
+                    if feature == "_":
+                        continue
+                
+                    key = feature.split("=")[0]
+                    value = feature.split("=")[1]
+
+                    node.features[key]=value
+
+    def get_feature_names(self):
+        '''
+        Returns a set containing all feature names
+        for the tree
+        '''
+        feat_names = set()
+
+        for child in self.children:
+            feat_names = feat_names.union(child.get_feature_names())
+        if self.features is not None:
+            feat_names = feat_names.union(set(self.features.keys()))
+
+        return feat_names            
+
+# Word and Postags getters
+    def get_words(self):
+        '''
+        Function that returns the words (terminal labels) of a tree
+        '''
+        if self.is_terminal():
+            return [self.label]
+        else:
+            return [node for child in self.children for node in child.get_words()]
+
+    def get_postags(self):
+        '''
+        Function that returns the part-of-speech tags (preterminal labels) of a tree
+        '''
+        if self.is_preterminal():
+            return [self.label]
+        else:
+            return [node for child in self.children for node in child.get_postags()]
+
+# Terminal checking
+    def is_terminal(self):
+        '''
+        Function that checks if a tree is a terminal
+        '''
+        return len(self.children) == 0
+
+    def is_preterminal(self):
+        '''
+        Function that checks if a tree is a preterminal
+        '''
+        return len(self.children) == 1 and self.children[0].is_terminal()
+
+# Terminal getters
+    def get_terminals(self):
+        '''
+        Function that returns the terminal nodes of a tree
+        '''
+        if self.is_terminal():
+            return [self]
+        else:
+            return [node for child in self.children for node in child.get_terminals()]
+
+    def get_preterminals(self):
+        '''
+        Function that returns the preterminal nodes of a tree
+        '''
+        if self.is_preterminal():
+            return [self]
+        else:
+            return [node for child in self.children for node in child.get_preterminals()]
+
+# Tree processing
+    def collapse_unary(self, unary_joiner="+"):
+        '''
+        Function that collapses unary chains
+        into single nodes using a unary_joiner as join character
+        '''
+        for child in self.children:
+            child.collapse_unary(unary_joiner)
+        if len(self.children)==1 and not self.is_preterminal():
+            self.label += unary_joiner + self.children[0].label
+            self.children = self.children[0].children
+
+    def inherit_tree(self):
+        '''
+        Removes the top node of the tree and delegates it
+        to its firstborn child. 
+        
+        (S (NP (NNP John)) (VP (VBD died))) => (NP (NNP John))
+        '''
+        self.label = self.children[0].label
+        self.children = self.children[0].children
+
+    def add_end_node(self):
+        '''
+        Function that adds a dummy end node to the 
+        rightmost part of the tree
+        '''
+        self.add_child(C_Tree(C_END_LABEL, []))
+
+    def add_start_node(self):
+        '''
+        Function that adds a dummy start node to the leftmost
+        part of the tree
+        '''
+        self.add_left_child(C_Tree(C_START_LABEL, []))
+        
+    def path_to_leaves(self, collapse_unary=True, unary_joiner="+"):
+        '''
+        Function that given a Tree returns a list of paths
+        from the root to the leaves, encoding a level index into
+        nodes to make them unique.
+        '''
+        self.add_end_node()
+                    
+        if collapse_unary:
+            self.collapse_unary(unary_joiner)
+
+        paths = self.path_to_leaves_rec([],[],0)
+        return paths
+
+    def path_to_leaves_rec(self, current_path, paths, idx):
+        '''
+        Recursive step of the path_to_leaves function where we store
+        the common path based on the current node
+        '''
+        # pass by value
+        path = copy.deepcopy(current_path)
+        
+        if (len(self.children)==0):
+            # we are at a leaf. store the path in a new list
+            path.append(self.label)
+            paths.append(path)
+        else:
+            path.append(self.label+str(idx))
+            for child in self.children:
+                child.path_to_leaves_rec(path, paths, idx)
+                idx+=1
+        return paths
+
+    def fill_pos_nodes(self, postag, word, unary_chain, unary_joiner):
+        if self.label == postag:
+            # if the current level is already a postag level. This may happen on 
+            # trees shaped as (NP tree) that exist on the SPMRL treebanks
+            self.children.append(C_Tree(word, []))
+            return
+        
+        if unary_chain:
+            unary_chain = unary_chain.split(unary_joiner)
+            unary_chain.reverse()
+            pos_tree = C_Tree(postag, [C_Tree(word, [])])
+            for node in unary_chain:
+                temp_tree = C_Tree(node, [pos_tree])
+                pos_tree = temp_tree
+        else:
+            pos_tree = C_Tree(postag, [C_Tree(word, [])])
+
+        self.add_child(pos_tree)
+
+    def renounce_children(self):
+        '''
+        Function that deletes current tree from its parent
+        and adds its children to the parent
+        '''
+        self.parent.children = self.l_siblings() + self.children + self.r_siblings()
+        for child in self.children:
+            child.parent = self.parent
+
+
+    def prune_nones(self):
+        """
+        Return a copy of the tree without 
+        null nodes (nodes with label C_NONE_LABEL)
+        """
+        if self.label != C_NONE_LABEL:
+            t = C_Tree(self.label, [])
+            new_childs = [c.prune_nones() for c in self.children]
+            t.add_child(new_childs)
+            return t
+        
+        else:
+            return [c.prune_nones() for c in self.children]
+
+    def remove_conflicts(self, conflict_strat):
+        '''
+        Removes all conflicts in the label of the tree generated
+        during the decoding process. Conflicts will be signaled by -||- 
+        string.
+        '''
+        for c in self.children:
+            if type(c) is C_Tree:
+                c.remove_conflicts(conflict_strat)
+        if C_CONFLICT_SEPARATOR in self.label:
+            labels = self.label.split(C_CONFLICT_SEPARATOR)
+            
+            if conflict_strat == C_STRAT_MAX:
+                self.label = max(set(labels), key=labels.count)
+            if conflict_strat == C_STRAT_FIRST:
+                self.label = labels[0]
+            if conflict_strat == C_STRAT_LAST:
+                self.label = labels[len(labels)-1]
+
+    def postprocess_tree(self, conflict_strat, clean_nulls=True, default_root="S"):
+        '''
+        Returns a C_Tree object with conflicts in node labels removed
+        and with NULL nodes cleaned.
+        '''
+        if clean_nulls:
+            if self.label == C_NONE_LABEL or self.label==C_ROOT_LABEL:
+                self.label = default_root
+            t = self.prune_nones()
+        else:
+            t = self
+        t.remove_conflicts(conflict_strat)
+        return t
+
+    def reverse_tree(self):
+        '''
+        Reverses the order of all the tree children
+        '''
+        for c in self.children:
+            if type(c) is C_Tree:
+                c.reverse_tree()
+        self.children.reverse()
+
+# Printing and python-related functions
+    def __str__(self):
+        if len(self.children) == 0:
+            label_str = self.label
+
+            # serialize features (if any) inline after the label
+            if self.features is not None:
+                label_str += "##" + "|".join([key+"="+value for key,value in self.features.items()])
+
+            label_str = label_str.replace("(","-LRB-")
+            label_str = label_str.replace(")","-RRB-")
+        else:
+            label_str = "(" + self.label
+            if self.features is not None:
+                label_str += "##" + "|".join([key+"="+value for key,value in self.features.items()])
+
+            label_str += " " + " ".join([str(child) for child in self.children]) + ")"
+        return label_str
+
+    def __repr__(self):
+        return self.__str__()
+
+    def __eq__(self, other):
+        if isinstance(other, C_Tree):
+            return self.label == other.label and self.children == other.children
+        return False
+
+    def __hash__(self):
+        return hash((self.label, tuple(self.children)))
+
+    def __len__(self):
+        return len(self.children)
+
+    def __iter__(self):
+        yield self.label
+        for child in self.children:
+            yield child
+
+    def __contains__(self, item):
+        return item in self.label or item in self.children
+
+
+# Tree creation
+    @staticmethod
+    def from_string(s):
+        s = s.replace("(","( ")
+        s = s.replace(")"," )")
+        s = s.split(" ")
+        
+        # parse the token stream with an explicit stack
+        stack = []        
+        i=0
+        while i < (len(s)):
+            if s[i]=="(":
+                # If we find a l_brk we create a new tree
+                # with label=next_word. Skip next_word.
+                w = s[i+1]
+                t = C_Tree(w, [])
+                stack.append(t)
+                i+=1
+
+            elif s[i]==")":
+                # If we find a r_brk set top of the stack
+                # as children to the second top of the stack
+
+                t = stack.pop()
+                
+                if len(stack)==0:
+                    return t
+
+                pt = stack.pop()
+                pt.add_child(t)
+                stack.append(pt)
+            
+            else:
+                # If we find a word, add it as a child
+                # of the tree on top of the stack.
+                t = stack.pop()
+                w = s[i]
+                c = C_Tree(w, [])
+                t.add_child(c)
+                stack.append(t)
+
+            i+=1
+        return t
+
+# Default trees
+    @staticmethod
+    def empty_tree():
+        return C_Tree(C_NONE_LABEL, [])
\ No newline at end of file
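
A round trip through the bracketed-string parser shows the core API of this class; illustrative sketch, not part of the patch:

```python
# Illustrative: C_Tree round trip from a bracketed string.
from supar.codelin.models.const_tree import C_Tree

t = C_Tree.from_string("(S (NP (NNP John)) (VP (VBD died)))")
print(t.get_words())      # ['John', 'died']
print(t.get_postags())    # ['NNP', 'VBD']
print(t)                  # (S (NP (NNP John)) (VP (VBD died)))
```
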
diff --git a/tania_scripts/supar/codelin/models/deps_label.py b/tania_scripts/supar/codelin/models/deps_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..ada0b83c036cabb6bff8ebf715d1c068cc2f2c82
--- /dev/null
+++ b/tania_scripts/supar/codelin/models/deps_label.py
@@ -0,0 +1,18 @@
+class D_Label:
+    def __init__(self, xi, li, sp):
+        self.separator = sp
+
+        self.xi = xi    # encoding label (e.g. head index or brackets)
+        self.li = li    # dependency relation
+
+    def __repr__(self):
+        return f'{self.xi}{self.separator}{self.li}'
+
+    @staticmethod
+    def from_string(lbl_str, sep):
+        xi, li = lbl_str.split(sep)
+        return D_Label(xi, li, sep)
+    
+    @staticmethod
+    def empty_label(separator):
+        return D_Label("", "", separator)
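
Dependency labels are plain `(xi, li)` pairs serialized as `xi<sep>li`; a minimal sketch (not part of the patch):

```python
# Illustrative: D_Label string round trip.
from supar.codelin.models.deps_label import D_Label

lbl = D_Label.from_string("2_nsubj", sep="_")
print(lbl.xi, lbl.li)    # 2 nsubj
print(lbl)               # 2_nsubj
```
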
diff --git a/tania_scripts/supar/codelin/models/deps_tree.py b/tania_scripts/supar/codelin/models/deps_tree.py
new file mode 100644
index 0000000000000000000000000000000000000000..43130267b3ea24f27855f420ae2ce4abbbf08f94
--- /dev/null
+++ b/tania_scripts/supar/codelin/models/deps_tree.py
@@ -0,0 +1,371 @@
+from supar.codelin.utils.constants import D_ROOT_HEAD, D_NULLHEAD, D_ROOT_REL, D_POSROOT, D_EMPTYREL
+
+class D_Node:
+    def __init__(self, wid, form, lemma=None, upos=None, xpos=None, feats=None, head=None, deprel=None, deps=None, misc=None):
+        self.id = int(wid)                      # word id
+        
+        self.form = form if form else "_"       # word 
+        self.lemma = lemma if lemma else "_"    # word lemma/stem
+        self.upos = upos if upos else "_"       # universal postag
+        self.xpos = xpos if xpos else "_"       # language-specific postag
+        self.feats = self.parse_feats(feats) if feats else "_"    # morphological features
+        
+        self.head = int(head)                   # id of the word that depends on
+        self.relation = deprel                  # type of relation with head
+
+        self.deps = deps if deps else "_"       # enhanced dependency graph
+        self.misc = misc if misc else "_"       # miscellaneous data
+    
+    def is_left_arc(self):
+        return self.head > self.id
+
+    def delta_head(self):
+        return self.head - self.id
+    
+    def parse_feats(self, feats):
+        if feats == '_':
+            return [None]
+        else:
+            return [x for x in feats.split('|')]
+
+    def check_cross(self, other):
+        if (self.head == other.head) or (self.head == other.id):
+            return False
+
+        r_id_inside = (other.head < self.id < other.id)
+        l_id_inside = (other.id < self.id < other.head)
+
+        id_inside = r_id_inside or l_id_inside
+
+        r_head_inside = (other.head < self.head < other.id)
+        l_head_inside = (other.id < self.head < other.head)
+
+        head_inside = r_head_inside or l_head_inside
+
+        return head_inside^id_inside
+    
+    def __repr__(self):
+        return '\t'.join(str(e) for e in list(self.__dict__.values()))+'\n'
+
+    def __eq__(self, other):
+        return self.__dict__ == other.__dict__
+    
+
+    @staticmethod
+    def from_string(conll_str):
+        wid,form,lemma,upos,xpos,feats,head,deprel,deps,misc = conll_str.split('\t')
+        return D_Node(int(wid), form, lemma, upos, xpos, feats, int(head), deprel, deps, misc)
+
+    @staticmethod
+    def dummy_root():
+        return D_Node(0, D_POSROOT, None, D_POSROOT, None, None, 0, D_EMPTYREL, None, None)
+    
+    @staticmethod
+    def empty_node():
+        return D_Node(0, None, None, None, None, None, 0, None, None, None)
+
+class D_Tree:
+    def __init__(self, nodes):
+        self.nodes = nodes
+
+# getters    
+    def get_node(self, id):
+        return self.nodes[id-1]
+
+    def get_edges(self):
+        '''
+        Return sentence dependency edges as a tuple 
+        shaped as ((d,h),r) where d is the dependant of the relation,
+        h the head of the relation and r the relationship type
+        '''
+        return list(map((lambda x :((x.id, x.head), x.relation)), self.nodes))
+    
+    def get_arcs(self):
+        '''
+        Return sentence dependency edges as a tuple 
+        shaped as (d,h) where d is the dependant of the relation,
+        and h the head of the relation.
+        '''
+        return list(map((lambda x :(x.id, x.head)), self.nodes))
+
+    def get_relations(self):
+        '''
+        Return a list of relations between nodes
+        '''
+        return list(map((lambda x :x.relation), self.nodes))
+    
+    def get_sentence(self):
+        '''
+        Return the sentence as a string
+        '''
+        return " ".join(list(map((lambda x :x.form), self.nodes)))
+
+    def get_words(self):
+        '''
+        Returns the words of the sentence as a list
+        '''
+        return list(map((lambda x :x.form), self.nodes))
+
+    def get_indexes(self):
+        '''
+        Returns a list of integers representing the words of the 
+        dependency tree
+        '''
+        return list(map((lambda x :x.id), self.nodes))
+
+    def get_postags(self):
+        '''
+        Returns the part of speech tags of the tree
+        '''
+        return list(map((lambda x :x.upos), self.nodes))
+
+    def get_lemmas(self):
+        '''
+        Returns the lemmas of the tree
+        '''
+        return list(map((lambda x :x.lemma), self.nodes))
+
+    def get_heads(self):
+        '''
+        Returns the heads of the tree
+        '''
+        return list(map((lambda x :x.head), self.nodes))
+    
+    def get_feats(self):
+        '''
+        Returns the morphological features of the tree
+        '''
+        return list(map((lambda x :x.feats), self.nodes))
+
+# update functions
+    def append_node(self, node):
+        '''
+        Append a node to the tree and sorts the nodes by id
+        '''
+        self.nodes.append(node)
+        self.nodes.sort(key=lambda x: x.id)
+
+    def update_head(self, node_id, head_value):
+        '''
+        Update the head of a node indicated by its id
+        '''
+        for node in self.nodes:
+            if node.id == node_id:
+                node.head = head_value
+                break
+    
+    def update_relation(self, node_id, relation_value):
+        '''
+        Update the relation of a node indicated by its id
+        '''
+        for node in self.nodes:
+            if node.id == node_id:
+                node.relation = relation_value
+                break
+    
+    def update_word(self, node_id, word):
+        '''
+        Update the word of a node indicated by its id
+        '''
+        for node in self.nodes:
+            if node.id == node_id:
+                node.form = word
+                break
+
+    def update_upos(self, node_id, postag):
+        '''
+        Update the upos field of a node indicated by its id
+        '''
+        for node in self.nodes:
+            if node.id == node_id:
+                node.upos = postag
+                break
+
+# properties functions
+    def is_projective(self):
+        '''
+        Returns a boolean indicating if the dependency tree
+        is projective (i.e. no edges are crossing)
+        '''
+        arcs = self.get_arcs()
+        for (i,j) in arcs:
+            for (k,l) in arcs:
+                if (i,j) != (k,l) and min(i,j) < min(k,l) < max(i,j) < max(k,l):
+                    return False
+        return True
+    
+# postprocessing
+    def remove_dummy(self):
+        self.nodes = self.nodes[1:]
+
+    def postprocess_tree(self, search_root_strat, allow_multi_roots=False):
+        '''
+        Postprocess the tree by finding the root according to the selected 
+        strategy and fixing cycles and out of bounds heads
+        '''
+        # 1) Find the root
+        root = self.root_search(search_root_strat)
+        
+        # 2) Fix oob heads
+        self.fix_oob_heads()
+        
+        # 3) Fix cycles
+        self.fix_cycles(root)
+        
+        # 4) Set all null heads to root and remove other root candidates
+        for node in self.nodes:
+            if node.id == root:
+                node.head = 0
+                continue
+            if node.head == D_NULLHEAD:
+                node.head = root
+            if not allow_multi_roots and node.head == 0:
+                node.head = root
+
+    def root_search(self, search_root_strat):
+        '''
+        Search for the root of the tree using the method indicated
+        '''
+        root = 1 # Default root
+        for node in self.nodes:    
+            if search_root_strat == D_ROOT_HEAD:
+                if node.head == 0:
+                    root = node.id
+                    break
+            
+            elif search_root_strat == D_ROOT_REL:
+                if node.relation == 'root' or node.relation == 'ROOT':
+                    root = node.id
+                    break
+        return root
+
+    def fix_oob_heads(self):
+        '''
+        Fixes heads of the tree (if they don't exist, if they are out of bounds, etc.)
+        If a head is out of bounds set it to nullhead
+        '''
+        for node in self.nodes:
+            if node.head==D_NULLHEAD:
+                continue
+            if int(node.head) < 0:
+                node.head = D_NULLHEAD
+            elif int(node.head) > len(self.nodes):
+                node.head = D_NULLHEAD
+    
+    def fix_cycles(self, root):
+        '''
+        Breaks cycles in the tree by setting the head of the node to root_id
+        '''
+        for node in self.nodes:
+            visited = []
+            
+            while (node.id != root) and (node.head !=D_NULLHEAD):
+                if node in visited:
+                    node.head = D_NULLHEAD
+                else:
+                    visited.append(node)
+                    next_node = min(max(node.head-1, 0), len(self.nodes)-1)
+                    node = self.nodes[next_node]
+        
+# python related functions
+    def __repr__(self):
+        return "".join(str(e) for e in self.nodes)+"\n"
+    
+    def __iter__(self):
+        for n in self.nodes:
+            yield n
+
+    def __getitem__(self, key):
+        return self.nodes[key]  
+    
+    def __len__(self):
+        return self.nodes.__len__()
+
+# base tree
+    @staticmethod
+    def empty_tree(l=1):
+        ''' 
+        Creates an empty dependency tree with l nodes
+        '''
+        t = D_Tree([])
+        for i in range(l):
+            n = D_Node.empty_node()
+            n.id = i
+            t.append_node(n)
+        return t
+
+# reader and writter
+    @staticmethod
+    def from_string(conll_str, dummy_root=True, clean_contractions=True, clean_omisions=True):
+        '''
+        Create a D_Tree from a dependency tree CoNLL-U string.
+        '''
+        data = conll_str.split('\n')
+        dependency_tree_start_index = 0
+        for line in data:
+            if len(line)>0 and line[0]!="#":
+                break
+            dependency_tree_start_index+=1
+        data = data[dependency_tree_start_index:]
+        nodes = []
+        if dummy_root:
+            nodes.append(D_Node.dummy_root())
+        
+        for line in data:
+            # check if not valid line (empty or not enough fields)
+            if (len(line)<=1) or len(line.split('\t'))<10:
+                continue 
+            
+            wid = line.split('\t')[0]
+
+            # check if node is a comment (comments are marked with #)
+            if "#" in wid:
+                continue
+            
+            # check if node is a contraction (multiword-expression ranges are marked with -)
+            if clean_contractions and "-" in wid:
+                continue
+            
+            # check if node is an omitted word (empty nodes are marked with .)
+            if clean_omisions and "." in wid:
+                continue
+
+            conll_node = D_Node.from_string(line)
+            nodes.append(conll_node)
+        
+        return D_Tree(nodes)
+    
+    @staticmethod
+    def read_conllu_file(file_path, filter_projective = True):
+        '''
+        Read a CoNLL-U file and return a list of D_Tree objects.
+        '''
+        with open(file_path, 'r') as f:
+            data = f.read()
+        data = data.split('\n\n')
+        # remove last empty line
+        data = data[:-1]
+        
+        trees = []
+        for x in data:
+            t = D_Tree.from_string(x)
+            if not filter_projective or t.is_projective():
+                trees.append(t)
+        return trees    
+
+    @staticmethod
+    def write_conllu_file(file_path, trees):
+        '''
+        Write a list of D_Tree objects to a CoNLL-U file.
+        '''
+        with open(file_path, 'w') as f:
+            f.write("".join(str(e) for e in trees))
+
+    @staticmethod
+    def write_conllu(file_io, tree):
+        '''
+        Write a single D_Tree to an already open file.
+        Includes the # text = ... line
+        '''
+        file_io.write("# text = "+tree.get_sentence()+"\n")
+        file_io.write("".join(str(e) for e in tree)+"\n")
\ No newline at end of file
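
Reading a tree back from CoNLL-U is the usual entry point for the encoders above; an illustrative sketch (not part of the patch):

```python
# Illustrative: reading a CoNLL-U string into a D_Tree.
from supar.codelin.models.deps_tree import D_Tree

conllu = ("# text = John died\n"
          "1\tJohn\tJohn\tPROPN\t_\t_\t2\tnsubj\t_\t_\n"
          "2\tdied\tdie\tVERB\t_\t_\t0\troot\t_\t_\n")
tree = D_Tree.from_string(conllu)      # prepends the dummy root node
print(tree.get_arcs())                 # [(0, 0), (1, 2), (2, 0)]
print(tree.is_projective())            # True
```
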
diff --git a/tania_scripts/supar/codelin/models/linearized_tree.py b/tania_scripts/supar/codelin/models/linearized_tree.py
new file mode 100644
index 0000000000000000000000000000000000000000..1619b315c986e2c682f1e8d474f053653dd86558
--- /dev/null
+++ b/tania_scripts/supar/codelin/models/linearized_tree.py
@@ -0,0 +1,179 @@
+from supar.codelin.utils.constants import BOS, EOS, C_NO_POSTAG_LABEL
+from supar.codelin.models.const_label import C_Label
+from supar.codelin.models.deps_label import D_Label
+
+class LinearizedTree:
+    def __init__(self, words, postags, additional_feats, labels, n_feats):
+        self.words = words
+        self.postags = postags
+        self.additional_feats = additional_feats
+        self.labels = labels
+        self.n_feats = n_feats  # number of additional feature columns (== len(f_idx_dict.keys()))
+        
+    def get_sentence(self):
+        return "".join(self.words)
+    
+    def get_labels(self):
+        return self.labels
+
+    def get_word(self, index):
+        return self.words[index]
+
+    def get_postag(self, index):
+        return self.postags[index]
+    
+    def get_additional_feat(self, index):
+        return self.additional_feats[index] if len(self.additional_feats) > 0 else None
+    
+    def get_label(self, index):
+        return self.labels[index]
+    
+    def reverse_tree(self, ignore_bos_eos=True):
+        '''
+        Reverses the lists of words, postags, additional_feats and labels.
+        If ignore_bos_eos is True, the first (BOS) and last (EOS) rows
+        are dropped before reversing.
+        '''
+        if ignore_bos_eos:
+            self.words = self.words[1:-1][::-1]
+            self.postags = self.postags[1:-1][::-1]
+            self.additional_feats = self.additional_feats[1:-1][::-1]
+            self.labels = self.labels[1:-1][::-1]
+        else:
+            self.words = self.words[::-1]
+            self.postags = self.postags[::-1]
+            self.additional_feats = self.additional_feats[::-1]
+            self.labels = self.labels[::-1]
+
+    def add_row(self, word, postag, additional_feat, label):
+        self.words.append(word)
+        self.postags.append(postag)
+        self.additional_feats.append(additional_feat)
+        self.labels.append(label)
+    
+    def iterrows(self):
+        for i in range(len(self)):
+            yield self.get_word(i), self.get_postag(i), self.get_additional_feat(i), self.get_label(i)
+            
+    def __len__(self):
+        return len(self.words)
+
+    def __repr__(self):        
+        return self.to_string()
+    
+    def to_string(self, f_idx_dict=None, add_bos_eos=True):
+        if add_bos_eos:
+            self.words = [BOS] + self.words + [EOS]
+            self.postags = [BOS] + self.postags + [EOS]
+            if f_idx_dict:
+                self.additional_feats = [len(f_idx_dict.keys()) * [BOS]] + self.additional_feats + [len(f_idx_dict.keys()) * [EOS]]
+            else:
+                # pad with empty rows so the original list can be restored below
+                self.additional_feats = [None] + self.additional_feats + [None]
+            
+            self.labels = [BOS] + self.labels + [EOS]
+        
+        tree_string = ""
+        for w, p, af, l in self.iterrows():
+            # create the output line of the linearized tree
+            output_line = [w,p]
+            
+            # check for features
+            if f_idx_dict:
+                if w == BOS:
+                    f_list = [BOS] * (len(f_idx_dict.keys())+1)
+                elif w == EOS:
+                    f_list = [EOS] * (len(f_idx_dict.keys())+1)
+                else:
+                    f_list = ["_"] * (len(f_idx_dict.keys())+1)
+                
+                if af != [None]:
+                    for element in af:
+                        key, value = element.split("=", 1) if len(element.split("=",1))==2 else (None, None)
+                        if key in f_idx_dict.keys():
+                            f_list[f_idx_dict[key]] = value
+                
+                # append the additional elements or the placeholder
+                for element in f_list:
+                    output_line.append(element)
+
+            # add the label
+            output_line.append(str(l))
+            tree_string+=u"\t".join(output_line)+u"\n"
+        
+        if add_bos_eos:
+            self.words = self.words[1:-1]
+            self.postags = self.postags[1:-1]
+            # one BOS row and one EOS row were added above, strip them back off
+            self.additional_feats = self.additional_feats[1:-1]
+            self.labels = self.labels[1:-1]
+
+        return tree_string
+
+    @staticmethod
+    def empty_tree(n_feats = 1):
+        temp_tree = LinearizedTree(labels=[], words=[], postags=[], additional_feats=[], n_feats=n_feats)
+        return temp_tree
+
+    @staticmethod
+    def from_string(content, mode, separator="_", unary_joiner="|", n_features=0):
+        '''
+        Reads a linearized tree from a string shaped as
+        -BOS- \t -BOS- \t (...) \t -BOS- \n
+        word \t postag \t (...) \t label \n
+        word \t postag \t (...) \t label \n
+        -EOS- \t -EOS- \t (...) \t -EOS- \n
+        '''
+        labels = []
+        words  = []
+        postags = []
+        additional_feats = []
+        
+        linearized_tree = None
+        for line in content.split("\n"):
+            # skip empty lines
+            if len(line) <= 1:
+                continue
+
+            # Separate the label file into columns
+            line_columns = line.split("\t") if ("\t") in line else line.split(" ")
+            word = line_columns[0]
+
+            if BOS == word:
+                labels = []
+                words  = []
+                postags = []
+                additional_feats = []
+
+                continue
+            
+            if EOS == word:
+                linearized_tree = LinearizedTree(words, postags, additional_feats, labels, n_features)
+                continue
+
+            if len(line_columns) == 2:
+                word, label = line_columns
+                postag = C_NO_POSTAG_LABEL
+                feats = "_"
+            elif len(line_columns) == 3:
+                word, postag, label = line_columns
+                feats = "_"
+            else:
+                # every column between the postag and the final label is a feature
+                word, postag, *feats, label = line_columns
+            
+            # Check for predictions with no label
+            if BOS in label or EOS in label:
+                label = "1"+separator+"ROOT"
+
+            words.append(word)
+            postags.append(postag)
+            if mode == "CONST":
+                labels.append(C_Label.from_string(label, separator, unary_joiner))
+            elif mode == "DEPS":
+                labels.append(D_Label.from_string(label, separator))
+            else:
+                raise ValueError("[!] Unknown mode: %s" % mode)
+            
+            additional_feats.append(feats)
+
+        return linearized_tree
\ No newline at end of file
diff --git a/tania_scripts/supar/codelin/utils/__init__.py b/tania_scripts/supar/codelin/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tania_scripts/supar/codelin/utils/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/codelin/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e624705816d7181d4f932b5d47cf136958510862
Binary files /dev/null and b/tania_scripts/supar/codelin/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/utils/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/codelin/utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d9b483a30076ad9fcf5cd79e963435945670210
Binary files /dev/null and b/tania_scripts/supar/codelin/utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/utils/__pycache__/constants.cpython-310.pyc b/tania_scripts/supar/codelin/utils/__pycache__/constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..653174b605a0b2dafa036c155f450c26128665e9
Binary files /dev/null and b/tania_scripts/supar/codelin/utils/__pycache__/constants.cpython-310.pyc differ
diff --git a/tania_scripts/supar/codelin/utils/__pycache__/constants.cpython-311.pyc b/tania_scripts/supar/codelin/utils/__pycache__/constants.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee60e7632471dbdadeb0deae1b865ffba40f7ef2
Binary files /dev/null and b/tania_scripts/supar/codelin/utils/__pycache__/constants.cpython-311.pyc differ
diff --git a/tania_scripts/supar/codelin/utils/constants.py b/tania_scripts/supar/codelin/utils/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..66b0745e63dcd603f3f81dc61198c047bdd53ac8
--- /dev/null
+++ b/tania_scripts/supar/codelin/utils/constants.py
@@ -0,0 +1,54 @@
+# COMMON CONSTANTS
+
+EOS = "-EOS-"
+BOS = "-BOS-"
+
+F_CONSTITUENT = "CONST"
+F_DEPENDENCY = "DEPS"
+
+OP_ENC = "ENC"
+OP_DEC = "DEC"
+
+# CONSTITUENT ENCODINGS
+
+C_ABSOLUTE_ENCODING = 'ABS'
+C_RELATIVE_ENCODING = 'REL'
+C_DYNAMIC_ENCODING = 'DYN'
+C_INCREMENTAL_ENCODING = 'INC'
+
+C_STRAT_FIRST="strat_first"
+C_STRAT_LAST="strat_last"
+C_STRAT_MAX="strat_max"
+C_STRAT_NONE="strat_none"
+
+# CONSTITUENT MISC
+
+C_NONE_LABEL = "-NONE-"
+C_NO_POSTAG_LABEL = "-NOPOS-"
+C_ROOT_LABEL = "-ROOT-"
+C_END_LABEL = "-END-"
+C_START_LABEL = "-START-"
+C_CONFLICT_SEPARATOR = "-||-"
+C_DUMMY_END = "DUMMY_END"
+
+# DEPENDENCY ENCODINGS
+
+D_NONE_LABEL = "-NONE-"
+
+D_ABSOLUTE_ENCODING = 'ABS'
+D_RELATIVE_ENCODING = 'REL'
+D_POS_ENCODING = 'POS'
+D_BRACKET_ENCODING = 'BRK'
+D_BRACKET_ENCODING_2P = 'BRK_2P'
+
+D_2P_GREED = 'GREED'
+D_2P_PROP = 'PROPAGATE'
+
+# DEPENDENCY MISC
+
+D_EMPTYREL = "-NOREL-"
+D_POSROOT = "-ROOT-"
+D_NULLHEAD = "-NULL-"
+
+D_ROOT_HEAD = "strat_gethead"
+D_ROOT_REL = "strat_getrel"
diff --git a/tania_scripts/supar/codelin/utils/extract_feats.py b/tania_scripts/supar/codelin/utils/extract_feats.py
new file mode 100644
index 0000000000000000000000000000000000000000..be867de99e84f24d9cc35366f9904502011c3be9
--- /dev/null
+++ b/tania_scripts/supar/codelin/utils/extract_feats.py
@@ -0,0 +1,41 @@
+import argparse
+from supar.codelin.models.const_tree import C_Tree
+from supar.codelin.models.deps_tree import D_Tree
+
+def extract_features_const(in_path):
+    feats_set = set()
+    with open(in_path, "r") as file_in:
+        for line in file_in:
+            line = line.rstrip()
+            tree = C_Tree.from_string(line)
+            tree.extract_features()
+            feats = tree.get_feature_names()
+            feats_set = feats_set.union(feats)
+
+    return sorted(feats_set)
+
+def extract_features_deps(in_path):
+    feats_set = set()
+    trees = D_Tree.read_conllu_file(in_path, filter_projective=False)
+    for t in trees:
+        for node in t:
+            if node.feats != "_":
+                feats_set = feats_set.union(node.feats.keys())
+    return sorted(feats_set)
+
+
+'''
+Python script that returns an ordered list of the features
+included in a CoNLL tree or a constituent tree.
+'''
+
+# parser = argparse.ArgumentParser(description='Prints all features in a constituent treebank')
+# parser.add_argument('form', metavar='formalism', type=str, choices=['CONST','DEPS'], help='Grammar encoding the file to extract features')
+# parser.add_argument('input', metavar='in file', type=str, help='Path of the file to clean (.trees file).')
+# args = parser.parse_args()
+# if args.form=='CONST':
+#     feats = extract_features_const(args.input)
+# elif args.form=='DEPS':
+#     feats = extract_features_deps(args.input)
+# print(" ".join(feats))
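+
+# Usage sketch (hypothetical path): print the sorted feature names found in a
+# CoNLL-U treebank.
+#
+#     feats = extract_features_deps("train.conllu")
+#     print(" ".join(feats))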
diff --git a/tania_scripts/supar/model.py b/tania_scripts/supar/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2d52770165cede33df41efe7dd6aace5dd87452
--- /dev/null
+++ b/tania_scripts/supar/model.py
@@ -0,0 +1,249 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+from supar.modules import (CharLSTM, ELMoEmbedding, IndependentDropout,
+                           SharedDropout, TransformerEmbedding,
+                           TransformerWordEmbedding, VariationalLSTM)
+from supar.modules.transformer import (TransformerEncoder,
+                                       TransformerEncoderLayer)
+from supar.utils import Config
+#from supar.vector_quantize import VectorQuantize
+
+from vector_quantize_pytorch import VectorQuantize
+
+
+class Model(nn.Module):
+
+    def __init__(self,
+                 n_words,
+                 n_tags=None,
+                 n_chars=None,
+                 n_lemmas=None,
+                 encoder='lstm',
+                 feat=['tag', 'char'],
+                 n_embed=300,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=45,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 char_dropout=0,
+                 elmo_bos_eos=(True, True),
+                 elmo_dropout=0.5,
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 encoder_dropout=.33,
+                 pad_index=0,
+
+                 n_decoder_layers: int = 2,
+                 bidirectional: bool = False,
+
+                 use_vq=False,
+                 vq_passes: int = 300,
+                 codebook_size=512,
+                 vq_decay=0.3,
+                 commitment_weight=3,
+
+                 **kwargs):
+        super().__init__()
+
+        self.args = Config().update(locals())
+
+        if encoder == 'lstm':
+            self.word_embed = nn.Embedding(num_embeddings=self.args.n_words,
+                                           embedding_dim=self.args.n_embed)
+
+            n_input = self.args.n_embed
+
+            if self.args.n_pretrained != self.args.n_embed:
+                n_input += self.args.n_pretrained
+            if 'tag' in self.args.feat:
+                self.tag_embed = nn.Embedding(num_embeddings=self.args.n_tags,
+                                              embedding_dim=self.args.n_feat_embed)
+                n_input += self.args.n_feat_embed
+
+            if 'char' in self.args.feat:
+                self.char_embed = CharLSTM(n_chars=self.args.n_chars,
+                                           n_embed=self.args.n_char_embed,
+                                           n_hidden=self.args.n_char_hidden,
+                                           n_out=self.args.n_feat_embed,
+                                           pad_index=self.args.char_pad_index,
+                                           dropout=self.args.char_dropout)
+                n_input += self.args.n_feat_embed
+                
+
+            if 'lemma' in self.args.feat:
+                self.lemma_embed = nn.Embedding(num_embeddings=self.args.n_lemmas,
+                                                embedding_dim=self.args.n_feat_embed)
+                n_input += self.args.n_feat_embed
+            if 'elmo' in self.args.feat:
+                self.elmo_embed = ELMoEmbedding(n_out=self.args.n_plm_embed,
+                                                bos_eos=self.args.elmo_bos_eos,
+                                                dropout=self.args.elmo_dropout,
+                                                finetune=self.args.finetune)
+                n_input += self.elmo_embed.n_out
+            if 'bert' in self.args.feat:
+                self.bert_embed = TransformerEmbedding(name=self.args.bert,
+                                                       n_layers=self.args.n_bert_layers,
+                                                       n_out=self.args.n_plm_embed,
+                                                       pooling=self.args.bert_pooling,
+                                                       pad_index=self.args.bert_pad_index,
+                                                       mix_dropout=self.args.mix_dropout,
+                                                       finetune=self.args.finetune)
+                n_input += self.bert_embed.n_out
+            self.embed_dropout = IndependentDropout(p=self.args.embed_dropout)
+            self.encoder = VariationalLSTM(
+                input_size=n_input,
+                hidden_size=self.args.n_encoder_hidden//2 if self.args.bidirectional else self.args.n_encoder_hidden,
+                num_layers=self.args.n_encoder_layers, bidirectional=self.args.bidirectional,
+                dropout=self.args.encoder_dropout)
+            self.encoder_dropout = SharedDropout(p=self.args.encoder_dropout)
+        elif encoder == 'transformer':
+            self.word_embed = TransformerWordEmbedding(n_vocab=self.args.n_words,
+                                                       n_embed=self.args.n_embed,
+                                                       pos=self.args.pos,
+                                                       pad_index=self.args.pad_index)
+            self.embed_dropout = nn.Dropout(p=self.args.embed_dropout)
+            self.encoder = TransformerEncoder(layer=TransformerEncoderLayer(n_heads=self.args.n_encoder_heads,
+                                                                            n_model=self.args.n_encoder_hidden,
+                                                                            n_inner=self.args.n_encoder_inner,
+                                                                            attn_dropout=self.args.encoder_attn_dropout,
+                                                                            ffn_dropout=self.args.encoder_ffn_dropout,
+                                                                            dropout=self.args.encoder_dropout),
+                                              n_layers=self.args.n_encoder_layers,
+                                              n_model=self.args.n_encoder_hidden)
+            self.encoder_dropout = nn.Dropout(p=self.args.encoder_dropout)
+        elif encoder == 'bert':
+            self.encoder = TransformerEmbedding(name=self.args.bert,
+                                                n_layers=self.args.n_bert_layers,
+                                                n_out=self.args.n_encoder_hidden,
+                                                pooling=self.args.bert_pooling,
+                                                pad_index=self.args.pad_index,
+                                                mix_dropout=self.args.mix_dropout,
+                                                finetune=self.args.finetune)
+            self.encoder_dropout = nn.Dropout(p=self.args.encoder_dropout)
+            self.args.n_encoder_hidden = self.encoder.n_out
+
+        self.passes_remaining = vq_passes
+        if use_vq:
+            self.vq = VectorQuantize(dim=self.args.n_encoder_hidden, codebook_size=codebook_size, decay=vq_decay,
+                                     commitment_weight=commitment_weight, eps=1e-5) #, wait_steps=0, observe_steps=vq_passes)
+        else:
+            self.vq = nn.Identity()
+
+    def load_pretrained(self, embed=None):
+        if embed is not None:
+            self.pretrained = nn.Embedding.from_pretrained(embed)
+            if embed.shape[1] != self.args.n_pretrained:
+                self.embed_proj = nn.Linear(embed.shape[1], self.args.n_pretrained)
+            nn.init.zeros_(self.word_embed.weight)
+        return self
+
+    def forward(self):
+        raise NotImplementedError
+
+    def loss(self):
+        raise NotImplementedError
+
+    def embed(self, words, feats=None):
+        ext_words = words
+        # set the indices larger than num_embeddings to unk_index
+        if hasattr(self, 'pretrained'):
+            ext_mask = words.ge(self.word_embed.num_embeddings)
+            ext_words = words.masked_fill(ext_mask, self.args.unk_index)
+
+        # get outputs from embedding layers
+        word_embed = self.word_embed(ext_words)
+        if hasattr(self, 'pretrained'):
+            pretrained = self.pretrained(words)
+            if self.args.n_embed == self.args.n_pretrained:
+                word_embed += pretrained
+            else:
+                word_embed = torch.cat((word_embed, self.embed_proj(pretrained)), -1)
+        feat_embed = []
+
+        if 'tag' in self.args.feat:
+            feat_embed.append(self.tag_embed(feats.pop(0)))
+        if 'char' in self.args.feat:
+            feat_embed.append(self.char_embed(feats.pop(0)))
+        if 'elmo' in self.args.feat:
+            feat_embed.append(self.elmo_embed(feats.pop(0)))
+        if 'bert' in self.args.feat:
+            feat_embed.append(self.bert_embed(feats.pop(0)))
+        if 'lemma' in self.args.feat:
+            feat_embed.append(self.lemma_embed(feats.pop(0)))
+        if isinstance(self.embed_dropout, IndependentDropout):
+            if len(feat_embed) == 0:
+                raise RuntimeError(f"`feat` is not allowed to be empty, which is {self.args.feat} now")
+            embed = torch.cat(self.embed_dropout(word_embed, torch.cat(feat_embed, -1)), -1)
+        else:
+            embed = word_embed
+            if len(feat_embed) > 0:
+                embed = torch.cat((embed, torch.cat(feat_embed, -1)), -1)
+            embed = self.embed_dropout(embed)
+        return embed
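+    # Shape sketch: assuming feat=['tag', 'char'] and pretrained vectors with
+    # n_pretrained != n_embed, `embed` returns
+    # [batch_size, seq_len, n_embed + n_pretrained + 2 * n_feat_embed].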
+
+    def encode(self, words, feats):
+        if self.args.encoder == 'lstm':
+            x = pack_padded_sequence(self.embed(words, feats), words.ne(self.args.pad_index).sum(1).tolist(), True, False)
+
+            x, _ = self.encoder(x)
+            x, _ = pad_packed_sequence(x, True, total_length=words.shape[1])
+        elif self.args.encoder == 'transformer':
+            x = self.encoder(self.embed(words, feats), words.ne(self.args.pad_index))
+        else:
+            x = self.encoder(words)
+        return self.encoder_dropout(x)
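+    # Call sketch (hypothetical tensors): for the LSTM encoder, `words` is a
+    # [batch_size, seq_len] LongTensor padded with args.pad_index and `feats`
+    # is a list matching args.feat; the returned states are
+    # [batch_size, seq_len, n_encoder_hidden] after encoder dropout.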
+
+    def decode(self):
+        raise NotImplementedError
+
+    def vq_forward(self, x: torch.Tensor):
+        if not self.args.use_vq:
+            return x, torch.tensor(0)
+
+        half = self.args.vq_passes / 2
+        if self.passes_remaining > half:
+            # observation phase: update the codebook but leave x unquantized
+            _, _, commit_loss = self.vq(x)
+            self.passes_remaining -= 1
+        elif 0 < self.passes_remaining <= half:
+            # annealing phase: blend x toward its quantized version.
+            # NOTE: the original referenced an undefined `self.passes`; the
+            # weight below assumes a linear ramp from 0 to 1 over this phase.
+            x_quantized, _, commit_loss = self.vq(x)
+            x = torch.lerp(x, x_quantized, 1 - self.passes_remaining / half)
+            self.passes_remaining -= 1
+        else:
+            x, _, commit_loss = self.vq(x)
+        qloss = commit_loss.squeeze()
+        return x, qloss
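+
+    # Schedule sketch (assumed semantics of vq_passes): with vq_passes=300,
+    # the first 150 calls only train the codebook, the next 150 interpolate
+    # between raw and quantized states, and all later calls return quantized
+    # states directly.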
+    
+    @property 
+    def device(self):
+        if self.args.device == 'cpu':
+            return 'cpu'
+        else:
+            return f'cuda:{self.args.device}'
diff --git a/tania_scripts/supar/models/.ipynb_checkpoints/__init__-checkpoint.py b/tania_scripts/supar/models/.ipynb_checkpoints/__init__-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..e54ae9fad6a968ecbdb484b334c299ce00f07d8e
--- /dev/null
+++ b/tania_scripts/supar/models/.ipynb_checkpoints/__init__-checkpoint.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+
+from .const import (AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos, CRFConstituencyParser,
+                    TetraTaggingConstituencyParser, VIConstituencyParser, SLConstituentParser)
+from .dep import (BiaffineDependencyParser, CRF2oDependencyParser,
+                  CRFDependencyParser, VIDependencyParser,
+                  SLDependencyParser, ArcEagerDependencyParser)
+from .sdp import BiaffineSemanticDependencyParser, VISemanticDependencyParser
+
+__all__ = ['BiaffineDependencyParser',
+           'CRFDependencyParser',
+           'CRF2oDependencyParser',
+           'VIDependencyParser',
+           'AttachJuxtaposeConstituencyParser',
+           'CRFConstituencyParser',
+           'TetraTaggingConstituencyParser',
+           'VIConstituencyParser',
+           'SLConstituentParser',
+           'BiaffineSemanticDependencyParser',
+           'VISemanticDependencyParser',
+           'SLDependencyParser',
+           'ArcEagerDependencyParser']
\ No newline at end of file
diff --git a/tania_scripts/supar/models/__init__.py b/tania_scripts/supar/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e54ae9fad6a968ecbdb484b334c299ce00f07d8e
--- /dev/null
+++ b/tania_scripts/supar/models/__init__.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+
+from .const import (AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos, CRFConstituencyParser,
+                    TetraTaggingConstituencyParser, VIConstituencyParser, SLConstituentParser)
+from .dep import (BiaffineDependencyParser, CRF2oDependencyParser,
+                  CRFDependencyParser, VIDependencyParser,
+                  SLDependencyParser, ArcEagerDependencyParser)
+from .sdp import BiaffineSemanticDependencyParser, VISemanticDependencyParser
+
+__all__ = ['BiaffineDependencyParser',
+           'CRFDependencyParser',
+           'CRF2oDependencyParser',
+           'VIDependencyParser',
+           'AttachJuxtaposeConstituencyParser',
+           'CRFConstituencyParser',
+           'TetraTaggingConstituencyParser',
+           'VIConstituencyParser',
+           'SLConstituentParser',
+           'BiaffineSemanticDependencyParser',
+           'VISemanticDependencyParser',
+           'SLDependencyParser',
+           'ArcEagerDependencyParser']
\ No newline at end of file
diff --git a/tania_scripts/supar/models/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c56ad26f856966acabe985c3781a35035bb4d79
Binary files /dev/null and b/tania_scripts/supar/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9521805749c757cbf900e8dbd4ebb1339359dd88
Binary files /dev/null and b/tania_scripts/supar/models/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/__pycache__/__init__.cpython-39.pyc b/tania_scripts/supar/models/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ca2e6217fc9a781a4fd1f5df3b0ce2c9a6ddf79
Binary files /dev/null and b/tania_scripts/supar/models/__pycache__/__init__.cpython-39.pyc differ
diff --git a/tania_scripts/supar/models/const/__init__.py b/tania_scripts/supar/models/const/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a15c165338e852b5d02ed348db763d43da09628
--- /dev/null
+++ b/tania_scripts/supar/models/const/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+
+from .aj import (AttachJuxtaposeConstituencyModel,AttachJuxtaposeConstituencyModelPos,
+                 AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos)
+from .crf import CRFConstituencyModel, CRFConstituencyParser
+from .tt import TetraTaggingConstituencyModel, TetraTaggingConstituencyParser
+from .vi import VIConstituencyModel, VIConstituencyParser
+from .sl import SLConstituentParser, SLConstituentModel
+
+__all__ = ['AttachJuxtaposeConstituencyModel', 'AttachJuxtaposeConstituencyModelPos',   'AttachJuxtaposeConstituencyParser', 'AttachJuxtaposeConstituencyParserPos',
+           'CRFConstituencyModel', 'CRFConstituencyParser',
+           'TetraTaggingConstituencyModel', 'TetraTaggingConstituencyParser',
+           'VIConstituencyModel', 'VIConstituencyParser',
+           'SLConstituentModel', 'SLConstituentParser']
diff --git a/tania_scripts/supar/models/const/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/const/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d250b33146033de8f6e222867a4f01c6f2d8a54
Binary files /dev/null and b/tania_scripts/supar/models/const/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/const/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24326643d82fa6033432ee9bbd7ebf061b74a4d0
Binary files /dev/null and b/tania_scripts/supar/models/const/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/__pycache__/__init__.cpython-39.pyc b/tania_scripts/supar/models/const/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19a59b8eb7ae83aa8e4089626004ec916d1b1ec8
Binary files /dev/null and b/tania_scripts/supar/models/const/__pycache__/__init__.cpython-39.pyc differ
diff --git a/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/__init__-checkpoint.py b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/__init__-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fe4855b790afea3b677caaaa7de5f6ac3549216
--- /dev/null
+++ b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/__init__-checkpoint.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .model import AttachJuxtaposeConstituencyModel, AttachJuxtaposeConstituencyModelPos
+from .parser import AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos
+
+__all__ = ['AttachJuxtaposeConstituencyModel', 'AttachJuxtaposeConstituencyModelPos', 'AttachJuxtaposeConstituencyParser', 'AttachJuxtaposeConstituencyParserPos']
diff --git a/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/model-checkpoint.py b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/model-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..30de07c6de4c6e6c60b32a31a94aa0a29040d8ff
--- /dev/null
+++ b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/model-checkpoint.py
@@ -0,0 +1,786 @@
+# -*- coding: utf-8 -*-
+
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+from supar.model import Model
+from supar.models.const.aj.transform import AttachJuxtaposeTree
+from supar.modules import GraphConvolutionalNetwork, MLP, DecoderLSTM
+from supar.utils import Config
+from supar.utils.common import INF
+from supar.utils.fn import pad
+
+class DecoderLSTMPos(nn.Module):
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout, device):
+        super().__init__()
+        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
+                            batch_first=True, dropout=dropout)
+        self.classifier = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x):
+        # x: [batch_size, seq_len, input_dim]
+        output, _ = self.lstm(x)
+        logits = self.classifier(output)
+        return logits
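+
+# Shape note: logits ~ [batch_size, seq_len, output_dim]. `device` is accepted
+# for interface parity but unused, and nn.LSTM ignores `dropout` when
+# num_layers == 1.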
+
+class AttachJuxtaposeConstituencyModel(Model):
+    r"""
+    The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_labels (int):
+            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'char'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word embeddings. Default: 100.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, True)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs would be weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling way to get token embeddings.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .33.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 800.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layers. Default: .33.
+        n_gnn_layers (int):
+            The number of GNN layers. Default: 3.
+        gnn_dropout (float):
+            The dropout ratio of GNN layers. Default: .33.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_labels,
+                 n_tags=None,
+                 n_chars=None,
+                 encoder='lstm',
+                 feat=['char'],
+                 n_embed=100,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, True),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 n_encoder_hidden=800,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_gnn_layers=3,
+                 gnn_dropout=.33,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        # the last one represents the dummy node in the initial states
+        self.label_embed = nn.Embedding(n_labels+1, self.args.n_encoder_hidden)
+        self.gnn_layers = GraphConvolutionalNetwork(n_model=self.args.n_encoder_hidden,
+                                                    n_layers=self.args.n_gnn_layers,
+                                                    dropout=self.args.gnn_dropout)
+
+        self.node_classifier = nn.Sequential(
+            nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2),
+            nn.LayerNorm(self.args.n_encoder_hidden // 2),
+            nn.ReLU(),
+            nn.Linear(self.args.n_encoder_hidden // 2, 1),
+        )
+        self.label_classifier = nn.Sequential(
+            nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2),
+            nn.LayerNorm(self.args.n_encoder_hidden // 2),
+            nn.ReLU(),
+            nn.Linear(self.args.n_encoder_hidden // 2, 2 * n_labels),
+        )
+
+        # create delay projection
+        if self.args.delay != 0:
+            self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay+1),
+                                       n_out=self.args.n_encoder_hidden, dropout=gnn_dropout)
+
+        self.criterion = nn.CrossEntropyLoss()
+
+    def forward(
+        self,
+        words: torch.LongTensor,
+        feats: List[torch.LongTensor]
+    ) -> Tuple[torch.Tensor]:
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+                Default: ``None``.
+
+        Returns:
+            ~torch.Tensor:
+                Contextualized output hidden states of shape ``[batch_size, seq_len, n_model]`` of the input.
+        """
+        x = self.encode(words, feats)
+
+        # adjust lengths to allow delay predictions
+        if self.args.delay != 0:
+            x = torch.cat([x[:, i:(x.shape[1] - self.args.delay + i)] for i in range(self.args.delay + 1)], dim=2)
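+            # e.g. with delay=1 and states [h0, h1, h2, h3], the concatenation
+            # yields [[h0;h1], [h1;h2], [h2;h3]]; delay_proj then maps the
+            # widened vectors back to n_encoder_hidden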
+            x = self.delay_proj(x)
+
+        # pass through vector quantization
+        x, qloss = self.vq_forward(x)
+
+        return x, qloss
+
+    def loss(
+        self,
+        x: torch.Tensor,
+        nodes: torch.LongTensor,
+        parents: torch.LongTensor,
+        news: torch.LongTensor,
+        mask: torch.BoolTensor
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            x (~torch.Tensor): ``[batch_size, seq_len, n_model]``.
+                Contextualized output hidden states.
+            nodes (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The target node positions on rightmost chains.
+            parents (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The parent node labels of terminals.
+            news (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The parent node labels of juxtaposed targets and terminals.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
+
+        Returns:
+            ~torch.Tensor:
+                The training loss.
+        """
+
+        spans, s_node, x_node = None, [], []
+        actions = torch.stack((nodes, parents, news))
+        for t, action in enumerate(actions.unbind(-1)):
+            if t == 0:
+                x_span = self.label_embed(actions.new_full((x.shape[0], 1), self.args.n_labels))
+                span_mask = mask[:, :1]
+            else:
+                x_span = self.rightmost_chain(x, spans, mask, t)
+                span_lens = spans[:, :-1, -1].ge(0).sum(-1)
+                span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
+            x_rightmost = torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)
+            s_node.append(self.node_classifier(x_rightmost).squeeze(-1))
+            # we found softmax is slightly better than sigmoid in the original paper
+            s_node[-1] = s_node[-1].masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0)
+            x_node.append(torch.bmm(s_node[-1].softmax(-1).unsqueeze(1), x_span).squeeze(1))
+            spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t])
+        attach_mask = x.new_tensor(range(self.args.n_labels)).eq(self.args.nul_index)
+        s_node, x_node = pad(s_node, -INF).transpose(0, 1), torch.stack(x_node, 1)
+        s_parent, s_new = self.label_classifier(torch.cat((x, x_node), -1)).chunk(2, -1)
+        s_parent = torch.cat((s_parent[:, :1].masked_fill(attach_mask, -INF), s_parent[:, 1:]), 1)
+        s_new = torch.cat((s_new[:, :1].masked_fill(~attach_mask, -INF), s_new[:, 1:]), 1)
+        node_loss = self.criterion(s_node[mask], nodes[mask])
+        label_loss = self.criterion(s_parent[mask], parents[mask]) + self.criterion(s_new[mask], news[mask])
+        return node_loss + label_loss
+
+    def decode(
+        self,
+        x: torch.Tensor,
+        mask: torch.BoolTensor,
+        beam_size: int = 1
+    ) -> List[List[Tuple]]:
+        r"""
+        Args:
+            x (~torch.Tensor): ``[batch_size, seq_len, n_model]``.
+                Contextualized output hidden states.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
+            beam_size (int):
+                Beam size for decoding. Default: 1.
+
+        Returns:
+            List[List[Tuple]]:
+                Sequences of factorized labeled trees.
+        """
+        tokenwise_predictions = []
+        spans = None
+        batch_size, *_ = x.shape
+        n_labels = self.args.n_labels
+        # [batch_size * beam_size, ...]
+        x = x.unsqueeze(1).repeat(1, beam_size, 1, 1).view(-1, *x.shape[1:])
+        mask = mask.unsqueeze(1).repeat(1, beam_size, 1).view(-1, *mask.shape[1:])
+        # [batch_size]
+        batches = x.new_tensor(range(batch_size)).long() * beam_size
+        # accumulated scores
+        scores = x.new_full((batch_size, beam_size), -INF).index_fill_(-1, x.new_tensor(0).long(), 0).view(-1)
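+        # beam 0 starts at score 0 and the rest at -INF, so decoding expands
+        # from a single live hypothesis per sentence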
+        for t in range(x.shape[1]):
+            if t == 0:
+                x_span = self.label_embed(batches.new_full((x.shape[0], 1), n_labels))
+                span_mask = mask[:, :1]
+            else:
+                x_span = self.rightmost_chain(x, spans, mask, t)
+                span_lens = spans[:, :-1, -1].ge(0).sum(-1)
+                span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
+            s_node = self.node_classifier(torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)).squeeze(-1)
+            s_node = s_node.masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0).log_softmax(-1)
+            # we found softmax is slightly better than sigmoid in the original paper
+            x_node = torch.bmm(s_node.exp().unsqueeze(1), x_span).squeeze(1)
+            s_parent, s_new = self.label_classifier(torch.cat((x[:, t], x_node), -1)).chunk(2, -1)
+            s_parent, s_new = s_parent.log_softmax(-1), s_new.log_softmax(-1)
+            if t == 0:
+                s_parent[:, self.args.nul_index] = -INF
+                s_new[:, s_new.new_tensor(range(self.args.n_labels)).ne(self.args.nul_index)] = -INF
+            s_node, nodes = s_node.topk(min(s_node.shape[-1], beam_size), -1)
+            s_parent, parents = s_parent.topk(min(n_labels, beam_size), -1)
+            s_new, news = s_new.topk(min(n_labels, beam_size), -1)
+            s_action = s_node.unsqueeze(2) + (s_parent.unsqueeze(2) + s_new.unsqueeze(1)).view(x.shape[0], 1, -1)
+            s_action = s_action.view(x.shape[0], -1)
+            k_beam, k_node, k_parent = s_action.shape[-1], parents.shape[-1] * news.shape[-1], news.shape[-1]
+            # [batch_size * beam_size, k_beam]
+            scores = scores.unsqueeze(-1) + s_action
+            # [batch_size, beam_size]
+            scores, cands = scores.view(batch_size, -1).topk(beam_size, -1)
+            # [batch_size * beam_size]
+            scores = scores.view(-1)
+            beams = cands.div(k_beam, rounding_mode='floor')
+            nodes = nodes.view(batch_size, -1).gather(-1, cands.div(k_node, rounding_mode='floor'))
+            indices = (batches.unsqueeze(-1) + beams).view(-1)
+            
+            #print('indices', indices)
+            parents = parents[indices].view(batch_size, -1).gather(-1, cands.div(k_parent, rounding_mode='floor') % k_parent)
+            news = news[indices].view(batch_size, -1).gather(-1, cands % k_parent)
+            action = torch.stack((nodes, parents, news)).view(3, -1)
+            tokenwise_predictions.append([t, [a[0] for a in action.tolist()]])
+            spans = spans[indices] if spans is not None else None
+            spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t])
+            #print("SPANS", spans)
+        mask = mask.view(batch_size, beam_size, -1)[:, 0]
+        # select the 1-best tree for each sentence
+        spans = spans[batches + scores.view(batch_size, -1).argmax(-1)]
+        span_mask = spans.ge(0)
+        span_indices = torch.where(span_mask)
+        span_labels = spans[span_indices]
+        chart_preds = [[] for _ in range(x.shape[0])]
+        for i, *span in zip(*[s.tolist() for s in span_indices], span_labels.tolist()):
+            chart_preds[i].append(span)
+        return [chart_preds + tokenwise_predictions]
+
+    def rightmost_chain(
+        self,
+        x: torch.Tensor,
+        spans: torch.LongTensor,
+        mask: torch.BoolTensor,
+        t: int
+    ) -> torch.Tensor:
+        x_p, mask_p = x[:, :t], mask[:, :t]
+        lens = mask_p.sum(-1)
+        span_mask = spans[:, :-1, 1:].ge(0)
+        span_lens = span_mask.sum((-1, -2))
+        span_indices = torch.where(span_mask)
+        span_labels = spans[:, :-1, 1:][span_indices]
+        x_span = self.label_embed(span_labels)
+        x_span += x[span_indices[0], span_indices[1]] + x[span_indices[0], span_indices[2]]
+        node_lens = lens + span_lens
+        adj_mask = node_lens.unsqueeze(-1).gt(x.new_tensor(range(node_lens.max())))
+        x_mask = lens.unsqueeze(-1).gt(x.new_tensor(range(adj_mask.shape[-1])))
+        span_mask = ~x_mask & adj_mask
+        # concatenate terminals and spans
+        x_tree = x.new_zeros(*adj_mask.shape, x.shape[-1]).masked_scatter_(x_mask.unsqueeze(-1), x_p[mask_p])
+        x_tree = x_tree.masked_scatter_(span_mask.unsqueeze(-1), x_span)
+        adj = mask.new_zeros(*x_tree.shape[:-1], x_tree.shape[1])
+        adj_spans = lens.new_tensor(range(x_tree.shape[1])).view(1, 1, -1).repeat(2, x.shape[0], 1)
+        adj_spans = adj_spans.masked_scatter_(span_mask.unsqueeze(0), torch.stack(span_indices[1:]))
+        adj_l, adj_r, adj_w = *adj_spans.unbind(), adj_spans[1] - adj_spans[0]
+        adj_parent = adj_l.unsqueeze(-1).ge(adj_l.unsqueeze(-2)) & adj_r.unsqueeze(-1).le(adj_r.unsqueeze(-2))
+        # set the parent of root as itself
+        adj_parent.diagonal(0, 1, 2).copy_(adj_w.eq(t - 1))
+        adj_parent = adj_parent & span_mask.unsqueeze(1)
+        # closest ancestor spans as parents
+        adj_parent = (adj_w.unsqueeze(-2) - adj_w.unsqueeze(-1)).masked_fill_(~adj_parent, t).argmin(-1)
+        adj.scatter_(-1, adj_parent.unsqueeze(-1), 1)
+        adj = (adj | adj.transpose(-1, -2)).float()
+        x_tree = self.gnn_layers(x_tree, adj, adj_mask)
+        span_mask = span_mask.masked_scatter(span_mask, span_indices[2].eq(t - 1))
+        span_lens = span_mask.sum(-1)
+        x_tree, span_mask = x_tree[span_mask], span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
+        x_span = x.new_zeros(*span_mask.shape, x.shape[-1]).masked_scatter_(span_mask.unsqueeze(-1), x_tree)
+        return x_span
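+
+    # rightmost_chain in brief: each span on the current rightmost chain is
+    # embedded as its label embedding plus the hidden states of its boundary
+    # tokens, contextualized by the GCN over the ancestor adjacency, and
+    # returned padded per sentence as candidates for step t's attach decision.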
+
+
+class AttachJuxtaposeConstituencyModelPos(Model):
+    r"""
+    The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_labels (int):
+            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'char'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word embeddings. Default: 100.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, True)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs would be weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling way to get token embeddings.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .33.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 800.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layers. Default: .33.
+        n_gnn_layers (int):
+            The number of GNN layers. Default: 3.
+        gnn_dropout (float):
+            The dropout ratio of GNN layers. Default: .33.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_labels,
+                 n_tags=16,
+                 n_chars=None,
+                 encoder='lstm',
+                 feat=['char', 'tag'],
+                 n_embed=100,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, True),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 n_encoder_hidden=800,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_gnn_layers=3,
+                 gnn_dropout=.33,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        # the last one represents the dummy node in the initial states
+        self.label_embed = nn.Embedding(n_labels+1, self.args.n_encoder_hidden)
+        self.gnn_layers = GraphConvolutionalNetwork(n_model=self.args.n_encoder_hidden,
+                                                    n_layers=self.args.n_gnn_layers,
+                                                    dropout=self.args.gnn_dropout)
+
+        self.node_classifier = nn.Sequential(
+            nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2),
+            nn.LayerNorm(self.args.n_encoder_hidden // 2),
+            nn.ReLU(),
+            nn.Linear(self.args.n_encoder_hidden // 2, 1),
+        )
+        self.label_classifier = nn.Sequential(
+            nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2),
+            nn.LayerNorm(self.args.n_encoder_hidden // 2),
+            nn.ReLU(),
+            nn.Linear(self.args.n_encoder_hidden // 2, 2 * n_labels),
+        )
+        self.pos_classifier = DecoderLSTMPos(
+                self.args.n_encoder_hidden, self.args.n_encoder_hidden, self.args.n_tags, 
+                num_layers=1, dropout=encoder_dropout, device=self.device
+            )
+        
+        #self.pos_tagger = nn.Identity()
+        # create delay projection
+        if self.args.delay != 0:
+            self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay+1),
+                                       n_out=self.args.n_encoder_hidden, dropout=gnn_dropout)
+
+        self.criterion = nn.CrossEntropyLoss()
+        
+    def encoder_forward(self, words: torch.Tensor, feats: List[torch.Tensor]) -> Tuple[torch.Tensor]:
+        """
+        Applies encoding forward pass. Maps a tensor of word indices (`words`) to their corresponding neural
+        representation.
+        Args:
+            words: torch.IntTensor ~ [batch_size, bos + pad(seq_len) + eos + delay]
+            feats: List[torch.Tensor]
+            lens: List[int]
+
+        Returns: x, qloss
+            x: torch.FloatTensor ~ [batch_size, bos + pad(seq_len) + eos, embed_dim]
+            qloss: torch.FloatTensor ~ 1
+
+        """
+
+        x = super().encode(words, feats)
+        s_tag = self.pos_classifier(x[:, 1:-(1+self.args.delay), :])
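+        # the slice keeps only real tokens: it drops the BOS position and the
+        # EOS + delay padding positions described in the docstring above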
+
+        # adjust lengths to allow delay predictions
+        # x ~ [batch_size, bos + pad(seq_len) + eos, embed_dim]
+        if self.args.delay != 0:
+            x = torch.cat([x[:, i:(x.shape[1] - self.args.delay + i), :] for i in range(self.args.delay + 1)], dim=2)
+            x = self.delay_proj(x)
+
+        # pass through vector quantization
+        x, qloss = self.vq_forward(x)
+        return x, s_tag, qloss
+    
+    
+    def forward(
+        self,
+        words: torch.LongTensor,
+        feats: List[torch.LongTensor]
+    ) -> Tuple[torch.Tensor]:
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+                Default: ``None``.
+
+        Returns:
+            ~torch.Tensor:
+                Contextualized output hidden states of shape ``[batch_size, seq_len, n_model]`` of the input.
+        """
+        x, s_tag, qloss = self.encoder_forward(words, feats)
+
+        return x, s_tag, qloss
+
+    def loss(
+        self,
+        x: torch.Tensor,
+        nodes: torch.LongTensor,
+        parents: torch.LongTensor,
+        news: torch.LongTensor,
+        mask: torch.BoolTensor,
+        s_tags: torch.Tensor,
+        tags: torch.LongTensor
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            x (~torch.Tensor): ``[batch_size, seq_len, n_model]``.
+                Contextualized output hidden states.
+            nodes (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The target node positions on rightmost chains.
+            parents (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The parent node labels of terminals.
+            news (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The parent node labels of juxtaposed targets and terminals.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
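+            s_tags (~torch.Tensor): ``[batch_size, seq_len, n_tags]``.
+                POS tag scores from the auxiliary tagging head.
+            tags (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Gold POS tag indices.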
+
+        Returns:
+            ~torch.Tensor:
+                The training loss.
+        """
+
+        spans, s_node, x_node = None, [], []
+        actions = torch.stack((nodes, parents, news))
+        for t, action in enumerate(actions.unbind(-1)):
+            if t == 0:
+                x_span = self.label_embed(actions.new_full((x.shape[0], 1), self.args.n_labels))
+                span_mask = mask[:, :1]
+            else:
+                x_span = self.rightmost_chain(x, spans, mask, t)
+                span_lens = spans[:, :-1, -1].ge(0).sum(-1)
+                span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
+            x_rightmost = torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)
+            s_node.append(self.node_classifier(x_rightmost).squeeze(-1))
+            # we found softmax is slightly better than sigmoid in the original paper
+            s_node[-1] = s_node[-1].masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0)
+            x_node.append(torch.bmm(s_node[-1].softmax(-1).unsqueeze(1), x_span).squeeze(1))
+            spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t])
+        attach_mask = x.new_tensor(range(self.args.n_labels)).eq(self.args.nul_index)
+        s_node, x_node = pad(s_node, -INF).transpose(0, 1), torch.stack(x_node, 1)
+        s_parent, s_new = self.label_classifier(torch.cat((x, x_node), -1)).chunk(2, -1)
+        #s_postag = self.pos_classifier(x[:, 1:-(1+self.args.delay), :]).chunk(2, -1)
+        s_parent = torch.cat((s_parent[:, :1].masked_fill(attach_mask, -INF), s_parent[:, 1:]), 1)
+        s_new = torch.cat((s_new[:, :1].masked_fill(~attach_mask, -INF), s_new[:, 1:]), 1)
+        node_loss = self.criterion(s_node[mask], nodes[mask])
+        #print('node loss', node_loss)
+        label_loss = self.criterion(s_parent[mask], parents[mask]) + self.criterion(s_new[mask], news[mask])
+        #print('label loss', label_loss)
+
+        #print(s_tag[mask].shape, tags[mask].shape)
+        tag_loss = self.criterion(s_tags[mask], tags[mask])
+        #print('tag loss', tag_loss)
+        #tag_loss = self.pos_loss(s_tags, tags, mask)
+        print("node loss, label loss, tag loss", node_loss, label_loss, tag_loss, node_loss + label_loss + tag_loss)
+        return node_loss + label_loss + tag_loss
+
+    def decode(
+        self,
+        x: torch.Tensor,
+        mask: torch.BoolTensor,
+        beam_size: int = 1
+    ) -> List[List[Tuple]]:
+        r"""
+        Args:
+            x (~torch.Tensor): ``[batch_size, seq_len, n_model]``.
+                Contextualized output hidden states.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
+            beam_size (int):
+                Beam size for decoding. Default: 1.
+
+        Returns:
+            List[List[Tuple]]:
+                Sequences of factorized labeled trees.
+        """
+        tokenwise_predictions = []
+        spans = None
+        batch_size, *_ = x.shape
+        n_labels = self.args.n_labels
+        # [batch_size * beam_size, ...]
+        x = x.unsqueeze(1).repeat(1, beam_size, 1, 1).view(-1, *x.shape[1:])
+        mask = mask.unsqueeze(1).repeat(1, beam_size, 1).view(-1, *mask.shape[1:])
+        # [batch_size]
+        batches = x.new_tensor(range(batch_size)).long() * beam_size
+        # accumulated scores
+        scores = x.new_full((batch_size, beam_size), -INF).index_fill_(-1, x.new_tensor(0).long(), 0).view(-1)
+        for t in range(x.shape[1]):
+            if t == 0:
+                x_span = self.label_embed(batches.new_full((x.shape[0], 1), n_labels))
+                span_mask = mask[:, :1]
+            else:
+                x_span = self.rightmost_chain(x, spans, mask, t)
+                span_lens = spans[:, :-1, -1].ge(0).sum(-1)
+                span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
+            s_node = self.node_classifier(torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)).squeeze(-1)
+            s_node = s_node.masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0).log_softmax(-1)
+            # softmax was found to work slightly better than sigmoid in the original paper
+            x_node = torch.bmm(s_node.exp().unsqueeze(1), x_span).squeeze(1)
+            s_parent, s_new = self.label_classifier(torch.cat((x[:, t], x_node), -1)).chunk(2, -1)
+            s_parent, s_new = s_parent.log_softmax(-1), s_new.log_softmax(-1)
+            if t == 0:
+                s_parent[:, self.args.nul_index] = -INF
+                s_new[:, s_new.new_tensor(range(self.args.n_labels)).ne(self.args.nul_index)] = -INF
+            s_node, nodes = s_node.topk(min(s_node.shape[-1], beam_size), -1)
+            s_parent, parents = s_parent.topk(min(n_labels, beam_size), -1)
+            s_new, news = s_new.topk(min(n_labels, beam_size), -1)
+            s_action = s_node.unsqueeze(2) + (s_parent.unsqueeze(2) + s_new.unsqueeze(1)).view(x.shape[0], 1, -1)
+            s_action = s_action.view(x.shape[0], -1)
+            k_beam, k_node, k_parent = s_action.shape[-1], parents.shape[-1] * news.shape[-1], news.shape[-1]
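+            # decomposition constants for the flat candidate ids below:
+            #   k_beam:   number of (node, parent, new) candidate combinations per beam item
+            #   k_node:   number of (parent, new) pairs per node candidate
+            #   k_parent: number of new-label candidates per parent candidate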
+            # [batch_size * beam_size, k_beam]
+            scores = scores.unsqueeze(-1) + s_action
+            # [batch_size, beam_size]
+            scores, cands = scores.view(batch_size, -1).topk(beam_size, -1)
+            # [batch_size * beam_size]
+            scores = scores.view(-1)
+            beams = cands.div(k_beam, rounding_mode='floor')
+            nodes = nodes.view(batch_size, -1).gather(-1, cands.div(k_node, rounding_mode='floor'))
+            indices = (batches.unsqueeze(-1) + beams).view(-1)
+            parents = parents[indices].view(batch_size, -1).gather(-1, cands.div(k_parent, rounding_mode='floor') % k_parent)
+            news = news[indices].view(batch_size, -1).gather(-1, cands % k_parent)
+            action = torch.stack((nodes, parents, news)).view(3, -1)
+            tokenwise_predictions.append([t, [a[0] for a in action.tolist()]])
+            spans = spans[indices] if spans is not None else None
+            spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t])
+            #print("SPANS", spans)
+        mask = mask.view(batch_size, beam_size, -1)[:, 0]
+        # select the 1-best tree for each sentence
+        spans = spans[batches + scores.view(batch_size, -1).argmax(-1)]
+        span_mask = spans.ge(0)
+        span_indices = torch.where(span_mask)
+        span_labels = spans[span_indices]
+        chart_preds = [[] for _ in range(x.shape[0])]
+        for i, *span in zip(*[s.tolist() for s in span_indices], span_labels.tolist()):
+            chart_preds[i].append(span)
+        # return the span charts together with the stepwise action predictions
+        return [chart_preds + tokenwise_predictions]
+
+    def rightmost_chain(
+        self,
+        x: torch.Tensor,
+        spans: torch.LongTensor,
+        mask: torch.BoolTensor,
+        t: int
+    ) -> torch.Tensor:
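+        r"""
+        Encodes the rightmost chain of each partial tree at step ``t``.
+
+        Terminals and the spans built so far are packed into one graph per
+        sentence, with edges linking every node to its closest ancestor span,
+        and propagated through the GCN layers; only representations of spans
+        lying on the rightmost chain are then kept, padded into a tensor of
+        shape ``[batch_size, chain_len, n_model]``.
+        """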
+        x_p, mask_p = x[:, :t], mask[:, :t]
+        lens = mask_p.sum(-1)
+        span_mask = spans[:, :-1, 1:].ge(0)
+        span_lens = span_mask.sum((-1, -2))
+        span_indices = torch.where(span_mask)
+        span_labels = spans[:, :-1, 1:][span_indices]
+        x_span = self.label_embed(span_labels)
+        x_span += x[span_indices[0], span_indices[1]] + x[span_indices[0], span_indices[2]]
+        node_lens = lens + span_lens
+        adj_mask = node_lens.unsqueeze(-1).gt(x.new_tensor(range(node_lens.max())))
+        x_mask = lens.unsqueeze(-1).gt(x.new_tensor(range(adj_mask.shape[-1])))
+        span_mask = ~x_mask & adj_mask
+        # concatenate terminals and spans
+        x_tree = x.new_zeros(*adj_mask.shape, x.shape[-1]).masked_scatter_(x_mask.unsqueeze(-1), x_p[mask_p])
+        x_tree = x_tree.masked_scatter_(span_mask.unsqueeze(-1), x_span)
+        adj = mask.new_zeros(*x_tree.shape[:-1], x_tree.shape[1])
+        adj_spans = lens.new_tensor(range(x_tree.shape[1])).view(1, 1, -1).repeat(2, x.shape[0], 1)
+        adj_spans = adj_spans.masked_scatter_(span_mask.unsqueeze(0), torch.stack(span_indices[1:]))
+        adj_l, adj_r, adj_w = *adj_spans.unbind(), adj_spans[1] - adj_spans[0]
+        adj_parent = adj_l.unsqueeze(-1).ge(adj_l.unsqueeze(-2)) & adj_r.unsqueeze(-1).le(adj_r.unsqueeze(-2))
+        # set the parent of root as itself
+        adj_parent.diagonal(0, 1, 2).copy_(adj_w.eq(t - 1))
+        adj_parent = adj_parent & span_mask.unsqueeze(1)
+        # closest ancestor spans as parents
+        adj_parent = (adj_w.unsqueeze(-2) - adj_w.unsqueeze(-1)).masked_fill_(~adj_parent, t).argmin(-1)
+        adj.scatter_(-1, adj_parent.unsqueeze(-1), 1)
+        adj = (adj | adj.transpose(-1, -2)).float()
+        x_tree = self.gnn_layers(x_tree, adj, adj_mask)
+        span_mask = span_mask.masked_scatter(span_mask, span_indices[2].eq(t - 1))
+        span_lens = span_mask.sum(-1)
+        x_tree, span_mask = x_tree[span_mask], span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
+        x_span = x.new_zeros(*span_mask.shape, x.shape[-1]).masked_scatter_(span_mask.unsqueeze(-1), x_tree)
+        return x_span
+
+    def pos_loss(self, pos_logits: torch.Tensor, pos_tags: torch.LongTensor, mask: torch.BoolTensor) -> torch.Tensor:
+        """
+        Args:
+            pos_logits (~torch.Tensor): [batch_size, seq_len, n_tags].
+            pos_tags (~torch.LongTensor): [batch_size, seq_len].
+            mask (~torch.BoolTensor): [batch_size, seq_len].
+
+        Returns:
+            torch.Tensor: The POS tagging loss.
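+
+        Example (a minimal sketch with toy tensors; shapes are illustrative):
+
+        >>> logits, tags = torch.randn(2, 5, 17), torch.randint(17, (2, 5))
+        >>> mask = torch.ones(2, 5, dtype=torch.bool)
+        >>> nn.CrossEntropyLoss()(logits[mask], tags[mask]).dim()  # a scalar
+        0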
+        """
+        loss_fn = nn.CrossEntropyLoss()
+        return loss_fn(pos_logits[mask], pos_tags[mask])
+
+    def decode_pos(self, s_tag: torch.Tensor) -> torch.LongTensor:
+        """
+        Decode the most likely POS tags.
+
+        Args:
+            s_tag (~torch.Tensor): [batch_size, seq_len, n_tags].
+                POS tag scores for each token.
+
+        Returns:
+            ~torch.LongTensor: [batch_size, seq_len]. The highest-scoring tag index per token.
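+
+        Example (a minimal shape sketch; values are illustrative):
+
+        >>> # s_tag: [batch_size, seq_len, n_tags] -> tag ids: [batch_size, seq_len]
+        >>> torch.randn(2, 5, 17).argmax(-1).shape
+        torch.Size([2, 5])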
+        """
+        return s_tag.argmax(-1)
diff --git a/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/parser-checkpoint.py b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/parser-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..543e8284b90805f045e1d5de2b15aa7c75ffc417
--- /dev/null
+++ b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/parser-checkpoint.py
@@ -0,0 +1,534 @@
+# -*- coding: utf-8 -*-
+
+import os
+from typing import Dict, Iterable, Set, Union
+
+import torch
+
+from supar.models.const.aj.model import AttachJuxtaposeConstituencyModel, AttachJuxtaposeConstituencyModelPos
+from supar.models.const.aj.transform import AttachJuxtaposeTree
+from supar.parser import Parser
+from supar.utils import Config, Dataset, Embedding
+from supar.utils.common import BOS, EOS, NUL, PAD, UNK
+from supar.utils.field import Field, RawField, SubwordField
+from supar.utils.logging import get_logger
+from supar.utils.metric import SpanMetric
+from supar.utils.tokenizer import TransformerTokenizer
+from supar.utils.transform import Batch
+from torch.nn.utils.rnn import pad_sequence
+
+
+logger = get_logger(__name__)
+
+
+def compute_pos_accuracy(pos_gold, pos_preds):
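+    """
+    Counts token-level matches between gold and predicted POS sequences.
+
+    >>> compute_pos_accuracy([['DET', 'NOUN']], [['DET', 'VERB']])
+    (1, 2)
+    """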
+    correct = 0
+    total = 0
+    for gold_seq, pred_seq in zip(pos_gold, pos_preds):
+        for g, p in zip(gold_seq, pred_seq):
+            if g == p:
+                correct += 1
+        total += len(gold_seq)
+    return correct, total
+
+
+class AttachJuxtaposeConstituencyParser(Parser):
+    r"""
+    The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`.
+    """
+
+    NAME = 'attach-juxtapose-constituency'
+    MODEL = AttachJuxtaposeConstituencyModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.TREE = self.transform.TREE
+        self.NODE = self.transform.NODE
+        self.PARENT = self.transform.PARENT
+        self.NEW = self.transform.NEW
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        beam_size: int = 1,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        print("here")
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        beam_size: int = 1,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        beam_size: int = 1,
+        verbose: bool = True,
+        **kwargs
+    ):
+
+        return super().predict(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        #print("TRAIN STEP")
+        words, *feats, trees, nodes, parents, news = batch
+        mask = batch.mask[:, (2+self.args.delay):]
+        x, qloss = self.model(words, feats)
+        loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask) + qloss
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> SpanMetric:
+        #print("EVAL STEP")
+        words, *feats, trees, nodes, parents, news = batch
+        #print("WORDS", words.shape, words)
+        mask = batch.mask[:, (2+self.args.delay):]
+        x, qloss = self.model(words, feats)
+        loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask) + qloss
+        chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size)
+        #print("CHART PREDS")
+        #print("self new vocab", self.NEW.vocab.items())
+        #print()
+        preds = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL})
+                 for tree, chart in zip(trees, chart_preds)]
+
+        return SpanMetric(loss,
+                          [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in preds],
+                          [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in trees])
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, *feats, trees = batch
+        mask = batch.mask[:, (2+self.args.delay):]
+        x, _ = self.model(words, feats)
+        chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size)
+        chart_preds = chart_preds[0]
+        batch.trees = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL})
+                       for tree, chart in zip(trees, chart_preds)]
+        if self.args.prob:
+            raise NotImplementedError("Returning action probabilities is not yet supported.")
+
+        # note: the slicing below assumes a prediction batch size of 1,
+        # where chart_preds[0] is the span chart and the rest are the stepwise actions
+        new_tokenwise_preds = []
+        for k, y in chart_preds[1:]:
+            new_tokenwise_preds.append([k, y[0], self.NEW.vocab[y[1]], self.NEW.vocab[y[2]]])
+
+        chart_preds = [[[s[0], s[1], self.NEW.vocab[s[2]]] for s in chart_preds[0]]] + new_tokenwise_preds
+
+        return chart_preds  # returns predictions rather than the annotated batch
+
+    @classmethod
+    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+        r"""
+        Build a brand-new Parser, including initialization of all data fields and model parameters.
+
+        Args:
+            path (str):
+                The path of the model to be saved.
+            min_freq (int):
+                The minimum frequency needed to include a token in the vocabulary. Default: 2.
+            fix_len (int):
+                The max length of all subword pieces. The excess part of each piece will be truncated.
+                Required if using CharLSTM/BERT.
+                Default: 20.
+            kwargs (Dict):
+                A dict holding the unconsumed arguments.
+        """
+
+        args = Config(**locals())
+        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
+        if os.path.exists(path) and not args.build:
+            parser = cls.load(**args)
+            parser.model = cls.MODEL(**parser.args)
+            parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device)
+            return parser
+
+        logger.info("Building the fields")
+        TAG, CHAR = None, None
+        if args.encoder == 'bert':
+            t = TransformerTokenizer(args.bert)
+            pad_token = t.pad if t.pad else PAD
+            WORD = SubwordField('words', pad=pad_token, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t, delay=args.delay)
+            WORD.vocab = t.vocab
+        else:
+            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay)
+            if 'char' in args.feat:
+                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len, delay=args.delay)
+        TAG = Field('tags', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay)
+        TREE = RawField('trees')
+        NODE, PARENT, NEW = Field('node', use_vocab=False), Field('parent', unk=UNK), Field('new', unk=UNK)
+        transform = AttachJuxtaposeTree(WORD=(WORD, CHAR), POS=TAG, TREE=TREE, NODE=NODE, PARENT=PARENT, NEW=NEW)
+
+        train = Dataset(transform, args.train, **args)
+        if args.encoder != 'bert':
+            WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
+            if CHAR is not None:
+                CHAR.build(train)
+        TAG.build(train)
+        PARENT, NEW = PARENT.build(train), NEW.build(train)
+        PARENT.vocab = NEW.vocab.update(PARENT.vocab)
+        args.update({
+            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
+            'n_labels': len(NEW.vocab),
+            'n_tags': len(TAG.vocab) if TAG is not None else None,
+            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
+            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
+            'pad_index': WORD.pad_index,
+            'unk_index': WORD.unk_index,
+            'bos_index': WORD.bos_index,
+            'eos_index': WORD.eos_index,
+            'nul_index': NEW.vocab[NUL]
+        })
+        logger.info(f"{transform}")
+
+        logger.info("Building the model")
+        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None)
+        logger.info(f"{model}\n")
+
+        parser = cls(args, model, transform)
+        parser.model.to(parser.device)
+        return parser
+
+def flatten(x):
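+    """Flattens arbitrarily nested lists, e.g. ``flatten([1, [2, [3]]]) == [1, 2, 3]``."""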
+    result = []
+    for el in x:
+        if isinstance(el, list):
+            result.extend(flatten(el))
+        else:
+            result.append(el)
+    return result
+
+
+class AttachJuxtaposeConstituencyParserPos(Parser):
+    r"""
+    The implementation of the AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`,
+    extended with an auxiliary POS tagging objective.
+    """
+
+    NAME = 'attach-juxtapose-constituency'
+    MODEL = AttachJuxtaposeConstituencyModelPos
+
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.TREE = self.transform.TREE
+        self.NODE = self.transform.NODE
+        self.PARENT = self.transform.PARENT
+        self.NEW = self.transform.NEW
+        self.TAG = self.transform.POS
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        beam_size: int = 1,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        beam_size: int = 1,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        beam_size: int = 1,
+        verbose: bool = True,
+        **kwargs
+    ):
+
+        return super().predict(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        #print("TRAIN STEP")
+        words, *feats, postags, nodes, parents, news = batch
+        mask = batch.mask[:, (2+self.args.delay):]
+        x, s_tag, qloss = self.model(words, feats)
+        tags = feats[-1]
+        loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask, s_tag, tags[:, 1:-1]) + qloss
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> SpanMetric:
+        words, *feats, trees, nodes, parents, news = batch
+
+        tags = feats[-1]
+        mask = batch.mask[:, (2+self.args.delay):]
+        x, s_tag, qloss = self.model(words, feats)
+        loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask, s_tag, tags[:, 1:-1]) + qloss
+        chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size)
+
+        pos_preds = [[self.TAG.vocab[p.item()] for p in sent] for sent in s_tag.argmax(-1)]
+        trimmed_tags = [tag_seq[1:-1] for tag_seq in tags]
+        trimmed_tags_list = [x.cpu().tolist() for x in trimmed_tags]
+        pos_gold = [[self.TAG.vocab[p] for p in sent] for sent in trimmed_tags_list]
+
+        pos_correct, pos_total = compute_pos_accuracy(pos_gold, pos_preds)
+        pos_acc = pos_correct / pos_total
+        logger.info(f"POS accuracy: {pos_acc:.4f}")
+        """
+        tags = [*feats[-1:]]
+        trimmed_tags = [tag_seq[1:-1] for tag_seq in tags[0]]
+        gtags = [*feats[-1:]][0].cpu().tolist()
+        gtags = [x[1:-1] for x in gtags]
+        mask = batch.mask[:, (2+self.args.delay):]
+        x, s_tag, qloss = self.model(words, feats)
+
+        tag_tensor = pad_sequence(trimmed_tags, batch_first=True, padding_value=self.TAG.pad_index).to(s_tag.device)
+
+        assert s_tag.shape[:2] == tag_tensor.shape[:2]
+        print("s_tag shape", s_tag.shape[:2], "tag_tensor shape", tag_tensor.shape[:2])
+        
+
+        loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask, s_tag, tag_tensor) + qloss
+
+        chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size)
+        #pos_preds = [self.TAG.vocab[x] for x in s_tag.argmax(-1)]
+
+        for tag_seq in gtags:
+            for pos in tag_seq:
+                assert 0 <= pos < len(self.TAG.vocab)
+        
+        pos_preds = [[self.TAG.vocab[p.item()] for p in sent] for sent in s_tag.argmax(-1)]
+        pos_gold = [[self.TAG.vocab[pos] for pos in tag_seq] for tag_seq in gtags]
+
+        #print("357!", len(pos_gold), len(pos_preds))
+        pos_correct, pos_total = compute_pos_accuracy(pos_gold, pos_preds)
+        print("358!", pos_correct, pos_total, pos_correct/pos_total)
+
+            
+        #print('%%% POS preds', pos_preds)
+        #print('%%%% POS gold', pos_gold)
+        """
+
+        #batch.trees = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL})
+         #              for tree, chart in zip(trees, chart_preds)]
+        #print(batch.trees)
+              
+        preds = [AttachJuxtaposeTree.build(tree, [[x[0], x[1], self.NEW.vocab[x[2]]] for x in chart], {UNK, NUL})
+                 for tree, chart in zip(trees, chart_preds[0])]
+
+        return SpanMetric(loss,
+                          [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in preds],
+                          [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in trees], pos_acc)
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, *feats, trees = batch
+        mask = batch.mask[:, (2+self.args.delay):]
+        x, s_tag, qloss = self.model(words, feats)
+        chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size)
+
+        chart_preds = chart_preds[0]
+        batch.trees = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL})
+                       for tree, chart in zip(trees, chart_preds)]
+        if self.args.prob:
+            raise NotImplementedError("Returning action probabilities is not yet supported.")
+
+        # note: the slicing below assumes a prediction batch size of 1,
+        # where chart_preds[0] is the span chart and the rest are the stepwise actions
+        span_preds = [[s[0], s[1], self.NEW.vocab[s[2]]] for s in chart_preds[0]]
+        verbalized_preds = []
+        for a in chart_preds[1:]:
+            verbalized_preds.append([a[0], [a[1][0], self.NEW.vocab[a[1][1]], self.NEW.vocab[a[1][2]]]])
+        pos_preds = [[self.TAG.vocab[p.item()] for p in sent] for sent in s_tag.argmax(-1)]
+        return span_preds, verbalized_preds, pos_preds  # returns predictions rather than the annotated batch
+
+    @classmethod
+    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+        r"""
+        Build a brand-new Parser, including initialization of all data fields and model parameters.
+
+        Args:
+            path (str):
+                The path of the model to be saved.
+            min_freq (int):
+                The minimum frequency needed to include a token in the vocabulary. Default: 2.
+            fix_len (int):
+                The max length of all subword pieces. The excess part of each piece will be truncated.
+                Required if using CharLSTM/BERT.
+                Default: 20.
+            kwargs (Dict):
+                A dict holding the unconsumed arguments.
+        """
+
+        args = Config(**locals())
+        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
+        if os.path.exists(path) and not args.build:
+            parser = cls.load(**args)
+            parser.model = cls.MODEL(**parser.args)
+            parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device)
+            return parser
+
+        logger.info("Building the fields")
+        TAG, CHAR = None, None
+        if args.encoder == 'bert':
+            t = TransformerTokenizer(args.bert)
+            pad_token = t.pad if t.pad else PAD
+            WORD = SubwordField('words', pad=pad_token, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t, delay=args.delay)
+            WORD.vocab = t.vocab
+        else:
+            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay)
+            if 'char' in args.feat:
+                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len, delay=args.delay)
+        TAG = Field('tags', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay)
+        TREE = RawField('trees')
+        NODE, PARENT, NEW = Field('node', use_vocab=False), Field('parent', unk=UNK), Field('new', unk=UNK)
+        transform = AttachJuxtaposeTree(WORD=(WORD, CHAR), POS=TAG, TREE=TREE, NODE=NODE, PARENT=PARENT, NEW=NEW)
+
+        train = Dataset(transform, args.train, **args)
+        if args.encoder != 'bert':
+            WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
+            if CHAR is not None:
+                CHAR.build(train)
+        TAG.build(train)
+        #print("203 TAG VOCAB", [x for x in TAG.vocab.items()])
+        PARENT, NEW = PARENT.build(train), NEW.build(train)
+        PARENT.vocab = NEW.vocab.update(PARENT.vocab)
+        args.update({
+            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
+            'n_labels': len(NEW.vocab),
+            'n_tags': len(TAG.vocab) if TAG is not None else None,
+            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
+            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
+            'pad_index': WORD.pad_index,
+            'unk_index': WORD.unk_index,
+            'bos_index': WORD.bos_index,
+            'eos_index': WORD.eos_index,
+            'nul_index': NEW.vocab[NUL]
+        })
+        logger.info(f"{transform}")
+
+        logger.info("Building the model")
+        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None)
+        logger.info(f"{model}\n")
+
+        parser = cls(args, model, transform)
+        parser.model.to(parser.device)
+        return parser
\ No newline at end of file
diff --git a/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/transform-checkpoint.py b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/transform-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..56e0f5158b6b3b214b7e99e8a479fbabc56bcbf6
--- /dev/null
+++ b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/transform-checkpoint.py
@@ -0,0 +1,459 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
+
+import nltk
+import torch
+
+from supar.models.const.crf.transform import Tree
+from supar.utils.common import NUL
+from supar.utils.logging import get_logger
+from supar.utils.tokenizer import Tokenizer
+from supar.utils.transform import Sentence
+
+if TYPE_CHECKING:
+    from supar.utils import Field
+
+logger = get_logger(__name__)
+
+
+class AttachJuxtaposeTree(Tree):
+    r"""
+    :class:`AttachJuxtaposeTree` is derived from the :class:`Tree` class,
+    supporting back-and-forth transformations between trees and AttachJuxtapose actions :cite:`yang-deng-2020-aj`.
+
+    Attributes:
+        WORD:
+            Words in the sentence.
+        POS:
+            Part-of-speech tags, or underscores if not available.
+        TREE:
+            The raw constituency tree in :class:`nltk.tree.Tree` format.
+        NODE:
+            The target node on each rightmost chain.
+        PARENT:
+            The label of the parent node of each terminal.
+        NEW:
+            The label of each newly inserted non-terminal with a target node and a terminal as juxtaposed children.
+            ``NUL`` represents the `Attach` action.
+    """
+
+    fields = ['WORD', 'POS', 'TREE', 'NODE', 'PARENT', 'NEW']
+
+    def __init__(
+        self,
+        WORD: Optional[Union[Field, Iterable[Field]]] = None,
+        POS: Optional[Union[Field, Iterable[Field]]] = None,
+        TREE: Optional[Union[Field, Iterable[Field]]] = None,
+        NODE: Optional[Union[Field, Iterable[Field]]] = None,
+        PARENT: Optional[Union[Field, Iterable[Field]]] = None,
+        NEW: Optional[Union[Field, Iterable[Field]]] = None
+    ) -> Tree:
+        super().__init__()
+
+        self.WORD = WORD
+        self.POS = POS
+        self.TREE = TREE
+        self.NODE = NODE
+        self.PARENT = PARENT
+        self.NEW = NEW
+
+    @property
+    def src(self):
+        return self.WORD, self.POS, self.TREE
+
+    @property
+    def tgt(self):
+        return self.NODE, self.PARENT, self.NEW
+
+    @classmethod
+    def tree2action(cls, tree: nltk.Tree):
+        r"""
+        Converts a constituency tree into AttachJuxtapose actions.
+
+        Args:
+            tree (nltk.tree.Tree):
+                A constituency tree in :class:`nltk.tree.Tree` format.
+
+        Returns:
+            A sequence of AttachJuxtapose actions.
+
+        Examples:
+            >>> from supar.models.const.aj.transform import AttachJuxtaposeTree
+            >>> tree = nltk.Tree.fromstring('''
+                                            (TOP
+                                              (S
+                                                (NP (_ Arthur))
+                                                (VP
+                                                  (_ is)
+                                                  (NP (NP (_ King)) (PP (_ of) (NP (_ the) (_ Britons)))))
+                                                (_ .)))
+                                            ''')
+            >>> tree.pretty_print()
+                            TOP
+                             |
+                             S
+               ______________|_______________________
+              |              VP                      |
+              |      ________|___                    |
+              |     |            NP                  |
+              |     |    ________|___                |
+              |     |   |            PP              |
+              |     |   |     _______|___            |
+              NP    |   NP   |           NP          |
+              |     |   |    |        ___|_____      |
+              _     _   _    _       _         _     _
+              |     |   |    |       |         |     |
+            Arthur  is King  of     the     Britons  .
+            >>> AttachJuxtaposeTree.tree2action(tree)
+            [(0, 'NP', '<nul>'), (0, 'VP', 'S'), (1, 'NP', '<nul>'),
+             (2, 'PP', 'NP'), (3, 'NP', '<nul>'), (4, '<nul>', '<nul>'),
+             (0, '<nul>', '<nul>')]
+        """
+
+        def isroot(node):
+            return node == tree[0]
+
+        def isterminal(node):
+            return len(node) == 1 and not isinstance(node[0], nltk.Tree)
+
+        def last_leaf(node):
+            pos = ()
+            while True:
+                pos += (len(node) - 1,)
+                node = node[-1]
+                if isterminal(node):
+                    return node, pos
+
+        def parent(position):
+            return tree[position[:-1]]
+
+        def grand(position):
+            return tree[position[:-2]]
+
+        def detach(tree):
+            last, last_pos = last_leaf(tree)
+            siblings = parent(last_pos)[:-1]
+
+            if len(siblings) > 0:
+                last_subtree = last
+                last_subtree_siblings = siblings
+                parent_label = NUL
+            else:
+                last_subtree, last_pos = parent(last_pos), last_pos[:-1]
+                last_subtree_siblings = [] if isroot(last_subtree) else parent(last_pos)[:-1]
+                parent_label = last_subtree.label()
+
+            target_pos, new_label, last_tree = 0, NUL, tree
+            if isroot(last_subtree):
+                last_tree = None
+            elif len(last_subtree_siblings) == 1 and not isterminal(last_subtree_siblings[0]):
+                new_label = parent(last_pos).label()
+                target = last_subtree_siblings[0]
+                last_grand = grand(last_pos)
+                if last_grand is None:
+                    last_tree = target
+                else:
+                    last_grand[-1] = target
+                target_pos = len(last_pos) - 2
+            else:
+                target = parent(last_pos)
+                target.pop()
+                target_pos = len(last_pos) - 2
+            action = target_pos, parent_label, new_label
+            return action, last_tree
+
+        if tree is None:
+            return []
+        action, last_tree = detach(tree)
+        return cls.tree2action(last_tree) + [action]
+
+    @classmethod
+    def action2tree(
+        cls,
+        tree: nltk.Tree,
+        actions: List[Tuple[int, str, str, str]],
+        join: str = '::',
+    ) -> nltk.Tree:
+        r"""
+        Recovers a constituency tree from a sequence of AttachJuxtapose actions.
+
+        Args:
+            tree (nltk.tree.Tree):
+                An empty tree that provides a base for building a result tree.
+            actions (List[Tuple[int, str, str, str]]):
+                A sequence of AttachJuxtapose actions, each carrying the POS tag of the incoming terminal as its last element.
+            join (str):
+                A string used to connect collapsed node labels. Non-terminals containing this will be expanded to unary chains.
+                Default: ``'::'``.
+
+        Returns:
+            A result constituency tree.
+
+        Examples:
+            >>> from supar.models.const.aj.transform import AttachJuxtaposeTree
+            >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP')
+            >>> AttachJuxtaposeTree.action2tree(tree,
+                                                [(0, 'NP', '<nul>', '_'), (0, 'VP', 'S', '_'), (1, 'NP', '<nul>', '_'),
+                                                 (2, 'PP', 'NP', '_'), (3, 'NP', '<nul>', '_'), (4, '<nul>', '<nul>', '_'),
+                                                 (0, '<nul>', '<nul>', '_')]).pretty_print()
+                            TOP
+                             |
+                             S
+               ______________|_______________________
+              |              VP                      |
+              |      ________|___                    |
+              |     |            NP                  |
+              |     |    ________|___                |
+              |     |   |            PP              |
+              |     |   |     _______|___            |
+              NP    |   NP   |           NP          |
+              |     |   |    |        ___|_____      |
+              _     _   _    _       _         _     _
+              |     |   |    |       |         |     |
+            Arthur  is King  of     the     Britons  .
+        """
+
+        def target(node, depth):
+            node_pos = ()
+            for _ in range(depth):
+                node_pos += (len(node) - 1,)
+                node = node[-1]
+            return node, node_pos
+
+        def parent(tree, position):
+            return tree[position[:-1]]
+
+        def execute(tree: nltk.Tree, terminal: Tuple[str, str], action: Tuple[int, str, str, str]) -> nltk.Tree:
+            target_pos, parent_label, new_label, post = action
+            new_leaf = nltk.Tree(post, [terminal[0]])
+
+            # create the subtree to be inserted
+            new_subtree = new_leaf if parent_label == NUL else nltk.Tree(parent_label, [new_leaf])
+            # find the target position at which to insert the new subtree
+            target_node = tree
+            if target_node is not None:
+                target_node, target_pos = target(target_node, target_pos)
+
+            # Attach
+            if new_label == NUL:
+                # attach the first token
+                if target_node is None:
+                    return new_subtree
+                target_node.append(new_subtree)
+            # Juxtapose
+            else:
+                new_subtree = nltk.Tree(new_label, [target_node, new_subtree])
+                if len(target_pos) > 0:
+                    parent_node = parent(tree, target_pos)
+                    parent_node[-1] = new_subtree
+                else:
+                    tree = new_subtree
+            return tree
+
+        tree, root, terminals = None, tree.label(), tree.pos()
+        for terminal, action in zip(terminals, actions):
+            tree = execute(tree, terminal, action)
+        # recover unary chains
+        nodes = [tree]
+        while nodes:
+            node = nodes.pop()
+            if isinstance(node, nltk.Tree):
+                nodes.extend(node)
+                if join in node.label():
+                    labels = node.label().split(join)
+                    node.set_label(labels[0])
+                    subtree = nltk.Tree(labels[-1], node)
+                    for label in reversed(labels[1:-1]):
+                        subtree = nltk.Tree(label, [subtree])
+                    node[:] = [subtree]
+        return nltk.Tree(root, [tree])
+
+    @classmethod
+    def action2span(
+        cls,
+        action: torch.Tensor,
+        spans: torch.Tensor = None,
+        nul_index: int = -1,
+        mask: torch.BoolTensor = None
+    ) -> torch.Tensor:
+        r"""
+        Converts a batch of the tensorized action at a given step into spans.
+
+        Args:
+            action (~torch.Tensor): ``[3, batch_size]``.
+                A batch of the tensorized action at a given step, containing indices of target nodes, parent and new labels.
+            spans (~torch.Tensor):
+                Spans generated at previous steps, ``None`` at the first step. Default: ``None``.
+            nul_index (int):
+                The index for the :obj:`NUL` token, representing the Attach action. Default: -1.
+            mask (~torch.BoolTensor): ``[batch_size]``.
+                The mask for covering the unpadded tokens.
+
+        Returns:
+            A tensor representing a batch of spans for the given step.
+
+        Examples:
+            >>> from collections import Counter
+            >>> from supar.models.const.aj.transform import AttachJuxtaposeTree, Vocab
+            >>> from supar.utils.common import NUL
+            >>> nodes, parents, news = zip(*[(0, 'NP', NUL), (0, 'VP', 'S'), (1, 'NP', NUL),
+                                             (2, 'PP', 'NP'), (3, 'NP', NUL), (4, NUL, NUL),
+                                             (0, NUL, NUL)])
+            >>> vocab = Vocab(Counter(sorted(set([*parents, *news]))))
+            >>> actions = torch.tensor([nodes, vocab[parents], vocab[news]]).unsqueeze(1)
+            >>> spans = None
+            >>> for action in actions.unbind(-1):
+            ...     spans = AttachJuxtaposeTree.action2span(action, spans, vocab[NUL])
+            ...
+            >>> spans
+            tensor([[[-1,  1, -1, -1, -1, -1, -1,  3],
+                     [-1, -1, -1, -1, -1, -1,  4, -1],
+                     [-1, -1, -1,  1, -1, -1,  1, -1],
+                     [-1, -1, -1, -1, -1, -1,  2, -1],
+                     [-1, -1, -1, -1, -1, -1,  1, -1],
+                     [-1, -1, -1, -1, -1, -1, -1, -1],
+                     [-1, -1, -1, -1, -1, -1, -1, -1],
+                     [-1, -1, -1, -1, -1, -1, -1, -1]]])
+            >>> sequence = torch.where(spans.ge(0))
+            >>> sequence = list(zip(sequence[1].tolist(), sequence[2].tolist(), vocab[spans[sequence]]))
+            >>> sequence
+            [(0, 1, 'NP'), (0, 7, 'S'), (1, 6, 'VP'), (2, 3, 'NP'), (2, 6, 'NP'), (3, 6, 'PP'), (4, 6, 'NP')]
+            >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP')
+            >>> AttachJuxtaposeTree.build(tree, sequence).pretty_print()
+                            TOP
+                             |
+                             S
+               ______________|_______________________
+              |              VP                      |
+              |      ________|___                    |
+              |     |            NP                  |
+              |     |    ________|___                |
+              |     |   |            PP              |
+              |     |   |     _______|___            |
+              NP    |   NP   |           NP          |
+              |     |   |    |        ___|_____      |
+              _     _   _    _       _         _     _
+              |     |   |    |       |         |     |
+            Arthur  is King  of     the     Britons  .
+
+        """
+
+        # [batch_size]
+        target, parent, new = action
+        if spans is None:
+            spans = action.new_full((action.shape[1], 2, 2), -1)
+            spans[:, 0, 1] = parent
+            return spans
+        if mask is None:
+            mask = torch.ones_like(target, dtype=torch.bool)
+        juxtapose_mask = new.ne(nul_index) & mask
+        # ancestor nodes are those on the rightmost chain and higher than the target node
+        # [batch_size, seq_len]
+        rightmost_mask = spans[..., -1].ge(0)
+        ancestors = rightmost_mask.cumsum(-1).masked_fill_(~rightmost_mask, -1) - 1
+        # should not include the target node for the Juxtapose action
+        ancestor_mask = mask.unsqueeze(-1) & ancestors.ge(0) & ancestors.le((target - juxtapose_mask.long()).unsqueeze(-1))
+        target_pos = torch.where(ancestors.eq(target.unsqueeze(-1))[juxtapose_mask])[-1]
+        # the right boundaries of ancestor nodes should be aligned with the new generated terminals
+        spans = torch.cat((spans, torch.where(ancestor_mask, spans[..., -1], -1).unsqueeze(-1)), -1)
+        spans[..., -2].masked_fill_(ancestor_mask, -1)
+        spans[juxtapose_mask, target_pos, -1] = new.masked_fill(new.eq(nul_index), -1)[juxtapose_mask]
+        spans[mask, -1, -1] = parent.masked_fill(parent.eq(nul_index), -1)[mask]
+        # [batch_size, seq_len+1, seq_len+1]
+        spans = torch.cat((spans, torch.full_like(spans[:, :1], -1)), 1)
+        return spans
+
+    def load(
+        self,
+        data: Union[str, Iterable],
+        lang: Optional[str] = None,
+        **kwargs
+    ) -> Iterable[AttachJuxtaposeTreeSentence]:
+        r"""
+        Args:
+            data (Union[str, Iterable]):
+                A filename or a list of instances.
+            lang (str):
+                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``None``.
+
+        Returns:
+            An iterator over :class:`AttachJuxtaposeTreeSentence` instances.
+        """
+
+        if lang is not None:
+            tokenizer = Tokenizer(lang)
+        if isinstance(data, str) and os.path.exists(data):
+            if data.endswith('.txt'):
+                data = (s.split() if lang is None else tokenizer(s) for s in open(data) if len(s) > 1)
+            else:
+                data = open(data)
+        else:
+            if lang is not None:
+                data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)]
+            else:
+                data = [data] if isinstance(data[0], str) else data
+
+        index = 0
+        for s in data:
+            try:
+                tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root)
+                sentence = AttachJuxtaposeTreeSentence(self, tree, index)
+            except ValueError:
+                logger.warning(f"Error found while converting Sentence {index} to a tree:\n{s}\nDiscarding it!")
+                continue
+            except IndexError:
+                # fall back to wrapping the sentence in a dummy root so it is still yielded
+                tree = nltk.Tree.fromstring('(S ' + s + ')')
+                sentence = AttachJuxtaposeTreeSentence(self, tree, index)
+            yield sentence
+            index += 1
+        self.root = tree.label()
+
+
+class AttachJuxtaposeTreeSentence(Sentence):
+    r"""
+    Args:
+        transform (AttachJuxtaposeTree):
+            A :class:`AttachJuxtaposeTree` object.
+        tree (nltk.tree.Tree):
+            A :class:`nltk.tree.Tree` object.
+        index (Optional[int]):
+            Index of the sentence in the corpus. Default: ``None``.
+    """
+
+    def __init__(
+        self,
+        transform: AttachJuxtaposeTree,
+        tree: nltk.Tree,
+        index: Optional[int] = None
+    ) -> AttachJuxtaposeTreeSentence:
+        super().__init__(transform, index)
+
+        words, tags = zip(*tree.pos())
+        nodes, parents, news = None, None, None
+        if transform.training:
+            oracle_tree = tree.copy(True)
+            # the root node must have a unary chain
+            if len(oracle_tree) > 1:
+                oracle_tree[:] = [nltk.Tree('*', oracle_tree)]
+            oracle_tree.collapse_unary(joinChar='::')
+            if len(oracle_tree) == 1 and not isinstance(oracle_tree[0][0], nltk.Tree):
+                oracle_tree[0] = nltk.Tree('*', [oracle_tree[0]])
+            nodes, parents, news = zip(*transform.tree2action(oracle_tree))
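+        # keep only the portion of each POS tag before any '##' marker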
+        tags = [x.split("##")[0] for x in tags]
+        self.values = [words, tags, tree, nodes, parents, news]
+
+    def __repr__(self):
+        return self.values[-4].pformat(1000000)
+
+    def pretty_print(self):
+        self.values[-4].pretty_print()
diff --git a/tania_scripts/supar/models/const/aj/__init__.py b/tania_scripts/supar/models/const/aj/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fe4855b790afea3b677caaaa7de5f6ac3549216
--- /dev/null
+++ b/tania_scripts/supar/models/const/aj/__init__.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .model import AttachJuxtaposeConstituencyModel, AttachJuxtaposeConstituencyModelPos
+from .parser import AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos
+
+__all__ = ['AttachJuxtaposeConstituencyModel', 'AttachJuxtaposeConstituencyModelPos', 'AttachJuxtaposeConstituencyParser', 'AttachJuxtaposeConstituencyParserPos']
diff --git a/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf09a38138e561426e88c74ad582913e27d79d0e
Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7bd79cb4caac4fa72e4cd50659aee9fc13dcc3d4
Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-39.pyc b/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb0340d7a5173f1782e452ef95c2faf9025e674d
Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-39.pyc differ
diff --git a/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4da15ae58bc418dda6f19b8e44874d984ba30fda
Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a76052d5e93b7b2ce0a33eae863662d4a4a1438
Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-39.pyc b/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d40ac4f8b95a668781a063d3ee4aa591fe26927
Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-39.pyc differ
diff --git a/tania_scripts/supar/models/const/aj/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/const/aj/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..080b0dba88ea7c05edcde2a12f2d5e17b2756b70
Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/parser.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/aj/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/const/aj/__pycache__/parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b58212f55501521f49a41f39da967e6efed3566
Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/parser.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/aj/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/const/aj/__pycache__/transform.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..731814f0b914ad90ae5e604f6a0883b85690c09e
Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/transform.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/aj/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/const/aj/__pycache__/transform.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aba0194d23a0b3d6f8024b35018c2b00be0c9752
Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/transform.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/aj/model.py b/tania_scripts/supar/models/const/aj/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f4a8ce0b5ab9a6a35e8adb18eecb6b8cacecd81
--- /dev/null
+++ b/tania_scripts/supar/models/const/aj/model.py
@@ -0,0 +1,786 @@
+# -*- coding: utf-8 -*-
+
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+from supar.model import Model
+from supar.models.const.aj.transform import AttachJuxtaposeTree
+from supar.modules import GraphConvolutionalNetwork, MLP, DecoderLSTM
+from supar.utils import Config
+from supar.utils.common import INF
+from supar.utils.fn import pad
+
+class DecoderLSTMPos(nn.Module):
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout, device):
+        super().__init__()
+        # `device` is accepted for interface parity with DecoderLSTM but is unused here
+        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
+                            batch_first=True, dropout=dropout)
+        self.classifier = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x):
+        # x: [batch_size, seq_len, input_dim]
+        output, _ = self.lstm(x)
+        logits = self.classifier(output)
+        return logits
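+
+# A minimal usage sketch for DecoderLSTMPos (illustrative shapes only):
+#   decoder = DecoderLSTMPos(input_dim=800, hidden_dim=400, output_dim=17,
+#                            num_layers=1, dropout=0.0, device='cpu')
+#   logits = decoder(torch.randn(2, 5, 800))  # -> [2, 5, 17]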
+
+class AttachJuxtaposeConstituencyModel(Model):
+    r"""
+    The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_labels (int):
+            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'char'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word embeddings. Default: 100.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, False)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs would be weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling way to get token embeddings.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .33.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 800.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layers. Default: .33.
+        n_gnn_layers (int):
+            The number of GNN layers. Default: 3.
+        gnn_dropout (float):
+            The dropout ratio of GNN layers. Default: .33.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_labels,
+                 n_tags=None,
+                 n_chars=None,
+                 encoder='lstm',
+                 feat=['char'],
+                 n_embed=100,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, True),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 n_encoder_hidden=800,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_gnn_layers=3,
+                 gnn_dropout=.33,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        # the last one represents the dummy node in the initial states
+        self.label_embed = nn.Embedding(n_labels+1, self.args.n_encoder_hidden)
+        self.gnn_layers = GraphConvolutionalNetwork(n_model=self.args.n_encoder_hidden,
+                                                    n_layers=self.args.n_gnn_layers,
+                                                    dropout=self.args.gnn_dropout)
+
+        self.node_classifier = nn.Sequential(
+            nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2),
+            nn.LayerNorm(self.args.n_encoder_hidden // 2),
+            nn.ReLU(),
+            nn.Linear(self.args.n_encoder_hidden // 2, 1),
+        )
+        self.label_classifier = nn.Sequential(
+            nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2),
+            nn.LayerNorm(self.args.n_encoder_hidden // 2),
+            nn.ReLU(),
+            nn.Linear(self.args.n_encoder_hidden // 2, 2 * n_labels),
+        )
+
+        # create delay projection
+        if self.args.delay != 0:
+            self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay+1),
+                                  n_out=self.args.n_encoder_hidden, dropout=gnn_dropout)
+
+        self.criterion = nn.CrossEntropyLoss()
+
+    def forward(
+        self,
+        words: torch.LongTensor,
+        feats: List[torch.LongTensor]
+    ) -> Tuple[torch.Tensor]:
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+                Default: ``None``.
+
+        Returns:
+            Tuple[~torch.Tensor, ~torch.Tensor]:
+                Contextualized hidden states of shape ``[batch_size, seq_len, n_model]``
+                and the scalar quantization loss.
+        """
+        x = self.encode(words, feats)
+
+        # adjust lengths to allow delay predictions
+        if self.args.delay != 0:
+            x = torch.cat([x[:, i:(x.shape[1] - self.args.delay + i)] for i in range(self.args.delay + 1)], dim=2)
+            x = self.delay_proj(x)
+
+        # pass through vector quantization
+        x, qloss = self.vq_forward(x)
+
+        return x, qloss
+
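+    # Sketch of the delay windowing in `forward` (illustrative, assuming delay=2):
+    # with x of shape [B, L, H], concatenating the slices x[:, 0:L-2],
+    # x[:, 1:L-1] and x[:, 2:L] along the feature axis yields [B, L-2, 3*H],
+    # so each position also sees the next `delay` encoder states before
+    # `delay_proj` projects the result back to [B, L-2, H].
+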
+    def loss(
+        self,
+        x: torch.Tensor,
+        nodes: torch.LongTensor,
+        parents: torch.LongTensor,
+        news: torch.LongTensor,
+        mask: torch.BoolTensor
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            x (~torch.Tensor): ``[batch_size, seq_len, n_model]``.
+                Contextualized output hidden states.
+            nodes (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The target node positions on rightmost chains.
+            parents (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The parent node labels of terminals.
+            news (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The parent node labels of juxtaposed targets and terminals.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
+
+        Returns:
+            ~torch.Tensor:
+                The training loss.
+        """
+
+        spans, s_node, x_node = None, [], []
+        actions = torch.stack((nodes, parents, news))
+        for t, action in enumerate(actions.unbind(-1)):
+            if t == 0:
+                x_span = self.label_embed(actions.new_full((x.shape[0], 1), self.args.n_labels))
+                span_mask = mask[:, :1]
+            else:
+                x_span = self.rightmost_chain(x, spans, mask, t)
+                span_lens = spans[:, :-1, -1].ge(0).sum(-1)
+                span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
+            x_rightmost = torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)
+            s_node.append(self.node_classifier(x_rightmost).squeeze(-1))
+            # softmax works slightly better here than the sigmoid used in the original paper
+            s_node[-1] = s_node[-1].masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0)
+            x_node.append(torch.bmm(s_node[-1].softmax(-1).unsqueeze(1), x_span).squeeze(1))
+            spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t])
+        attach_mask = x.new_tensor(range(self.args.n_labels)).eq(self.args.nul_index)
+        s_node, x_node = pad(s_node, -INF).transpose(0, 1), torch.stack(x_node, 1)
+        s_parent, s_new = self.label_classifier(torch.cat((x, x_node), -1)).chunk(2, -1)
+        s_parent = torch.cat((s_parent[:, :1].masked_fill(attach_mask, -INF), s_parent[:, 1:]), 1)
+        s_new = torch.cat((s_new[:, :1].masked_fill(~attach_mask, -INF), s_new[:, 1:]), 1)
+        node_loss = self.criterion(s_node[mask], nodes[mask])
+        label_loss = self.criterion(s_parent[mask], parents[mask]) + self.criterion(s_new[mask], news[mask])
+        return node_loss + label_loss
+
+    def decode(
+        self,
+        x: torch.Tensor,
+        mask: torch.BoolTensor,
+        beam_size: int = 1
+    ) -> List[List[Tuple]]:
+        r"""
+        Args:
+            x (~torch.Tensor): ``[batch_size, seq_len, n_model]``.
+                Contextualized output hidden states.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
+            beam_size (int):
+                Beam size for decoding. Default: 1.
+
+        Returns:
+            List[List[Tuple]]:
+                Sequences of factorized labeled trees.
+        """
+        tokenwise_predictions = []
+        spans = None
+        batch_size, *_ = x.shape
+        n_labels = self.args.n_labels
+        # [batch_size * beam_size, ...]
+        x = x.unsqueeze(1).repeat(1, beam_size, 1, 1).view(-1, *x.shape[1:])
+        mask = mask.unsqueeze(1).repeat(1, beam_size, 1).view(-1, *mask.shape[1:])
+        # [batch_size]
+        batches = x.new_tensor(range(batch_size)).long() * beam_size
+        # accumulated scores
+        scores = x.new_full((batch_size, beam_size), -INF).index_fill_(-1, x.new_tensor(0).long(), 0).view(-1)
+        for t in range(x.shape[1]):
+            if t == 0:
+                x_span = self.label_embed(batches.new_full((x.shape[0], 1), n_labels))
+                span_mask = mask[:, :1]
+            else:
+                x_span = self.rightmost_chain(x, spans, mask, t)
+                span_lens = spans[:, :-1, -1].ge(0).sum(-1)
+                span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
+            s_node = self.node_classifier(torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)).squeeze(-1)
+            s_node = s_node.masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0).log_softmax(-1)
+            # softmax works slightly better here than the sigmoid used in the original paper
+            x_node = torch.bmm(s_node.exp().unsqueeze(1), x_span).squeeze(1)
+            s_parent, s_new = self.label_classifier(torch.cat((x[:, t], x_node), -1)).chunk(2, -1)
+            s_parent, s_new = s_parent.log_softmax(-1), s_new.log_softmax(-1)
+            if t == 0:
+                s_parent[:, self.args.nul_index] = -INF
+                s_new[:, s_new.new_tensor(range(self.args.n_labels)).ne(self.args.nul_index)] = -INF
+            s_node, nodes = s_node.topk(min(s_node.shape[-1], beam_size), -1)
+            s_parent, parents = s_parent.topk(min(n_labels, beam_size), -1)
+            s_new, news = s_new.topk(min(n_labels, beam_size), -1)
+            s_action = s_node.unsqueeze(2) + (s_parent.unsqueeze(2) + s_new.unsqueeze(1)).view(x.shape[0], 1, -1)
+            s_action = s_action.view(x.shape[0], -1)
+            k_beam, k_node, k_parent = s_action.shape[-1], parents.shape[-1] * news.shape[-1], news.shape[-1]
+            # [batch_size * beam_size, k_beam]
+            scores = scores.unsqueeze(-1) + s_action
+            # [batch_size, beam_size]
+            scores, cands = scores.view(batch_size, -1).topk(beam_size, -1)
+            # [batch_size * beam_size]
+            scores = scores.view(-1)
+            beams = cands.div(k_beam, rounding_mode='floor')
+            nodes = nodes.view(batch_size, -1).gather(-1, cands.div(k_node, rounding_mode='floor'))
+            indices = (batches.unsqueeze(-1) + beams).view(-1)
+
+            parents = parents[indices].view(batch_size, -1).gather(-1, cands.div(k_parent, rounding_mode='floor') % k_parent)
+            news = news[indices].view(batch_size, -1).gather(-1, cands % k_parent)
+            action = torch.stack((nodes, parents, news)).view(3, -1)
+            tokenwise_predictions.append([t, [a[0] for a in action.tolist()]])
+            spans = spans[indices] if spans is not None else None
+            spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t])
+        mask = mask.view(batch_size, beam_size, -1)[:, 0]
+        # select the 1-best tree for each sentence
+        spans = spans[batches + scores.view(batch_size, -1).argmax(-1)]
+        span_mask = spans.ge(0)
+        span_indices = torch.where(span_mask)
+        span_labels = spans[span_indices]
+        chart_preds = [[] for _ in range(x.shape[0])]
+        for i, *span in zip(*[s.tolist() for s in span_indices], span_labels.tolist()):
+            chart_preds[i].append(span)
+        # per-sentence charts followed by the tokenwise action trace
+        return [chart_preds + tokenwise_predictions]
+
+    def rightmost_chain(
+        self,
+        x: torch.Tensor,
+        spans: torch.LongTensor,
+        mask: torch.BoolTensor,
+        t: int
+    ) -> torch.Tensor:
+        x_p, mask_p = x[:, :t], mask[:, :t]
+        lens = mask_p.sum(-1)
+        span_mask = spans[:, :-1, 1:].ge(0)
+        span_lens = span_mask.sum((-1, -2))
+        span_indices = torch.where(span_mask)
+        span_labels = spans[:, :-1, 1:][span_indices]
+        x_span = self.label_embed(span_labels)
+        x_span += x[span_indices[0], span_indices[1]] + x[span_indices[0], span_indices[2]]
+        node_lens = lens + span_lens
+        adj_mask = node_lens.unsqueeze(-1).gt(x.new_tensor(range(node_lens.max())))
+        x_mask = lens.unsqueeze(-1).gt(x.new_tensor(range(adj_mask.shape[-1])))
+        span_mask = ~x_mask & adj_mask
+        # concatenate terminals and spans
+        x_tree = x.new_zeros(*adj_mask.shape, x.shape[-1]).masked_scatter_(x_mask.unsqueeze(-1), x_p[mask_p])
+        x_tree = x_tree.masked_scatter_(span_mask.unsqueeze(-1), x_span)
+        adj = mask.new_zeros(*x_tree.shape[:-1], x_tree.shape[1])
+        adj_spans = lens.new_tensor(range(x_tree.shape[1])).view(1, 1, -1).repeat(2, x.shape[0], 1)
+        adj_spans = adj_spans.masked_scatter_(span_mask.unsqueeze(0), torch.stack(span_indices[1:]))
+        adj_l, adj_r, adj_w = *adj_spans.unbind(), adj_spans[1] - adj_spans[0]
+        adj_parent = adj_l.unsqueeze(-1).ge(adj_l.unsqueeze(-2)) & adj_r.unsqueeze(-1).le(adj_r.unsqueeze(-2))
+        # set the parent of root as itself
+        adj_parent.diagonal(0, 1, 2).copy_(adj_w.eq(t - 1))
+        adj_parent = adj_parent & span_mask.unsqueeze(1)
+        # closest ancestor spans as parents
+        adj_parent = (adj_w.unsqueeze(-2) - adj_w.unsqueeze(-1)).masked_fill_(~adj_parent, t).argmin(-1)
+        adj.scatter_(-1, adj_parent.unsqueeze(-1), 1)
+        adj = (adj | adj.transpose(-1, -2)).float()
+        x_tree = self.gnn_layers(x_tree, adj, adj_mask)
+        span_mask = span_mask.masked_scatter(span_mask, span_indices[2].eq(t - 1))
+        span_lens = span_mask.sum(-1)
+        x_tree, span_mask = x_tree[span_mask], span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
+        x_span = x.new_zeros(*span_mask.shape, x.shape[-1]).masked_scatter_(span_mask.unsqueeze(-1), x_tree)
+        return x_span
+
+
+class AttachJuxtaposeConstituencyModelPos(Model):
+    r"""
+    The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`,
+    extended with an auxiliary LSTM head for POS tagging.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_labels (int):
+            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, required for the auxiliary POS-tagging head. Default: 16.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'char'``, ``'tag'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word embeddings. Default: 100.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, True)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs are the weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling way to get token embeddings.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .33.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 800.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layers. Default: .33.
+        n_gnn_layers (int):
+            The number of GNN layers. Default: 3.
+        gnn_dropout (float):
+            The dropout ratio of GNN layers. Default: .33.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_labels,
+                 n_tags=16,
+                 n_chars=None,
+                 encoder='lstm',
+                 feat=['char', 'tag'],
+                 n_embed=100,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, True),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 n_encoder_hidden=800,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_gnn_layers=3,
+                 gnn_dropout=.33,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        # the last one represents the dummy node in the initial states
+        self.label_embed = nn.Embedding(n_labels+1, self.args.n_encoder_hidden)
+        self.gnn_layers = GraphConvolutionalNetwork(n_model=self.args.n_encoder_hidden,
+                                                    n_layers=self.args.n_gnn_layers,
+                                                    dropout=self.args.gnn_dropout)
+
+        self.node_classifier = nn.Sequential(
+            nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2),
+            nn.LayerNorm(self.args.n_encoder_hidden // 2),
+            nn.ReLU(),
+            nn.Linear(self.args.n_encoder_hidden // 2, 1),
+        )
+        self.label_classifier = nn.Sequential(
+            nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2),
+            nn.LayerNorm(self.args.n_encoder_hidden // 2),
+            nn.ReLU(),
+            nn.Linear(self.args.n_encoder_hidden // 2, 2 * n_labels),
+        )
+        self.pos_classifier = DecoderLSTMPos(
+            self.args.n_encoder_hidden, self.args.n_encoder_hidden, self.args.n_tags,
+            num_layers=1, dropout=encoder_dropout, device=self.device
+        )
+
+        # create delay projection
+        if self.args.delay != 0:
+            self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay+1),
+                                  n_out=self.args.n_encoder_hidden, dropout=gnn_dropout)
+
+        self.criterion = nn.CrossEntropyLoss()
+
+    def encoder_forward(self, words: torch.Tensor, feats: List[torch.Tensor]) -> Tuple[torch.Tensor]:
+        """
+        Applies the encoding forward pass, mapping a tensor of word indices (`words`) to their
+        contextualized neural representations, plus POS-tag logits from the auxiliary head.
+
+        Args:
+            words: torch.IntTensor ~ [batch_size, bos + pad(seq_len) + eos + delay]
+            feats: List[torch.Tensor]
+
+        Returns: x, s_tag, qloss
+            x: torch.FloatTensor ~ [batch_size, bos + pad(seq_len) + eos, embed_dim]
+            s_tag: torch.FloatTensor ~ [batch_size, seq_len, n_tags]
+            qloss: torch.FloatTensor ~ 1
+        """
+
+        x = super().encode(words, feats)
+        s_tag = self.pos_classifier(x[:, 1:-(1+self.args.delay), :])
+
+        # adjust lengths to allow delay predictions
+        # x ~ [batch_size, bos + pad(seq_len) + eos, embed_dim]
+        if self.args.delay != 0:
+            x = torch.cat([x[:, i:(x.shape[1] - self.args.delay + i), :] for i in range(self.args.delay + 1)], dim=2)
+            x = self.delay_proj(x)
+
+        # pass through vector quantization
+        x, qloss = self.vq_forward(x)
+        return x, s_tag, qloss
+
+    def forward(
+        self,
+        words: torch.LongTensor,
+        feats: List[torch.LongTensor]
+    ) -> Tuple[torch.Tensor]:
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+                Default: ``None``.
+
+        Returns:
+            Tuple[~torch.Tensor, ~torch.Tensor, ~torch.Tensor]:
+                Contextualized hidden states of shape ``[batch_size, seq_len, n_model]``,
+                POS-tag logits of shape ``[batch_size, seq_len, n_tags]``, and the scalar quantization loss.
+        """
+        x, s_tag, qloss = self.encoder_forward(words, feats)
+
+        return x, s_tag, qloss
+
+    def loss(
+        self,
+        x: torch.Tensor,
+        nodes: torch.LongTensor,
+        parents: torch.LongTensor,
+        news: torch.LongTensor,
+        mask: torch.BoolTensor,
+        s_tags: torch.Tensor,
+        tags: torch.LongTensor
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            x (~torch.Tensor): ``[batch_size, seq_len, n_model]``.
+                Contextualized output hidden states.
+            nodes (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The target node positions on rightmost chains.
+            parents (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The parent node labels of terminals.
+            news (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The parent node labels of juxtaposed targets and terminals.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
+            s_tags (~torch.Tensor): ``[batch_size, seq_len, n_tags]``.
+                POS-tag logits from the auxiliary tagging head.
+            tags (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Gold POS tags.
+
+        Returns:
+            ~torch.Tensor:
+                The training loss.
+        """
+
+        spans, s_node, x_node = None, [], []
+        actions = torch.stack((nodes, parents, news))
+        for t, action in enumerate(actions.unbind(-1)):
+            if t == 0:
+                x_span = self.label_embed(actions.new_full((x.shape[0], 1), self.args.n_labels))
+                span_mask = mask[:, :1]
+            else:
+                x_span = self.rightmost_chain(x, spans, mask, t)
+                span_lens = spans[:, :-1, -1].ge(0).sum(-1)
+                span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
+            x_rightmost = torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)
+            s_node.append(self.node_classifier(x_rightmost).squeeze(-1))
+            # softmax works slightly better here than the sigmoid used in the original paper
+            s_node[-1] = s_node[-1].masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0)
+            x_node.append(torch.bmm(s_node[-1].softmax(-1).unsqueeze(1), x_span).squeeze(1))
+            spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t])
+        attach_mask = x.new_tensor(range(self.args.n_labels)).eq(self.args.nul_index)
+        s_node, x_node = pad(s_node, -INF).transpose(0, 1), torch.stack(x_node, 1)
+        s_parent, s_new = self.label_classifier(torch.cat((x, x_node), -1)).chunk(2, -1)
+        s_parent = torch.cat((s_parent[:, :1].masked_fill(attach_mask, -INF), s_parent[:, 1:]), 1)
+        s_new = torch.cat((s_new[:, :1].masked_fill(~attach_mask, -INF), s_new[:, 1:]), 1)
+        node_loss = self.criterion(s_node[mask], nodes[mask])
+        label_loss = self.criterion(s_parent[mask], parents[mask]) + self.criterion(s_new[mask], news[mask])
+        # auxiliary POS-tagging loss over the same unpadded positions
+        tag_loss = self.criterion(s_tags[mask], tags[mask])
+        return node_loss + label_loss + tag_loss
+
+    def decode(
+        self,
+        x: torch.Tensor,
+        mask: torch.BoolTensor,
+        beam_size: int = 1
+    ) -> List[List[Tuple]]:
+        r"""
+        Args:
+            x (~torch.Tensor): ``[batch_size, seq_len, n_model]``.
+                Contextualized output hidden states.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
+            beam_size (int):
+                Beam size for decoding. Default: 1.
+
+        Returns:
+            List[List[Tuple]]:
+                Sequences of factorized labeled trees.
+        """
+        tokenwise_predictions = []
+        spans = None
+        batch_size, *_ = x.shape
+        n_labels = self.args.n_labels
+        # [batch_size * beam_size, ...]
+        x = x.unsqueeze(1).repeat(1, beam_size, 1, 1).view(-1, *x.shape[1:])
+        mask = mask.unsqueeze(1).repeat(1, beam_size, 1).view(-1, *mask.shape[1:])
+        # [batch_size]
+        batches = x.new_tensor(range(batch_size)).long() * beam_size
+        # accumulated scores
+        scores = x.new_full((batch_size, beam_size), -INF).index_fill_(-1, x.new_tensor(0).long(), 0).view(-1)
+        for t in range(x.shape[1]):
+            if t == 0:
+                x_span = self.label_embed(batches.new_full((x.shape[0], 1), n_labels))
+                span_mask = mask[:, :1]
+            else:
+                x_span = self.rightmost_chain(x, spans, mask, t)
+                span_lens = spans[:, :-1, -1].ge(0).sum(-1)
+                span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
+            s_node = self.node_classifier(torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)).squeeze(-1)
+            s_node = s_node.masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0).log_softmax(-1)
+            # softmax works slightly better here than the sigmoid used in the original paper
+            x_node = torch.bmm(s_node.exp().unsqueeze(1), x_span).squeeze(1)
+            s_parent, s_new = self.label_classifier(torch.cat((x[:, t], x_node), -1)).chunk(2, -1)
+            s_parent, s_new = s_parent.log_softmax(-1), s_new.log_softmax(-1)
+            if t == 0:
+                s_parent[:, self.args.nul_index] = -INF
+                s_new[:, s_new.new_tensor(range(self.args.n_labels)).ne(self.args.nul_index)] = -INF
+            s_node, nodes = s_node.topk(min(s_node.shape[-1], beam_size), -1)
+            s_parent, parents = s_parent.topk(min(n_labels, beam_size), -1)
+            s_new, news = s_new.topk(min(n_labels, beam_size), -1)
+            s_action = s_node.unsqueeze(2) + (s_parent.unsqueeze(2) + s_new.unsqueeze(1)).view(x.shape[0], 1, -1)
+            s_action = s_action.view(x.shape[0], -1)
+            k_beam, k_node, k_parent = s_action.shape[-1], parents.shape[-1] * news.shape[-1], news.shape[-1]
+            # [batch_size * beam_size, k_beam]
+            scores = scores.unsqueeze(-1) + s_action
+            # [batch_size, beam_size]
+            scores, cands = scores.view(batch_size, -1).topk(beam_size, -1)
+            # [batch_size * beam_size]
+            scores = scores.view(-1)
+            beams = cands.div(k_beam, rounding_mode='floor')
+            nodes = nodes.view(batch_size, -1).gather(-1, cands.div(k_node, rounding_mode='floor'))
+            indices = (batches.unsqueeze(-1) + beams).view(-1)
+
+            parents = parents[indices].view(batch_size, -1).gather(-1, cands.div(k_parent, rounding_mode='floor') % k_parent)
+            news = news[indices].view(batch_size, -1).gather(-1, cands % k_parent)
+            action = torch.stack((nodes, parents, news)).view(3, -1)
+            tokenwise_predictions.append([t, [a[0] for a in action.tolist()]])
+            spans = spans[indices] if spans is not None else None
+            spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t])
+        mask = mask.view(batch_size, beam_size, -1)[:, 0]
+        # select the 1-best tree for each sentence
+        spans = spans[batches + scores.view(batch_size, -1).argmax(-1)]
+        span_mask = spans.ge(0)
+        span_indices = torch.where(span_mask)
+        span_labels = spans[span_indices]
+        chart_preds = [[] for _ in range(x.shape[0])]
+        for i, *span in zip(*[s.tolist() for s in span_indices], span_labels.tolist()):
+            chart_preds[i].append(span)
+        # per-sentence charts followed by the tokenwise action trace
+        return [chart_preds + tokenwise_predictions]
+
+    def rightmost_chain(
+        self,
+        x: torch.Tensor,
+        spans: torch.LongTensor,
+        mask: torch.BoolTensor,
+        t: int
+    ) -> torch.Tensor:
+        x_p, mask_p = x[:, :t], mask[:, :t]
+        lens = mask_p.sum(-1)
+        span_mask = spans[:, :-1, 1:].ge(0)
+        span_lens = span_mask.sum((-1, -2))
+        span_indices = torch.where(span_mask)
+        span_labels = spans[:, :-1, 1:][span_indices]
+        x_span = self.label_embed(span_labels)
+        x_span += x[span_indices[0], span_indices[1]] + x[span_indices[0], span_indices[2]]
+        node_lens = lens + span_lens
+        adj_mask = node_lens.unsqueeze(-1).gt(x.new_tensor(range(node_lens.max())))
+        x_mask = lens.unsqueeze(-1).gt(x.new_tensor(range(adj_mask.shape[-1])))
+        span_mask = ~x_mask & adj_mask
+        # concatenate terminals and spans
+        x_tree = x.new_zeros(*adj_mask.shape, x.shape[-1]).masked_scatter_(x_mask.unsqueeze(-1), x_p[mask_p])
+        x_tree = x_tree.masked_scatter_(span_mask.unsqueeze(-1), x_span)
+        adj = mask.new_zeros(*x_tree.shape[:-1], x_tree.shape[1])
+        adj_spans = lens.new_tensor(range(x_tree.shape[1])).view(1, 1, -1).repeat(2, x.shape[0], 1)
+        adj_spans = adj_spans.masked_scatter_(span_mask.unsqueeze(0), torch.stack(span_indices[1:]))
+        adj_l, adj_r, adj_w = *adj_spans.unbind(), adj_spans[1] - adj_spans[0]
+        adj_parent = adj_l.unsqueeze(-1).ge(adj_l.unsqueeze(-2)) & adj_r.unsqueeze(-1).le(adj_r.unsqueeze(-2))
+        # set the parent of root as itself
+        adj_parent.diagonal(0, 1, 2).copy_(adj_w.eq(t - 1))
+        adj_parent = adj_parent & span_mask.unsqueeze(1)
+        # closest ancestor spans as parents
+        adj_parent = (adj_w.unsqueeze(-2) - adj_w.unsqueeze(-1)).masked_fill_(~adj_parent, t).argmin(-1)
+        adj.scatter_(-1, adj_parent.unsqueeze(-1), 1)
+        adj = (adj | adj.transpose(-1, -2)).float()
+        x_tree = self.gnn_layers(x_tree, adj, adj_mask)
+        span_mask = span_mask.masked_scatter(span_mask, span_indices[2].eq(t - 1))
+        span_lens = span_mask.sum(-1)
+        x_tree, span_mask = x_tree[span_mask], span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
+        x_span = x.new_zeros(*span_mask.shape, x.shape[-1]).masked_scatter_(span_mask.unsqueeze(-1), x_tree)
+        return x_span
+
+
+    def pos_loss(self, pos_logits: torch.Tensor, pos_tags: torch.LongTensor, mask: torch.BoolTensor) -> torch.Tensor:
+        """
+        Args:
+            pos_logits (~torch.Tensor): [batch_size, seq_len, n_tags].
+            pos_tags (~torch.LongTensor): [batch_size, seq_len].
+            mask (~torch.BoolTensor): [batch_size, seq_len].
+
+        Returns:
+            torch.Tensor: The POS tagging loss over unpadded positions.
+        """
+        return self.criterion(pos_logits[mask], pos_tags[mask])
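+
+    # Illustrative example (hypothetical shapes): with pos_logits ~ [2, 5, 16],
+    # pos_tags ~ [2, 5] and a boolean mask ~ [2, 5], pos_loss computes the mean
+    # cross-entropy over the True (unpadded) token positions only.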
+
+    def decode_pos(self, s_tag: torch.Tensor):
+        """
+        Decode the most likely POS tags.
+
+        Args:
+            s_tag (~torch.Tensor): [batch_size, seq_len, n_tags].
+                POS-tag logits from the auxiliary head.
+
+        Returns:
+            ~torch.LongTensor: [batch_size, seq_len].
+                The argmax tag index for each token.
+        """
+        return s_tag.argmax(-1)
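+
+
+# A minimal sketch of how the auxiliary POS head plugs into training
+# (illustrative only; it mirrors the parser's train_step, with hypothetical
+# batch shapes and delay=0):
+#
+#     x, s_tag, qloss = model(words, feats)          # s_tag: [B, L, n_tags]
+#     loss = model.loss(x[:, 1:-1], nodes, parents, news, mask,
+#                       s_tag, tags[:, 1:-1]) + qloss
+#     pos_preds = model.decode_pos(s_tag)            # [B, L] tag indices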
diff --git a/tania_scripts/supar/models/const/aj/parser.py b/tania_scripts/supar/models/const/aj/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..543e8284b90805f045e1d5de2b15aa7c75ffc417
--- /dev/null
+++ b/tania_scripts/supar/models/const/aj/parser.py
@@ -0,0 +1,534 @@
+# -*- coding: utf-8 -*-
+
+import os
+from typing import Dict, Iterable, Set, Union
+
+import torch
+
+from supar.models.const.aj.model import AttachJuxtaposeConstituencyModel, AttachJuxtaposeConstituencyModelPos
+from supar.models.const.aj.transform import AttachJuxtaposeTree
+from supar.parser import Parser
+from supar.utils import Config, Dataset, Embedding
+from supar.utils.common import BOS, EOS, NUL, PAD, UNK
+from supar.utils.field import Field, RawField, SubwordField
+from supar.utils.logging import get_logger
+from supar.utils.metric import SpanMetric
+from supar.utils.tokenizer import TransformerTokenizer
+from supar.utils.transform import Batch
+from torch.nn.utils.rnn import pad_sequence
+
+
+logger = get_logger(__name__)
+
+
+def compute_pos_accuracy(pos_gold, pos_preds):
+    correct = 0
+    total = 0
+    for gold_seq, pred_seq in zip(pos_gold, pos_preds):
+        for g, p in zip(gold_seq, pred_seq):
+            if g == p:
+                correct += 1
+        total += len(gold_seq)
+    return correct, total
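+
+# Example: compute_pos_accuracy([['DET', 'NOUN']], [['DET', 'VERB']]) returns
+# (1, 2): one correctly predicted tag out of two compared positions.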
+
+
+class AttachJuxtaposeConstituencyParser(Parser):
+    r"""
+    The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`.
+    """
+
+    NAME = 'attach-juxtapose-constituency'
+    MODEL = AttachJuxtaposeConstituencyModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.TREE = self.transform.TREE
+        self.NODE = self.transform.NODE
+        self.PARENT = self.transform.PARENT
+        self.NEW = self.transform.NEW
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        beam_size: int = 1,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        print("here")
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        beam_size: int = 1,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        beam_size: int = 1,
+        verbose: bool = True,
+        **kwargs
+    ):
+
+        return super().predict(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        words, *feats, trees, nodes, parents, news = batch
+        mask = batch.mask[:, (2+self.args.delay):]
+        # the base model returns only the encoded states and the quantization loss
+        x, qloss = self.model(words, feats)
+        loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask) + qloss
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> SpanMetric:
+        words, *feats, trees, nodes, parents, news = batch
+        mask = batch.mask[:, (2+self.args.delay):]
+        x, qloss = self.model(words, feats)
+        loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask) + qloss
+        # decode returns [per-sentence charts + tokenwise trace]; keep the charts only
+        chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size)[0]
+        preds = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL})
+                 for tree, chart in zip(trees, chart_preds)]
+        return SpanMetric(loss,
+                          [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in preds],
+                          [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in trees])
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, *feats, trees = batch
+        mask = batch.mask[:, (2+self.args.delay):]
+        x, _ = self.model(words, feats)
+        chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size)
+        # unwrap the outer list: per-sentence charts followed by the tokenwise action trace
+        chart_preds = chart_preds[0]
+        batch.trees = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL})
+                       for tree, chart in zip(trees, chart_preds)]
+        if self.args.prob:
+            raise NotImplementedError("Returning action probs is currently not supported.")
+
+        # verbalize the tokenwise action trace: [step, node, parent label, new label]
+        new_tokenwise_preds = []
+        for k, y in chart_preds[1:]:
+            new_tokenwise_preds.append([k, y[0], self.NEW.vocab[y[1]], self.NEW.vocab[y[2]]])
+
+        chart_preds = [[[i, j, self.NEW.vocab[label]] for i, j, label in chart_preds[0]]] + new_tokenwise_preds
+        return chart_preds  # NOTE: returns predictions rather than the batch
+
+    @classmethod
+    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+        r"""
+        Build a brand-new Parser, including initialization of all data fields and model parameters.
+
+        Args:
+            path (str):
+                The path of the model to be saved.
+            min_freq (str):
+                The minimum frequency needed to include a token in the vocabulary. Default: 2.
+            fix_len (int):
+                The max length of all subword pieces. The excess part of each piece will be truncated.
+                Required if using CharLSTM/BERT.
+                Default: 20.
+            kwargs (Dict):
+                A dict holding the unconsumed arguments.
+        """
+
+        args = Config(**locals())
+        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
+        if os.path.exists(path) and not args.build:
+            parser = cls.load(**args)
+            parser.model = cls.MODEL(**parser.args)
+            parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device)
+            return parser
+
+        logger.info("Building the fields")
+        TAG, CHAR = None, None
+        if args.encoder == 'bert':
+            t = TransformerTokenizer(args.bert)
+            WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t, delay=args.delay)
+            WORD.vocab = t.vocab
+        else:
+            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay)
+            if 'char' in args.feat:
+                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len, delay=args.delay)
+        TAG = Field('tags', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay)
+        TREE = RawField('trees')
+        NODE, PARENT, NEW = Field('node', use_vocab=False), Field('parent', unk=UNK), Field('new', unk=UNK)
+        transform = AttachJuxtaposeTree(WORD=(WORD, CHAR), POS=TAG, TREE=TREE, NODE=NODE, PARENT=PARENT, NEW=NEW)
+
+        train = Dataset(transform, args.train, **args)
+        if args.encoder != 'bert':
+            WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
+            if CHAR is not None:
+                CHAR.build(train)
+        TAG.build(train)
+        PARENT, NEW = PARENT.build(train), NEW.build(train)
+        PARENT.vocab = NEW.vocab.update(PARENT.vocab)
+        args.update({
+            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
+            'n_labels': len(NEW.vocab),
+            'n_tags': len(TAG.vocab) if TAG is not None else None,
+            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
+            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
+            'pad_index': WORD.pad_index,
+            'unk_index': WORD.unk_index,
+            'bos_index': WORD.bos_index,
+            'eos_index': WORD.eos_index,
+            'nul_index': NEW.vocab[NUL]
+        })
+        logger.info(f"{transform}")
+
+        logger.info("Building the model")
+        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None)
+        logger.info(f"{model}\n")
+
+        parser = cls(args, model, transform)
+        parser.model.to(parser.device)
+        return parser
+
+def flatten(x):
+    result = []
+    for el in x:
+        if isinstance(el, list):
+            result.extend(flatten(el))
+        else:
+            result.append(el)
+    return result
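+
+# Example: flatten([1, [2, [3, 4]], 5]) returns [1, 2, 3, 4, 5].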
+
+
+class AttachJuxtaposeConstituencyParserPos(Parser):
+    r"""
+    The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`.
+    """
+
+    NAME = 'attach-juxtapose-constituency'
+    MODEL = AttachJuxtaposeConstituencyModelPos
+
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.TREE = self.transform.TREE
+        self.NODE = self.transform.NODE
+        self.PARENT = self.transform.PARENT
+        self.NEW = self.transform.NEW
+        self.TAG = self.transform.POS
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        beam_size: int = 1,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        beam_size: int = 1,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        beam_size: int = 1,
+        verbose: bool = True,
+        **kwargs
+    ):
+
+        return super().predict(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        words, *feats, trees, nodes, parents, news = batch
+        mask = batch.mask[:, (2+self.args.delay):]
+        x, s_tag, qloss = self.model(words, feats)
+        # gold POS tags come in as the last feature; strip bos/eos to align with s_tag
+        tags = feats[-1]
+        loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask, s_tag, tags[:, 1:-1]) + qloss
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> SpanMetric:
+        words, *feats, trees, nodes, parents, news = batch
+        # gold POS tags come in as the last feature
+        tags = feats[-1]
+        mask = batch.mask[:, (2+self.args.delay):]
+        x, s_tag, qloss = self.model(words, feats)
+        loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask, s_tag, tags[:, 1:-1]) + qloss
+        chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size)
+
+        # POS accuracy of the auxiliary head against the gold tags (bos/eos stripped)
+        pos_preds = [[self.TAG.vocab[p.item()] for p in sent] for sent in s_tag.argmax(-1)]
+        trimmed_tags = [tag_seq[1:-1].cpu().tolist() for tag_seq in tags]
+        pos_gold = [[self.TAG.vocab[p] for p in sent] for sent in trimmed_tags]
+        pos_correct, pos_total = compute_pos_accuracy(pos_gold, pos_preds)
+        pos_acc = pos_correct / pos_total
+        logger.info(f"POS accuracy: {pos_acc:.4f}")
+        """
+        tags = [*feats[-1:]]
+        trimmed_tags = [tag_seq[1:-1] for tag_seq in tags[0]]
+        gtags = [*feats[-1:]][0].cpu().tolist()
+        gtags = [x[1:-1] for x in gtags]
+        mask = batch.mask[:, (2+self.args.delay):]
+        x, s_tag, qloss = self.model(words, feats)
+
+        tag_tensor = pad_sequence(trimmed_tags, batch_first=True, padding_value=self.TAG.pad_index).to(s_tag.device)
+
+        assert s_tag.shape[:2] == tag_tensor.shape[:2]
+        print("s_tag shape", s_tag.shape[:2], "tag_tensor shape", tag_tensor.shape[:2])
+        
+
+        loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask, s_tag, tag_tensor) + qloss
+
+        chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size)
+        #pos_preds = [self.TAG.vocab[x] for x in s_tag.argmax(-1)]
+
+        for tag_seq in gtags:
+            for pos in tag_seq:
+                assert 0 <= pos < len(self.TAG.vocab)
+        
+        pos_preds = [[self.TAG.vocab[p.item()] for p in sent] for sent in s_tag.argmax(-1)]
+        pos_gold = [[self.TAG.vocab[pos] for pos in tag_seq] for tag_seq in gtags]
+
+        #print("357!", len(pos_gold), len(pos_preds))
+        pos_correct, pos_total = compute_pos_accuracy(pos_gold, pos_preds)
+        print("358!", pos_correct, pos_total, pos_correct/pos_total)
+
+            
+        #print('%%% POS preds', pos_preds)
+        #print('%%%% POS gold', pos_gold)
+        """
+
+        #batch.trees = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL})
+         #              for tree, chart in zip(trees, chart_preds)]
+        #print(batch.trees)
+              
+        preds = [AttachJuxtaposeTree.build(tree, [[x[0], x[1], self.NEW.vocab[x[2]]] for x in chart], {UNK, NUL})
+                 for tree, chart in zip(trees, chart_preds[0])]
+        #print('POS preds', s_tag.argmax(-1))
+        #print("PREDS", len(preds))
+      
+
+        """
+        span_preds = [[x[0], x[1], self.NEW.vocab[x[2]]] for x in chart_preds[0][0]]
+        #print("new SPAN preds", span_preds)
+        verbalized_preds = []
+
+
+        for x in chart_preds[1:]:
+            new_item = [x[0], [x[1][0], self.NEW.vocab[x[1][1]], self.NEW.vocab[x[1][2]]]]
+            #print(new_item)
+            verbalized_preds.append(new_item)
+        """
+        
+        print("424", loss, pos_acc)
+        return SpanMetric(loss,
+                          [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in preds],
+                          [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in trees], pos_acc)
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, *feats, trees = batch
+        mask = batch.mask[:, (2+self.args.delay):]
+        x, s_tag, qloss = self.model(words, feats)
+        chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size)
+
+        # unwrap the outer list: per-sentence charts followed by the tokenwise action trace
+        chart_preds = chart_preds[0]
+        batch.trees = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL})
+                       for tree, chart in zip(trees, chart_preds)]
+        if self.args.prob:
+            raise NotImplementedError("Returning action probs is currently not supported.")
+
+        span_preds = [[i, j, self.NEW.vocab[label]] for i, j, label in chart_preds[0]]
+        # verbalize the tokenwise action trace: [step, [node, parent label, new label]]
+        verbalized_preds = []
+        for step, action in chart_preds[1:]:
+            verbalized_preds.append([step, [action[0], self.NEW.vocab[action[1]], self.NEW.vocab[action[2]]]])
+        pos_preds = [[self.TAG.vocab[p.item()] for p in sent] for sent in s_tag.argmax(-1)]
+        return span_preds, verbalized_preds, pos_preds  # NOTE: returns predictions rather than the batch
+
+    @classmethod
+    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+        r"""
+        Build a brand-new Parser, including initialization of all data fields and model parameters.
+
+        Args:
+            path (str):
+                The path of the model to be saved.
+            min_freq (str):
+                The minimum frequency needed to include a token in the vocabulary. Default: 2.
+            fix_len (int):
+                The max length of all subword pieces. The excess part of each piece will be truncated.
+                Required if using CharLSTM/BERT.
+                Default: 20.
+            kwargs (Dict):
+                A dict holding the unconsumed arguments.
+        """
+
+        args = Config(**locals())
+        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
+        if os.path.exists(path) and not args.build:
+            parser = cls.load(**args)
+            parser.model = cls.MODEL(**parser.args)
+            parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device)
+            return parser
+
+        logger.info("Building the fields")
+        TAG, CHAR = None, None
+        if args.encoder == 'bert':
+            t = TransformerTokenizer(args.bert)
+            WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t, delay=args.delay)
+            WORD.vocab = t.vocab
+        else:
+            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay)
+            if 'char' in args.feat:
+                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len, delay=args.delay)
+        TAG = Field('tags', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay)
+        TREE = RawField('trees')
+        NODE, PARENT, NEW = Field('node', use_vocab=False), Field('parent', unk=UNK), Field('new', unk=UNK)
+        transform = AttachJuxtaposeTree(WORD=(WORD, CHAR), POS=TAG, TREE=TREE, NODE=NODE, PARENT=PARENT, NEW=NEW)
+
+        train = Dataset(transform, args.train, **args)
+        if args.encoder != 'bert':
+            WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
+            if CHAR is not None:
+                CHAR.build(train)
+        TAG.build(train)
+        #print("203 TAG VOCAB", [x for x in TAG.vocab.items()])
+        PARENT, NEW = PARENT.build(train), NEW.build(train)
+        PARENT.vocab = NEW.vocab.update(PARENT.vocab)
+        args.update({
+            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
+            'n_labels': len(NEW.vocab),
+            'n_tags': len(TAG.vocab) if TAG is not None else None,
+            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
+            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
+            'pad_index': WORD.pad_index,
+            'unk_index': WORD.unk_index,
+            'bos_index': WORD.bos_index,
+            'eos_index': WORD.eos_index,
+            'nul_index': NEW.vocab[NUL]
+        })
+        logger.info(f"{transform}")
+
+        logger.info("Building the model")
+        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None)
+        logger.info(f"{model}\n")
+
+        parser = cls(args, model, transform)
+        parser.model.to(parser.device)
+        return parser
\ No newline at end of file
diff --git a/tania_scripts/supar/models/const/aj/transform.py b/tania_scripts/supar/models/const/aj/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..56e0f5158b6b3b214b7e99e8a479fbabc56bcbf6
--- /dev/null
+++ b/tania_scripts/supar/models/const/aj/transform.py
@@ -0,0 +1,459 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
+
+import nltk
+import torch
+
+from supar.models.const.crf.transform import Tree
+from supar.utils.common import NUL
+from supar.utils.logging import get_logger
+from supar.utils.tokenizer import Tokenizer
+from supar.utils.transform import Sentence
+
+if TYPE_CHECKING:
+    from supar.utils import Field
+
+logger = get_logger(__name__)
+
+
+class AttachJuxtaposeTree(Tree):
+    r"""
+    :class:`AttachJuxtaposeTree` is derived from the :class:`Tree` class,
+    supporting back-and-forth transformations between trees and AttachJuxtapose actions :cite:`yang-deng-2020-aj`.
+
+    Attributes:
+        WORD:
+            Words in the sentence.
+        POS:
+            Part-of-speech tags, or underscores if not available.
+        TREE:
+            The raw constituency tree in :class:`nltk.tree.Tree` format.
+        NODE:
+            The target node on each rightmost chain.
+        PARENT:
+            The label of the parent node of each terminal.
+        NEW:
+            The label of each newly inserted non-terminal with a target node and a terminal as juxtaposed children.
+            ``NUL`` represents the `Attach` action.
+    """
+
+    fields = ['WORD', 'POS', 'TREE', 'NODE', 'PARENT', 'NEW']
+
+    def __init__(
+        self,
+        WORD: Optional[Union[Field, Iterable[Field]]] = None,
+        POS: Optional[Union[Field, Iterable[Field]]] = None,
+        TREE: Optional[Union[Field, Iterable[Field]]] = None,
+        NODE: Optional[Union[Field, Iterable[Field]]] = None,
+        PARENT: Optional[Union[Field, Iterable[Field]]] = None,
+        NEW: Optional[Union[Field, Iterable[Field]]] = None
+    ) -> Tree:
+        super().__init__()
+
+        self.WORD = WORD
+        self.POS = POS
+        self.TREE = TREE
+        self.NODE = NODE
+        self.PARENT = PARENT
+        self.NEW = NEW
+
+    @property
+    def src(self):
+        return self.WORD, self.POS, self.TREE
+
+    @property
+    def tgt(self):
+        return self.NODE, self.PARENT, self.NEW
+
+    @classmethod
+    def tree2action(cls, tree: nltk.Tree):
+        r"""
+        Converts a constituency tree into AttachJuxtapose actions.
+
+        Args:
+            tree (nltk.tree.Tree):
+                A constituency tree in :class:`nltk.tree.Tree` format.
+
+        Returns:
+            A sequence of AttachJuxtapose actions.
+
+        Examples:
+            >>> from supar.models.const.aj.transform import AttachJuxtaposeTree
+            >>> tree = nltk.Tree.fromstring('''
+                                            (TOP
+                                              (S
+                                                (NP (_ Arthur))
+                                                (VP
+                                                  (_ is)
+                                                  (NP (NP (_ King)) (PP (_ of) (NP (_ the) (_ Britons)))))
+                                                (_ .)))
+                                            ''')
+            >>> tree.pretty_print()
+                            TOP
+                             |
+                             S
+               ______________|_______________________
+              |              VP                      |
+              |      ________|___                    |
+              |     |            NP                  |
+              |     |    ________|___                |
+              |     |   |            PP              |
+              |     |   |     _______|___            |
+              NP    |   NP   |           NP          |
+              |     |   |    |        ___|_____      |
+              _     _   _    _       _         _     _
+              |     |   |    |       |         |     |
+            Arthur  is King  of     the     Britons  .
+            >>> AttachJuxtaposeTree.tree2action(tree)
+            [(0, 'NP', '<nul>'), (0, 'VP', 'S'), (1, 'NP', '<nul>'),
+             (2, 'PP', 'NP'), (3, 'NP', '<nul>'), (4, '<nul>', '<nul>'),
+             (0, '<nul>', '<nul>')]
+        """
+
+        def isroot(node):
+            return node == tree[0]
+
+        def isterminal(node):
+            return len(node) == 1 and not isinstance(node[0], nltk.Tree)
+
+        def last_leaf(node):
+            pos = ()
+            while True:
+                pos += (len(node) - 1,)
+                node = node[-1]
+                if isterminal(node):
+                    return node, pos
+
+        def parent(position):
+            return tree[position[:-1]]
+
+        def grand(position):
+            return tree[position[:-2]]
+
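+        # `detach` inverts one AttachJuxtapose step: it strips the last terminal off the
+        # rightmost chain and returns the action that would re-attach it, plus the remaining tree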
+        def detach(tree):
+            last, last_pos = last_leaf(tree)
+            siblings = parent(last_pos)[:-1]
+
+            if len(siblings) > 0:
+                last_subtree = last
+                last_subtree_siblings = siblings
+                parent_label = NUL
+            else:
+                last_subtree, last_pos = parent(last_pos), last_pos[:-1]
+                last_subtree_siblings = [] if isroot(last_subtree) else parent(last_pos)[:-1]
+                parent_label = last_subtree.label()
+
+            target_pos, new_label, last_tree = 0, NUL, tree
+            if isroot(last_subtree):
+                last_tree = None
+            elif len(last_subtree_siblings) == 1 and not isterminal(last_subtree_siblings[0]):
+                new_label = parent(last_pos).label()
+                target = last_subtree_siblings[0]
+                last_grand = grand(last_pos)
+                if last_grand is None:
+                    last_tree = target
+                else:
+                    last_grand[-1] = target
+                target_pos = len(last_pos) - 2
+            else:
+                target = parent(last_pos)
+                target.pop()
+                target_pos = len(last_pos) - 2
+            action = target_pos, parent_label, new_label
+            return action, last_tree
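+
+        # peel terminals off right-to-left; the recursion emits exactly one action per terminal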
+        if tree is None:
+            return []
+        action, last_tree = detach(tree)
+        return cls.tree2action(last_tree) + [action]
+
+    @classmethod
+    def action2tree(
+        cls,
+        tree: nltk.Tree,
+        actions: List[Tuple[int, str, str, str]],
+        join: str = '::',
+    ) -> nltk.Tree:
+        r"""
+        Recovers a constituency tree from a sequence of AttachJuxtapose actions.
+
+        Args:
+            tree (nltk.tree.Tree):
+                An empty tree that provides a base for building a result tree.
+            actions (List[Tuple[int, str, str, str]]):
+                A sequence of AttachJuxtapose actions, each holding the target node index,
+                the parent label, the new label and the POS tag of the incoming terminal.
+            join (str):
+                A string used to connect collapsed node labels. Non-terminals containing this will be expanded to unary chains.
+                Default: ``'::'``.
+
+        Returns:
+            A result constituency tree.
+
+        Examples:
+            >>> from supar.models.const.aj.transform import AttachJuxtaposeTree
+            >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP')
+            >>> AttachJuxtaposeTree.action2tree(tree,
+                                                [(0, 'NP', '<nul>', '_'), (0, 'VP', 'S', '_'), (1, 'NP', '<nul>', '_'),
+                                                 (2, 'PP', 'NP', '_'), (3, 'NP', '<nul>', '_'), (4, '<nul>', '<nul>', '_'),
+                                                 (0, '<nul>', '<nul>', '_')]).pretty_print()
+                            TOP
+                             |
+                             S
+               ______________|_______________________
+              |              VP                      |
+              |      ________|___                    |
+              |     |            NP                  |
+              |     |    ________|___                |
+              |     |   |            PP              |
+              |     |   |     _______|___            |
+              NP    |   NP   |           NP          |
+              |     |   |    |        ___|_____      |
+              _     _   _    _       _         _     _
+              |     |   |    |       |         |     |
+            Arthur  is King  of     the     Britons  .
+        """
+
+        def target(node, depth):
+            node_pos = ()
+            for _ in range(depth):
+                node_pos += (len(node) - 1,)
+                node = node[-1]
+            return node, node_pos
+
+        def parent(tree, position):
+            return tree[position[:-1]]
+
+        def execute(tree: nltk.Tree, terminal: Tuple[str, str], action: Tuple[int, str, str, str]) -> nltk.Tree:
+            # each action also carries the POS tag of the incoming terminal as its fourth element
+            target_pos, parent_label, new_label, post = action
+            new_leaf = nltk.Tree(post, [terminal[0]])
+
+            # create the subtree to be inserted
+            new_subtree = new_leaf if parent_label == NUL else nltk.Tree(parent_label, [new_leaf])
+            # find the target position at which to insert the new subtree
+            target_node = tree
+            if target_node is not None:
+                target_node, target_pos = target(target_node, target_pos)
+
+            # Attach
+            if new_label == NUL:
+                # attach the first token
+                if target_node is None:
+                    return new_subtree
+                target_node.append(new_subtree)
+            # Juxtapose
+            else:
+                new_subtree = nltk.Tree(new_label, [target_node, new_subtree])
+                if len(target_pos) > 0:
+                    parent_node = parent(tree, target_pos)
+                    parent_node[-1] = new_subtree
+                else:
+                    tree = new_subtree
+            return tree
+
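+        # replay the actions left-to-right, consuming one terminal per action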
+        tree, root, terminals = None, tree.label(), tree.pos()
+        for terminal, action in zip(terminals, actions):
+            tree = execute(tree, terminal, action)
+        # recover unary chains
+        nodes = [tree]
+        while nodes:
+            node = nodes.pop()
+            if isinstance(node, nltk.Tree):
+                nodes.extend(node)
+                if join in node.label():
+                    labels = node.label().split(join)
+                    node.set_label(labels[0])
+                    subtree = nltk.Tree(labels[-1], node)
+                    for label in reversed(labels[1:-1]):
+                        subtree = nltk.Tree(label, [subtree])
+                    node[:] = [subtree]
+        return nltk.Tree(root, [tree])
+
+    @classmethod
+    def action2span(
+        cls,
+        action: torch.Tensor,
+        spans: torch.Tensor = None,
+        nul_index: int = -1,
+        mask: torch.BoolTensor = None
+    ) -> torch.Tensor:
+        r"""
+        Converts a batch of the tensorized action at a given step into spans.
+
+        Args:
+            action (~torch.Tensor): ``[3, batch_size]``.
+                A batch of the tensorized action at a given step, containing indices of target nodes, parent and new labels.
+            spans (~torch.Tensor):
+                Spans generated at previous steps, ``None`` at the first step. Default: ``None``.
+            nul_index (int):
+                The index for the :obj:`NUL` token, representing the Attach action. Default: -1.
+            mask (~torch.BoolTensor): ``[batch_size]``.
+                The mask for covering the unpadded tokens.
+
+        Returns:
+            A tensor representing a batch of spans for the given step.
+
+        Examples:
+            >>> from collections import Counter
+            >>> from supar.models.const.aj.transform import AttachJuxtaposeTree, Vocab
+            >>> from supar.utils.common import NUL
+            >>> nodes, parents, news = zip(*[(0, 'NP', NUL), (0, 'VP', 'S'), (1, 'NP', NUL),
+                                             (2, 'PP', 'NP'), (3, 'NP', NUL), (4, NUL, NUL),
+                                             (0, NUL, NUL)])
+            >>> vocab = Vocab(Counter(sorted(set([*parents, *news]))))
+            >>> actions = torch.tensor([nodes, vocab[parents], vocab[news]]).unsqueeze(1)
+            >>> spans = None
+            >>> for action in actions.unbind(-1):
+            ...     spans = AttachJuxtaposeTree.action2span(action, spans, vocab[NUL])
+            ...
+            >>> spans
+            tensor([[[-1,  1, -1, -1, -1, -1, -1,  3],
+                     [-1, -1, -1, -1, -1, -1,  4, -1],
+                     [-1, -1, -1,  1, -1, -1,  1, -1],
+                     [-1, -1, -1, -1, -1, -1,  2, -1],
+                     [-1, -1, -1, -1, -1, -1,  1, -1],
+                     [-1, -1, -1, -1, -1, -1, -1, -1],
+                     [-1, -1, -1, -1, -1, -1, -1, -1],
+                     [-1, -1, -1, -1, -1, -1, -1, -1]]])
+            >>> sequence = torch.where(spans.ge(0))
+            >>> sequence = list(zip(sequence[1].tolist(), sequence[2].tolist(), vocab[spans[sequence]]))
+            >>> sequence
+            [(0, 1, 'NP'), (0, 7, 'S'), (1, 6, 'VP'), (2, 3, 'NP'), (2, 6, 'NP'), (3, 6, 'PP'), (4, 6, 'NP')]
+            >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP')
+            >>> AttachJuxtaposeTree.build(tree, sequence).pretty_print()
+                            TOP
+                             |
+                             S
+               ______________|_______________________
+              |              VP                      |
+              |      ________|___                    |
+              |     |            NP                  |
+              |     |    ________|___                |
+              |     |   |            PP              |
+              |     |   |     _______|___            |
+              NP    |   NP   |           NP          |
+              |     |   |    |        ___|_____      |
+              _     _   _    _       _         _     _
+              |     |   |    |       |         |     |
+            Arthur  is King  of     the     Britons  .
+
+        """
+
+        # [batch_size]
+        target, parent, new = action
+        if spans is None:
+            spans = action.new_full((action.shape[1], 2, 2), -1)
+            spans[:, 0, 1] = parent
+            return spans
+        if mask is None:
+            mask = torch.ones_like(target, dtype=bool)
+        juxtapose_mask = new.ne(nul_index) & mask
+        # ancestor nodes are those on the rightmost chain and higher than the target node
+        # [batch_size, seq_len]
+        rightmost_mask = spans[..., -1].ge(0)
+        ancestors = rightmost_mask.cumsum(-1).masked_fill_(~rightmost_mask, -1) - 1
+        # should not include the target node for the Juxtapose action
+        ancestor_mask = mask.unsqueeze(-1) & ancestors.ge(0) & ancestors.le((target - juxtapose_mask.long()).unsqueeze(-1))
+        target_pos = torch.where(ancestors.eq(target.unsqueeze(-1))[juxtapose_mask])[-1]
+        # the right boundaries of ancestor nodes should be aligned with the newly generated terminals
+        spans = torch.cat((spans, torch.where(ancestor_mask, spans[..., -1], -1).unsqueeze(-1)), -1)
+        spans[..., -2].masked_fill_(ancestor_mask, -1)
+        spans[juxtapose_mask, target_pos, -1] = new.masked_fill(new.eq(nul_index), -1)[juxtapose_mask]
+        spans[mask, -1, -1] = parent.masked_fill(parent.eq(nul_index), -1)[mask]
+        # [batch_size, seq_len+1, seq_len+1]
+        spans = torch.cat((spans, torch.full_like(spans[:, :1], -1)), 1)
+        return spans
+
+    def load(
+        self,
+        data: Union[str, Iterable],
+        lang: Optional[str] = None,
+        **kwargs
+    ) -> List[AttachJuxtaposeTreeSentence]:
+        r"""
+        Args:
+            data (Union[str, Iterable]):
+                A filename or a list of instances.
+            lang (str):
+                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``None``.
+
+        Returns:
+            A list of :class:`AttachJuxtaposeTreeSentence` instances.
+        """
+
+        if lang is not None:
+            tokenizer = Tokenizer(lang)
+        if isinstance(data, str) and os.path.exists(data):
+            if data.endswith('.txt'):
+                data = (s.split() if lang is None else tokenizer(s) for s in open(data) if len(s) > 1)
+            else:
+                data = open(data)
+        else:
+            if lang is not None:
+                data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)]
+            else:
+                data = [data] if isinstance(data[0], str) else data
+
+        index = 0
+        for s in data:
+            try:
+                tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root)
+                sentence = AttachJuxtaposeTreeSentence(self, tree, index)
+            except ValueError:
+                logger.warning(f"Error found while converting Sentence {index} to a tree:\n{s}\nDiscarding it!")
+                continue
+            except IndexError:
+                # recover sentences missing an outer bracket by wrapping them in a default (S ...) root
+                tree = nltk.Tree.fromstring('(S ' + s + ')')
+                sentence = AttachJuxtaposeTreeSentence(self, tree, index)
+            yield sentence
+            index += 1
+        self.root = tree.label()
+
+
+class AttachJuxtaposeTreeSentence(Sentence):
+    r"""
+    Args:
+        transform (AttachJuxtaposeTree):
+            A :class:`AttachJuxtaposeTree` object.
+        tree (nltk.tree.Tree):
+            A :class:`nltk.tree.Tree` object.
+        index (Optional[int]):
+            Index of the sentence in the corpus. Default: ``None``.
+    """
+
+    def __init__(
+        self,
+        transform: AttachJuxtaposeTree,
+        tree: nltk.Tree,
+        index: Optional[int] = None
+    ) -> AttachJuxtaposeTreeSentence:
+        super().__init__(transform, index)
+
+        words, tags = zip(*tree.pos())
+        nodes, parents, news = None, None, None
+        if transform.training:
+            oracle_tree = tree.copy(True)
+            # the root node must have a unary chain
+            if len(oracle_tree) > 1:
+                oracle_tree[:] = [nltk.Tree('*', oracle_tree)]
+            oracle_tree.collapse_unary(joinChar='::')
+            if len(oracle_tree) == 1 and not isinstance(oracle_tree[0][0], nltk.Tree):
+                oracle_tree[0] = nltk.Tree('*', [oracle_tree[0]])
+            nodes, parents, news = zip(*transform.tree2action(oracle_tree))
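+        # keep only the coarse tag preceding any '##'-separated refinement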
+        tags = [x.split("##")[0] for x in tags]
+        self.values = [words, tags, tree, nodes, parents, news]
+
+    def __repr__(self):
+        return self.values[-4].pformat(1000000)
+
+    def pretty_print(self):
+        self.values[-4].pretty_print()
diff --git a/tania_scripts/supar/models/const/crf/__init__.py b/tania_scripts/supar/models/const/crf/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3a1e583e5a598110b37f745bb68b7f473f20149
--- /dev/null
+++ b/tania_scripts/supar/models/const/crf/__init__.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .model import CRFConstituencyModel
+from .parser import CRFConstituencyParser
+
+__all__ = ['CRFConstituencyModel', 'CRFConstituencyParser']
diff --git a/tania_scripts/supar/models/const/crf/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/const/crf/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b23ae21ab8db14bfbcb97e155b4203d8c61d82a
Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/crf/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/const/crf/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ebb3eb81e33d0ac6cbfe3c6b1f467617348d959
Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/crf/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/const/crf/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d01854d3d9e7b162c4bbd2903109ba1d1b0384a1
Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/model.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/crf/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/const/crf/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ba33210adee6c2b20457972f93ac9ce0b2e68e0
Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/model.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/crf/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/const/crf/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7e69785c24d5aba0b6e9363ddd28806b9f04cfa
Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/parser.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/crf/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/const/crf/__pycache__/parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7282c0aa3c3e6e0aa427723393109f9b057bba50
Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/parser.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/crf/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/const/crf/__pycache__/transform.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fec7bc1c3fc13a4209bb5c3f1be993b016f7cd7a
Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/transform.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/crf/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/const/crf/__pycache__/transform.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..284c610e9e875040233cc1a2d0138345a2c0e4ae
Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/transform.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/crf/model.py b/tania_scripts/supar/models/const/crf/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..03199e80df175292e83628962f576c11f9bf83be
--- /dev/null
+++ b/tania_scripts/supar/models/const/crf/model.py
@@ -0,0 +1,221 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+from supar.model import Model
+from supar.modules import MLP, Biaffine
+from supar.structs import ConstituencyCRF
+from supar.utils import Config
+
+
+class CRFConstituencyModel(Model):
+    r"""
+    The implementation of CRF Constituency Parser :cite:`zhang-etal-2020-fast`,
+    also called FANCY (abbr. of Fast and Accurate Neural Crf constituencY) Parser.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_labels (int):
+            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'char'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word embeddings. Default: 100.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, True)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs would be weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling way to get token embeddings.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .33.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 800.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layer. Default: .33.
+        n_span_mlp (int):
+            Span MLP size. Default: 500.
+        n_label_mlp  (int):
+            Label MLP size. Default: 100.
+        mlp_dropout (float):
+            The dropout ratio of MLP layers. Default: .33.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_labels,
+                 n_tags=None,
+                 n_chars=None,
+                 encoder='lstm',
+                 feat=['char'],
+                 n_embed=100,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, True),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 n_encoder_hidden=800,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_span_mlp=500,
+                 n_label_mlp=100,
+                 mlp_dropout=.33,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        self.span_mlp_l = MLP(n_in=self.args.n_encoder_hidden, n_out=n_span_mlp, dropout=mlp_dropout)
+        self.span_mlp_r = MLP(n_in=self.args.n_encoder_hidden, n_out=n_span_mlp, dropout=mlp_dropout)
+        self.label_mlp_l = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=mlp_dropout)
+        self.label_mlp_r = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=mlp_dropout)
+
+        self.span_attn = Biaffine(n_in=n_span_mlp, bias_x=True, bias_y=False)
+        self.label_attn = Biaffine(n_in=n_label_mlp, n_out=n_labels, bias_x=True, bias_y=True)
+        self.criterion = nn.CrossEntropyLoss()
+
+    def forward(self, words, feats=None):
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+                Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The first tensor of shape ``[batch_size, seq_len, seq_len]`` holds scores of all possible constituents.
+                The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds
+                scores of all possible labels on each constituent.
+        """
+
+        x = self.encode(words, feats)
+
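+        # split encoder states into forward/backward halves and form boundary (fencepost)
+        # representations: boundary i pairs the forward state ending at token i with the
+        # backward state starting at token i+1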
+        x_f, x_b = x.chunk(2, -1)
+        x = torch.cat((x_f[:, :-1], x_b[:, 1:]), -1)
+
+        span_l = self.span_mlp_l(x)
+        span_r = self.span_mlp_r(x)
+        label_l = self.label_mlp_l(x)
+        label_r = self.label_mlp_r(x)
+
+        # [batch_size, seq_len, seq_len]
+        s_span = self.span_attn(span_l, span_r)
+        # [batch_size, seq_len, seq_len, n_labels]
+        s_label = self.label_attn(label_l, label_r).permute(0, 2, 3, 1)
+
+        return s_span, s_label
+
+    def loss(self, s_span, s_label, charts, mask, mbr=True):
+        r"""
+        Args:
+            s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+                Scores of all constituents.
+            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all constituent labels.
+            charts (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
+                The tensor of gold-standard labels. Positions without labels are filled with -1.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
+            mbr (bool):
+                If ``True``, returns marginals for MBR decoding. Default: ``True``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The training loss and original constituent scores
+                of shape ``[batch_size, seq_len, seq_len]`` if ``mbr=False``, or marginals otherwise.
+        """
+
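+        # gold constituents are the chart cells holding a non-negative label index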
+        span_mask = charts.ge(0) & mask
+        span_dist = ConstituencyCRF(s_span, mask[:, 0].sum(-1))
+        span_loss = -span_dist.log_prob(charts).sum() / mask[:, 0].sum()
+        span_probs = span_dist.marginals if mbr else s_span
+        label_loss = self.criterion(s_label[span_mask], charts[span_mask])
+        loss = span_loss + label_loss
+
+        return loss, span_probs
+
+    def decode(self, s_span, s_label, mask):
+        r"""
+        Args:
+            s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+                Scores of all constituents.
+            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all constituent labels.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
+
+        Returns:
+            List[List[Tuple]]:
+                Sequences of factorized labeled trees.
+        """
+
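+        # find the best tree over span scores, then take the best label for each predicted span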
+        span_preds = ConstituencyCRF(s_span, mask[:, 0].sum(-1)).argmax
+        label_preds = s_label.argmax(-1).tolist()
+        return [[(i, j, labels[i][j]) for i, j in spans] for spans, labels in zip(span_preds, label_preds)]
diff --git a/tania_scripts/supar/models/const/crf/parser.py b/tania_scripts/supar/models/const/crf/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad5bd16b2da4d46f156b76fa4d64d97c5e88bcd9
--- /dev/null
+++ b/tania_scripts/supar/models/const/crf/parser.py
@@ -0,0 +1,205 @@
+# -*- coding: utf-8 -*-
+
+import os
+from typing import Dict, Iterable, Set, Union
+
+import torch
+
+from supar.models.const.crf.model import CRFConstituencyModel
+from supar.models.const.crf.transform import Tree
+from supar.parser import Parser
+from supar.structs import ConstituencyCRF
+from supar.utils import Config, Dataset, Embedding
+from supar.utils.common import BOS, EOS, PAD, UNK
+from supar.utils.field import ChartField, Field, RawField, SubwordField
+from supar.utils.logging import get_logger
+from supar.utils.metric import SpanMetric
+from supar.utils.tokenizer import TransformerTokenizer
+from supar.utils.transform import Batch
+
+logger = get_logger(__name__)
+
+
+class CRFConstituencyParser(Parser):
+    r"""
+    The implementation of CRF Constituency Parser :cite:`zhang-etal-2020-fast`.
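+
+    Examples:
+        >>> # a hedged usage sketch, not part of the original file: the model name
+        >>> # 'crf-con-en' and the example sentence are placeholders
+        >>> parser = CRFConstituencyParser.load('crf-con-en')  # doctest: +SKIP
+        >>> parser.predict('She enjoys playing tennis.', lang='en', prob=True)  # doctest: +SKIP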
+    """
+
+    NAME = 'crf-constituency'
+    MODEL = CRFConstituencyModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.TREE = self.transform.TREE
+        self.CHART = self.transform.CHART
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        mbr: bool = True,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        mbr: bool = True,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        mbr: bool = True,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().predict(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        words, *feats, _, charts = batch
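+        # chart mask over boundaries: cell (i, j) is valid iff i < j and both positions are unpadded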
+        mask = batch.mask[:, 1:]
+        mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
+        s_span, s_label = self.model(words, feats)
+        loss, _ = self.model.loss(s_span, s_label, charts, mask, self.args.mbr)
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> SpanMetric:
+        words, *feats, trees, charts = batch
+        mask = batch.mask[:, 1:]
+        mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
+        s_span, s_label = self.model(words, feats)
+        loss, s_span = self.model.loss(s_span, s_label, charts, mask, self.args.mbr)
+        chart_preds = self.model.decode(s_span, s_label, mask)
+        preds = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
+                 for tree, chart in zip(trees, chart_preds)]
+        return SpanMetric(loss,
+                          [Tree.factorize(tree, self.args.delete, self.args.equal) for tree in preds],
+                          [Tree.factorize(tree, self.args.delete, self.args.equal) for tree in trees])
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, *feats, trees = batch
+        mask, lens = batch.mask[:, 1:], batch.lens - 2
+        mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
+        s_span, s_label = self.model(words, feats)
+        s_span = ConstituencyCRF(s_span, mask[:, 0].sum(-1)).marginals if self.args.mbr else s_span
+        chart_preds = self.model.decode(s_span, s_label, mask)
+        batch.trees = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
+                       for tree, chart in zip(trees, chart_preds)]
+        if self.args.prob:
+            batch.probs = [prob[:i-1, 1:i].cpu() for i, prob in zip(lens, s_span)]
+        return batch
+
+    @classmethod
+    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+        r"""
+        Build a brand-new Parser, including initialization of all data fields and model parameters.
+
+        Args:
+            path (str):
+                The path of the model to be saved.
+            min_freq (int):
+                The minimum frequency needed to include a token in the vocabulary. Default: 2.
+            fix_len (int):
+                The max length of all subword pieces. The excess part of each piece will be truncated.
+                Required if using CharLSTM/BERT.
+                Default: 20.
+            kwargs (Dict):
+                A dict holding the unconsumed arguments.
+        """
+
+        args = Config(**locals())
+        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
+        if os.path.exists(path) and not args.build:
+            parser = cls.load(**args)
+            parser.model = cls.MODEL(**parser.args)
+            parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device)
+            return parser
+
+        logger.info("Building the fields")
+        TAG, CHAR, ELMO, BERT = None, None, None, None
+        if args.encoder == 'bert':
+            t = TransformerTokenizer(args.bert)
+            WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t)
+            WORD.vocab = t.vocab
+        else:
+            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True)
+            if 'tag' in args.feat:
+                TAG = Field('tags', bos=BOS, eos=EOS)
+            if 'char' in args.feat:
+                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len)
+            if 'elmo' in args.feat:
+                from allennlp.modules.elmo import batch_to_ids
+                ELMO = RawField('elmo')
+                ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device)
+            if 'bert' in args.feat:
+                t = TransformerTokenizer(args.bert)
+                BERT = SubwordField('bert', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t)
+                BERT.vocab = t.vocab
+        TREE = RawField('trees')
+        CHART = ChartField('charts')
+        transform = Tree(WORD=(WORD, CHAR, ELMO, BERT), POS=TAG, TREE=TREE, CHART=CHART)
+
+        train = Dataset(transform, args.train, **args)
+        if args.encoder != 'bert':
+            WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
+            if TAG is not None:
+                TAG.build(train)
+            if CHAR is not None:
+                CHAR.build(train)
+        CHART.build(train)
+        args.update({
+            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
+            'n_labels': len(CHART.vocab),
+            'n_tags': len(TAG.vocab) if TAG is not None else None,
+            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
+            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
+            'bert_pad_index': BERT.pad_index if BERT is not None else None,
+            'pad_index': WORD.pad_index,
+            'unk_index': WORD.unk_index,
+            'bos_index': WORD.bos_index,
+            'eos_index': WORD.eos_index
+        })
+        logger.info(f"{transform}")
+
+        logger.info("Building the model")
+        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None)
+        logger.info(f"{model}\n")
+
+        parser = cls(args, model, transform)
+        parser.model.to(parser.device)
+        return parser
diff --git a/tania_scripts/supar/models/const/crf/transform.py b/tania_scripts/supar/models/const/crf/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3ba03c85e3e46439d62cd2fa13299ac847437b5
--- /dev/null
+++ b/tania_scripts/supar/models/const/crf/transform.py
@@ -0,0 +1,494 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import os
+from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple,
+                    Union)
+
+import nltk
+
+from supar.utils.logging import get_logger
+from supar.utils.tokenizer import Tokenizer
+from supar.utils.transform import Sentence, Transform
+
+if TYPE_CHECKING:
+    from supar.utils import Field
+
+logger = get_logger(__name__)
+
+
+class Tree(Transform):
+    r"""
+    A :class:`Tree` object factorize a constituency tree into four fields,
+    each associated with one or more :class:`~supar.utils.field.Field` objects.
+
+    Attributes:
+        WORD:
+            Words in the sentence.
+        POS:
+            Part-of-speech tags, or underscores if not available.
+        TREE:
+            The raw constituency tree in :class:`nltk.tree.Tree` format.
+        CHART:
+            The factorized sequence of binarized tree traversed in post-order.
+    """
+
+    root = ''
+    fields = ['WORD', 'POS', 'TREE', 'CHART']
+
+    def __init__(
+        self,
+        WORD: Optional[Union[Field, Iterable[Field]]] = None,
+        POS: Optional[Union[Field, Iterable[Field]]] = None,
+        TREE: Optional[Union[Field, Iterable[Field]]] = None,
+        CHART: Optional[Union[Field, Iterable[Field]]] = None
+    ) -> Tree:
+        super().__init__()
+
+        self.WORD = WORD
+        self.POS = POS
+        self.TREE = TREE
+        self.CHART = CHART
+
+    @property
+    def src(self):
+        return self.WORD, self.POS, self.TREE
+
+    @property
+    def tgt(self):
+        return self.CHART,
+
+    @classmethod
+    def totree(
+        cls,
+        tokens: List[Union[str, Tuple]],
+        root: str = '',
+        normalize: Dict[str, str] = {'(': '-LRB-', ')': '-RRB-'}
+    ) -> nltk.Tree:
+        r"""
+        Converts a list of tokens to a :class:`nltk.tree.Tree`, with missing fields filled in with underscores.
+
+        Args:
+            tokens (List[Union[str, Tuple]]):
+                This can be either a list of words or word/pos pairs.
+            root (str):
+                The root label of the tree. Default: ''.
+            normalize (Dict):
+                Keys within the dict in each token will be replaced by the values. Default: ``{'(': '-LRB-', ')': '-RRB-'}``.
+
+        Returns:
+            A :class:`nltk.tree.Tree` object.
+
+        Examples:
+            >>> from supar.models.const.crf.transform import Tree
+            >>> Tree.totree(['She', 'enjoys', 'playing', 'tennis', '.'], 'TOP').pprint()
+            (TOP ( (_ She)) ( (_ enjoys)) ( (_ playing)) ( (_ tennis)) ( (_ .)))
+            >>> Tree.totree(['(', 'If', 'You', 'Let', 'It', ')'], 'TOP').pprint()
+            (TOP
+              ( (_ -LRB-))
+              ( (_ If))
+              ( (_ You))
+              ( (_ Let))
+              ( (_ It))
+              ( (_ -RRB-)))
+        """
+
+        normalize = str.maketrans(normalize)
+        if isinstance(tokens[0], str):
+            tokens = [(token, '_') for token in tokens]
+        return nltk.Tree(root, [nltk.Tree('', [nltk.Tree(pos, [word.translate(normalize)])]) for word, pos in tokens])
+
+    @classmethod
+    def binarize(
+        cls,
+        tree: nltk.Tree,
+        left: bool = True,
+        mark: str = '*',
+        join: str = '::',
+        implicit: bool = False
+    ) -> nltk.Tree:
+        r"""
+        Conducts binarization over the tree.
+
+        First, the tree is transformed to satisfy `Chomsky Normal Form (CNF)`_.
+        Here we call :meth:`~nltk.tree.Tree.chomsky_normal_form` to conduct left-binarization.
+        Second, all unary productions in the tree are collapsed.
+
+        Args:
+            tree (nltk.tree.Tree):
+                The tree to be binarized.
+            left (bool):
+                If ``True``, left-binarization is conducted. Default: ``True``.
+            mark (str):
+                A string used to mark newly inserted nodes, working if performing explicit binarization. Default: ``'*'``.
+            join (str):
+                A string used to connect collapsed node labels. Default: ``'::'``.
+            implicit (bool):
+                If ``True``, performs implicit binarization. Default: ``False``.
+
+        Returns:
+            The binarized tree.
+
+        Examples:
+            >>> from supar.models.const.crf.transform import Tree
+            >>> tree = nltk.Tree.fromstring('''
+                                            (TOP
+                                              (S
+                                                (NP (_ She))
+                                                (VP (_ enjoys) (S (VP (_ playing) (NP (_ tennis)))))
+                                                (_ .)))
+                                            ''')
+            >>> tree.pretty_print()
+                         TOP
+                          |
+                          S
+              ____________|________________
+             |            VP               |
+             |     _______|_____           |
+             |    |             S          |
+             |    |             |          |
+             |    |             VP         |
+             |    |        _____|____      |
+             NP   |       |          NP    |
+             |    |       |          |     |
+             _    _       _          _     _
+             |    |       |          |     |
+            She enjoys playing     tennis  .
+
+            >>> Tree.binarize(tree).pretty_print()
+                             TOP
+                              |
+                              S
+                         _____|__________________
+                        S*                       |
+              __________|_____                   |
+             |                VP                 |
+             |     ___________|______            |
+             |    |                S::VP         |
+             |    |            ______|_____      |
+             NP  VP*         VP*           NP    S*
+             |    |           |            |     |
+             _    _           _            _     _
+             |    |           |            |     |
+            She enjoys     playing       tennis  .
+
+            >>> Tree.binarize(tree, implicit=True).pretty_print()
+                             TOP
+                              |
+                              S
+                         _____|__________________
+                                                 |
+              __________|_____                   |
+             |                VP                 |
+             |     ___________|______            |
+             |    |                S::VP         |
+             |    |            ______|_____      |
+             NP                            NP
+             |    |           |            |     |
+             _    _           _            _     _
+             |    |           |            |     |
+            She enjoys     playing       tennis  .
+
+            >>> Tree.binarize(tree, left=False).pretty_print()
+                         TOP
+                          |
+                          S
+              ____________|______
+             |                   S*
+             |             ______|___________
+             |            VP                 |
+             |     _______|______            |
+             |    |            S::VP         |
+             |    |        ______|_____      |
+             NP  VP*     VP*           NP    S*
+             |    |       |            |     |
+             _    _       _            _     _
+             |    |       |            |     |
+            She enjoys playing       tennis  .
+
+        .. _Chomsky Normal Form (CNF):
+            https://en.wikipedia.org/wiki/Chomsky_normal_form
+        """
+
+        tree = tree.copy(True)
+        nodes = [tree]
+        if len(tree) == 1:
+            if not isinstance(tree[0][0], nltk.Tree):
+                tree[0] = nltk.Tree(f'{tree.label()}{mark}', [tree[0]])
+            nodes = [tree[0]]
+        while nodes:
+            node = nodes.pop()
+            if isinstance(node, nltk.Tree):
+                if implicit:
+                    label = ''
+                else:
+                    label = node.label()
+                    if mark not in label:
+                        label = f'{label}{mark}'
+                # ensure that only non-terminals can be attached to an n-ary subtree
+                if len(node) > 1:
+                    for child in node:
+                        if not isinstance(child[0], nltk.Tree):
+                            child[:] = [nltk.Tree(child.label(), child[:])]
+                            child.set_label(label)
+                # chomsky normal form factorization
+                if len(node) > 2:
+                    if left:
+                        node[:-1] = [nltk.Tree(label, node[:-1])]
+                    else:
+                        node[1:] = [nltk.Tree(label, node[1:])]
+                nodes.extend(node)
+        # collapse unary productions; this should be conducted after binarization
+        tree.collapse_unary(joinChar=join)
+        return tree
+
+    @classmethod
+    def factorize(
+        cls,
+        tree: nltk.Tree,
+        delete_labels: Optional[Set[str]] = None,
+        equal_labels: Optional[Dict[str, str]] = None
+    ) -> Iterable[Tuple]:
+        r"""
+        Factorizes the tree into a sequence traversed in post-order.
+
+        Args:
+            tree (nltk.tree.Tree):
+                The tree to be factorized.
+            delete_labels (Optional[Set[str]]):
+                A set of labels to be ignored. This is used for evaluation.
+                If it is a pre-terminal label, delete the word along with the brackets.
+                If it is a non-terminal label, just delete the brackets (don't delete children).
+                In `EVALB`_, the default set is:
+                {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}
+                Default: ``None``.
+            equal_labels (Optional[Dict[str, str]]):
+                The key-val pairs in the dict are considered equivalent (non-directional). This is used for evaluation.
+                The default dict defined in `EVALB`_ is: {'ADVP': 'PRT'}
+                Default: ``None``.
+
+        Returns:
+            The sequence of the factorized tree.
+
+        Examples:
+            >>> from supar.models.const.crf.transform import Tree
+            >>> tree = nltk.Tree.fromstring('''
+                                            (TOP
+                                              (S
+                                                (NP (_ She))
+                                                (VP (_ enjoys) (S (VP (_ playing) (NP (_ tennis)))))
+                                                (_ .)))
+                                            ''')
+            >>> Tree.factorize(tree)
+            [(0, 1, 'NP'), (3, 4, 'NP'), (2, 4, 'VP'), (2, 4, 'S'), (1, 4, 'VP'), (0, 5, 'S'), (0, 5, 'TOP')]
+            >>> Tree.factorize(tree, delete_labels={'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''})
+            [(0, 1, 'NP'), (3, 4, 'NP'), (2, 4, 'VP'), (2, 4, 'S'), (1, 4, 'VP'), (0, 5, 'S')]
+
+        .. _EVALB:
+            https://nlp.cs.nyu.edu/evalb/
+        """
+
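+        # walk the tree depth-first, threading the running token index through the children and
+        # collecting (left, right, label) triples for each kept non-terminal in post-order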
+        def track(tree, i):
+            label = tree if isinstance(tree, str) else tree.label()
+            if delete_labels is not None and label in delete_labels:
+                label = None
+            if equal_labels is not None:
+                label = equal_labels.get(label, label)
+            if len(tree) == 1 and not isinstance(tree[0], nltk.Tree):
+                return (i + 1 if label is not None else i), []
+            j, spans = i, []
+            for child in tree:
+                j, s = track(child, j)
+                spans += s
+            if label is not None and j > i:
+                spans = spans + [(i, j, label)]
+            return j, spans
+        return track(tree, 0)[1]
+
+    @classmethod
+    def build(
+        cls,
+        sentence: Union[nltk.Tree, Iterable],
+        spans: Iterable[Tuple],
+        delete_labels: Optional[Set[str]] = None,
+        mark: Union[str, Tuple[str]] = ('*', '|<>'),
+        root: str = '',
+        join: str = '::',
+        postorder: bool = True
+    ) -> nltk.Tree:
+        r"""
+        Builds a constituency tree from a span sequence.
+        During building, the sequence is recovered, i.e., de-binarized to the original format.
+
+        Args:
+            sentence (Union[nltk.tree.Tree, Iterable]):
+                Sentence to provide a base for building a result tree, both `nltk.tree.Tree` and tokens are allowed.
+            spans (Iterable[Tuple]):
+                A list of spans, each consisting of the indices of left/right boundaries and label of the constituent.
+            delete_labels (Optional[Set[str]]):
+                A set of labels to be ignored. Default: ``None``.
+            mark (Union[str, List[str]]):
+                A string used to mark newly inserted nodes. Non-terminals containing this will be removed.
+                Default: ``('*', '|<>')``.
+            root (str):
+                The root label of the tree, needed if input a list of tokens. Default: ''.
+            join (str):
+                A string used to connect collapsed node labels. Non-terminals containing this will be expanded to unary chains.
+                Default: ``'::'``.
+            postorder (bool):
+                If ``True``, enforces the sequence is sorted in post-order. Default: ``True``.
+
+        Returns:
+            A result constituency tree.
+
+        Examples:
+            >>> from supar.models.const.crf.transform import Tree
+            >>> Tree.build(['She', 'enjoys', 'playing', 'tennis', '.'],
+                           [(0, 5, 'S'), (0, 4, 'S*'), (0, 1, 'NP'), (1, 4, 'VP'), (1, 2, 'VP*'),
+                            (2, 4, 'S::VP'), (2, 3, 'VP*'), (3, 4, 'NP'), (4, 5, 'S*')],
+                           root='TOP').pretty_print()
+                         TOP
+                          |
+                          S
+              ____________|________________
+             |            VP               |
+             |     _______|_____           |
+             |    |             S          |
+             |    |             |          |
+             |    |             VP         |
+             |    |        _____|____      |
+             NP   |       |          NP    |
+             |    |       |          |     |
+             _    _       _          _     _
+             |    |       |          |     |
+            She enjoys playing     tennis  .
+
+            >>> Tree.build(['She', 'enjoys', 'playing', 'tennis', '.'],
+                           [(0, 1, 'NP'), (3, 4, 'NP'), (2, 4, 'VP'), (2, 4, 'S'), (1, 4, 'VP'), (0, 5, 'S')],
+                           root='TOP').pretty_print()
+                         TOP
+                          |
+                          S
+              ____________|________________
+             |            VP               |
+             |     _______|_____           |
+             |    |             S          |
+             |    |             |          |
+             |    |             VP         |
+             |    |        _____|____      |
+             NP   |       |          NP    |
+             |    |       |          |     |
+             _    _       _          _     _
+             |    |       |          |     |
+            She enjoys playing     tennis  .
+
+        """
+
+        tree = sentence if isinstance(sentence, nltk.Tree) else Tree.totree(sentence, root)
+        leaves = [subtree for subtree in tree.subtrees() if not isinstance(subtree[0], nltk.Tree)]
+        if postorder:
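+            # sorting by right boundary first and span width second guarantees that
+            # children always precede their parents (post-order)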
+            spans = sorted(spans, key=lambda x: (x[1], x[1] - x[0]))
+
+        root = tree.label()
+        start, stack = 0, []
+        for span in spans:
+            i, j, label = span
+            if delete_labels is not None and label in delete_labels:
+                continue
+            stack.extend([(n, n + 1, leaf) for n, leaf in enumerate(leaves[start:i], start)])
+            children = []
+            while len(stack) > 0 and i <= stack[-1][0]:
+                children = [stack.pop()] + children
+            start = children[-1][1] if len(children) > 0 else i
+            children.extend([(n, n + 1, leaf) for n, leaf in enumerate(leaves[start:j], start)])
+            start = j
+            if not label or label.endswith(mark):
+                stack.extend(children)
+                continue
+            labels = label.split(join)
+            tree = nltk.Tree(labels[-1], [child[-1] for child in children])
+            for label in reversed(labels[:-1]):
+                tree = nltk.Tree(label, [tree])
+            stack.append((i, j, tree))
+        stack.extend([(n, n + 1, leaf) for n, leaf in enumerate(leaves[start:], start)])
+        return nltk.Tree(root, [i[-1] for i in stack])
+
+    def load(
+        self,
+        data: Union[str, Iterable],
+        lang: Optional[str] = None,
+        **kwargs
+    ) -> List[TreeSentence]:
+        r"""
+        Args:
+            data (Union[str, Iterable]):
+                A filename or a list of instances.
+            lang (str):
+                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``None``.
+
+        Returns:
+            A list of :class:`TreeSentence` instances.
+        """
+
+        if lang is not None:
+            tokenizer = Tokenizer(lang)
+        if isinstance(data, str) and os.path.exists(data):
+            if data.endswith('.txt'):
+                data = (s.split() if lang is None else tokenizer(s) for s in open(data) if len(s) > 1)
+            else:
+                data = open(data)
+        else:
+            if lang is not None:
+                data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)]
+            else:
+                data = [data] if isinstance(data[0], str) else data
+
+        index = 0
+        for s in data:
+            try:
+                tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root)
+                sentence = TreeSentence(self, tree, index, **kwargs)
+            except ValueError:
+                logger.warning(f"Error found while converting Sentence {index} to a tree:\n{s}\nDiscarding it!")
+                continue
+            else:
+                yield sentence
+                index += 1
+        self.root = tree.label()
+
+
+class TreeSentence(Sentence):
+    r"""
+    Args:
+        transform (Tree):
+            A :class:`Tree` object.
+        tree (nltk.tree.Tree):
+            A :class:`nltk.tree.Tree` object.
+        index (Optional[int]):
+            Index of the sentence in the corpus. Default: ``None``.
+    """
+
+    def __init__(
+        self,
+        transform: Tree,
+        tree: nltk.Tree,
+        index: Optional[int] = None,
+        **kwargs
+    ) -> TreeSentence:
+        super().__init__(transform, index)
+
+        words, tags, chart = *zip(*tree.pos()), None
+        if transform.training:
+            chart = [[None] * (len(words) + 1) for _ in range(len(words) + 1)]
+            for i, j, label in Tree.factorize(Tree.binarize(tree, implicit=kwargs.get('implicit', False))[0]):
+                chart[i][j] = label
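+            # e.g., for "She enjoys playing tennis ." (see the docstring examples
+            # above), chart[0][5] = 'S', chart[1][4] = 'VP' and chart[2][4] = 'S::VP',
+            # while cells that do not cover a constituent stay None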
+        self.values = [words, tags, tree, chart]
+
+    def __repr__(self):
+        return self.values[-2].pformat(1000000)
+
+    def pretty_print(self):
+        self.values[-2].pretty_print()
diff --git a/tania_scripts/supar/models/const/sl/__init__.py b/tania_scripts/supar/models/const/sl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b6b7bd1056edd874f9589057530c6fcae8ed25d
--- /dev/null
+++ b/tania_scripts/supar/models/const/sl/__init__.py
@@ -0,0 +1,4 @@
+from .model import SLConstituentModel
+from .parser import SLConstituentParser
+
+__all__ = ['SLConstituentModel', 'SLConstituentParser']
\ No newline at end of file
diff --git a/tania_scripts/supar/models/const/sl/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/const/sl/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..25ff091642789417721ecf105676e6aa0cec7ae7
Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/sl/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/const/sl/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d46638efe8006ba326a4c8302ee488c3cbf8e14
Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/sl/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/const/sl/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f671041ff9274200b3afc20f09983fcd26145c5b
Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/model.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/sl/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/const/sl/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dc15de997aec065667819a9092434d52331791f
Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/model.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/sl/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/const/sl/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d61fb3dc0394c8bb1146467f5c1b762d37cca0b
Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/parser.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/sl/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/const/sl/__pycache__/parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdce8977c12fc63d881480ee1b0a5cde8a562fd9
Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/parser.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/sl/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/const/sl/__pycache__/transform.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1874e521ca0c7d7e3004a1b69a4ae48e1f6bf140
Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/transform.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/sl/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/const/sl/__pycache__/transform.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bbffb6675228c80bd72bea723d3dbbfbcbec55b4
Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/transform.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/sl/model.py b/tania_scripts/supar/models/const/sl/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..640c1f99b42927dc1c84a4fb23eb3314001eaf23
--- /dev/null
+++ b/tania_scripts/supar/models/const/sl/model.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+from supar.model import Model
+from supar.modules import MLP, DecoderLSTM
+from supar.utils import Config
+from typing import List
+
+
+class SLConstituentModel(Model):
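+    r"""
+    Constituent parser cast as sequence labeling: a shared encoder feeds two
+    decoders that predict, for every token, the "common" and "ancestor" halves
+    of a codelin-style linearized tree label.
+    """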
+
+    def __init__(self,
+                 n_words,
+                 n_commons,
+                 n_ancestors,
+                 n_tags=None,
+                 n_chars=None,
+                 encoder='lstm',
+                 feat: List[str] = ['char'],
+                 n_embed=100,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, False),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 n_encoder_hidden=800,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_arc_mlp=500,
+                 n_rel_mlp=100,
+                 mlp_dropout=.33,
+                 scale=0,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        # create decoder
+        self.common_decoder, self.ancestor_decoder = None, None
+        if self.args.decoder == 'lstm':
+            decoder = lambda out_dim: DecoderLSTM(
+                self.args.n_encoder_hidden, self.args.n_encoder_hidden, out_dim, 
+                self.args.n_decoder_layers, dropout=mlp_dropout, device=self.device
+            )
+        else:
+            decoder = lambda out_dim: MLP(
+                n_in=self.args.n_encoder_hidden, n_out=out_dim,
+                dropout=mlp_dropout, activation=True
+            )
+
+        self.common_decoder = decoder(self.args.n_commons)
+        self.ancestor_decoder = decoder(self.args.n_ancestors)
+
+        # create delay projection
+        if self.args.delay != 0:
+            self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay + 1),
+                                  n_out=self.args.n_encoder_hidden, dropout=mlp_dropout)
+
+        self.criterion = nn.CrossEntropyLoss()
+
+    def forward(self, words: torch.Tensor, feats: List[torch.Tensor] = None):
+        # x ~ [batch_size, bos + pad(seq_len) + delay, n_encoder_hidden]
+        x = self.encode(words, feats)
+        # drop the BOS position
+        x = x[:, 1:, :]
+
+        # x ~ [batch_size, pad(seq_len), n_encoder_hidden]
+        batch_size, pad_seq_len, embed_size = x.shape
+        if self.args.delay != 0:
+            x = torch.cat([x[:, i:(pad_seq_len - self.args.delay + i), :] for i in range(self.args.delay+1)], dim=2)
+            x = self.delay_proj(x)
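+            # e.g., with delay=1 and pad_seq_len=6, position t is concatenated with
+            # position t+1 into [batch_size, 5, 2 * n_encoder_hidden], which
+            # delay_proj maps back to n_encoder_hidden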
+
+        x, qloss = self.vq_forward(x)   # vector quantization
+
+        # s_common ~ [batch_size, pad(seq_len), n_commons]
+        # s_ancestor ~ [batch_size, pad(seq_len), n_ancestors]
+        s_common, s_ancestor = self.common_decoder(x), self.ancestor_decoder(x)
+        return s_common, s_ancestor, qloss
+
+    def loss(self,
+             s_common: torch.Tensor, s_ancestor: torch.Tensor, 
+             commons: torch.Tensor, ancestors: torch.Tensor, mask: torch.Tensor):
+        s_common, commons = s_common[mask], commons[mask]
+        s_ancestor, ancestors = s_ancestor[mask], ancestors[mask]
+        common_loss = self.criterion(s_common, commons)
+        ancestor_loss = self.criterion(s_ancestor, ancestors)
+        return common_loss + ancestor_loss
+
+    def decode(self, s_common, s_ancestor):
+        common_pred = s_common.argmax(-1)
+        ancestor_pred = s_ancestor.argmax(-1)
+        return common_pred, ancestor_pred
\ No newline at end of file
diff --git a/tania_scripts/supar/models/const/sl/parser.py b/tania_scripts/supar/models/const/sl/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..17e8a3a91e4aa862223a846e045b3d442e1e43dd
--- /dev/null
+++ b/tania_scripts/supar/models/const/sl/parser.py
@@ -0,0 +1,219 @@
+# -*- coding: utf-8 -*-
+
+import os
+from typing import Dict, Iterable, Set, Union
+
+import torch, nltk
+from supar.models.const.sl.model import SLConstituentModel
+from supar.parser import Parser
+from supar.utils import Config, Dataset, Embedding
+from supar.utils.common import BOS, EOS, PAD, UNK
+from supar.utils.field import ChartField, Field, RawField, SubwordField
+from supar.utils.logging import get_logger
+from supar.utils.metric import SpanMetric
+from supar.utils.tokenizer import TransformerTokenizer
+from supar.models.const.sl.transform import SLConstituent
+from supar.utils.transform import Batch
+from supar.models.const.crf.transform import Tree
+from supar.codelin import get_con_encoder, LinearizedTree, C_Label
+from supar.codelin.utils.constants import BOS as C_BOS, EOS as C_EOS
+
+logger = get_logger(__name__)
+
+
+class SLConstituentParser(Parser):
+
+    NAME = 'SLConstituentParser'
+    MODEL = SLConstituentModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.COMMON = self.transform.COMMON
+        self.ANCESTOR = self.transform.ANCESTOR
+        self.encoder = self.transform.encoder
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().predict(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        words, texts, *feats, commons, ancestors, trees = batch
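+        # skip the BOS position and the `delay` look-ahead positions, which carry no gold labels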
+        mask = batch.mask[:, (1+self.args.delay):]
+        s_common, s_ancestor, qloss = self.model(words, feats)
+        loss = self.model.loss(s_common, s_ancestor, commons, ancestors, mask) + qloss
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> SpanMetric:
+        words, texts, *feats, commons, ancestors, trees = batch
+        mask = batch.mask[:, (1+self.args.delay):]
+
+        # forward pass
+        s_common, s_ancestor, qloss = self.model(words, feats)
+        loss = self.model.loss(s_common, s_ancestor, commons, ancestors, mask) + qloss
+
+        # make predictions of the output of the network
+        common_preds, ancestor_preds = self.model.decode(s_common, s_ancestor)
+
+        # obtain original tokens to compute decoding
+        lens = (batch.lens - 1 - self.args.delay).tolist()
+        common_preds = [self.COMMON.vocab[i.tolist()] for i in common_preds[mask].split(lens)]
+        ancestor_preds = [self.ANCESTOR.vocab[i.tolist()] for i in ancestor_preds[mask].split(lens)]
+        # tag_preds = [self.transform.POS.vocab[i.tolist()] for i in tags[mask].split(lens)]
+        tag_preds = map(lambda tree: tuple(zip(*tree.pos()))[1], trees)
+
+        preds = list()
+        for i, (forms, upos, common_pred, ancestor_pred) in enumerate(zip(texts, tag_preds, common_preds, ancestor_preds)):
+            labels = list(map(lambda x: self.encoder.separator.join(x), zip(common_pred, ancestor_pred)))
+            linearized = list(map(lambda x: '\t'.join(x), zip(forms, upos, labels)))
+            linearized = [f'{C_BOS}\t{C_BOS}\t{C_BOS}'] + linearized + [f'{C_EOS}\t{C_EOS}\t{C_EOS}']
+            linearized = '\n'.join(linearized)
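+            # one 'form<TAB>upos<TAB>label' row per token, wrapped in C_BOS/C_EOS
+            # sentinel rows, as expected by LinearizedTree.from_string below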
+            tree = LinearizedTree.from_string(linearized, mode='CONST', separator=self.encoder.separator,
+                                              unary_joiner=self.encoder.unary_joiner)
+            tree = self.encoder.decode(tree)
+            tree = tree.postprocess_tree('strat_max', clean_nulls=False)
+            preds.append(nltk.Tree.fromstring(str(tree)))
+
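+            # debugging aid: dump the linearized tree whenever decoding loses or adds leaves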
+            if len(preds[-1].leaves()) != len(trees[i].leaves()):
+                with open('error', 'w') as file:
+                    file.write(linearized)
+
+        return SpanMetric(loss,
+                          [Tree.factorize(tree, None, None) for tree in preds],
+                          [Tree.factorize(tree, None, None) for tree in trees])
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, texts, *feats = batch
+        tags = feats[-1][:, (1+self.args.delay):]
+        mask = batch.mask[:, (1 + self.args.delay):]
+
+        # forward pass
+        s_common, s_ancestor, _ = self.model(words, feats)
+
+        # make predictions of the output of the network
+        common_preds, ancestor_preds = self.model.decode(s_common, s_ancestor)
+
+        # obtain original tokens to compute decoding
+        lens = (batch.lens - 1 - self.args.delay).tolist()
+        common_preds = [self.COMMON.vocab[i.tolist()] for i in common_preds[mask].split(lens)]
+        ancestor_preds = [self.ANCESTOR.vocab[i.tolist()] for i in ancestor_preds[mask].split(lens)]
+        tag_preds = [self.transform.POS.vocab[i.tolist()] for i in tags[mask].split(lens)]
+
+        preds = list()
+        for i, (forms, upos, common_pred, ancestor_pred) in enumerate(zip(texts, tag_preds, common_preds, ancestor_preds)):
+            labels = list(map(lambda x: self.encoder.separator.join(x), zip(common_pred, ancestor_pred)))
+            linearized = list(map(lambda x: '\t'.join(x), zip(forms, upos, labels)))
+            linearized = [f'{C_BOS}\t{C_BOS}\t{C_BOS}'] + linearized + [f'{C_EOS}\t{C_EOS}\t{C_EOS}']
+            linearized = '\n'.join(linearized)
+            tree = LinearizedTree.from_string(linearized, mode='CONST', separator=self.encoder.separator, unary_joiner=self.encoder.unary_joiner)
+            tree = self.encoder.decode(tree)
+            tree = tree.postprocess_tree('strat_max', clean_nulls=False)
+            preds.append(nltk.Tree.fromstring(str(tree)))
+        batch.trees = preds
+        return batch
+
+    @classmethod
+    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
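+        r"""
+        Build a brand-new SLConstituentParser, including initialization of all data
+        fields and model parameters.
+        """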
+
+        args = Config(**locals())
+        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
+        if os.path.exists(path) and not args.build:
+            parser = cls.load(**args)
+            parser.model = cls.MODEL(**parser.args)
+            parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device)
+            return parser
+
+        logger.info("Building the fields")
+        TAG, CHAR = None, None
+        if args.encoder == 'bert':
+            t = TransformerTokenizer(args.bert)
+            pad_token = t.pad if t.pad else PAD
+            WORD = SubwordField('words', pad=pad_token, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t, delay=args.delay)
+            WORD.vocab = t.vocab
+        else:
+            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True, delay=args.delay)
+            if 'char' in args.feat:
+                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, fix_len=args.fix_len, delay=args.delay)
+        TAG = Field('tags', bos=BOS)
+        TEXT = RawField('texts')
+        COMMON = Field('commons')
+        ANCESTOR = Field('ancestors')
+        TREE = RawField('trees')
+        encoder = get_con_encoder(args.codes)
+        transform = SLConstituent(encoder=encoder, WORD=(WORD, TEXT, CHAR), POS=TAG,
+                                  COMMON=COMMON, ANCESTOR=ANCESTOR, TREE=TREE)
+
+        train = Dataset(transform, args.train, **args)
+        if args.encoder != 'bert':
+            WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
+            if CHAR is not None:
+                CHAR.build(train)
+        TAG.build(train)
+        COMMON.build(train)
+        ANCESTOR.build(train)
+        args.update({
+            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
+            'n_commons': len(COMMON.vocab),
+            'n_ancestors': len(ANCESTOR.vocab),
+            'n_tags': len(TAG.vocab),
+            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
+            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
+            'pad_index': WORD.pad_index,
+            'unk_index': WORD.unk_index,
+            'delay': 0 if 'delay' not in args.keys() else args.delay,
+        })
+        logger.info(f"{transform}")
+
+        logger.info("Building the model")
+        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None)
+        logger.info(f"{model}\n")
+
+        parser = cls(args, model, transform)
+        parser.model.to(parser.device)
+        return parser
diff --git a/tania_scripts/supar/models/const/sl/transform.py b/tania_scripts/supar/models/const/sl/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..b94b9660d42e20953d4a031e96aa363f80301357
--- /dev/null
+++ b/tania_scripts/supar/models/const/sl/transform.py
@@ -0,0 +1,94 @@
+from typing import Union, Iterable, Optional, Set
+from supar.utils.field import Field
+from supar.utils.transform import Sentence, Transform
+from supar.codelin import C_Tree, UNARY_JOINER, LABEL_SEPARATOR
+import nltk, re
+
+
+class SLConstituent(Transform):
+    fields = ['WORD', 'POS', 'COMMON', 'ANCESTOR', 'TREE']
+
+    def __init__(
+        self,
+        encoder,
+        WORD: Union[Field, Iterable[Field]],
+        POS: Union[Field, Iterable[Field]],
+        COMMON: Union[Field, Iterable[Field]],
+        ANCESTOR: Union[Field, Iterable[Field]],
+        TREE: Union[Field, Iterable[Field]]
+    ):
+        super().__init__()
+
+        self.WORD = WORD
+        self.POS = POS
+        self.COMMON = COMMON
+        self.ANCESTOR = ANCESTOR
+        self.TREE = TREE
+        self.encoder = encoder
+
+    @property
+    def src(self):
+        return self.WORD, self.POS
+
+    @property
+    def tgt(self):
+        return self.COMMON, self.ANCESTOR, self.TREE
+
+    def load(
+        self,
+        data: Union[str, Iterable],
+        **kwargs
+    ) -> Iterable[Sentence]:
+        lines = open(data)
+        index = 0
+        for line in lines:
+            line = line.strip()
+            if len(line) > 0:
+                sentence = SLConstituentSentence(self, line, self.encoder, index)
+                yield sentence
+                index += 1
+
+
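+# collect the set of non-terminal labels of a tree, e.g.
+# get_nodes(nltk.Tree.fromstring('(S (NP (DT the) (NN cat)) (VP (VBZ sleeps))))'.replace('))))', ')))')))
+# returns {'S', 'NP', 'VP'}; pre-terminal (POS) nodes are never visited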
+def get_nodes(tree: nltk.Tree):
+    nodes = {tree.label()}
+    for subtree in tree:
+        if isinstance(subtree[0], nltk.Tree):
+            nodes = {*nodes, *get_nodes(subtree)}
+    return nodes
+
+def remove_doubles(tree: nltk.Tree):
+    for i, subtree in enumerate(tree):
+        if isinstance(subtree, str) and len(tree) > 1:
+            tree.pop(i)
+        elif isinstance(subtree, nltk.Tree):
+            remove_doubles(subtree)
+
+class SLConstituentSentence(Sentence):
+
+    def __init__(self, transform: SLConstituent, line: str, encoder, index: Optional[int] = None):
+        super().__init__(transform, index)
+
+        # get nodes of the tree
+        gold = nltk.Tree.fromstring(line)
+        nodes = ''.join(get_nodes(gold))
+        assert (encoder.separator not in nodes) and (encoder.unary_joiner not in nodes)
+
+        # create linearized tree
+        tree = C_Tree.from_string(line)
+        linearized_tree = encoder.encode(tree)
+        self.annotations = []
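+        # each codelin label serializes as '<common><separator><ancestor...>';
+        # split it into the two per-token target sequences predicted by the model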
+        commons = list(map(lambda x: repr(x).split(encoder.separator)[0], linearized_tree.labels))
+        ancestors = list(map(lambda x: encoder.separator.join(repr(x).split(encoder.separator)[1:]), linearized_tree.labels))
+
+        _, postags = zip(*gold.pos())
+        self.values = [
+            linearized_tree.words,
+            postags,
+            commons, ancestors
+        ]
+        self.values.append(nltk.Tree.fromstring(line))
+
+    def __repr__(self):
+        remove_doubles(self.values[-1])
+        return re.sub(' +', ' ', str(self.values[-1]).replace('\n', '').replace('\t', ''))
diff --git a/tania_scripts/supar/models/const/tt/__init__.py b/tania_scripts/supar/models/const/tt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..43892195ca6b97953149e492744adacaf5f01015
--- /dev/null
+++ b/tania_scripts/supar/models/const/tt/__init__.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .model import TetraTaggingConstituencyModel
+from .parser import TetraTaggingConstituencyParser
+
+__all__ = ['TetraTaggingConstituencyModel', 'TetraTaggingConstituencyParser']
diff --git a/tania_scripts/supar/models/const/tt/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/const/tt/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a577e63e231c15d7c6ffbde1d8660f3f458b985
Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/tt/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/const/tt/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f242de08b1629939a1184b1ebf706bb4b2b2cbda
Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/tt/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/const/tt/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4deac65c29f9618638ef3d168e5f3799e148ae96
Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/model.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/tt/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/const/tt/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08a12831d12bab671d404d55859b4bb81ba4c6e0
Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/model.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/tt/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/const/tt/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..740bf37c4d0abed35cbd2fe734b4899404da1c63
Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/parser.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/tt/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/const/tt/__pycache__/parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d0f9c3cecab729455b1cf901cfed200de9a4883
Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/parser.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/tt/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/const/tt/__pycache__/transform.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2226fc4c38750791963263a693f1896078972e05
Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/transform.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/tt/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/const/tt/__pycache__/transform.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee00f85bb15eed4b144ed777f9dbfeef8c753e8c
Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/transform.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/tt/model.py b/tania_scripts/supar/models/const/tt/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bce1b6489bef08ba66494937a7a9c64873582ec
--- /dev/null
+++ b/tania_scripts/supar/models/const/tt/model.py
@@ -0,0 +1,265 @@
+# -*- coding: utf-8 -*-
+
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+from supar.model import Model
+from supar.utils import Config
+from supar.utils.common import INF
+
+
+class TetraTaggingConstituencyModel(Model):
+    r"""
+    The implementation of TetraTagging Constituency Parser :cite:`kitaev-klein-2020-tetra`.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_tags (int):
+            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'char'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word embeddings. Default: 100.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, True)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs would be weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling way to get token embeddings.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .33.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 800.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layers. Default: .33.
+        n_gnn_layers (int):
+            The number of GNN layers. Default: 3.
+        gnn_dropout (float):
+            The dropout ratio of GNN layers. Default: .33.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_tags=None,
+                 n_chars=None,
+                 encoder='lstm',
+                 feat=['char'],
+                 n_embed=100,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, True),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 n_encoder_hidden=800,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_gnn_layers=3,
+                 gnn_dropout=.33,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        self.proj = nn.Linear(self.args.n_encoder_hidden, self.args.n_leaves + self.args.n_nodes)
+        self.criterion = nn.CrossEntropyLoss()
+
+    def forward(
+        self,
+        words: torch.LongTensor,
+        feats: List[torch.LongTensor] = None
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+                Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                Scores for all leaves (``[batch_size, seq_len, n_leaves]``) and nodes (``[batch_size, seq_len, n_nodes]``).
+        """
+
+        s = self.proj(self.encode(words, feats)[:, 1:-1])
+        s_leaf, s_node = s[..., :self.args.n_leaves], s[..., self.args.n_leaves:]
+        return s_leaf, s_node
+
+    def loss(
+        self,
+        s_leaf: torch.Tensor,
+        s_node: torch.Tensor,
+        leaves: torch.LongTensor,
+        nodes: torch.LongTensor,
+        mask: torch.BoolTensor
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            s_leaf (~torch.Tensor): ``[batch_size, seq_len, n_leaves]``.
+                Leaf scores.
+            s_node (~torch.Tensor): ``[batch_size, seq_len, n_leaves]``.
+                Non-terminal scores.
+            leaves (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Actions for leaves.
+            nodes (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Actions for non-terminals.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
+
+        Returns:
+            ~torch.Tensor:
+                The training loss.
+        """
+
+        leaf_mask, node_mask = mask, mask[:, 1:]
+        leaf_loss = self.criterion(s_leaf[leaf_mask], leaves[leaf_mask])
+        node_loss = self.criterion(s_node[:, :-1][node_mask], nodes[node_mask]) if nodes.shape[1] > 0 else 0
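+        # batches of single-token sentences have no non-terminal actions, hence the guard above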
+        return leaf_loss + node_loss
+
+    def decode(
+        self,
+        s_leaf: torch.Tensor,
+        s_node: torch.Tensor,
+        mask: torch.BoolTensor,
+        left_mask: torch.BoolTensor,
+        depth: int = 8
+    ) -> List[List[Tuple]]:
+        r"""
+        Args:
+            s_leaf (~torch.Tensor): ``[batch_size, seq_len, n_leaves]``.
+                Leaf scores.
+            s_node (~torch.Tensor): ``[batch_size, seq_len, n_leaves]``.
+                Non-terminal scores.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
+            left_mask (~torch.BoolTensor): ``[n_leaves + n_nodes]``.
+                The mask for distinguishing left/rightward actions.
+            depth (int):
+                Stack depth. Default: 8.
+
+        Returns:
+            List[List[Tuple]]:
+                Sequences of factorized labeled trees.
+        """
+        from torch_scatter import scatter_max
+
+        lens = mask.sum(-1)
+        batch_size, seq_len, n_leaves = s_leaf.shape
+        end_mask = (lens - 1).unsqueeze(-1).eq(lens.new_tensor(range(seq_len)))
+        leaf_left_mask, node_left_mask = left_mask[:n_leaves], left_mask[n_leaves:]
+        s_leaf = s_leaf.masked_fill_(end_mask.unsqueeze(-1) & leaf_left_mask, -INF)
+        # [n_leaves], [n_nodes]
+        changes = (torch.where(leaf_left_mask, 1, 0), torch.where(node_left_mask, 0, -1))
+        # [batch_size, depth]
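+        # depth slot 0 starts at -1 (the empty stack) and all other slots at -2, so
+        # any transition taken from an unused slot falls below 0 and is masked out
+        # inside advance()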
+        depths = lens.new_full((depth,), -2).index_fill_(-1, lens.new_tensor(0), -1).repeat(batch_size, 1)
+        # [2, batch_size, depth, seq_len]
+        labels, paths = lens.new_zeros(2, batch_size, depth, seq_len), lens.new_zeros(2, batch_size, depth, seq_len)
+        # [batch_size, depth]
+        s = s_leaf.new_zeros(batch_size, depth)
+
+        def advance(s, s_t, depths, changes):
+            batch_size, n_labels = s_t.shape
+            # [batch_size, depth * n_labels]
+            depths = (depths.unsqueeze(-1) + changes).view(batch_size, -1)
+            # [batch_size, depth, n_labels]
+            s_t = s.unsqueeze(-1) + s_t.unsqueeze(1)
+            # [batch_size, depth * n_labels]
+            # fill scores of invalid depths with -INF
+            s_t = s_t.view(batch_size, -1).masked_fill_((depths < 0).logical_or_(depths >= depth), -INF)
+            # [batch_size, depth]
+            # for each depth, we use the `scatter_max` trick to obtain the 1-best label
+            s, ls = scatter_max(s_t, depths.clamp(0, depth - 1), -1, s_t.new_full((batch_size, depth), -INF))
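+            # `s` keeps the best score per resulting depth; `ls` is its argmax over
+            # the flattened depth * n_labels candidates (hence the clamps below)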
+            # [batch_size, depth]
+            depths = depths.gather(-1, ls.clamp(0, depths.shape[1] - 1)).masked_fill_(s.eq(-INF), -1)
+            ll = ls % n_labels
+            lp = depths - changes[ll]
+            return s, ll, lp, depths
+
+        for t in range(seq_len):
+            m = lens.gt(t)
+            s[m], labels[0, m, :, t], paths[0, m, :, t], depths[m] = advance(s[m], s_leaf[m, t], depths[m], changes[0])
+            if t == seq_len - 1:
+                break
+            m = lens.gt(t + 1)
+            s[m], labels[1, m, :, t], paths[1, m, :, t], depths[m] = advance(s[m], s_node[m, t], depths[m], changes[1])
+
+        lens = lens.tolist()
+        labels, paths = labels.movedim((0, 2), (2, 3))[mask].split(lens), paths.movedim((0, 2), (2, 3))[mask].split(lens)
+        leaves, nodes = [], []
+        for i, length in enumerate(lens):
+            leaf_labels, node_labels = labels[i].transpose(0, 1).tolist()
+            leaf_paths, node_paths = paths[i].transpose(0, 1).tolist()
+            leaf_pred, node_pred, prev = [leaf_labels[-1][0]], [], leaf_paths[-1][0]
+            for j in reversed(range(length - 1)):
+                node_pred.append(node_labels[j][prev])
+                prev = node_paths[j][prev]
+                leaf_pred.append(leaf_labels[j][prev])
+                prev = leaf_paths[j][prev]
+            leaves.append(list(reversed(leaf_pred)))
+            nodes.append(list(reversed(node_pred)))
+        return leaves, nodes
diff --git a/tania_scripts/supar/models/const/tt/parser.py b/tania_scripts/supar/models/const/tt/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..7289c222a454629876ecc155721fef36ef79a0fb
--- /dev/null
+++ b/tania_scripts/supar/models/const/tt/parser.py
@@ -0,0 +1,205 @@
+# -*- coding: utf-8 -*-
+
+import os
+from typing import Dict, Iterable, Set, Union
+
+import torch
+
+from supar.models.const.tt.model import TetraTaggingConstituencyModel
+from supar.models.const.tt.transform import TetraTaggingTree
+from supar.parser import Parser
+from supar.utils import Config, Dataset, Embedding
+from supar.utils.common import BOS, EOS, PAD, UNK
+from supar.utils.field import Field, RawField, SubwordField
+from supar.utils.logging import get_logger
+from supar.utils.metric import SpanMetric
+from supar.utils.tokenizer import TransformerTokenizer
+from supar.utils.transform import Batch
+
+logger = get_logger(__name__)
+
+
+class TetraTaggingConstituencyParser(Parser):
+    r"""
+    The implementation of TetraTagging Constituency Parser :cite:`kitaev-klein-2020-tetra`.
+    """
+
+    NAME = 'tetra-tagging-constituency'
+    MODEL = TetraTaggingConstituencyModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.TREE = self.transform.TREE
+        self.LEAF = self.transform.LEAF
+        self.NODE = self.transform.NODE
+
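+        # boolean mask over the concatenated [leaf; node] label space marking
+        # leftward actions ('l/...' and 'L/...'), consumed by the model's decode()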
+        self.left_mask = torch.tensor([*(i.startswith('l') for i in self.LEAF.vocab.itos),
+                                       *(i.startswith('L') for i in self.NODE.vocab.itos)]).to(self.device)
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        depth: int = 1,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        depth: int = 1,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        depth: int = 1,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().predict(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        words, *feats, _, leaves, nodes = batch
+        mask = batch.mask[:, 2:]
+        s_leaf, s_node = self.model(words, feats)
+        loss = self.model.loss(s_leaf, s_node, leaves, nodes, mask)
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> SpanMetric:
+        words, *feats, trees, leaves, nodes = batch
+        mask = batch.mask[:, 2:]
+        s_leaf, s_node = self.model(words, feats)
+        loss = self.model.loss(s_leaf, s_node, leaves, nodes, mask)
+        preds = self.model.decode(s_leaf, s_node, mask, self.left_mask, self.args.depth)
+        preds = [TetraTaggingTree.action2tree(tree, (self.LEAF.vocab[i], self.NODE.vocab[j] if len(j) > 0 else []))
+                 for tree, i, j in zip(trees, *preds)]
+        return SpanMetric(loss,
+                          [TetraTaggingTree.factorize(tree, self.args.delete, self.args.equal) for tree in preds],
+                          [TetraTaggingTree.factorize(tree, self.args.delete, self.args.equal) for tree in trees])
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, *feats, trees = batch
+        mask = batch.mask[:, 2:]
+        s_leaf, s_node = self.model(words, feats)
+        preds = self.model.decode(s_leaf, s_node, mask, self.left_mask, self.args.depth)
+        batch.trees = [TetraTaggingTree.action2tree(tree, (self.LEAF.vocab[i], self.NODE.vocab[j] if len(j) > 0 else []))
+                       for tree, i, j in zip(trees, *preds)]
+        if self.args.prob:
+            raise NotImplementedError("Returning action probabilities is not supported yet.")
+        return batch
+
+    @classmethod
+    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+        r"""
+        Build a brand-new Parser, including initialization of all data fields and model parameters.
+
+        Args:
+            path (str):
+                The path of the model to be saved.
+            min_freq (int):
+                The minimum frequency needed to include a token in the vocabulary. Default: 2.
+            fix_len (int):
+                The max length of all subword pieces. The excess part of each piece will be truncated.
+                Required if using CharLSTM/BERT.
+                Default: 20.
+            kwargs (Dict):
+                A dict holding the unconsumed arguments.
+        """
+
+        args = Config(**locals())
+        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
+        if os.path.exists(path) and not args.build:
+            parser = cls.load(**args)
+            parser.model = cls.MODEL(**parser.args)
+            parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device)
+            return parser
+
+        logger.info("Building the fields")
+        TAG, CHAR, ELMO, BERT = None, None, None, None
+        if args.encoder == 'bert':
+            t = TransformerTokenizer(args.bert)
+            WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t)
+            WORD.vocab = t.vocab
+        else:
+            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True)
+            if 'tag' in args.feat:
+                TAG = Field('tags', bos=BOS, eos=EOS)
+            if 'char' in args.feat:
+                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len)
+            if 'elmo' in args.feat:
+                from allennlp.modules.elmo import batch_to_ids
+                ELMO = RawField('elmo')
+                ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device)
+            if 'bert' in args.feat:
+                t = TransformerTokenizer(args.bert)
+                BERT = SubwordField('bert', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t)
+                BERT.vocab = t.vocab
+        TREE = RawField('trees')
+        LEAF, NODE = Field('leaf'), Field('node')
+        transform = TetraTaggingTree(WORD=(WORD, CHAR, ELMO, BERT), POS=TAG, TREE=TREE, LEAF=LEAF, NODE=NODE)
+
+        train = Dataset(transform, args.train, **args)
+        if args.encoder != 'bert':
+            WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
+            if TAG is not None:
+                TAG.build(train)
+            if CHAR is not None:
+                CHAR.build(train)
+        LEAF, NODE = LEAF.build(train), NODE.build(train)
+        args.update({
+            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
+            'n_leaves': len(LEAF.vocab),
+            'n_nodes': len(NODE.vocab),
+            'n_tags': len(TAG.vocab) if TAG is not None else None,
+            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
+            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
+            'bert_pad_index': BERT.pad_index if BERT is not None else None,
+            'pad_index': WORD.pad_index,
+            'unk_index': WORD.unk_index,
+            'bos_index': WORD.bos_index,
+            'eos_index': WORD.eos_index
+        })
+        logger.info(f"{transform}")
+
+        logger.info("Building the model")
+        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None)
+        logger.info(f"{model}\n")
+
+        parser = cls(args, model, transform)
+        parser.model.to(parser.device)
+        return parser
diff --git a/tania_scripts/supar/models/const/tt/transform.py b/tania_scripts/supar/models/const/tt/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..a337b94eaaab48f7a61568a4a396604ffe839691
--- /dev/null
+++ b/tania_scripts/supar/models/const/tt/transform.py
@@ -0,0 +1,301 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union, Sequence
+
+import nltk
+
+from supar.models.const.crf.transform import Tree
+from supar.utils.logging import get_logger
+from supar.utils.tokenizer import Tokenizer
+from supar.utils.transform import Sentence
+
+if TYPE_CHECKING:
+    from supar.utils import Field
+
+logger = get_logger(__name__)
+
+
+class TetraTaggingTree(Tree):
+    r"""
+    :class:`TetraTaggingTree` is derived from the :class:`Tree` class and is defined for supporting the transition system of
+    tetra tagger :cite:`kitaev-klein-2020-tetra`.
+
+    Attributes:
+        WORD:
+            Words in the sentence.
+        POS:
+            Part-of-speech tags, or underscores if not available.
+        TREE:
+            The raw constituency tree in :class:`nltk.tree.Tree` format.
+        LEAF:
+            Action labels in tetra tagger transition system.
+        NODE:
+            Non-terminal labels.
+    """
+
+    fields = ['WORD', 'POS', 'TREE', 'LEAF', 'NODE']
+
+    def __init__(
+        self,
+        WORD: Optional[Union[Field, Iterable[Field]]] = None,
+        POS: Optional[Union[Field, Iterable[Field]]] = None,
+        TREE: Optional[Union[Field, Iterable[Field]]] = None,
+        LEAF: Optional[Union[Field, Iterable[Field]]] = None,
+        NODE: Optional[Union[Field, Iterable[Field]]] = None
+    ) -> Tree:
+        super().__init__()
+
+        self.WORD = WORD
+        self.POS = POS
+        self.TREE = TREE
+        self.LEAF = LEAF
+        self.NODE = NODE
+
+    @property
+    def tgt(self):
+        return self.LEAF, self.NODE
+
+    @classmethod
+    def tree2action(cls, tree: nltk.Tree) -> Tuple[Sequence, Sequence]:
+        r"""
+        Converts a (binarized) constituency tree into tetra-tagging actions.
+
+        Args:
+            tree (nltk.tree.Tree):
+                A constituency tree in :class:`nltk.tree.Tree` format.
+
+        Returns:
+            Tetra-tagging actions for leaves and non-terminals.
+
+        Examples:
+            >>> from supar.models.const.tt.transform import TetraTaggingTree
+            >>> tree = nltk.Tree.fromstring('''
+                                            (TOP
+                                              (S
+                                                (NP (_ She))
+                                                (VP (_ enjoys) (S (VP (_ playing) (NP (_ tennis)))))
+                                                (_ .)))
+                                            ''')
+            >>> tree.pretty_print()
+                         TOP
+                          |
+                          S
+              ____________|________________
+             |            VP               |
+             |     _______|_____           |
+             |    |             S          |
+             |    |             |          |
+             |    |             VP         |
+             |    |        _____|____      |
+             NP   |       |          NP    |
+             |    |       |          |     |
+             _    _       _          _     _
+             |    |       |          |     |
+            She enjoys playing     tennis  .
+
+            >>> tree = TetraTaggingTree.binarize(tree, left=False, implicit=True)
+            >>> tree.pretty_print()
+                         TOP
+                          |
+                          S
+              ____________|______
+             |
+             |             ______|___________
+             |            VP                 |
+             |     _______|______            |
+             |    |            S::VP         |
+             |    |        ______|_____      |
+             NP                        NP
+             |    |       |            |     |
+             _    _       _            _     _
+             |    |       |            |     |
+            She enjoys playing       tennis  .
+
+            >>> TetraTaggingTree.tree2action(tree)
+            (['l/NP', 'l/', 'l/', 'r/NP', 'r/'], ['L/S', 'L/VP', 'R/S::VP', 'R/'])
+        """
+
+        def traverse(tree: nltk.Tree, left: bool = True) -> List:
+            if len(tree) == 1 and not isinstance(tree[0], nltk.Tree):
+                return ['l' if left else 'r'], []
+            if len(tree) == 1 and not isinstance(tree[0][0], nltk.Tree):
+                return [f"{'l' if left else 'r'}/{tree.label()}"], []
+            return tuple(sum(i, []) for i in zip(*[traverse(tree[0]),
+                                                   ([], [f'{("L" if left else "R")}/{tree.label()}']),
+                                                   traverse(tree[1], False)]))
+        return traverse(tree[0])
+
+    @classmethod
+    def action2tree(
+        cls,
+        tree: nltk.Tree,
+        actions: Tuple[Sequence, Sequence],
+        mark: Union[str, Tuple[str]] = ('*', '|<>'),
+        join: str = '::',
+    ) -> nltk.Tree:
+        r"""
+        Recovers a constituency tree from tetra-tagging actions.
+
+        Args:
+            tree (nltk.tree.Tree):
+                An empty tree that provides the base for building the resulting tree.
+            actions (Tuple[Sequence, Sequence]):
+                Tetra-tagging actions.
+            mark (Union[str, Tuple[str]]):
+                A string used to mark newly inserted nodes. Non-terminals containing this will be removed.
+                Default: ``('*', '|<>')``.
+            join (str):
+                A string used to connect collapsed node labels. Non-terminals containing this will be expanded to unary chains.
+                Default: ``'::'``.
+
+        Returns:
+            The resulting constituency tree.
+
+        Examples:
+            >>> from supar.models.const.tt.transform import TetraTaggingTree
+            >>> tree = TetraTaggingTree.totree(['She', 'enjoys', 'playing', 'tennis', '.'], 'TOP')
+            >>> actions = (['l/NP', 'l/', 'l/', 'r/NP', 'r/'], ['L/S', 'L/VP', 'R/S::VP', 'R/'])
+            >>> TetraTaggingTree.action2tree(tree, actions).pretty_print()
+                         TOP
+                          |
+                          S
+              ____________|________________
+             |            VP               |
+             |     _______|_____           |
+             |    |             S          |
+             |    |             |          |
+             |    |             VP         |
+             |    |        _____|____      |
+             NP   |       |          NP    |
+             |    |       |          |     |
+             _    _       _          _     _
+             |    |       |          |     |
+            She enjoys playing     tennis  .
+
+        """
+
+        stack = []
+        leaves = [nltk.Tree(pos, [token]) for token, pos in tree.pos()]
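+        # Each stack entry is a pair [root, open]: `root` is the partial subtree
+        # built so far and `open` its most recent internal node still awaiting a
+        # right child. 'l'/'r' attach the next leaf either as a new branch or
+        # into the open slot; 'L'/'R' grow the top subtree upward or close it
+        # and attach it to its parent.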
+        for i, (al, an) in enumerate(zip(*actions)):
+            leaf = nltk.Tree(al.split('/', 1)[1], [leaves[i]])
+            if al.startswith('l'):
+                stack.append([leaf, None])
+            else:
+                slot = stack[-1][1]
+                slot.append(leaf)
+            if an.startswith('L'):
+                node = nltk.Tree(an.split('/', 1)[1], [stack[-1][0]])
+                stack[-1][0] = node
+            else:
+                node = nltk.Tree(an.split('/', 1)[1], [stack.pop()[0]])
+                slot = stack[-1][1]
+                slot.append(node)
+            stack[-1][1] = node
+        # the last leaf must be leftward
+        leaf = nltk.Tree(actions[0][-1].split('/', 1)[1], [leaves[-1]])
+        if len(stack) > 0:
+            stack[-1][1].append(leaf)
+        else:
+            stack.append([leaf, None])
+
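+        # undo binarization: splice out unlabeled/marked nodes and expand labels
+        # collapsed with `join` back into unary chains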
+        def debinarize(tree):
+            if len(tree) == 1 and not isinstance(tree[0], nltk.Tree):
+                return [tree]
+            label, children = tree.label(), []
+            for child in tree:
+                children.extend(debinarize(child))
+            if not label or label.endswith(mark):
+                return children
+            labels = label.split(join) if join in label else [label]
+            tree = nltk.Tree(labels[-1], children)
+            for label in reversed(labels[:-1]):
+                tree = nltk.Tree(label, [tree])
+            return [tree]
+        return debinarize(nltk.Tree(tree.label(), [stack[0][0]]))[0]
+
+    def load(
+        self,
+        data: Union[str, Iterable],
+        lang: Optional[str] = None,
+        **kwargs
+    ) -> Iterable[TetraTaggingTreeSentence]:
+        r"""
+        Args:
+            data (Union[str, Iterable]):
+                A filename or a list of instances.
+            lang (str):
+                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``None``.
+
+        Returns:
+            An iterable of :class:`TetraTaggingTreeSentence` instances.
+        """
+
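+        # accepted inputs: a plain-text file (one sentence per line, split or
+        # tokenized on the fly), a file of bracketed trees, or in-memory tokens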
+        if lang is not None:
+            tokenizer = Tokenizer(lang)
+        if isinstance(data, str) and os.path.exists(data):
+            if data.endswith('.txt'):
+                data = (s.split() if lang is None else tokenizer(s) for s in open(data) if len(s) > 1)
+            else:
+                data = open(data)
+        else:
+            if lang is not None:
+                data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)]
+            else:
+                data = [data] if isinstance(data[0], str) else data
+
+        index = 0
+        for s in data:
+            try:
+                tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root)
+                sentence = TetraTaggingTreeSentence(self, tree, index)
+            except ValueError:
+                logger.warning(f"Error found while converting Sentence {index} to a tree:\n{s}\nDiscarding it!")
+                continue
+            else:
+                yield sentence
+                index += 1
+        self.root = tree.label()
+
+
+class TetraTaggingTreeSentence(Sentence):
+    r"""
+    Args:
+        transform (TetraTaggingTree):
+            A :class:`TetraTaggingTree` object.
+        tree (nltk.tree.Tree):
+            A :class:`nltk.tree.Tree` object.
+        index (Optional[int]):
+            Index of the sentence in the corpus. Default: ``None``.
+    """
+
+    def __init__(
+        self,
+        transform: TetraTaggingTree,
+        tree: nltk.Tree,
+        index: Optional[int] = None
+    ) -> TetraTaggingTreeSentence:
+        super().__init__(transform, index)
+
+        words, tags = zip(*tree.pos())
+        leaves, nodes = None, None
+        if transform.training:
+            oracle_tree = tree.copy(True)
+            # the root node must have a unary chain
+            if len(oracle_tree) > 1:
+                oracle_tree[:] = [nltk.Tree('*', oracle_tree)]
+            oracle_tree = TetraTaggingTree.binarize(oracle_tree, left=False, implicit=True)
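+            # e.g., for single-token sentences, wrap the only child so that the
+            # oracle tree still has an internal node to derive actions from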
+            if len(oracle_tree) == 1 and not isinstance(oracle_tree[0][0], nltk.Tree):
+                oracle_tree[0] = nltk.Tree('*', [oracle_tree[0]])
+            leaves, nodes = transform.tree2action(oracle_tree)
+        self.values = [words, tags, tree, leaves, nodes]
+
+    def __repr__(self):
+        return self.values[-3].pformat(1000000)
+
+    def pretty_print(self):
+        self.values[-3].pretty_print()
diff --git a/tania_scripts/supar/models/const/vi/__init__.py b/tania_scripts/supar/models/const/vi/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..db91608946d018e5adb1dca9e49acb144fc7eb31
--- /dev/null
+++ b/tania_scripts/supar/models/const/vi/__init__.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .model import VIConstituencyModel
+from .parser import VIConstituencyParser
+
+__all__ = ['VIConstituencyModel', 'VIConstituencyParser']
diff --git a/tania_scripts/supar/models/const/vi/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/const/vi/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..061ba0053d502449b355ff1e49fecf53cd5b85e9
Binary files /dev/null and b/tania_scripts/supar/models/const/vi/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/vi/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/const/vi/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5780eac044e7f23e6effee39efea6913adc7c972
Binary files /dev/null and b/tania_scripts/supar/models/const/vi/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/vi/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/const/vi/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74c5b353f48c7a6aaa1d1f8944496ff7070b89ec
Binary files /dev/null and b/tania_scripts/supar/models/const/vi/__pycache__/model.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/vi/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/const/vi/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef42262709e0f807a94447a265c803f5688b9c97
Binary files /dev/null and b/tania_scripts/supar/models/const/vi/__pycache__/model.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/vi/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/const/vi/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6aec5892d74be8c41b6e936b63abef924b1a5841
Binary files /dev/null and b/tania_scripts/supar/models/const/vi/__pycache__/parser.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/const/vi/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/const/vi/__pycache__/parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2279ac265980ece4415c326074e52e8ec8ef509
Binary files /dev/null and b/tania_scripts/supar/models/const/vi/__pycache__/parser.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/const/vi/model.py b/tania_scripts/supar/models/const/vi/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c44daac95be0d324e239a2fd5ad6d791a0be64ea
--- /dev/null
+++ b/tania_scripts/supar/models/const/vi/model.py
@@ -0,0 +1,237 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+from supar.models.const.crf.model import CRFConstituencyModel
+from supar.modules import MLP, Biaffine, Triaffine
+from supar.structs import ConstituencyCRF, ConstituencyLBP, ConstituencyMFVI
+from supar.utils import Config
+
+
+class VIConstituencyModel(CRFConstituencyModel):
+    r"""
+    The implementation of Constituency Parser using variational inference.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_labels (int):
+            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'char'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word embeddings. Default: 100.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, True)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs would be a weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling method to get token embeddings.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .33.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 800.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layer. Default: .33.
+        n_span_mlp (int):
+            Span MLP size. Default: 500.
+        n_pair_mlp (int):
+            Binary factor MLP size. Default: 100.
+        n_label_mlp (int):
+            Label MLP size. Default: 100.
+        mlp_dropout (float):
+            The dropout ratio of MLP layers. Default: .33.
+        inference (str):
+            Approximate inference method, one of ``'mfvi'`` or ``'lbp'``. Default: ``'mfvi'``.
+        max_iter (int):
+            Maximum number of inference iterations. Default: 3.
+        interpolation (float):
+            Constant interpolating the label loss and the span loss. Default: .1.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_labels,
+                 n_tags=None,
+                 n_chars=None,
+                 encoder='lstm',
+                 feat=['char'],
+                 n_embed=100,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, True),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 n_encoder_hidden=800,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_span_mlp=500,
+                 n_pair_mlp=100,
+                 n_label_mlp=100,
+                 mlp_dropout=.33,
+                 inference='mfvi',
+                 max_iter=3,
+                 interpolation=0.1,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        self.span_mlp_l = MLP(n_in=self.args.n_encoder_hidden, n_out=n_span_mlp, dropout=mlp_dropout)
+        self.span_mlp_r = MLP(n_in=self.args.n_encoder_hidden, n_out=n_span_mlp, dropout=mlp_dropout)
+        self.pair_mlp_l = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=mlp_dropout)
+        self.pair_mlp_r = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=mlp_dropout)
+        self.pair_mlp_b = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=mlp_dropout)
+        self.label_mlp_l = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=mlp_dropout)
+        self.label_mlp_r = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=mlp_dropout)
+
+        self.span_attn = Biaffine(n_in=n_span_mlp, bias_x=True, bias_y=False)
+        self.pair_attn = Triaffine(n_in=n_pair_mlp, bias_x=True, bias_y=False)
+        self.label_attn = Biaffine(n_in=n_label_mlp, n_out=n_labels, bias_x=True, bias_y=True)
+        self.inference = (ConstituencyMFVI if inference == 'mfvi' else ConstituencyLBP)(max_iter)
+        self.criterion = nn.CrossEntropyLoss()
+
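+    # Construction sketch (hypothetical sizes; `n_chars` is only needed because
+    # the default `feat` includes 'char'):
+    #   model = VIConstituencyModel(n_words=10000, n_labels=50, n_chars=100,
+    #                               inference='mfvi', max_iter=3)
+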
+    def forward(self, words, feats):
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor, ~torch.Tensor:
+                Scores of all possible constituents (``[batch_size, seq_len, seq_len]``),
+                second-order triples (``[batch_size, seq_len, seq_len, seq_len]``) and
+                all possible labels on each constituent (``[batch_size, seq_len, seq_len, n_labels]``).
+        """
+
+        x = self.encode(words, feats)
+
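+        # fencepost representation: split the encoder states into forward and
+        # backward halves and pair the forward state of boundary i with the
+        # backward state of boundary i+1, giving one vector for each of the
+        # seq_len - 1 positions between tokens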
+        x_f, x_b = x.chunk(2, -1)
+        x = torch.cat((x_f[:, :-1], x_b[:, 1:]), -1)
+
+        span_l = self.span_mlp_l(x)
+        span_r = self.span_mlp_r(x)
+        pair_l = self.pair_mlp_l(x)
+        pair_r = self.pair_mlp_r(x)
+        pair_b = self.pair_mlp_b(x)
+        label_l = self.label_mlp_l(x)
+        label_r = self.label_mlp_r(x)
+
+        # [batch_size, seq_len, seq_len]
+        s_span = self.span_attn(span_l, span_r)
+        s_pair = self.pair_attn(pair_l, pair_r, pair_b).permute(0, 3, 1, 2)
+        # [batch_size, seq_len, seq_len, n_labels]
+        s_label = self.label_attn(label_l, label_r).permute(0, 2, 3, 1)
+
+        return s_span, s_pair, s_label
+
+    def loss(self, s_span, s_pair, s_label, charts, mask):
+        r"""
+        Args:
+            s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+                Scores of all constituents.
+            s_pair (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
+                Scores of second-order triples.
+            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all constituent labels.
+            charts (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
+                The tensor of gold-standard labels. Positions without labels are filled with -1.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The training loss and marginals of shape ``[batch_size, seq_len, seq_len]``.
+        """
+
+        span_mask = charts.ge(0) & mask
+        span_loss, span_probs = self.inference((s_span, s_pair), mask, span_mask)
+        label_loss = self.criterion(s_label[span_mask], charts[span_mask])
+        loss = self.args.interpolation * label_loss + (1 - self.args.interpolation) * span_loss
+
+        return loss, span_probs
+
+    def decode(self, s_span, s_label, mask):
+        r"""
+        Args:
+            s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+                Scores of all constituents.
+            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all constituent labels.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+                The mask for covering the unpadded tokens in each chart.
+
+        Returns:
+            List[List[Tuple]]:
+                Sequences of factorized labeled trees.
+        """
+
+        span_preds = ConstituencyCRF(s_span, mask[:, 0].sum(-1)).argmax
+        label_preds = s_label.argmax(-1).tolist()
+        return [[(i, j, labels[i][j]) for i, j in spans] for spans, labels in zip(span_preds, label_preds)]
diff --git a/tania_scripts/supar/models/const/vi/parser.py b/tania_scripts/supar/models/const/vi/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..5721a9dcc78f5c7e16484e4c18bc25fed51d81fd
--- /dev/null
+++ b/tania_scripts/supar/models/const/vi/parser.py
@@ -0,0 +1,108 @@
+# -*- coding: utf-8 -*-
+
+from typing import Dict, Iterable, Set, Union
+
+import torch
+
+from supar.models.const.crf.parser import CRFConstituencyParser
+from supar.models.const.crf.transform import Tree
+from supar.models.const.vi.model import VIConstituencyModel
+from supar.utils import Config
+from supar.utils.logging import get_logger
+from supar.utils.metric import SpanMetric
+from supar.utils.transform import Batch
+
+logger = get_logger(__name__)
+
+
+class VIConstituencyParser(CRFConstituencyParser):
+    r"""
+    The implementation of Constituency Parser using variational inference.
+    """
+
+    NAME = 'vi-constituency'
+    MODEL = VIConstituencyModel
+
+    def train(
+        self,
+        train,
+        dev,
+        test,
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32, workers: int = 0, amp: bool = False, cache: bool = False,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
+        equal: Dict = {'ADVP': 'PRT'},
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().predict(**Config().update(locals()))
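+
+    # Usage sketch (assumption: a trained checkpoint saved at `path`):
+    #   parser = VIConstituencyParser.load(path)
+    #   dataset = parser.predict([['She', 'enjoys', 'playing', 'tennis', '.']])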
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        words, *feats, _, charts = batch
+        mask = batch.mask[:, 1:]
+        mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
+        s_span, s_pair, s_label = self.model(words, feats)
+        loss, _ = self.model.loss(s_span, s_pair, s_label, charts, mask)
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> SpanMetric:
+        words, *feats, trees, charts = batch
+        mask = batch.mask[:, 1:]
+        mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
+        s_span, s_pair, s_label = self.model(words, feats)
+        loss, s_span = self.model.loss(s_span, s_pair, s_label, charts, mask)
+        chart_preds = self.model.decode(s_span, s_label, mask)
+        preds = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
+                 for tree, chart in zip(trees, chart_preds)]
+        return SpanMetric(loss,
+                          [Tree.factorize(tree, self.args.delete, self.args.equal) for tree in preds],
+                          [Tree.factorize(tree, self.args.delete, self.args.equal) for tree in trees])
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, *feats, trees = batch
+        mask, lens = batch.mask[:, 1:], batch.lens - 2
+        mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
+        s_span, s_pair, s_label = self.model(words, feats)
+        s_span = self.model.inference((s_span, s_pair), mask)
+        chart_preds = self.model.decode(s_span, s_label, mask)
+        batch.trees = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
+                       for tree, chart in zip(trees, chart_preds)]
+        if self.args.prob:
+            batch.probs = [prob[:i-1, 1:i].cpu() for i, prob in zip(lens, s_span)]
+        return batch
diff --git a/tania_scripts/supar/models/dep/__init__.py b/tania_scripts/supar/models/dep/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aad535787edd8e64b6a673c79f969d10f3133a39
--- /dev/null
+++ b/tania_scripts/supar/models/dep/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+
+from .biaffine import BiaffineDependencyModel, BiaffineDependencyParser
+from .crf import CRFDependencyModel, CRFDependencyParser
+from .crf2o import CRF2oDependencyModel, CRF2oDependencyParser
+from .vi import VIDependencyModel, VIDependencyParser
+from .sl import SLDependencyModel, SLDependencyParser
+from .eager import ArcEagerDependencyModel, ArcEagerDependencyParser
+
+__all__ = ['BiaffineDependencyModel', 'BiaffineDependencyParser',
+           'CRFDependencyModel', 'CRFDependencyParser',
+           'CRF2oDependencyModel', 'CRF2oDependencyParser',
+           'VIDependencyModel', 'VIDependencyParser',
+           'SLDependencyModel', 'SLDependencyParser',
+           'ArcEagerDependencyModel', 'ArcEagerDependencyParser']
diff --git a/tania_scripts/supar/models/dep/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7861152ee2c2a8276035f9246170cd9dc9cc615c
Binary files /dev/null and b/tania_scripts/supar/models/dep/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e046f387fd327180d2c8a0e599fbf49639f9884c
Binary files /dev/null and b/tania_scripts/supar/models/dep/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/biaffine/__init__.py b/tania_scripts/supar/models/dep/biaffine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d757c65a2c15d8e495528d9029ad34440eee6ebc
--- /dev/null
+++ b/tania_scripts/supar/models/dep/biaffine/__init__.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .model import BiaffineDependencyModel
+from .parser import BiaffineDependencyParser
+
+__all__ = ['BiaffineDependencyModel', 'BiaffineDependencyParser']
diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a03c15f409c6cb82e0c38e8de1c728163143075d
Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6cc85e00431ce55159ade848598530026f4c870f
Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c463ffc47a61e67f929c1d390efdc163db1ecc9
Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/model.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec57b786c69963f7b6418e915051312869bea45f
Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/model.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0a161c0ca80e6f499166fc7e4f195252909a9dc8
Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/parser.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..41daea0ef2c2181fcba2b5fbf9a540eded84891b
Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/parser.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/transform.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a172f6c96a8bba92e2b94ef264d11b820f66f204
Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/transform.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/transform.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b880c597dadc460d71c38295eb68e34d01977a2
Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/transform.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/biaffine/model.py b/tania_scripts/supar/models/dep/biaffine/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..14420cb8563dbbdd606afad43b6f699c293d2d68
--- /dev/null
+++ b/tania_scripts/supar/models/dep/biaffine/model.py
@@ -0,0 +1,234 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+from supar.model import Model
+from supar.models.dep.biaffine.transform import CoNLL
+from supar.modules import MLP, Biaffine
+from supar.structs import DependencyCRF, MatrixTree
+from supar.utils import Config
+from supar.utils.common import MIN
+
+
+class BiaffineDependencyModel(Model):
+    r"""
+    The implementation of Biaffine Dependency Parser :cite:`dozat-etal-2017-biaffine`.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_rels (int):
+            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'char'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word embeddings. Default: 100.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, False)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs would be a weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling method to get token embeddings.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .33.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 800.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layer. Default: .33.
+        n_arc_mlp (int):
+            Arc MLP size. Default: 500.
+        n_rel_mlp (int):
+            Label MLP size. Default: 100.
+        mlp_dropout (float):
+            The dropout ratio of MLP layers. Default: .33.
+        scale (float):
+            Scaling factor for affine scores. Default: 0.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_rels,
+                 n_tags=None,
+                 n_chars=None,
+                 encoder='lstm',
+                 feat=['tag', 'char'],
+                 n_embed=100,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, False),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 n_encoder_hidden=800,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_arc_mlp=500,
+                 n_rel_mlp=100,
+                 mlp_dropout=.33,
+                 scale=0,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        self.arc_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_arc_mlp, dropout=mlp_dropout)
+        self.arc_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_arc_mlp, dropout=mlp_dropout)
+        self.rel_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_rel_mlp, dropout=mlp_dropout)
+        self.rel_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_rel_mlp, dropout=mlp_dropout)
+
+        self.arc_attn = Biaffine(n_in=n_arc_mlp, scale=scale, bias_x=True, bias_y=False)
+        self.rel_attn = Biaffine(n_in=n_rel_mlp, n_out=n_rels, bias_x=True, bias_y=True)
+        self.criterion = nn.CrossEntropyLoss()
+
+    def forward(self, words, feats=None):
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+                Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The first tensor of shape ``[batch_size, seq_len, seq_len]`` holds scores of all possible arcs.
+                The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds
+                scores of all possible labels on each arc.
+        """
+
+        x = self.encode(words, feats)
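+        # words is [batch_size, seq_len] for word-level input and
+        # [batch_size, seq_len, fix_len] for subword input; a position counts
+        # as real if any of its indices differs from the pad index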
+        mask = words.ne(self.args.pad_index) if len(words.shape) < 3 else words.ne(self.args.pad_index).any(-1)
+
+        arc_d = self.arc_mlp_d(x)
+        arc_h = self.arc_mlp_h(x)
+        rel_d = self.rel_mlp_d(x)
+        rel_h = self.rel_mlp_h(x)
+
+        # [batch_size, seq_len, seq_len]
+        s_arc = self.arc_attn(arc_d, arc_h).masked_fill_(~mask.unsqueeze(1), MIN)
+        # [batch_size, seq_len, seq_len, n_rels]
+        s_rel = self.rel_attn(rel_d, rel_h).permute(0, 2, 3, 1)
+
+        return s_arc, s_rel
+
+    def loss(self, s_arc, s_rel, arcs, rels, mask, partial=False):
+        r"""
+        Args:
+            s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+                Scores of all possible arcs.
+            s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each arc.
+            arcs (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The tensor of gold-standard arcs.
+            rels (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The tensor of gold-standard labels.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens.
+            partial (bool):
+                ``True`` denotes the trees are partially annotated. Default: ``False``.
+
+        Returns:
+            ~torch.Tensor:
+                The training loss.
+        """
+
+        if partial:
+            mask = mask & arcs.ge(0)
+        s_arc, arcs = s_arc[mask], arcs[mask]
+        s_rel, rels = s_rel[mask], rels[mask]
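+        # for every token, keep only the label scores of its gold head, i.e.,
+        # labels are supervised conditioned on gold arcs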
+        s_rel = s_rel[torch.arange(len(arcs)), arcs]
+        arc_loss = self.criterion(s_arc, arcs)
+        rel_loss = self.criterion(s_rel, rels)
+
+        return arc_loss + rel_loss
+
+    def decode(self, s_arc, s_rel, mask, tree=False, proj=False):
+        r"""
+        Args:
+            s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+                Scores of all possible arcs.
+            s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each arc.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens.
+            tree (bool):
+                If ``True``, ensures to output well-formed trees. Default: ``False``.
+            proj (bool):
+                If ``True``, ensures to output projective trees. Default: ``False``.
+
+        Returns:
+            ~torch.LongTensor, ~torch.LongTensor:
+                Predicted arcs and labels of shape ``[batch_size, seq_len]``.
+        """
+
+        lens = mask.sum(1)
+        arc_preds = s_arc.argmax(-1)
+        bad = [not CoNLL.istree(seq[1:i+1], proj) for i, seq in zip(lens.tolist(), arc_preds.tolist())]
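+        # fall back to structured decoding for sentences whose greedy head
+        # choices do not form a valid (projective, if requested) tree:
+        # Eisner-style CRF argmax if projective, matrix-tree otherwise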
+        if tree and any(bad):
+            arc_preds[bad] = (DependencyCRF if proj else MatrixTree)(s_arc[bad], mask[bad].sum(-1)).argmax
+        rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
+
+        return arc_preds, rel_preds
diff --git a/tania_scripts/supar/models/dep/biaffine/parser.py b/tania_scripts/supar/models/dep/biaffine/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7b3093fc323131b372ffd744d5de9cf1e117cf9
--- /dev/null
+++ b/tania_scripts/supar/models/dep/biaffine/parser.py
@@ -0,0 +1,213 @@
+# -*- coding: utf-8 -*-
+
+import os
+from typing import Iterable, Union
+
+import torch
+
+from supar.models.dep.biaffine.model import BiaffineDependencyModel
+from supar.models.dep.biaffine.transform import CoNLL
+from supar.parser import Parser
+from supar.utils import Config, Dataset, Embedding
+from supar.utils.common import BOS, PAD, UNK
+from supar.utils.field import Field, RawField, SubwordField
+from supar.utils.fn import ispunct
+from supar.utils.logging import get_logger
+from supar.utils.metric import AttachmentMetric
+from supar.utils.tokenizer import TransformerTokenizer
+from supar.utils.transform import Batch
+
+logger = get_logger(__name__)
+
+
+class BiaffineDependencyParser(Parser):
+    r"""
+    The implementation of Biaffine Dependency Parser :cite:`dozat-etal-2017-biaffine`.
+    """
+
+    NAME = 'biaffine-dependency'
+    MODEL = BiaffineDependencyModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.TAG = self.transform.CPOS
+        self.ARC, self.REL = self.transform.HEAD, self.transform.DEPREL
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        punct: bool = False,
+        tree: bool = False,
+        proj: bool = False,
+        partial: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        punct: bool = False,
+        tree: bool = True,
+        proj: bool = False,
+        partial: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        tree: bool = True,
+        proj: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().predict(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        words, _, *feats, arcs, rels = batch
+        mask = batch.mask
+        # ignore the first token of each sentence
+        mask[:, 0] = 0
+        s_arc, s_rel = self.model(words, feats)
+        loss = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.partial)
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> AttachmentMetric:
+        words, _, *feats, arcs, rels = batch
+        mask = batch.mask
+        # ignore the first token of each sentence
+        mask[:, 0] = 0
+        s_arc, s_rel = self.model(words, feats)
+        loss = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.partial)
+        arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
+        if self.args.partial:
+            mask &= arcs.ge(0)
+        # ignore all punctuation if not specified
+        if not self.args.punct:
+            mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
+        return AttachmentMetric(loss, (arc_preds, rel_preds), (arcs, rels), mask)
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, _, *feats = batch
+        mask, lens = batch.mask, (batch.lens - 1).tolist()
+        # ignore the first token of each sentence
+        mask[:, 0] = 0
+        s_arc, s_rel = self.model(words, feats)
+        arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
+        batch.arcs = [i.tolist() for i in arc_preds[mask].split(lens)]
+        batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)]
+        if self.args.prob:
+            batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, s_arc.softmax(-1).unbind())]
+        return batch
+
+    @classmethod
+    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+        r"""
+        Build a brand-new Parser, including initialization of all data fields and model parameters.
+
+        Args:
+            path (str):
+                The path of the model to be saved.
+            min_freq (int):
+                The minimum frequency needed to include a token in the vocabulary.
+                Required if taking words as encoder input.
+                Default: 2.
+            fix_len (int):
+                The max length of all subword pieces. The excess part of each piece will be truncated.
+                Required if using CharLSTM/BERT.
+                Default: 20.
+            kwargs (Dict):
+                A dict holding the unconsumed arguments.
+        """
+
+        args = Config(**locals())
+        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
+        if os.path.exists(path) and not args.build:
+            parser = cls.load(**args)
+            parser.model = cls.MODEL(**parser.args)
+            parser.model.load_pretrained(parser.transform.FORM[0].embed).to(parser.device)
+            return parser
+
+        logger.info("Building the fields")
+        TAG, CHAR, ELMO, BERT = None, None, None, None
+        if args.encoder == 'bert':
+            t = TransformerTokenizer(args.bert)
+            WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t)
+            WORD.vocab = t.vocab
+        else:
+            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True)
+            if 'tag' in args.feat:
+                TAG = Field('tags', bos=BOS)
+            if 'char' in args.feat:
+                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, fix_len=args.fix_len)
+            if 'elmo' in args.feat:
+                from allennlp.modules.elmo import batch_to_ids
+                ELMO = RawField('elmo')
+                ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device)
+            if 'bert' in args.feat:
+                t = TransformerTokenizer(args.bert)
+                BERT = SubwordField('bert', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t)
+                BERT.vocab = t.vocab
+        TEXT = RawField('texts')
+        ARC = Field('arcs', bos=BOS, use_vocab=False, fn=CoNLL.get_arcs)
+        REL = Field('rels', bos=BOS)
+        transform = CoNLL(FORM=(WORD, TEXT, CHAR, ELMO, BERT), CPOS=TAG, HEAD=ARC, DEPREL=REL)
+
+        train = Dataset(transform, args.train, **args)
+        if args.encoder != 'bert':
+            WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
+            if TAG is not None:
+                TAG.build(train)
+            if CHAR is not None:
+                CHAR.build(train)
+        REL.build(train)
+        args.update({
+            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
+            'n_rels': len(REL.vocab),
+            'n_tags': len(TAG.vocab) if TAG is not None else None,
+            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
+            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
+            'bert_pad_index': BERT.pad_index if BERT is not None else None,
+            'pad_index': WORD.pad_index,
+            'unk_index': WORD.unk_index,
+            'bos_index': WORD.bos_index
+        })
+        logger.info(f"{transform}")
+
+        logger.info("Building the model")
+        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None)
+        logger.info(f"{model}\n")
+
+        parser = cls(args, model, transform)
+        parser.model.to(parser.device)
+        return parser
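+
+# Usage sketch (assumption: CoNLL-X formatted treebank files at these paths):
+#   parser = BiaffineDependencyParser.build('exp/biaffine/model', train='data/train.conllx')
+#   parser.train(train='data/train.conllx', dev='data/dev.conllx', test='data/test.conllx')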
diff --git a/tania_scripts/supar/models/dep/biaffine/transform.py b/tania_scripts/supar/models/dep/biaffine/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..073572b57a2610a829253e698d4f4f1bbee05b85
--- /dev/null
+++ b/tania_scripts/supar/models/dep/biaffine/transform.py
@@ -0,0 +1,379 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import os
+from io import StringIO
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
+
+from supar.utils.logging import get_logger
+from supar.utils.tokenizer import Tokenizer
+from supar.utils.transform import Sentence, Transform
+
+if TYPE_CHECKING:
+    from supar.utils import Field
+
+logger = get_logger(__name__)
+
+
+class CoNLL(Transform):
+    r"""
+    A :class:`CoNLL` object holds ten fields required for CoNLL-X data format :cite:`buchholz-marsi-2006-conll`.
+    Each field can be bound to one or more :class:`~supar.utils.field.Field` objects.
+    For example, ``FORM`` can contain both :class:`~supar.utils.field.Field` and :class:`~supar.utils.field.SubwordField`
+    to produce tensors for words and subwords.
+
+    Attributes:
+        ID:
+            Token counter, starting at 1.
+        FORM:
+            Words in the sentence.
+        LEMMA:
+            Lemmas or stems (depending on the particular treebank) of words, or underscores if not available.
+        CPOS:
+            Coarse-grained part-of-speech tags, where the tagset depends on the treebank.
+        POS:
+            Fine-grained part-of-speech tags, where the tagset depends on the treebank.
+        FEATS:
+            Unordered set of syntactic and/or morphological features (depending on the particular treebank),
+            or underscores if not available.
+        HEAD:
+            Heads of the tokens, which are either values of ID or zeros.
+        DEPREL:
+            Dependency relations to the HEAD.
+        PHEAD:
+            Projective heads of tokens, which are either values of ID or zeros, or underscores if not available.
+        PDEPREL:
+            Dependency relations to the PHEAD, or underscores if not available.
+    """
+
+    fields = ['ID', 'FORM', 'LEMMA', 'CPOS', 'POS', 'FEATS', 'HEAD', 'DEPREL', 'PHEAD', 'PDEPREL']
+
+    def __init__(
+        self,
+        ID: Optional[Union[Field, Iterable[Field]]] = None,
+        FORM: Optional[Union[Field, Iterable[Field]]] = None,
+        LEMMA: Optional[Union[Field, Iterable[Field]]] = None,
+        CPOS: Optional[Union[Field, Iterable[Field]]] = None,
+        POS: Optional[Union[Field, Iterable[Field]]] = None,
+        FEATS: Optional[Union[Field, Iterable[Field]]] = None,
+        HEAD: Optional[Union[Field, Iterable[Field]]] = None,
+        DEPREL: Optional[Union[Field, Iterable[Field]]] = None,
+        PHEAD: Optional[Union[Field, Iterable[Field]]] = None,
+        PDEPREL: Optional[Union[Field, Iterable[Field]]] = None
+    ) -> CoNLL:
+        super().__init__()
+
+        self.ID = ID
+        self.FORM = FORM
+        self.LEMMA = LEMMA
+        self.CPOS = CPOS
+        self.POS = POS
+        self.FEATS = FEATS
+        self.HEAD = HEAD
+        self.DEPREL = DEPREL
+        self.PHEAD = PHEAD
+        self.PDEPREL = PDEPREL
+
+    @property
+    def src(self):
+        return self.FORM, self.LEMMA, self.CPOS, self.POS, self.FEATS
+
+    @property
+    def tgt(self):
+        return self.HEAD, self.DEPREL, self.PHEAD, self.PDEPREL
+
+    @classmethod
+    def get_arcs(cls, sequence, placeholder='_'):
+        return [-1 if i == placeholder else int(i) for i in sequence]
+
+    @classmethod
+    def get_sibs(cls, sequence, placeholder='_'):
+        sibs = [[0] * (len(sequence) + 1) for _ in range(len(sequence) + 1)]
+        heads = [0] + [-1 if i == placeholder else int(i) for i in sequence]
+
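+        # among dependents sharing the same head on the same side, the one
+        # farther from the head records the nearer one as its adjacent sibling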
+        for i, hi in enumerate(heads[1:], 1):
+            for j, hj in enumerate(heads[i + 1:], i + 1):
+                di, dj = hi - i, hj - j
+                if hi >= 0 and hj >= 0 and hi == hj and di * dj > 0:
+                    if abs(di) > abs(dj):
+                        sibs[i][hi] = j
+                    else:
+                        sibs[j][hj] = i
+                    break
+        return sibs[1:]
+
+    @classmethod
+    def get_edges(cls, sequence):
+        edges = [[0] * (len(sequence) + 1) for _ in range(len(sequence) + 1)]
+        for i, s in enumerate(sequence, 1):
+            if s != '_':
+                for pair in s.split('|'):
+                    edges[i][int(pair.split(':')[0])] = 1
+        return edges
+
+    @classmethod
+    def get_labels(cls, sequence):
+        labels = [[None] * (len(sequence) + 1) for _ in range(len(sequence) + 1)]
+        for i, s in enumerate(sequence, 1):
+            if s != '_':
+                for pair in s.split('|'):
+                    edge, label = pair.split(':', 1)
+                    labels[i][int(edge)] = label
+        return labels
+
+    @classmethod
+    def build_relations(cls, chart):
+        sequence = ['_'] * len(chart)
+        for i, row in enumerate(chart):
+            pairs = [(j, label) for j, label in enumerate(row) if label is not None]
+            if len(pairs) > 0:
+                sequence[i] = '|'.join(f"{head}:{label}" for head, label in pairs)
+        return sequence
+
+    @classmethod
+    def toconll(cls, tokens: List[Union[str, Tuple]]) -> str:
+        r"""
+        Converts a list of tokens to a string in CoNLL-X format with missing fields filled with underscores.
+
+        Args:
+            tokens (List[Union[str, Tuple]]):
+                This can be either a list of words, word/pos pairs or word/lemma/pos triples.
+
+        Returns:
+            A string in CoNLL-X format.
+
+        Examples:
+            >>> print(CoNLL.toconll(['She', 'enjoys', 'playing', 'tennis', '.']))
+            1       She     _       _       _       _       _       _       _       _
+            2       enjoys  _       _       _       _       _       _       _       _
+            3       playing _       _       _       _       _       _       _       _
+            4       tennis  _       _       _       _       _       _       _       _
+            5       .       _       _       _       _       _       _       _       _
+
+            >>> print(CoNLL.toconll([('She',     'she',    'PRP'),
+                                     ('enjoys',  'enjoy',  'VBZ'),
+                                     ('playing', 'play',   'VBG'),
+                                     ('tennis',  'tennis', 'NN'),
+                                     ('.',       '_',      '.')]))
+            1       She     she     PRP     _       _       _       _       _       _
+            2       enjoys  enjoy   VBZ     _       _       _       _       _       _
+            3       playing play    VBG     _       _       _       _       _       _
+            4       tennis  tennis  NN      _       _       _       _       _       _
+            5       .       _       .       _       _       _       _       _       _
+
+        """
+
+        if isinstance(tokens[0], str):
+            s = '\n'.join([f"{i}\t{word}\t" + '\t'.join(['_'] * 8)
+                           for i, word in enumerate(tokens, 1)])
+        elif len(tokens[0]) == 2:
+            s = '\n'.join([f"{i}\t{word}\t_\t{tag}\t" + '\t'.join(['_'] * 6)
+                           for i, (word, tag) in enumerate(tokens, 1)])
+        elif len(tokens[0]) == 3:
+            s = '\n'.join([f"{i}\t{word}\t{lemma}\t{tag}\t" + '\t'.join(['_'] * 6)
+                           for i, (word, lemma, tag) in enumerate(tokens, 1)])
+        else:
+            raise RuntimeError(f"Invalid sequence {tokens}. Only lists of str, word/pos pairs or word/lemma/pos triples are supported.")
+        return s + '\n'
+
+    @classmethod
+    def isprojective(cls, sequence: List[int]) -> bool:
+        r"""
+        Checks if a dependency tree is projective.
+        This also works for partial annotation.
+
+        Besides the obvious crossing arcs, the examples below illustrate two non-projective cases
+        which are hard to detect in the scenario of partial annotation.
+
+        Args:
+            sequence (List[int]):
+                A list of head indices.
+
+        Returns:
+            ``True`` if the tree is projective, ``False`` otherwise.
+
+        Examples:
+            >>> CoNLL.isprojective([2, -1, 1])  # -1 denotes un-annotated cases
+            False
+            >>> CoNLL.isprojective([3, -1, 2])
+            False
+        """
+
+        pairs = [(h, d) for d, h in enumerate(sequence, 1) if h >= 0]
+        for i, (hi, di) in enumerate(pairs):
+            for hj, dj in pairs[i + 1:]:
+                (li, ri), (lj, rj) = sorted([hi, di]), sorted([hj, dj])
+                if li <= hj <= ri and hi == dj:
+                    return False
+                if lj <= hi <= rj and hj == di:
+                    return False
+                if (li < lj < ri or li < rj < ri) and (li - lj) * (ri - rj) > 0:
+                    return False
+        return True
+
+    @classmethod
+    def istree(cls, sequence: List[int], proj: bool = False, multiroot: bool = False) -> bool:
+        r"""
+        Checks if the arcs form a valid dependency tree.
+
+        Args:
+            sequence (List[int]):
+                A list of head indices.
+            proj (bool):
+                If ``True``, requires the tree to be projective. Default: ``False``.
+            multiroot (bool):
+                If ``False``, requires the tree to contain only a single root. Default: ``False``.
+
+        Returns:
+            ``True`` if the arcs form a valid tree, ``False`` otherwise.
+
+        Examples:
+            >>> CoNLL.istree([3, 0, 0, 3], multiroot=True)
+            True
+            >>> CoNLL.istree([3, 0, 0, 3], proj=True)
+            False
+        """
+
+        from supar.structs.fn import tarjan
+        if proj and not cls.isprojective(sequence):
+            return False
+        n_roots = sum(head == 0 for head in sequence)
+        if n_roots == 0:
+            return False
+        if not multiroot and n_roots > 1:
+            return False
+        if any(i == head for i, head in enumerate(sequence, 1)):
+            return False
+        return next(tarjan(sequence), None) is None
+
+    def load(
+        self,
+        data: Union[str, Iterable],
+        lang: Optional[str] = None,
+        proj: bool = False,
+        **kwargs
+    ) -> Iterable[CoNLLSentence]:
+        r"""
+        Loads the data in CoNLL-X format.
+        Also supports loading data from CoNLL-U files with comments and non-integer IDs.
+
+        Args:
+            data (Union[str, Iterable]):
+                A filename or a list of instances.
+            lang (str):
+                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``None``.
+            proj (bool):
+                If ``True``, discards all non-projective sentences. Default: ``False``.
+
+        Returns:
+            An iterable of :class:`CoNLLSentence` instances.
+        """
+
+        isconll = False
+        if lang is not None:
+            tokenizer = Tokenizer(lang)
+        if isinstance(data, str) and os.path.exists(data):
+            f = open(data)
+            if data.endswith('.txt'):
+                lines = (i
+                         for s in f
+                         if len(s) > 1
+                         for i in StringIO(self.toconll(s.split() if lang is None else tokenizer(s)) + '\n'))
+            else:
+                lines, isconll = f, True
+        else:
+            if lang is not None:
+                data = [tokenizer(s) for s in ([data] if isinstance(data, str) else data)]
+            else:
+                data = [data] if isinstance(data[0], str) else data
+            lines = (i for s in data for i in StringIO(self.toconll(s) + '\n'))
+
+        index, sentence = 0, []
+        for line in lines:
+            line = line.strip()
+            if len(line) == 0:
+                sentence = CoNLLSentence(self, sentence, index)
+                if isconll and self.training and proj and not self.isprojective(list(map(int, sentence.arcs))):
+                    logger.warning(f"Sentence {index} is not projective. Discarding it!")
+                else:
+                    yield sentence
+                    index += 1
+                sentence = []
+            else:
+                sentence.append(line)
+
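+    # Usage sketch (hypothetical file name; `conll` is an already-built CoNLL
+    # transform): iterate training sentences while filtering non-projective trees.
+    #   for sent in conll.load('train.conllx', proj=True):
+    #       ...  # non-projective trees are dropped only when self.training is set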
+
+class CoNLLSentence(Sentence):
+    r"""
+    Sentence in CoNLL-X format.
+
+    Args:
+        transform (CoNLL):
+            A :class:`~supar.utils.transform.CoNLL` object.
+        lines (List[str]):
+            A list of strings composing a sentence in CoNLL-X format.
+            Comments and non-integer IDs are permitted.
+        index (Optional[int]):
+            Index of the sentence in the corpus. Default: ``None``.
+
+    Examples:
+        >>> lines = ['# text = But I found the location wonderful and the neighbors very kind.',
+                     '1\tBut\t_\t_\t_\t_\t_\t_\t_\t_',
+                     '2\tI\t_\t_\t_\t_\t_\t_\t_\t_',
+                     '3\tfound\t_\t_\t_\t_\t_\t_\t_\t_',
+                     '4\tthe\t_\t_\t_\t_\t_\t_\t_\t_',
+                     '5\tlocation\t_\t_\t_\t_\t_\t_\t_\t_',
+                     '6\twonderful\t_\t_\t_\t_\t_\t_\t_\t_',
+                     '7\tand\t_\t_\t_\t_\t_\t_\t_\t_',
+                     '7.1\tfound\t_\t_\t_\t_\t_\t_\t_\t_',
+                     '8\tthe\t_\t_\t_\t_\t_\t_\t_\t_',
+                     '9\tneighbors\t_\t_\t_\t_\t_\t_\t_\t_',
+                     '10\tvery\t_\t_\t_\t_\t_\t_\t_\t_',
+                     '11\tkind\t_\t_\t_\t_\t_\t_\t_\t_',
+                     '12\t.\t_\t_\t_\t_\t_\t_\t_\t_']
+        >>> sentence = CoNLLSentence(transform, lines)  # fields in transform are built from ptb.
+        >>> sentence.arcs = [3, 3, 0, 5, 6, 3, 6, 9, 11, 11, 6, 3]
+        >>> sentence.rels = ['cc', 'nsubj', 'root', 'det', 'nsubj', 'xcomp',
+                             'cc', 'det', 'dep', 'advmod', 'conj', 'punct']
+        >>> sentence
+        # text = But I found the location wonderful and the neighbors very kind.
+        1       But     _       _       _       _       3       cc      _       _
+        2       I       _       _       _       _       3       nsubj   _       _
+        3       found   _       _       _       _       0       root    _       _
+        4       the     _       _       _       _       5       det     _       _
+        5       location        _       _       _       _       6       nsubj   _       _
+        6       wonderful       _       _       _       _       3       xcomp   _       _
+        7       and     _       _       _       _       6       cc      _       _
+        7.1     found   _       _       _       _       _       _       _       _
+        8       the     _       _       _       _       9       det     _       _
+        9       neighbors       _       _       _       _       11      dep     _       _
+        10      very    _       _       _       _       11      advmod  _       _
+        11      kind    _       _       _       _       6       conj    _       _
+        12      .       _       _       _       _       3       punct   _       _
+    """
+
+    def __init__(self, transform: CoNLL, lines: List[str], index: Optional[int] = None) -> None:
+        super().__init__(transform, index)
+
+        self.values = []
+        # record annotations for post-recovery
+        self.annotations = dict()
+
+        for i, line in enumerate(lines):
+            value = line.split('\t')
+            if value[0].startswith('#') or not value[0].isdigit():
+                self.annotations[-i - 1] = line
+            else:
+                self.annotations[len(self.values)] = line
+                self.values.append(value)
+        self.values = list(zip(*self.values))
+
+    def __repr__(self):
+        # merge annotations back with the current values to recover the raw lines
+        merged = {**self.annotations,
+                  **{i: '\t'.join(map(str, line))
+                     for i, line in enumerate(zip(*self.values))}}
+        return '\n'.join(merged.values()) + '\n'
diff --git a/tania_scripts/supar/models/dep/crf/__init__.py b/tania_scripts/supar/models/dep/crf/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..27cae45e1d3c23070acfa28ffb75aedb99bf560e
--- /dev/null
+++ b/tania_scripts/supar/models/dep/crf/__init__.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .model import CRFDependencyModel
+from .parser import CRFDependencyParser
+
+__all__ = ['CRFDependencyModel', 'CRFDependencyParser']
diff --git a/tania_scripts/supar/models/dep/crf/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/crf/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5810074dccba6a92cf481e86ef69ece8be74592c
Binary files /dev/null and b/tania_scripts/supar/models/dep/crf/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/crf/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/crf/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3609fddab845fa29cbbf98125bf4f14117dacd7c
Binary files /dev/null and b/tania_scripts/supar/models/dep/crf/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/crf/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/dep/crf/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..350a941b0ad158a2c7ec7f74b109e4200050bd1e
Binary files /dev/null and b/tania_scripts/supar/models/dep/crf/__pycache__/model.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/crf/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/dep/crf/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82c5f8a1f3181f461f6d3f7c6d863af8093a12ca
Binary files /dev/null and b/tania_scripts/supar/models/dep/crf/__pycache__/model.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/crf/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/dep/crf/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0a41a163099e2b9eb112143902fdc323def85a0
Binary files /dev/null and b/tania_scripts/supar/models/dep/crf/__pycache__/parser.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/crf/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/dep/crf/__pycache__/parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b93724d0f7d60bbc3af66307c55579f66b3ed2c5
Binary files /dev/null and b/tania_scripts/supar/models/dep/crf/__pycache__/parser.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/crf/model.py b/tania_scripts/supar/models/dep/crf/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..472bd359efe2aa3b0670e0e8ef964d86c3b22822
--- /dev/null
+++ b/tania_scripts/supar/models/dep/crf/model.py
@@ -0,0 +1,134 @@
+# -*- coding: utf-8 -*-
+
+import torch
+from supar.models.dep.biaffine.model import BiaffineDependencyModel
+from supar.structs import DependencyCRF, MatrixTree
+
+
+class CRFDependencyModel(BiaffineDependencyModel):
+    r"""
+    The implementation of first-order CRF Dependency Parser
+    (:cite:`zhang-etal-2020-efficient,ma-hovy-2017-neural,koo-etal-2007-structured`).
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_rels (int):
+            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'char'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word embeddings. Default: 100.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, False)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs would be a weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling method used to derive token embeddings from subtokens.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: average over all subtokens.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .33.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 800.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layer. Default: .33.
+        n_arc_mlp (int):
+            Arc MLP size. Default: 500.
+        n_rel_mlp (int):
+            Label MLP size. Default: 100.
+        mlp_dropout (float):
+            The dropout ratio of MLP layers. Default: .33.
+        scale (float):
+            Scaling factor for affine scores. Default: 0.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+        proj (bool):
+            If ``True``, uses :class:`DependencyCRF` as the inference layer, :class:`MatrixTree` otherwise.
+            Default: ``True``.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def loss(self, s_arc, s_rel, arcs, rels, mask, mbr=True, partial=False):
+        r"""
+        Args:
+            s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+                Scores of all possible arcs.
+            s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each arc.
+            arcs (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The tensor of gold-standard arcs.
+            rels (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The tensor of gold-standard labels.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens.
+            mbr (bool):
+                If ``True``, returns marginals for MBR decoding. Default: ``True``.
+            partial (bool):
+                ``True`` denotes the trees are partially annotated. Default: ``False``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The training loss and
+                original arc scores of shape ``[batch_size, seq_len, seq_len]`` if ``mbr=False``, or marginals otherwise.
+        """
+
+        CRF = DependencyCRF if self.args.proj else MatrixTree
+        arc_dist = CRF(s_arc, mask.sum(-1))
+        arc_loss = -arc_dist.log_prob(arcs, partial=partial).sum() / mask.sum()
+        arc_probs = arc_dist.marginals if mbr else s_arc
+        # -1 denotes un-annotated arcs
+        if partial:
+            mask = mask & arcs.ge(0)
+        s_rel, rels = s_rel[mask], rels[mask]
+        s_rel = s_rel[torch.arange(len(rels)), arcs[mask]]
+        rel_loss = self.criterion(s_rel, rels)
+        loss = arc_loss + rel_loss
+        return loss, arc_probs
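+
+# Partial-annotation sketch (hypothetical values): with arcs = [[2, 0, -1]],
+# the -1 head marks an un-annotated token; under partial=True the mask is
+# refined with arcs.ge(0) so that position is skipped by the label loss,
+# while the arc CRF receives partial=partial and handles it internally.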
diff --git a/tania_scripts/supar/models/dep/crf/parser.py b/tania_scripts/supar/models/dep/crf/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ed16896b85dac689a14c92d5b89d6732fc85640
--- /dev/null
+++ b/tania_scripts/supar/models/dep/crf/parser.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+
+from typing import Iterable, Union
+
+import torch
+
+from supar.models.dep.biaffine.parser import BiaffineDependencyParser
+from supar.models.dep.crf.model import CRFDependencyModel
+from supar.structs import DependencyCRF, MatrixTree
+from supar.utils import Config
+from supar.utils.fn import ispunct
+from supar.utils.logging import get_logger
+from supar.utils.metric import AttachmentMetric
+from supar.utils.transform import Batch
+
+logger = get_logger(__name__)
+
+
+class CRFDependencyParser(BiaffineDependencyParser):
+    r"""
+    The implementation of first-order CRF Dependency Parser :cite:`zhang-etal-2020-efficient`.
+    """
+
+    NAME = 'crf-dependency'
+    MODEL = CRFDependencyModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        punct: bool = False,
+        mbr: bool = True,
+        tree: bool = False,
+        proj: bool = False,
+        partial: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        punct: bool = False,
+        mbr: bool = True,
+        tree: bool = True,
+        proj: bool = True,
+        partial: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        mbr: bool = True,
+        tree: bool = True,
+        proj: bool = True,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().predict(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        words, _, *feats, arcs, rels = batch
+        mask = batch.mask
+        # ignore the first token of each sentence
+        mask[:, 0] = 0
+        s_arc, s_rel = self.model(words, feats)
+        loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.mbr, self.args.partial)
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> AttachmentMetric:
+        words, _, *feats, arcs, rels = batch
+        mask = batch.mask
+        # ignore the first token of each sentence
+        mask[:, 0] = 0
+        s_arc, s_rel = self.model(words, feats)
+        loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.mbr, self.args.partial)
+        arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
+        if self.args.partial:
+            mask &= arcs.ge(0)
+        # ignore all punctuation if not specified
+        if not self.args.punct:
+            mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
+        return AttachmentMetric(loss, (arc_preds, rel_preds), (arcs, rels), mask)
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        CRF = DependencyCRF if self.args.proj else MatrixTree
+        words, _, *feats = batch
+        mask, lens = batch.mask, batch.lens - 1
+        # ignore the first token of each sentence
+        mask[:, 0] = 0
+        s_arc, s_rel = self.model(words, feats)
+        s_arc = CRF(s_arc, lens).marginals if self.args.mbr else s_arc
+        arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
+        lens = lens.tolist()
+        batch.arcs = [i.tolist() for i in arc_preds[mask].split(lens)]
+        batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)]
+        if self.args.prob:
+            arc_probs = s_arc if self.args.mbr else s_arc.softmax(-1)
+            batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())]
+        return batch
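+
+# Usage sketch (hypothetical checkpoint path; released model names may differ):
+#   parser = CRFDependencyParser.load('path/to/model')
+#   dataset = parser.predict('data.conllx', pred='pred.conllx', mbr=True, proj=True)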
diff --git a/tania_scripts/supar/models/dep/crf2o/__init__.py b/tania_scripts/supar/models/dep/crf2o/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2acf9ce4dcc45c96c270384714470b59e94e927
--- /dev/null
+++ b/tania_scripts/supar/models/dep/crf2o/__init__.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .model import CRF2oDependencyModel
+from .parser import CRF2oDependencyParser
+
+__all__ = ['CRF2oDependencyModel', 'CRF2oDependencyParser']
diff --git a/tania_scripts/supar/models/dep/crf2o/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/crf2o/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb6754d2364431258038d51a7d60f97240ef9d09
Binary files /dev/null and b/tania_scripts/supar/models/dep/crf2o/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/crf2o/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/crf2o/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4cf033823fb198f1abbc2c5fe24ab7427a6e3819
Binary files /dev/null and b/tania_scripts/supar/models/dep/crf2o/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/crf2o/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/dep/crf2o/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6243392a5a1c5d884a93eb8cd18fbcf1edad6ec4
Binary files /dev/null and b/tania_scripts/supar/models/dep/crf2o/__pycache__/model.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/crf2o/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/dep/crf2o/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7586c6f2981ac0dc9e8dbe7252d6f9ed741d5805
Binary files /dev/null and b/tania_scripts/supar/models/dep/crf2o/__pycache__/model.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/crf2o/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/dep/crf2o/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7cf299268ee7bd3bd211e8bd3d917638bd104180
Binary files /dev/null and b/tania_scripts/supar/models/dep/crf2o/__pycache__/parser.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/crf2o/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/dep/crf2o/__pycache__/parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec22eb9df296b742697cdd799d431ff90162da41
Binary files /dev/null and b/tania_scripts/supar/models/dep/crf2o/__pycache__/parser.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/crf2o/model.py b/tania_scripts/supar/models/dep/crf2o/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b53bdd4f7db62faf30350e50877f7d61b4a6cc2
--- /dev/null
+++ b/tania_scripts/supar/models/dep/crf2o/model.py
@@ -0,0 +1,263 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+from supar.models.dep.biaffine.model import BiaffineDependencyModel
+from supar.models.dep.biaffine.transform import CoNLL
+from supar.modules import MLP, Biaffine, Triaffine
+from supar.structs import Dependency2oCRF, MatrixTree
+from supar.utils import Config
+from supar.utils.common import MIN
+
+
+class CRF2oDependencyModel(BiaffineDependencyModel):
+    r"""
+    The implementation of second-order CRF Dependency Parser :cite:`zhang-etal-2020-efficient`.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_rels (int):
+            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'char'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word embeddings. Default: 100.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, False)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs would be a weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling method used to derive token embeddings from subtokens.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: average over all subtokens.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .33.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 800.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layer. Default: .33.
+        n_arc_mlp (int):
+            Arc MLP size. Default: 500.
+        n_sib_mlp (int):
+            Sibling MLP size. Default: 100.
+        n_rel_mlp (int):
+            Label MLP size. Default: 100.
+        mlp_dropout (float):
+            The dropout ratio of MLP layers. Default: .33.
+        scale (float):
+            Scaling factor for affine scores. Default: 0.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_rels,
+                 n_tags=None,
+                 n_chars=None,
+                 encoder='lstm',
+                 feat=['char'],
+                 n_embed=100,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, False),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 n_encoder_hidden=800,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_arc_mlp=500,
+                 n_sib_mlp=100,
+                 n_rel_mlp=100,
+                 mlp_dropout=.33,
+                 scale=0,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        self.arc_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_arc_mlp, dropout=mlp_dropout)
+        self.arc_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_arc_mlp, dropout=mlp_dropout)
+        self.sib_mlp_s = MLP(n_in=self.args.n_encoder_hidden, n_out=n_sib_mlp, dropout=mlp_dropout)
+        self.sib_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_sib_mlp, dropout=mlp_dropout)
+        self.sib_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_sib_mlp, dropout=mlp_dropout)
+        self.rel_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_rel_mlp, dropout=mlp_dropout)
+        self.rel_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_rel_mlp, dropout=mlp_dropout)
+
+        self.arc_attn = Biaffine(n_in=n_arc_mlp, scale=scale, bias_x=True, bias_y=False)
+        self.sib_attn = Triaffine(n_in=n_sib_mlp, scale=scale, bias_x=True, bias_y=True)
+        self.rel_attn = Biaffine(n_in=n_rel_mlp, n_out=n_rels, bias_x=True, bias_y=True)
+        self.criterion = nn.CrossEntropyLoss()
+
+    def forward(self, words, feats=None):
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+                Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor, ~torch.Tensor:
+                Scores of all possible arcs (``[batch_size, seq_len, seq_len]``),
+                dependent-head-sibling triples (``[batch_size, seq_len, seq_len, seq_len]``) and
+                all possible labels on each arc (``[batch_size, seq_len, seq_len, n_labels]``).
+        """
+
+        x = self.encode(words, feats)
+        mask = words.ne(self.args.pad_index) if len(words.shape) < 3 else words.ne(self.args.pad_index).any(-1)
+
+        arc_d = self.arc_mlp_d(x)
+        arc_h = self.arc_mlp_h(x)
+        sib_s = self.sib_mlp_s(x)
+        sib_d = self.sib_mlp_d(x)
+        sib_h = self.sib_mlp_h(x)
+        rel_d = self.rel_mlp_d(x)
+        rel_h = self.rel_mlp_h(x)
+
+        # [batch_size, seq_len, seq_len]
+        s_arc = self.arc_attn(arc_d, arc_h).masked_fill_(~mask.unsqueeze(1), MIN)
+        # [batch_size, seq_len, seq_len, seq_len]
+        s_sib = self.sib_attn(sib_s, sib_d, sib_h).permute(0, 3, 1, 2)
+        # [batch_size, seq_len, seq_len, n_rels]
+        s_rel = self.rel_attn(rel_d, rel_h).permute(0, 2, 3, 1)
+
+        return s_arc, s_sib, s_rel
+
+    def loss(self, s_arc, s_sib, s_rel, arcs, sibs, rels, mask, mbr=True, partial=False):
+        r"""
+        Args:
+            s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+                Scores of all possible arcs.
+            s_sib (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
+                Scores of all possible dependent-head-sibling triples.
+            s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each arc.
+            arcs (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The tensor of gold-standard arcs.
+            sibs (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
+                The tensor of gold-standard siblings.
+            rels (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The tensor of gold-standard labels.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens.
+            mbr (bool):
+                If ``True``, returns marginals for MBR decoding. Default: ``True``.
+            partial (bool):
+                ``True`` denotes the trees are partially annotated. Default: ``False``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The training loss and
+                original arc scores of shape ``[batch_size, seq_len, seq_len]`` if ``mbr=False``, or marginals otherwise.
+        """
+
+        arc_dist = Dependency2oCRF((s_arc, s_sib), mask.sum(-1))
+        arc_loss = -arc_dist.log_prob((arcs, sibs), partial=partial).sum() / mask.sum()
+        if mbr:
+            s_arc, s_sib = arc_dist.marginals
+        # -1 denotes un-annotated arcs
+        if partial:
+            mask = mask & arcs.ge(0)
+        s_rel, rels = s_rel[mask], rels[mask]
+        s_rel = s_rel[torch.arange(len(rels)), arcs[mask]]
+        rel_loss = self.criterion(s_rel, rels)
+        loss = arc_loss + rel_loss
+        return loss, s_arc, s_sib
+
+    def decode(self, s_arc, s_sib, s_rel, mask, tree=False, mbr=True, proj=False):
+        r"""
+        Args:
+            s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+                Scores of all possible arcs.
+            s_sib (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
+                Scores of all possible dependent-head-sibling triples.
+            s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each arc.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens.
+            tree (bool):
+                If ``True``, ensures to output well-formed trees. Default: ``False``.
+            mbr (bool):
+                If ``True``, performs MBR decoding. Default: ``True``.
+            proj (bool):
+                If ``True``, ensures to output projective trees. Default: ``False``.
+
+        Returns:
+            ~torch.LongTensor, ~torch.LongTensor:
+                Predicted arcs and labels of shape ``[batch_size, seq_len]``.
+        """
+
+        lens = mask.sum(1)
+        arc_preds = s_arc.argmax(-1)
+        bad = [not CoNLL.istree(seq[1:i+1], proj) for i, seq in zip(lens.tolist(), arc_preds.tolist())]
+        if tree and any(bad):
+            if proj:
+                arc_preds[bad] = Dependency2oCRF((s_arc[bad], s_sib[bad]), mask[bad].sum(-1)).argmax
+            else:
+                arc_preds[bad] = MatrixTree(s_arc[bad], mask[bad].sum(-1)).argmax
+        rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
+
+        return arc_preds, rel_preds
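+
+# Decoding fallback sketch (illustrative): greedy argmax heads are kept per
+# sentence unless they fail CoNLL.istree (cycles, zero or multiple roots, or
+# non-projectivity when proj=True); offending items are re-decoded with the
+# Dependency2oCRF argmax if proj else the MatrixTree (non-projective) argmax.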
diff --git a/tania_scripts/supar/models/dep/crf2o/parser.py b/tania_scripts/supar/models/dep/crf2o/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..c822b2a154aa326c743fae85b157a476c9fdffcb
--- /dev/null
+++ b/tania_scripts/supar/models/dep/crf2o/parser.py
@@ -0,0 +1,216 @@
+# -*- coding: utf-8 -*-
+
+import os
+from typing import Iterable, Union
+
+import torch
+
+from supar.models.dep.biaffine.parser import BiaffineDependencyParser
+from supar.models.dep.biaffine.transform import CoNLL
+from supar.models.dep.crf2o.model import CRF2oDependencyModel
+from supar.structs import Dependency2oCRF
+from supar.utils import Config, Dataset, Embedding
+from supar.utils.common import BOS, PAD, UNK
+from supar.utils.field import ChartField, Field, RawField, SubwordField
+from supar.utils.fn import ispunct
+from supar.utils.logging import get_logger
+from supar.utils.metric import AttachmentMetric
+from supar.utils.tokenizer import TransformerTokenizer
+from supar.utils.transform import Batch
+
+logger = get_logger(__name__)
+
+
+class CRF2oDependencyParser(BiaffineDependencyParser):
+    r"""
+    The implementation of second-order CRF Dependency Parser :cite:`zhang-etal-2020-efficient`.
+    """
+
+    NAME = 'crf2o-dependency'
+    MODEL = CRF2oDependencyModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        punct: bool = False,
+        mbr: bool = True,
+        tree: bool = False,
+        proj: bool = False,
+        partial: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        punct: bool = False,
+        mbr: bool = True,
+        tree: bool = True,
+        proj: bool = True,
+        partial: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        mbr: bool = True,
+        tree: bool = True,
+        proj: bool = True,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().predict(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        words, _, *feats, arcs, sibs, rels = batch
+        mask = batch.mask
+        # ignore the first token of each sentence
+        mask[:, 0] = 0
+        s_arc, s_sib, s_rel = self.model(words, feats)
+        loss, *_ = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask, self.args.mbr, self.args.partial)
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> AttachmentMetric:
+        words, _, *feats, arcs, sibs, rels = batch
+        mask = batch.mask
+        # ignore the first token of each sentence
+        mask[:, 0] = 0
+        s_arc, s_sib, s_rel = self.model(words, feats)
+        loss, s_arc, s_sib = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask, self.args.mbr, self.args.partial)
+        arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask, self.args.tree, self.args.mbr, self.args.proj)
+        if self.args.partial:
+            mask &= arcs.ge(0)
+        # ignore all punctuation if not specified
+        if not self.args.punct:
+            mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
+        return AttachmentMetric(loss, (arc_preds, rel_preds), (arcs, rels), mask)
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, _, *feats = batch
+        mask, lens = batch.mask, batch.lens - 1
+        # ignore the first token of each sentence
+        mask[:, 0] = 0
+        s_arc, s_sib, s_rel = self.model(words, feats)
+        s_arc, s_sib = Dependency2oCRF((s_arc, s_sib), lens).marginals if self.args.mbr else (s_arc, s_sib)
+        arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask, self.args.tree, self.args.mbr, self.args.proj)
+        lens = lens.tolist()
+        batch.arcs = [i.tolist() for i in arc_preds[mask].split(lens)]
+        batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)]
+        if self.args.prob:
+            arc_probs = s_arc if self.args.mbr else s_arc.softmax(-1)
+            batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())]
+        return batch
+
+    @classmethod
+    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+        r"""
+        Build a brand-new Parser, including initialization of all data fields and model parameters.
+
+        Args:
+            path (str):
+                The path of the model to be saved.
+            min_freq (int):
+                The minimum frequency needed to include a token in the vocabulary. Default: 2.
+            fix_len (int):
+                The max length of all subword pieces. The excess part of each piece will be truncated.
+                Required if using CharLSTM/BERT.
+                Default: 20.
+            kwargs (Dict):
+                A dict holding the unconsumed arguments.
+        """
+
+        args = Config(**locals())
+        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
+        if os.path.exists(path) and not args.build:
+            parser = cls.load(**args)
+            parser.model = cls.MODEL(**parser.args)
+            parser.model.load_pretrained(parser.transform.FORM[0].embed).to(parser.device)
+            return parser
+
+        logger.info("Building the fields")
+        TAG, CHAR, ELMO, BERT = None, None, None, None
+        if args.encoder == 'bert':
+            t = TransformerTokenizer(args.bert)
+            WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t)
+            WORD.vocab = t.vocab
+        else:
+            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True)
+            if 'tag' in args.feat:
+                TAG = Field('tags', bos=BOS)
+            if 'char' in args.feat:
+                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, fix_len=args.fix_len)
+            if 'elmo' in args.feat:
+                from allennlp.modules.elmo import batch_to_ids
+                ELMO = RawField('elmo')
+                ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device)
+            if 'bert' in args.feat:
+                t = TransformerTokenizer(args.bert)
+                BERT = SubwordField('bert', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t)
+                BERT.vocab = t.vocab
+        TEXT = RawField('texts')
+        ARC = Field('arcs', bos=BOS, use_vocab=False, fn=CoNLL.get_arcs)
+        SIB = ChartField('sibs', bos=BOS, use_vocab=False, fn=CoNLL.get_sibs)
+        REL = Field('rels', bos=BOS)
+        transform = CoNLL(FORM=(WORD, TEXT, CHAR, ELMO, BERT), CPOS=TAG, HEAD=(ARC, SIB), DEPREL=REL)
+
+        train = Dataset(transform, args.train, **args)
+        if args.encoder != 'bert':
+            WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
+            if TAG is not None:
+                TAG.build(train)
+            if CHAR is not None:
+                CHAR.build(train)
+        REL.build(train)
+        args.update({
+            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
+            'n_rels': len(REL.vocab),
+            'n_tags': len(TAG.vocab) if TAG is not None else None,
+            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
+            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
+            'bert_pad_index': BERT.pad_index if BERT is not None else None,
+            'pad_index': WORD.pad_index,
+            'unk_index': WORD.unk_index,
+            'bos_index': WORD.bos_index
+        })
+        logger.info(f"{transform}")
+
+        logger.info("Building the model")
+        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None)
+        logger.info(f"{model}\n")
+
+        parser = cls(args, model, transform)
+        parser.model.to(parser.device)
+        return parser
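+
+# Build sketch (hypothetical paths and options): constructing a fresh parser
+# with char features before training.
+#   parser = CRF2oDependencyParser.build('exp/crf2o/model',
+#                                        train='train.conllx',
+#                                        encoder='lstm', feat=['char'])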
diff --git a/tania_scripts/supar/models/dep/eager/__init__.py b/tania_scripts/supar/models/dep/eager/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e985987237e25784b579e74d88c25f8a1b3de7
--- /dev/null
+++ b/tania_scripts/supar/models/dep/eager/__init__.py
@@ -0,0 +1,2 @@
+from .parser import ArcEagerDependencyParser
+from .model import ArcEagerDependencyModel
\ No newline at end of file
diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69b6b347907cd94bda81a4738bf8e7e85860c421
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4eb2047d9191fa7725d4f6c991fbd42ac25afea
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c35f8752a167848718305af293b7ba376172b991
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/model.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd043bef1f8af3f5ca69f649ed48a44101524e37
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/model.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f44d08db4b5106cd902edd53576ffd54fc22572
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/parser.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c53842f2152b6d56e2c6e8626ae48e37586c53a4
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/parser.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/transform.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35d39d5378eae49dfe2f72c59831aec35703b2ca
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/transform.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/transform.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fbdee27774e2be073b02a1de3f1e5540e6c700cf
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/transform.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/model.py b/tania_scripts/supar/models/dep/eager/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..300e82969f98af688a3f1700c7aab969ff2b3d75
--- /dev/null
+++ b/tania_scripts/supar/models/dep/eager/model.py
@@ -0,0 +1,198 @@
+import torch
+import torch.nn as nn
+from supar.model import Model
+from supar.modules import MLP, DecoderLSTM
+from supar.utils import Config
+from typing import Tuple, List
+
+class ArcEagerDependencyModel(Model):
+
+    def __init__(self,
+                 n_words,
+                 n_transitions,
+                 n_trels,
+                 n_tags=None,
+                 n_chars=None,
+                 encoder='lstm',
+                 feat=['char'],
+                 n_embed=100,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, False),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 n_encoder_hidden=800,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_arc_mlp=500,
+                 n_rel_mlp=100,
+                 mlp_dropout=.33,
+                 scale=0,
+                 pad_index=0,
+                 unk_index=1,
+                 n_decoder_layers=4,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        # create decoder for buffer front, stack top and rels
+        self.transition_decoder, self.trel_decoder = None, None
+
+        stack_size, buffer_size = [self.args.n_encoder_hidden//2] * 2 if (self.args.n_encoder_hidden % 2) == 0 \
+            else [self.args.n_encoder_hidden//2, self.args.n_encoder_hidden//2+1]
+
+        # create projection to reduce dimensionality of the encoder
+        self.stack_proj = MLP(
+                n_in=self.args.n_encoder_hidden, n_out=stack_size,
+                dropout=mlp_dropout)
+        self.buffer_proj = MLP(
+            n_in=self.args.n_encoder_hidden, n_out=buffer_size,
+            dropout=mlp_dropout
+        )
+
+        if self.args.decoder == 'lstm':
+            decoder = lambda out_dim: DecoderLSTM(
+                self.args.n_encoder_hidden, self.args.n_encoder_hidden, out_dim,
+                self.args.n_decoder_layers, dropout=mlp_dropout, device=self.device
+            )
+        else:
+            decoder = lambda out_dim: MLP(
+                n_in=self.args.n_encoder_hidden, n_out=out_dim, dropout=mlp_dropout
+            )
+
+        self.transition_decoder = decoder(n_transitions)
+        self.trel_decoder = decoder(n_trels)
+
+        # create delay projection
+        if self.args.delay != 0:
+            self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay + 1),
+                                  n_out=self.args.n_encoder_hidden, dropout=mlp_dropout)
+
+        # create PoS tagger
+        if self.args.encoder in ['lstm', 'bert']:
+            self.pos_tagger = DecoderLSTM(
+                self.args.n_encoder_hidden, self.args.n_encoder_hidden, self.args.n_tags, 
+                num_layers=1, dropout=mlp_dropout, device=self.device
+            )
+        else:
+            # no auxiliary PoS tagger for other encoder types
+            self.pos_tagger = nn.Identity()
+
+        self.criterion = nn.CrossEntropyLoss()
+
+    def encoder_forward(self, words: torch.Tensor, feats: List[torch.Tensor]) -> Tuple[torch.Tensor]:
+        """
+        Applies the encoder forward pass, mapping a tensor of word indices (`words`) to their contextualized
+        neural representations.
+        Args:
+            words: torch.IntTensor ~ [batch_size, bos + pad(seq_len) + eos + delay]
+            feats: List[torch.Tensor]
+
+        Returns: x, s_tag, qloss
+            x: torch.FloatTensor ~ [batch_size, bos + pad(seq_len) + eos, embed_dim]
+            s_tag: torch.FloatTensor ~ [batch_size, seq_len, n_tags]
+            qloss: torch.FloatTensor ~ 1
+
+        """
+
+        x = super().encode(words, feats)
+        s_tag = self.pos_tagger(x[:, 1:-(1+self.args.delay), :])
+
+        # adjust lengths to allow delay predictions
+        # x ~ [batch_size, bos + pad(seq_len) + eos, embed_dim]
+        if self.args.delay != 0:
+            x = torch.cat([x[:, i:(x.shape[1] - self.args.delay + i), :] for i in range(self.args.delay + 1)], dim=2)
+            x = self.delay_proj(x)
+
+        # pass through vector quantization
+        x, qloss = self.vq_forward(x)
+        return x, s_tag, qloss
+
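+    # Delay-window sketch (illustrative): with delay=1 and encoder outputs
+    # [x0, x1, x2, x3], the concatenation above stacks [x0;x1], [x1;x2], [x2;x3]
+    # along the feature axis, and delay_proj maps them back to n_encoder_hidden.
+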
+    def decoder_forward(self, x: torch.Tensor, stack_top: torch.Tensor, buffer_front: torch.Tensor) -> Tuple[torch.Tensor]:
+        """
+        Args:
+            x: torch.FloatTensor ~ [batch_size, bos + pad(seq_len) + eos, embed_dim]
+            stack_top: torch.IntTensor ~ [batch_size, pad(tr_len)]
+            buffer_front: torch.IntTensor ~ [batch_size, pad(tr_len)]
+
+        Returns: s_transition, s_trel
+            s_transition: torch.FloatTensor ~ [batch_size, pad(tr_len), n_transitions]
+            s_trel: torch.FloatTensor ~ [batch_size, pad(tr_len), n_trels]
+        """
+        batch_size = x.shape[0]
+
+        # obtain encoded embeddings for stack_top and buffer_front
+        stack_top = torch.stack([x[i, stack_top[i], :] for i in range(batch_size)])
+        buffer_front = torch.stack([x[i, buffer_front[i], :] for i in range(batch_size)])
+
+        # pass through projections
+        stack_top = self.stack_proj(stack_top)
+        buffer_front = self.buffer_proj(buffer_front)
+
+        # stack_top ~ [batch_size, pad(tr_len), embed_dim//2]
+        # buffer_front ~ [batch_size, pad(tr_len), embed_dim//2]
+        # x ~ [batch_size, pad(tr_len), embed_dim]
+        x = torch.concat([stack_top, buffer_front], dim=-1)
+        
+        # s_transition ~ [batch_size, pad(tr_len), n_transitions]
+        # s_trel = [batch_size, pad(tr_len), n_trels]
+        s_transition = self.transition_decoder(x)
+        s_trel = self.trel_decoder(x)
+
+        return s_transition, s_trel
+
+    def forward(self, words: torch.Tensor, stack_top: torch.Tensor, buffer_front: torch.Tensor, feats: List[torch.Tensor]) -> Tuple[torch.Tensor]:
+        """
+        Args:
+            words: torch.IntTensor ~ [batch_size, bos + pad(seq_len) + eos + delay].
+            stack_top: torch.IntTensor ~ [batch_size, pad(tr_len)]
+            buffer_front: torch.IntTensor ~ [batch_size, pad(tr_len)]
+            feats: List[torch.Tensor]
+
+        Returns: s_transition, s_trel, qloss
+            s_transition: torch.FloatTensor ~ [batch_size, pad(tr_len), n_transitions]
+            s_trel: torch.FloatTensor ~ [batch_size, pad(tr_len), n_trels]
+            qloss: torch.FloatTensor ~ 1
+        """
+        x, s_tag, qloss = self.encoder_forward(words, feats)
+
+        s_transition, s_trel = self.decoder_forward(x, stack_top, buffer_front)
+        return s_transition, s_trel, s_tag, qloss
+
+    def decode(self, s_transition: torch.Tensor, s_trel: torch.Tensor, exclude: list = None):
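+        # transitions are ranked with argsort (best first) so that oracle decoding
+        # can fall back to the next-best action when the top one is not applicable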
+        transition_preds = s_transition.argsort(-1, descending=True)
+        if exclude:
+            # mask excluded relation labels with -inf so they can never win the argmax
+            s_trel[:, :, exclude] = float('-inf')
+        trel_preds = s_trel.argmax(-1)
+        return transition_preds, trel_preds
+    
+    def decode_stag(self, s_tag: torch.Tensor):
+        return s_tag.argmax(-1)
+
+    def loss(self, s_transition: torch.Tensor, s_trel: torch.Tensor, s_tag,
+             transitions: torch.Tensor, trels: torch.Tensor, tags,
+             smask: torch.Tensor, trmask: torch.Tensor, TRANSITION):
+        s_transition, transitions = s_transition[trmask], transitions[trmask]
+        s_trel, trels = s_trel[trmask], trels[trmask]
+
+        # remove those values in trels that correspond to shift and reduce actions
+        transition_pred = TRANSITION.vocab[s_transition.argmax(-1).flatten().tolist()]
+        trel_mask = torch.tensor(list(map(lambda x: x not in ['reduce', 'shift'], transition_pred)))
+        s_trel, trels = s_trel[trel_mask], trels[trel_mask]
+
+        tag_loss = self.criterion(s_tag[smask], tags[smask]) if self.args.encoder == 'lstm' \
+            else torch.tensor(0., device=self.device)
+        transition_loss = self.criterion(s_transition, transitions)
+        trel_loss = self.criterion(s_trel, trels)
+
+        return transition_loss + trel_loss + tag_loss
diff --git a/tania_scripts/supar/models/dep/eager/oracle/__init__.py b/tania_scripts/supar/models/dep/eager/oracle/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5cf072a90eda87b6f954a0a161965c253cafd1b1
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13ce2553293d74a70524325df69e92ffe8f7771d
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/arceager.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/arceager.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93488668e82fa23ba9e3ff1768cb420f90132ce3
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/arceager.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/arceager.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/arceager.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09ee22f8790adc5bbc67d6acfbc51c23685850e4
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/arceager.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/buffer.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/buffer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..41d88bf55396dcd614cb06abfdcd6cdd80a64562
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/buffer.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/buffer.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/buffer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c22b12dc825c9a45c4c9b0f8e44b31477268e653
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/buffer.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/dependency.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/dependency.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7fce37dccace89b6d7436c953040e05640fbe754
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/dependency.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/dependency.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/dependency.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9534220075a54bc9b129b311b945ea708d02dc8
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/dependency.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/node.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/node.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d76682cca61dd66a9aa936ace916052a3ae2e0f7
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/node.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/node.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/node.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a8d46ef4a1c5317321c19f8323f13e6319c4fe0
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/node.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/stack.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/stack.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f963a4cfe22b0dc6246066d745d20c9344ac44ed
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/stack.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/stack.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/stack.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a362f4b907de1716061719405def54b641c7abd
Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/stack.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/eager/oracle/arceager.py b/tania_scripts/supar/models/dep/eager/oracle/arceager.py
new file mode 100644
index 0000000000000000000000000000000000000000..b65dad15199cb5d6338afdb3b84b6a18ace20918
--- /dev/null
+++ b/tania_scripts/supar/models/dep/eager/oracle/arceager.py
@@ -0,0 +1,219 @@
+from supar.models.dep.eager.oracle.stack import Stack
+from supar.models.dep.eager.oracle.buffer import Buffer
+from supar.models.dep.eager.oracle.node import Node
+from supar.models.dep.eager.oracle.dependency import Transition
+from typing import List, Dict
+
+class ArcEagerEncoder:
+    def __init__(self, bos: str, eos: str):
+        self.stack = None
+        self.buffer = None
+        self.transitions = []
+        self.dependencies = []
+        self.nodes_assigned = []
+        self.bos, self.eos = bos, eos
+        self.shift_token = '<shift>'
+        self.reduce_token = '<reduce>'
+        self.n = 0
+
+    def left_arc(self):
+        # head = buffer_front, dependent = stack_top (stack_top <- buffer_front)
+        stack_top = self.stack.pop()
+        buffer_front = self.buffer.get()
+        self.transitions.append(
+            Transition(type='left-arc', stack_top=stack_top, buffer_front=buffer_front, deprel=stack_top.DEPREL)
+        )
+        self.dependencies.append(
+            Node(ID=stack_top.ID, FORM=stack_top.FORM, UPOS=stack_top.UPOS, HEAD=buffer_front.ID, DEPREL=stack_top.DEPREL)
+        )
+        self.nodes_assigned.append(stack_top.ID)
+        return self.transitions
+
+    def right_arc(self):
+        # head = stack_top, dependent = buffer_front (stack_top -> buffer_front)
+        stack_top = self.stack.get()
+        buffer_front = self.buffer.remove()
+        self.stack.push(buffer_front)
+        self.transitions.append(
+            Transition(type='right-arc', stack_top=stack_top, buffer_front=buffer_front, deprel=buffer_front.DEPREL)
+        )
+        self.dependencies.append(
+            Node(ID=buffer_front.ID, FORM=buffer_front.FORM, UPOS=buffer_front.UPOS, HEAD=stack_top.ID, DEPREL=buffer_front.DEPREL)
+        )
+        self.nodes_assigned.append(buffer_front.ID)
+        return self.transitions
+
+    def shift(self):
+        front_item = self.buffer.remove()
+        stack_top = self.stack.get()
+        self.stack.push(front_item)
+        self.transitions.append(
+            Transition(type='shift', stack_top=stack_top, buffer_front=front_item, deprel=front_item.DEPREL))
+        return self.transitions
+
+    def reduce(self):
+        stack_top = self.stack.get()
+        buffer_front = self.buffer.get() if len(self.buffer) > 0 else Node.create_eos(self.n, self.eos)
+        self.stack.pop()
+        self.transitions.append(
+            Transition(type='reduce', stack_top=stack_top, buffer_front=buffer_front, deprel=stack_top.DEPREL)
+        )
+        return self.transitions
+
+    def next_action(self):
+        stack_top = self.stack.get()
+        try:
+            buffer_front = self.buffer.get()
+        except IndexError:
+            if stack_top.ID in self.nodes_assigned:
+                return self.reduce()
+            return None
+
+        if buffer_front.ID == stack_top.HEAD:
+            return self.left_arc()
+        elif (buffer_front.ID not in self.nodes_assigned) and (stack_top.ID == buffer_front.HEAD):
+            return self.right_arc()
+        elif (stack_top.ID in self.nodes_assigned) and (buffer_front.HEAD in [node.ID for node in self.stack.items]):
+            return self.reduce()
+        elif (stack_top.ID in self.nodes_assigned) and (buffer_front.ID in [node.HEAD for node in self.stack.items]):
+            return self.reduce()
+        else:
+            return self.shift()
+
+    def encode(self, sentence: List[Node]):
+        # create stack and buffer
+        self.stack = Stack([Node.create_root(self.bos)])
+        self.buffer = Buffer(sentence.copy())
+        self.n = len(sentence)
+
+        # reset
+        self.transitions, self.dependencies = [], []
+
+        next_action = self.next_action()
+        while next_action:
+            next_action = self.next_action()
+
+        # order the recovered dependencies by their position in the sentence
+        self.dependencies = sorted(self.dependencies, key=lambda dep: dep.ID)
+        return self.transitions
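+
+# Minimal usage sketch (hypothetical gold annotation; assumes HEAD/DEPREL are set):
+#
+#     nodes = [Node(ID=1, FORM='She', UPOS='PRON', HEAD=2, DEPREL='nsubj'),
+#              Node(ID=2, FORM='sleeps', UPOS='VERB', HEAD=0, DEPREL='root')]
+#     encoder = ArcEagerEncoder(bos='<bos>', eos='<eos>')
+#     transitions = encoder.encode(nodes)
+#     # -> shift, left-arc, right-arc, reduce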
+
+
+
+class ArcEagerDecoder:
+    def __init__(self, sentence: List[Node], bos: str, eos: str, unk: str):
+        self.sentence = sentence.copy()
+        self.decoded_nodes = [Node(ID=node.ID, FORM=node.FORM, UPOS=node.UPOS, HEAD=0, DEPREL=unk) for node in sentence]
+        self.transitions = list()
+        self.nodes_assigned = list()
+        self.stack = Stack([Node.create_root(bos)])
+        self.buffer = Buffer(sentence.copy())
+        self.bos, self.eos, self.unk = bos, eos, unk
+        self.shift_token, self.reduce_token = '<shift>',  '<reduce>'
+
+
+    def left_arc(self, deprel: str):
+        # head = buffer_front, dependent = stack_top (stack_top <- buffer_front)
+        stack_top = self.stack.pop()
+        buffer_front = self.buffer.get()
+        self.nodes_assigned.append(stack_top.ID)
+        self.transitions.append(
+            Transition(type='left-arc', stack_top=stack_top, buffer_front=buffer_front, deprel=deprel)
+        )
+        self.decoded_nodes[stack_top.ID - 1].HEAD = buffer_front.ID
+        self.decoded_nodes[stack_top.ID - 1].DEPREL = deprel
+        # get next states
+        stack_top = self.stack.get()
+        buffer_front = self.buffer.get() if len(self.buffer) > 0 else Node.create_eos(len(self.sentence), self.eos)
+        return stack_top, buffer_front
+
+    def right_arc(self, deprel: str):
+        # head = stack_top, dependent = buffer_front (stack_top -> buffer_front)
+        stack_top = self.stack.get()
+        buffer_front = self.buffer.remove()
+        self.stack.push(buffer_front)
+        self.transitions.append(
+            Transition(type='right-arc', stack_top=stack_top, buffer_front=buffer_front, deprel=deprel)
+        )
+        self.nodes_assigned.append(buffer_front.ID)
+        self.decoded_nodes[buffer_front.ID - 1].HEAD = stack_top.ID
+        self.decoded_nodes[buffer_front.ID - 1].DEPREL = deprel
+
+        # get next states
+        stack_top = self.stack.get()
+        buffer_front = self.buffer.get() if len(self.buffer) > 0 else Node.create_eos(len(self.sentence), self.eos)
+        return stack_top, buffer_front
+
+
+    def shift(self, deprel):
+        front_item = self.buffer.remove()
+        stack_top = self.stack.get()
+        self.stack.push(front_item)
+        self.transitions.append(
+            Transition(type='shift', stack_top=stack_top, buffer_front=front_item, deprel=deprel))
+        # get next states
+        stack_top = self.stack.get()
+        buffer_front = self.buffer.get() if len(self.buffer) > 0 else Node.create_eos(len(self.sentence), self.eos)
+        return stack_top, buffer_front
+
+    def reduce(self, deprel):
+        stack_top = self.stack.get()
+        try:
+            buffer_front = self.buffer.get()
+        except IndexError:
+            buffer_front = Node.create_eos(len(self.sentence), self.eos)
+        self.stack.pop()
+        self.transitions.append(
+            Transition(type='reduce', stack_top=stack_top, buffer_front=buffer_front, deprel=deprel)
+        )
+        # get next states
+        stack_top = self.stack.get()
+        buffer_front = self.buffer.get() if len(self.buffer) > 0 else Node.create_eos(len(self.sentence), self.eos)
+        return stack_top, buffer_front
+
+    def apply_transition(self, transitions: List[str], deprel: str):
+        stack_top = self.stack.get()
+        try:
+            buffer_front = self.buffer.get()
+        except IndexError:
+            if stack_top.ID in self.nodes_assigned:
+                self.reduce(deprel)
+            return None
+
+        for transition in transitions:
+            if (transition == 'left-arc') and (stack_top.ID not in self.nodes_assigned) and (not stack_top.is_root):
+                return self.left_arc(deprel)
+            if (transition == 'right-arc') and (buffer_front.ID not in self.nodes_assigned):
+                return self.right_arc(deprel)
+            if (transition == 'reduce') and (stack_top.ID in self.nodes_assigned):
+                return self.reduce(deprel)
+        # fall back to shift once no ranked transition is applicable
+        return self.shift(deprel)
+
+    def apply(self, transition: str, deprel: str):
+        if transition == 'left-arc':
+            return self.left_arc(deprel)
+        if transition == 'right-arc':
+            return self.right_arc(deprel)
+        if transition == 'reduce':
+            return self.reduce(deprel)
+        if transition == 'shift':
+            return self.shift(deprel)
+
+    def decode_sentence(self, transitions: List[List[str]], deprels: List[str]):
+        for transition_ops, deprel in zip(transitions, deprels):
+            info = self.apply_transition(transition_ops, deprel)
+            if info is None:
+                break
+        self.postprocess()
+        return self.decoded_nodes
+
+    def postprocess(self):
+        # check whether more than one node is attached to the root; if so, keep the
+        # leftmost root and attach the remaining root-headed nodes to it
+        roots = sum(node.HEAD == 0 for node in self.decoded_nodes)
+        if roots > 1:
+            # get the leftmost root
+            for node in self.decoded_nodes:
+                if node.HEAD == 0:
+                    root = node.ID
+                    break
+            for node in self.decoded_nodes[root:]:
+                if node.HEAD == 0:
+                    node.HEAD = root
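+
+# Usage sketch (hypothetical ranked predictions; '_' stands in for a deprel index):
+#
+#     decoder = ArcEagerDecoder(sentence=nodes, bos='<bos>', eos='<eos>', unk='<unk>')
+#     decoder.decode_sentence(transitions=[['shift', 'left-arc'], ['left-arc']],
+#                             deprels=['_', 'nsubj'])
+#
+# Each ranked list is scanned for the first applicable transition, and
+# postprocess() repairs any multi-root analysis afterwards.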
diff --git a/tania_scripts/supar/models/dep/eager/oracle/buffer.py b/tania_scripts/supar/models/dep/eager/oracle/buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a1f5d436d970327abef248843db0fb4890cbeeb
--- /dev/null
+++ b/tania_scripts/supar/models/dep/eager/oracle/buffer.py
@@ -0,0 +1,24 @@
+from typing import Optional, List
+from supar.models.dep.eager.oracle.node import Node
+
+class Buffer:
+    def __init__(self, items: Optional[List[Node]] = None):
+        self.items = items if items else []
+
+    def get(self) -> Node:
+        return self.items[0]
+
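+    # note: list.pop(0) below is O(n); a collections.deque (popleft) would be O(1),
+    # though a buffer only holds a single sentence, so it hardly matters in practice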
+    def remove(self) -> Node:
+        return self.items.pop(0)
+
+    def append(self, item: Node):
+        self.items.append(item)
+
+    def __len__(self):
+        return len(self.items)
+
+    def __repr__(self):
+        return str(self.items)
\ No newline at end of file
diff --git a/tania_scripts/supar/models/dep/eager/oracle/dependency.py b/tania_scripts/supar/models/dep/eager/oracle/dependency.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4825ba75f62256dd3038c43d96578d4fd6c99c3
--- /dev/null
+++ b/tania_scripts/supar/models/dep/eager/oracle/dependency.py
@@ -0,0 +1,30 @@
+from typing import Optional, List
+from supar.models.dep.eager.oracle.node import Node
+
+class Dependency:
+    def __init__(self, dependent_id: int, head_id: int, deprel: str):
+        self.dependent_id = dependent_id
+        self.head_id = head_id
+        self.deprel = deprel
+
+    def __repr__(self):
+        return self.toconll()
+
+    def toconll(self):
+        return '\t'.join([
+            str(self.dependent_id), str(self.head_id), self.deprel
+        ])
+
+class Transition:
+    def __init__(self, type: str, stack_top: Node, buffer_front: Node, deprel: str):
+        self.type = type
+        self.stack_top = stack_top
+        self.buffer_front = buffer_front
+        self.deprel = deprel
+
+    def __repr__(self):
+        return '\t'.join((str(self.stack_top.FORM), str(self.buffer_front.FORM), self.type, self.deprel))
+
+
+
+
diff --git a/tania_scripts/supar/models/dep/eager/oracle/node.py b/tania_scripts/supar/models/dep/eager/oracle/node.py
new file mode 100644
index 0000000000000000000000000000000000000000..7aa07a99a616853235e6b1d2dc18f6709319e2a6
--- /dev/null
+++ b/tania_scripts/supar/models/dep/eager/oracle/node.py
@@ -0,0 +1,74 @@
+from typing import List
+
+class Node:
+
+    def __init__(self, ID: int, FORM: str, UPOS: str, HEAD: int, DEPREL: str, is_root: bool = False):
+        self.ID = ID
+        self.FORM = FORM
+        self.UPOS = UPOS
+        self.HEAD = HEAD
+        self.DEPREL = DEPREL
+        self.is_root = is_root
+
+    def __str__(self):
+        return f'Node(ID={self.ID}, FORM={self.FORM}, UPOS={self.UPOS}, HEAD={self.HEAD})'
+
+    def __repr__(self):
+        return f'Node(ID={self.ID}, FORM={self.FORM}, UPOS={self.UPOS}, HEAD={self.HEAD})'
+
+    @classmethod
+    def from_conllu(cls, conll: str):
+        ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC = conll.split('\t')
+        if HEAD == '_':
+            return cls(int(ID), FORM, UPOS, HEAD, DEPREL)
+        return cls(int(ID), FORM, UPOS, int(HEAD), DEPREL)
+
+    @classmethod
+    def create_root(cls, token: str):
+        return Node(0, token, token, 0, token, is_root=True)
+
+    @classmethod
+    def create_eos(cls, position: int, token: str):
+        return Node(position, token, token, 0, token, is_root=False)
+
+    def coverage(self) -> range:
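+        # positions spanned by the arc, as the half-open range [min(ID, HEAD), max(ID, HEAD))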
+        limits = sorted([self.ID, self.HEAD])
+        return range(*limits)
+
+    @staticmethod
+    def isprojective(heads: List[int]):
+        return isprojective(heads)
+
+def isprojective(heads: List[int]):
+    """Checks projectivity of a head sequence (dependents are 1-indexed, 0 is the root)."""
+    pairs = [(h, d) for d, h in enumerate(heads, 1) if h >= 0]
+    for i, (hi, di) in enumerate(pairs):
+        for hj, dj in pairs[i + 1:]:
+            (li, ri), (lj, rj) = sorted([hi, di]), sorted([hj, dj])
+            if li <= hj <= ri and hi == dj:
+                return False
+            if lj <= hi <= rj and hj == di:
+                return False
+            if (li < lj < ri or li < rj < ri) and (li - lj) * (ri - rj) > 0:
+                return False
+    return True
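+
+# Examples (1-indexed heads, 0 = root):
+#     isprojective([2, 0, 2])     # True: the arcs are nested
+#     isprojective([0, 4, 1, 1])  # False: arcs 2<-4 and 1->3 cross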
+
+
+
diff --git a/tania_scripts/supar/models/dep/eager/oracle/stack.py b/tania_scripts/supar/models/dep/eager/oracle/stack.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee5545a54e71d347c514d385848a08889cade629
--- /dev/null
+++ b/tania_scripts/supar/models/dep/eager/oracle/stack.py
@@ -0,0 +1,24 @@
+from typing import Union, List, Optional
+from supar.models.dep.eager.oracle.node import Node
+
+class Stack:
+    def __init__(self, items: Optional[List[Node]] = None):
+        self.items = items if items else []
+
+    def pop(self) -> Node:
+        return self.items.pop()
+
+    def push(self, item: Node):
+        self.items.append(item)
+
+    def get(self) -> Node:
+        return self.items[-1]
+
+    def __repr__(self):
+        return repr(self.items)
+
+
+
diff --git a/tania_scripts/supar/models/dep/eager/parser.py b/tania_scripts/supar/models/dep/eager/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4857ded7c9bc3b61042ef60dbd33e9f4800972f
--- /dev/null
+++ b/tania_scripts/supar/models/dep/eager/parser.py
@@ -0,0 +1,380 @@
+import os
+from supar.models.dep.eager.oracle.node import Node
+
+import torch
+from supar.models.dep.eager.model import ArcEagerDependencyModel
+from supar.parser import Parser
+from supar.utils import Config, Dataset, Embedding
+from supar.utils.common import BOS, PAD, UNK, EOS
+from supar.utils.field import Field, RawField, SubwordField
+from supar.utils.fn import ispunct
+from supar.utils.logging import get_logger
+from supar.utils.metric import AttachmentMetric
+from supar.utils.tokenizer import TransformerTokenizer
+from supar.utils.transform import Batch
+from supar.models.dep.eager.transform import ArcEagerTransform
+from supar.models.dep.eager.oracle.arceager import ArcEagerDecoder
+from itertools import groupby
+
+
+logger = get_logger(__name__)
+from typing import Tuple, List, Union
+
+def consecutive_duplicate_spans(L):
+    """Collapses runs of identical actions, marking repeated runs with '*'.
+
+    e.g. ['left-arc', 'left-arc', 'shift'] -> ['left*', 'shift']
+    """
+    new_list = []
+    for _, group in groupby(L):
+        f = list(group)
+        suffix = '*' if len(f) > 1 else ''
+        new_list.append(f[0].replace('-arc', '') + suffix)
+    return new_list
+
+
+class ArcEagerDependencyParser(Parser):
+    MODEL = ArcEagerDependencyModel
+    NAME = 'arceager-dependency'
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.FORM = self.transform.FORM
+        self.STACK_TOP = self.transform.STACK_TOP
+        self.BUFFER_FRONT = self.transform.BUFFER_FRONT
+        self.TRANSITION, self.TREL = self.transform.TRANSITION, self.transform.TREL
+        self.TAG = self.transform.UPOS
+        self.HEAD = self.transform.HEAD
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        words, texts, *feats, tags, _, _, stack_top, buffer_front, transitions, trels = batch
+        tmask = self.get_padding_mask(stack_top)
+
+        # pad stack_top and buffer_front vectors: note that padding index must be the length of the sequence
+        pad_indices = (batch.lens - 1 - self.args.delay).tolist()
+        stack_top, buffer_front = self.pad_tensor(stack_top, pad_indices), self.pad_tensor(buffer_front, pad_indices)
+        # forward pass
+        #   stack_top: torch.Tensor ~ [batch_size, pad(tr_len), n_transitions]
+        #   buffer_front: torch.Tensor ~ [batch_size, pad(tr_len), n_trels]
+        s_transition, s_trel, s_tag, qloss = self.model(words, stack_top, buffer_front, feats)
+
+        # compute loss
+        smask = batch.mask[:, (2+self.args.delay):]
+        loss = self.model.loss(s_transition, s_trel, s_tag, transitions, trels, tags, smask, tmask, self.TRANSITION) + qloss
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> AttachmentMetric:
+        words, texts, *feats, tags, heads, deprels, stack_top, buffer_front, transitions, trels = batch
+        transition_mask = self.get_padding_mask(stack_top)
+
+        # obtain transition loss
+        stack_top, buffer_front = \
+            self.pad_tensor(stack_top, (batch.lens - 1 - self.args.delay).tolist()), \
+                self.pad_tensor(buffer_front, (batch.lens - 1 - self.args.delay).tolist())
+        s_transition, s_trel, s_tag, qloss = self.model(words, stack_top, buffer_front, feats.copy())
+        smask = batch.mask[:, (2+self.args.delay):]
+        loss = self.model.loss(s_transition, s_trel, s_tag, transitions, trels, tags, smask, transition_mask, self.TRANSITION) + qloss
+
+        # obtain indices of deprels from TREL field
+        batch_size = words.shape[0]
+        deprels = [self.TREL.vocab[deprels[b]] for b in range(batch_size)]
+
+        # create decoders
+        lens = list(map(len, texts))
+        sentences = list()
+        for b in range(batch_size):
+            sentences.append(
+                [
+                    Node(ID=i + 1, FORM=form, UPOS='', HEAD=head, DEPREL=deprel) for i, (form, head, deprel) in \
+                    enumerate(zip(texts[b], heads[b, :lens[b]].tolist(), deprels[b]))
+                ]
+            )
+        decoders = list(map(
+            lambda sentence: ArcEagerDecoder(sentence=sentence, bos='', eos='', unk=self.transform.TREL.unk_index),
+            sentences
+        ))
+
+        # compute oracle simulation for all elements in batch
+        head_preds, deprel_preds, *_ = self.oracle_decoding(decoders, words, feats)
+        head_preds, deprel_preds = self.pad_tensor(head_preds), self.pad_tensor(deprel_preds)
+        deprels = self.pad_tensor(deprels)
+
+        seq_mask = batch.mask[:, (2 + self.args.delay):]
+        return AttachmentMetric(loss, (head_preds, deprel_preds), (heads, deprels), seq_mask)
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, texts, *feats = batch
+        lens = (batch.lens - 2 - self.args.delay).tolist()
+
+        batch_size = words.shape[0]
+        # create decoders
+        sentences = list()
+        for b in range(batch_size):
+            sentences.append(
+                [
+                    Node(ID=i + 1, FORM='', UPOS='', HEAD=None, DEPREL=None) for i in range(lens[b])
+                ]
+            )
+        decoders = list(map(
+            lambda sentence: ArcEagerDecoder(sentence=sentence, bos='', eos='', unk=self.transform.TREL.unk_index),
+            sentences
+        ))
+        # compute oracle simulation for all elements in batch
+        head_preds, deprel_preds, stack_list, buffer_list, actions_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, act_dict = self.oracle_decoding(decoders, words, feats)
+        batch.heads = head_preds
+        batch.rels = deprel_preds
+        return batch, head_preds, deprel_preds, stack_list, buffer_list, actions_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, list(texts[0]), act_dict
+
+    def get_padding_mask(self, tensor_list: List[torch.Tensor]) -> torch.Tensor:
+        """
+        From a list of tensors of different lengths, creates a padding mask where False marks
+        padding positions and True real tokens.
+        Args:
+            tensor_list: List of tensors.
+        Returns: torch.Tensor ~ [len(tensor_list), max(lengths)]
+
+        """
+        lens = list(map(len, tensor_list))
+        max_len = max(lens)
+        return torch.tensor([[True] * length + [False] * (max_len - length) for length in lens]).to(self.model.device)
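+    # e.g. get_padding_mask([torch.zeros(3), torch.zeros(1)]) ->
+    #     tensor([[True, True, True], [True, False, False]])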
+
+    def pad_tensor(
+            self,
+            tensor_list: Union[List[torch.Tensor], List[List[int]]],
+            pad_index: Union[int, List[int]] = 0
+    ):
+        """
+        Applies padding to a list of tensors or list of lists.
+        Args:
+            tensor_list: List of tensors or list of lists.
+            pad_index: Index used for padding or list of indices used for padding for each item of tensor_list.
+        Returns: torch.Tensor ~ [len(tensor_list), max(lengths)]
+        """
+        max_length = max(map(len, tensor_list))
+        if isinstance(pad_index, int):
+            if isinstance(tensor_list[0], list):
+                return torch.tensor(
+                    [tensor + [pad_index] * (max_length - len(tensor)) for tensor in tensor_list]).to(self.model.device)
+            else:
+                return torch.tensor(
+                    [tensor.tolist() + [pad_index] * (max_length - len(tensor)) for tensor in tensor_list]).to(self.model.device)
+        else:
+            pad_indexes = pad_index
+            if isinstance(tensor_list[0], list):
+                return torch.tensor(
+                    [tensor + [pad_index] * (max_length - len(tensor)) for tensor, pad_index in
+                     zip(tensor_list, pad_indexes)]).to(self.model.device)
+            else:
+                return torch.tensor(
+                    [tensor.tolist() + [pad_index] * (max_length - len(tensor)) for tensor, pad_index in
+                     zip(tensor_list, pad_indexes)]).to(self.model.device)
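+    # e.g. pad_tensor([[1, 2], [3]], pad_index=0) -> tensor([[1, 2], [3, 0]])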
+
+    def get_text_mask(self, batch):
+        text_lens = (batch.lens - 2 - self.args.delay).tolist()
+        mask = batch.mask
+        mask[:, 0] = 0  # remove bos token
+        for i, text_len in enumerate(text_lens):
+            mask[i, (1 + text_len):] = 0
+        return mask
+
+    def oracle_decoding(self, decoders: List[ArcEagerDecoder], words: torch.Tensor, feats: List[torch.Tensor]):
+        """
+        Implements Arc-Eager decoding. Using word indices, creates the initial state of the Arc-Eager oracle
+        and predicts each (transition, trel) pair with the ArcEagerDependencyModel.
+        Args:
+            decoders: List[ArcEagerDecoder] ~ batch_size
+            words: torch.Tensor ~ [batch_size, seq_len]
+            feats: List[torch.Tensor ~ [batch_size, seq_len, feat_embed]] ~ n_feat
+
+        Returns: head_preds, deprel_preds, plus traces of the predicted transition sequence
+            head_preds: List[List[int] ~ sen_len] ~ batch_size: Head values for each sentence in batch.
+            deprel_preds: List[List[int] ~ sen_len] ~ batch_size: Indices of dependency relations for each sentence in batch.
+        """
+        # create a mask vector to filter those decoders that achieved the final state
+        compute = [True for _ in range(len(decoders))]
+        batch_size = len(decoders)
+
+        exclude = self.TREL.vocab[['<reduce>', '<shift>']]
+
+        # obtain word representations of the encoder
+        x, s_tag, _ = self.model.encoder_forward(words, feats)
+        transition_stags = self.model.decode_stag(s_tag)
+
+        stack_top = [torch.tensor([decoders[b].stack.get().ID]).reshape(1) for b in range(batch_size)]
+        buffer_front = [torch.tensor([decoders[b].buffer.get().ID]).reshape(1) for b in range(batch_size)]
+        counter = 0
+
+        stack_list = []
+        buffer_list = []
+        actions_list = []
+
+        stack_list.append(stack_top[-1].item())
+        buffer_list.append(buffer_front[-1].item())
+        
+        while any(compute):
+            s_transition, s_trel = self.model.decoder_forward(x, torch.stack(stack_top), torch.stack(buffer_front))
+            transition_preds, trel_preds = self.model.decode(s_transition, s_trel, exclude)
+            transition_preds, trel_preds = transition_preds[:, counter, :], trel_preds[:, counter]
+            transition_preds = [self.TRANSITION.vocab[i.tolist()] for i in
+                                transition_preds.reshape(batch_size, self.args.n_transitions)]
+            for b, decoder in enumerate(decoders):
+                actions_list.append(transition_preds[b][0])
+                if not compute[b]:
+                    stop, bfront = stack_top[b][-1].item(), buffer_front[b][-1].item()
+                else:
+                    result = decoder.apply_transition(transition_preds[b], trel_preds[b].item())
+                    if result is None:
+                        # this decoder reached its final state
+                        stop, bfront = stack_top[b][-1].item(), buffer_front[b][-1].item()
+                        compute[b] = False
+                    else:
+                        stop, bfront = result[0].ID, result[1].ID
+                        stack_list.append(stop)
+                        buffer_list.append(bfront)
+
+                stack_top[b] = torch.concat([stack_top[b], torch.tensor([stop])])
+                buffer_front[b] = torch.concat([buffer_front[b], torch.tensor([bfront])])
+            counter += 1
+
+        head_preds = [[node.HEAD for node in decoder.decoded_nodes] for decoder in decoders]
+        deprel_preds = [[node.DEPREL for node in decoder.decoded_nodes] for decoder in decoders]
+        deprel_preds_decoded = [[self.TREL.vocab[i] for i in dep_pre] for dep_pre in deprel_preds]
+        pos_preds_decoded = [[self.TAG.vocab[i.item()] for i in trans_stag] for trans_stag in transition_stags]
+
+        act_dict = {}
+        for i, j in zip(buffer_list, actions_list):
+            act_dict.setdefault(i, []).append(j)
+
+        act_dict_list = []
+        for _, actions in act_dict.items():
+            act_dict_list.append(" ".join(consecutive_duplicate_spans(actions)))
+
+        return head_preds, deprel_preds, stack_list, buffer_list, actions_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, act_dict
+
+    @classmethod
+    def build(cls, path, min_freq=1, fix_len=20, **kwargs):
+        args = Config(**locals())
+        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
+
+        if os.path.exists(path) and not args.build:
+            parser = cls.load(**args)
+            parser.model = cls.MODEL(**parser.args)
+            parser.model.load_pretrained(parser.transform.FORM[0].embed).to(parser.device)
+            return parser
+
+        logger.info("Building the fields")
+
+        # ------------------------------- source fields -------------------------------
+        WORD, TAG, CHAR = None, None, None
+        if args.encoder == 'bert':
+            t = TransformerTokenizer(args.bert)
+            WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t, delay=args.delay)
+            WORD.vocab = t.vocab
+        else:
+            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay)
+            if 'char' in args.feat:
+                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len, delay=args.delay)
+        # PoS tags are always required to supervise the auxiliary tagger
+        TAG = Field('tags')
+        TEXT = RawField('texts')
+        STACK_TOP = RawField('stack_top', fn=lambda x: torch.tensor(x))
+        BUFFER_FRONT = RawField('buffer_front', fn=lambda x: torch.tensor(x))
+
+        # ------------------------------- target fields -------------------------------
+        TRANSITION = Field('transition')
+        TREL = Field('trels')
+        HEAD = Field('heads', use_vocab=False)
+        DEPREL = RawField('rels')
+
+        transform = ArcEagerTransform(
+            FORM=(WORD, TEXT, CHAR), UPOS=TAG, HEAD=HEAD, DEPREL=DEPREL,
+            STACK_TOP=STACK_TOP, BUFFER_FRONT=BUFFER_FRONT, TRANSITION=TRANSITION, TREL=TREL,
+        )
+        train = Dataset(transform, args.train, **args)
+
+        if args.encoder != 'bert':
+            WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None),
+                       lambda x: x / torch.std(x))
+            if CHAR:
+                CHAR.build(train)
+        TAG.build(train)
+        TREL.build(train)
+        TRANSITION.build(train)
+
+
+        args.update({
+            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
+            'n_transitions': len(TRANSITION.vocab),
+            'n_trels': len(TREL.vocab),
+            'n_tags': len(TAG.vocab) if TAG is not None else None,
+            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
+            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
+            'pad_index': WORD.pad_index,
+            'unk_index': WORD.unk_index,
+            'bos_index': WORD.bos_index
+        })
+
+        logger.info(f"{transform}")
+        logger.info("Building the model")
+        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None)
+        logger.info(f"{model}\n")
+
+        parser = cls(args, model, transform)
+        parser.model.to(parser.device)
+        return parser
diff --git a/tania_scripts/supar/models/dep/eager/transform.py b/tania_scripts/supar/models/dep/eager/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..be5c883b8abba8f07d14ca75ddbfc8adaf9fa2f8
--- /dev/null
+++ b/tania_scripts/supar/models/dep/eager/transform.py
@@ -0,0 +1,176 @@
+from supar.utils.transform import Transform, Sentence
+from supar.utils.field import Field
+from typing import Iterable, Union, Optional, List, Tuple
+from supar.models.dep.eager.oracle.arceager import ArcEagerEncoder
+from supar.models.dep.eager.oracle.node import Node
+import os
+from io import StringIO
+
+
+class ArcEagerTransform(Transform):
+
+    fields = ['ID', 'FORM', 'LEMMA', 'UPOS', 'XPOS', 'FEATS', 'HEAD', 'DEPREL', 'DEPS', 'MISC', 'STACK_TOP', 'BUFFER_FRONT', 'TRANSITION', 'TREL']
+
+    def __init__(
+        self,
+        ID: Optional[Union[Field, Iterable[Field]]] = None,
+        FORM: Optional[Union[Field, Iterable[Field]]] = None,
+        LEMMA: Optional[Union[Field, Iterable[Field]]] = None,
+        UPOS: Optional[Union[Field, Iterable[Field]]] = None,
+        XPOS: Optional[Union[Field, Iterable[Field]]] = None,
+        FEATS: Optional[Union[Field, Iterable[Field]]] = None,
+        HEAD: Optional[Union[Field, Iterable[Field]]] = None,
+        DEPREL: Optional[Union[Field, Iterable[Field]]] = None,
+        DEPS: Optional[Union[Field, Iterable[Field]]] = None,
+        MISC: Optional[Union[Field, Iterable[Field]]] = None,
+        STACK_TOP: Optional[Union[Field, Iterable[Field]]] = None,
+        BUFFER_FRONT: Optional[Union[Field, Iterable[Field]]] = None,
+        TRANSITION: Optional[Union[Field, Iterable[Field]]] = None,
+        TREL: Optional[Union[Field, Iterable[Field]]] = None
+    ):
+        super().__init__()
+
+        self.ID = ID
+        self.FORM = FORM
+        self.LEMMA = LEMMA
+        self.UPOS = UPOS
+        self.XPOS = XPOS
+        self.FEATS = FEATS
+        self.HEAD = HEAD
+        self.DEPREL = DEPREL
+        self.DEPS = DEPS
+        self.MISC = MISC
+        self.STACK_TOP = STACK_TOP
+        self.BUFFER_FRONT = BUFFER_FRONT
+        self.TRANSITION = TRANSITION
+        self.TREL = TREL
+
+    @property
+    def src(self):
+        return self.FORM, self.LEMMA, self.UPOS, self.XPOS, self.FEATS
+    
+    @classmethod
+    def toconll(cls, tokens: List[Union[str, Tuple]]) -> str:
+        r"""
+        Converts a list of tokens to a string in CoNLL-X format with missing fields filled with underscores.
+
+        Args:
+            tokens (List[Union[str, Tuple]]):
+                This can be either a list of words, word/pos pairs or word/lemma/pos triples.
+
+        Returns:
+            A string in CoNLL-X format.
+
+        Examples:
+            >>> print(ArcEagerTransform.toconll(['She', 'enjoys', 'playing', 'tennis', '.']))
+            1       She     _       _       _       _       _       _       _       _
+            2       enjoys  _       _       _       _       _       _       _       _
+            3       playing _       _       _       _       _       _       _       _
+            4       tennis  _       _       _       _       _       _       _       _
+            5       .       _       _       _       _       _       _       _       _
+
+            >>> print(ArcEagerTransform.toconll([('She',     'she',    'PRP'),
+                                                 ('enjoys',  'enjoy',  'VBZ'),
+                                                 ('playing', 'play',   'VBG'),
+                                                 ('tennis',  'tennis', 'NN'),
+                                                 ('.',       '_',      '.')]))
+            1       She     she     PRP     _       _       _       _       _       _
+            2       enjoys  enjoy   VBZ     _       _       _       _       _       _
+            3       playing play    VBG     _       _       _       _       _       _
+            4       tennis  tennis  NN      _       _       _       _       _       _
+            5       .       _       .       _       _       _       _       _       _
+
+        """
+        if isinstance(tokens[0], str):
+            s = '\n'.join([f"{i}\t{word}\t" + '\t'.join(['_'] * 8)
+                           for i, word in enumerate(tokens, 1)])
+        elif len(tokens[0]) == 2:
+            s = '\n'.join([f"{i}\t{word}\t_\t{tag}\t" + '\t'.join(['_'] * 6)
+                           for i, (word, tag) in enumerate(tokens, 1)])
+        elif len(tokens[0]) == 3:
+            s = '\n'.join([f"{i}\t{word}\t{lemma}\t{tag}\t" + '\t'.join(['_'] * 6)
+                           for i, (word, lemma, tag) in enumerate(tokens, 1)])
+        elif len(tokens[0]) == 10:
+            s = '\n'.join([f"{i}\t{word}\t{lemma}\t{tag}\t{xpos}\t{upos}\t{head}\t{rel}\t{comm}\t{morph}"
+                           for (i, word, lemma, tag, xpos, upos, head, rel, comm, morph) in tokens])
+        else:
+            raise RuntimeError(f"Invalid sequence {tokens}. Only lists of str or of word/pos/lemma tuples are supported.")
+        return s + '\n'
+
+    @property
+    def tgt(self):
+        return self.HEAD, self.DEPREL, self.DEPS, self.MISC, self.STACK_TOP, self.BUFFER_FRONT, self.TRANSITION, self.TREL
+
+    def load(
+        self,
+        data: Union[str, Iterable],
+        lang: Optional[str] = None,
+        **kwargs
+    ): 
+        if isinstance(data, str) and os.path.exists(data):
+            if os.path.isfile(data):
+                lines = open(data)
+            if os.path.isdir(data):
+                # gather the lines of every .conllu file inside the directory
+                lines = []
+                for filename in sorted(os.listdir(data)):
+                    if filename.endswith('.conllu'):
+                        with open(os.path.join(data, filename)) as f:
+                            lines.extend(f.readlines())
+        else:
+            if lang is not None:
+                #data = [tokenizer(s) for s in ([data] if isinstance(data, str) else data)]
+                data = [data.split() if isinstance(data, str) else data]
+            else:
+                data = [data] if isinstance(data[0], str) else data
+            lines = (i for s in data for i in StringIO(self.toconll(s) + '\n'))
+
+        index, sentence = 0, []
+        for line in lines:
+            line = line.strip()
+            if len(line) == 0:
+                sentence = ArcEagerSentence(self, sentence, index)
+                yield sentence
+                index += 1
+                sentence = []
+            else:
+                sentence.append(line)
+
+
+
+class ArcEagerSentence(Sentence):
+    def __init__(self, transform: ArcEagerTransform, lines: List[str], index: Optional[int] = None):
+        super().__init__(transform, index)
+        self.values = list()
+        self.annotations = dict()
+
+        for i, line in enumerate(lines):
+            value = line.split('\t')
+            if value[0].startswith('#') or not value[0].isdigit():
+                self.annotations[-i-1] = line
+            else:
+                self.annotations[len(self.values)] = line
+                self.values.append(value)
+
+        nodes = [Node.from_conllu('\t'.join(value)) for value in self.values]
+
+        algorithm = ArcEagerEncoder(bos=transform.FORM[0].bos, eos=transform.FORM[0].eos)
+        transitions = algorithm.encode(nodes.copy())
+        stack_top, buffer_front, transition, trel = zip(
+            *[(transition.stack_top.ID, transition.buffer_front.ID, transition.type, transition.deprel)
+              for transition in transitions])
+
+        self.values = list(zip(*self.values))
+        if self.values[6][0] != '_':
+            self.values[6] = tuple(map(int, self.values[6]))
+        self.values += [stack_top, buffer_front, transition, trel]
+
+
+    def __repr__(self):
+        # cover the raw lines
+        merged = {**self.annotations,
+                  **{i: '\t'.join(map(str, line))
+                     for i, line in enumerate(zip(*self.values[:-4]))}}
+        return '\n'.join(merged.values()) + '\n'
\ No newline at end of file
diff --git a/tania_scripts/supar/models/dep/sl/__init__.py b/tania_scripts/supar/models/dep/sl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0e75ee7b541dea11bd85069b1c4cb47e8dce78c
--- /dev/null
+++ b/tania_scripts/supar/models/dep/sl/__init__.py
@@ -0,0 +1,2 @@
+from .model import SLDependencyModel
+from .parser import SLDependencyParser
\ No newline at end of file
diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e84041e18d76d00c1a56d47aa5947ed6eabe7f48
Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7da5b0abda7bcaada2f81498ced6765d9b1bb222
Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0eca97c4fc223779fc95ca035a2bb17aa33cb321
Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/model.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f0f4cad3dcc22e7872d1c2ee274f49f8582da16
Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/model.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95e1cddfd1b7d2af2f9bbf1573def416297a278b
Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/parser.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b16ba48675cbab2d7e21471d46322179b0928fb
Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/parser.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/transform.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7df4678723720a507453f379f951e9dbd859a3a2
Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/transform.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/transform.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..027f048a42a233403d83761297506d43e3d196b9
Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/transform.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/sl/model.py b/tania_scripts/supar/models/dep/sl/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..79e5ec7e5a5ac1c99fbe2218d3f4a46aee6069ac
--- /dev/null
+++ b/tania_scripts/supar/models/dep/sl/model.py
@@ -0,0 +1,154 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+from supar.model import Model
+from supar.modules import MLP, DecoderLSTM
+from supar.utils import Config
+from typing import Tuple, List, Union
+
+class SLDependencyModel(Model):
+    def __init__(self,
+                 n_words: int,
+                 n_labels: Union[Tuple[int], int],
+                 n_rels: int,
+                 n_tags: int = None,
+                 n_chars: int = None,
+                 encoder: str = 'lstm',
+                 feat: List[str] = [],
+                 n_embed=100,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, False),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 n_encoder_hidden=800,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_arc_mlp=500,
+                 n_rel_mlp=100,
+                 mlp_dropout=.33,
+                 scale=0,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        # create decoder
+        self.label_decoder, self.rel_decoder = None, None
+        if self.args.decoder == 'lstm':
+            decoder = lambda out_dim: DecoderLSTM(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+                input_size=self.args.n_encoder_hidden, hidden_size=self.args.n_encoder_hidden,
+                num_layers=self.args.n_decoder_layers, dropout=mlp_dropout,
+                output_size=out_dim)
+        else:
+            decoder = lambda out_dim: MLP(
+                n_in=self.args.n_encoder_hidden, n_out=out_dim, dropout=mlp_dropout
+            )
+
+        if self.args.codes == '2p':
+            self.label_decoder1 = decoder(self.args.n_labels[0])
+            self.label_decoder2 = decoder(self.args.n_labels[1])
+            self.label_decoder = lambda x: (self.label_decoder1(x), self.label_decoder2(x))
+        else:
+            self.label_decoder = decoder(self.args.n_labels)
+
+        self.rel_decoder = decoder(self.args.n_rels)
+
+        # create delay projection
+        if self.args.delay != 0:
+            self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay + 1),
+                                  n_out=self.args.n_encoder_hidden, dropout=mlp_dropout)
+
+        # create PoS tagger
+        if self.args.encoder == 'lstm':
+            self.pos_tagger = DecoderLSTM(
+                input_size=self.args.n_encoder_hidden, hidden_size=self.args.n_encoder_hidden,
+                output_size=self.args.n_tags, num_layers=1, dropout=mlp_dropout, device=self.device)
+        else:
+            self.pos_tagger = nn.Identity()
+
+        self.criterion = nn.CrossEntropyLoss()
+
+    def forward(self, words: torch.Tensor, feats: List[torch.Tensor] = None) -> Tuple[torch.Tensor]:
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+                Default: ``None``.
+
+        Returns:
+            s_label (~Union[torch.Tensor, Tuple[torch.Tensor]]): ``[batch_size, seq_len, n_labels]``
+                Tensor or 2-dimensional tensor tuple (if 2-planar bracketing coding is being used) which holds the
+                scores of all possible labels.
+            s_rel (~torch.Tensor): ``[batch_size, seq_len, n_rels]``
+                Holds scores of all possible dependency relations.
+            s_tag (~torch.Tensor): ``[batch_size, seq_len, n_tags]``
+                Holds scores of all possible tags for each word.
+            qloss (~torch.Tensor):
+                Vector quantization loss.
+        """
+
+        # x ~ [batch_size, bos + pad(seq_len) + delay, n_encoder_hidden]
+        x = self.encode(words, feats)
+        x = x[:, 1:, :]                 # remove BoS token
+
+        # s_tag ~ [batch_size, pad(seq_len), n_tags]
+        s_tag = self.pos_tagger(x if self.args.delay == 0 else x[:, :-self.args.delay, :])
+
+        # map or concatenate delayed representations
+        if self.args.delay != 0:
+            x = torch.cat([x[:, i:(x.shape[1] - self.args.delay + i), :] for i in range(self.args.delay+1)], dim=2)
+            x = self.delay_proj(x)
+
+        # x ~ [batch_size, pad(seq_len), n_encoder_hidden]
+        batch_size, pad_seq_len, _ = x.shape
+
+        # pass through vector quantization module
+        x, qloss = self.vq_forward(x)
+
+        # make predictions of labels/relations
+        s_label = self.label_decoder(x)
+        s_rel = self.rel_decoder(x)
+
+        return s_label, s_rel, s_tag, qloss
+
+    def loss(
+        self,
+        s_label: Union[Tuple[torch.Tensor], torch.Tensor],
+        s_rel: torch.Tensor,
+        s_tag: torch.Tensor,
+        labels: Union[Tuple[torch.Tensor], torch.Tensor],
+        rels: torch.Tensor,
+        tags: torch.Tensor,
+        mask: torch.Tensor
+    ) -> torch.Tensor:
+
+        if self.args.codes != '2p':
+            loss = self.criterion(s_label[mask], labels[mask])
+        else:
+            loss = sum(self.criterion(scores[mask], golds[mask])
+                       for scores, golds in zip(s_label, labels))
+
+        loss += self.criterion(s_rel[mask], rels[mask])
+
+        if self.args.encoder == 'lstm':
+            loss += self.criterion(s_tag[mask], tags[mask])
+        return loss
+
+    def decode(self, s_label: Union[Tuple[torch.Tensor], torch.Tensor], s_rel: torch.Tensor, s_tag: torch.Tensor,
+               mask: torch.Tensor):
+        if self.args.codes != '2p':
+            label_preds = s_label.argmax(-1)
+        else:
+            label_preds = tuple(x.argmax(-1) for x in s_label)
+        rel_preds = s_rel.argmax(-1)
+        tag_preds = s_tag.argmax(-1) if self.args.encoder == 'lstm' else None
+        return label_preds, rel_preds, tag_preds
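+
+# Shape sketch (illustrative; assumes a single-plane label coding, with
+# B = batch_size and L = padded sequence length):
+#   s_label, s_rel, s_tag, qloss = model(words, feats)
+#   # s_label: [B, L, n_labels], s_rel: [B, L, n_rels], s_tag: [B, L, n_tags]
+#   label_preds, rel_preds, tag_preds = model.decode(s_label, s_rel, s_tag, mask)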
diff --git a/tania_scripts/supar/models/dep/sl/parser.py b/tania_scripts/supar/models/dep/sl/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..af95475824a3b5719b4ff13ab955cdd51b5493bf
--- /dev/null
+++ b/tania_scripts/supar/models/dep/sl/parser.py
@@ -0,0 +1,323 @@
+# -*- coding: utf-8 -*-
+
+import os, re
+from typing import Iterable, List, Tuple, Union
+
+import torch
+from supar.models.dep.sl.model import SLDependencyModel
+from supar.parser import Parser
+from supar.utils import Config, Dataset, Embedding
+from supar.utils.common import BOS, PAD, UNK
+from supar.utils.field import Field, RawField, SubwordField
+from supar.utils.logging import get_logger
+from supar.utils.metric import AttachmentMetric
+from supar.utils.tokenizer import TransformerTokenizer
+from supar.utils.transform import Batch
+from supar.models.dep.sl.transform import SLDependency
+from supar.codelin import get_dep_encoder, LinearizedTree, D_Label, LABEL_SEPARATOR
+from supar.codelin.utils.constants import D_ROOT_HEAD
+
+logger = get_logger(__name__)
+
+
+NONE = '<none>'
+OPTIONS = [r'>\*', r'<\*', r'/\*', r'\\\*']
+
+
+def split_planes(labels: Tuple[str, ...], plane: int) -> Tuple[str, ...]:
+    # for each label, locate the position where the second-plane suffix starts,
+    # i.e. the earliest match of any starred bracket symbol in OPTIONS
+    split = lambda label: min(
+        match.span()[0] if match is not None else len(label)
+        for match in (re.search(pattern, label) for pattern in OPTIONS)
+    )
+    splits = list(map(split, labels))
+    if plane == 0:
+        # first plane: everything before the starred suffix
+        return tuple(label[:i] for label, i in zip(labels, splits))
+    else:
+        # second plane: the starred suffix, or '<none>' if the label has none
+        return tuple(label[i:] if i < len(label) else NONE for label, i in zip(labels, splits))
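+
+# Illustrative example (not part of the module): a label's starred suffix
+# encodes its second-plane part, so e.g.
+#   split_planes(('<>*', '<'), plane=0)  ->  ('<', '<')
+#   split_planes(('<>*', '<'), plane=1)  ->  ('>*', '<none>')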
+
+class SLDependencyParser(Parser):
+
+    NAME = 'SLDependencyParser'
+    MODEL = SLDependencyModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.TAG = self.transform.UPOS
+        self.LABEL, self.DEPREL = self.transform.LABEL, self.transform.DEPREL
+        self.encoder = get_dep_encoder(self.args.codes, LABEL_SEPARATOR)
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 1000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        punct: bool = False,
+        tree: bool = False,
+        proj: bool = False,
+        partial: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().train(**Config().update(locals()))
+
+    def few_shot(
+            self,
+            train: str,
+            dev: Union[str, Iterable],
+            test: Union[str, Iterable],
+            n_samples: int,
+            epochs: int = 2,
+            batch_size: int = 50
+    ) -> None:
+        return super().few_shot(**Config().update(locals()))
+
+    def predict(
+            self,
+            data: Union[str, Iterable],
+            pred: str = None,
+            lang: str = None,
+            prob: bool = False,
+            batch_size: int = 5000,
+            buckets: int = 8,
+            workers: int = 0,
+            amp: bool = False,
+            cache: bool = False,
+            tree: bool = True,
+            proj: bool = False,
+            verbose: bool = True,
+            **kwargs
+    ):
+        return super().predict(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 500,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        punct: bool = False,
+        tree: bool = True,
+        proj: bool = False,
+        partial: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        if self.args.encoder == 'lstm':
+            words, texts, chars, tags, _, rels, *labels = batch
+            feats = [chars]
+        else:
+            words, texts, tags, _, rels, *labels = batch
+            feats = []
+        labels = labels[0] if len(labels) == 1 else labels
+        mask = batch.mask[:, (1+self.args.delay):]
+
+        # forward pass
+        s_label, s_rel, s_tag, qloss = self.model(words, feats)
+
+        # compute loss
+        loss = self.model.loss(s_label, s_rel, s_tag, labels, rels, tags, mask) + qloss
+
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> AttachmentMetric:
+        if self.args.encoder == 'lstm':
+            words, texts, chars, tags, heads, rels, *labels = batch
+            feats = [chars]
+        else:
+            words, texts, tags, heads, rels, *labels = batch
+            feats = []
+        labels = labels[0] if len(labels) == 1 else labels
+        mask = batch.mask[:, (1+self.args.delay):]
+        lens = (batch.lens - 1 - self.args.delay).tolist()
+
+        # forward pass
+        s_label, s_rel, s_tag, qloss = self.model(words, feats)
+
+        # compute loss and decode
+        loss = self.model.loss(s_label, s_rel, s_tag, labels, rels, tags, mask) + qloss
+        label_preds, rel_preds, tag_preds = self.model.decode(s_label, s_rel, s_tag, mask)
+
+        # recover the original label and deprel strings for decoding
+        if self.args.codes == '2p':
+            label_preds = [
+                (self.LABEL[0].vocab[l1.tolist()], self.LABEL[1].vocab[l2.tolist()])
+                for l1, l2 in zip(*map(lambda x: x[mask].split(lens), label_preds))
+            ]
+            label_preds = [
+                list(map(lambda x: x[0] + (x[1] if x[1] != NONE else ''), zip(*label_pred)))
+                for label_pred in label_preds
+            ]
+        else:
+            label_preds = [self.LABEL.vocab[i.tolist()] for i in label_preds[mask].split(lens)]
+        deprel_preds = [self.DEPREL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)]
+
+        if self.args.encoder == 'lstm':
+            tag_preds = [self.TAG.vocab[i.tolist()] for i in tag_preds[mask].split(lens)]
+        else:
+            tag_preds = [self.TAG.vocab[i.tolist()] for i in tags[mask].split(lens)]
+
+        # decode
+        head_preds = list()
+        for label_pred, deprel_pred, tag_pred, forms in zip(label_preds, deprel_preds, tag_preds, texts):
+            labels = [D_Label(label, deprel, self.encoder.separator) for label, deprel in zip(label_pred, deprel_pred)]
+            linearized_tree = LinearizedTree(list(forms), tag_pred, ['_']*len(forms), labels, 0)
+
+            decoded_tree = self.encoder.decode(linearized_tree)
+            decoded_tree.postprocess_tree(D_ROOT_HEAD)
+            head_preds.append(torch.tensor([int(node.head) for node in decoded_tree.nodes]))
+
+        # resize head predictions (add padding)
+        resize = lambda list_of_tensors: \
+            torch.stack([
+                torch.concat([x, torch.zeros(mask.shape[1] - len(x))])
+                for x in list_of_tensors])
+
+        head_preds = resize(head_preds).to(torch.int32).to(self.model.device)
+
+        return AttachmentMetric(loss, (head_preds, rel_preds), (heads, rels), mask)
+
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        # the batch unpacks identically for both encoders here; any extra
+        # features (e.g. chars for the LSTM encoder) are gathered into *feats
+        words, texts, *feats, tags = batch
+        mask = batch.mask[:, (1 + self.args.delay):]
+        lens = (batch.lens - 1 - self.args.delay).tolist()
+
+        # forward pass
+        s_label, s_rel, s_tag, qloss = self.model(words, feats)
+
+        # decode
+        label_preds, rel_preds, tag_preds = self.model.decode(s_label, s_rel, s_tag, mask)
+
+        # recover the original label and deprel strings for decoding
+        if self.args.codes == '2p':
+            label_preds = [
+                (self.LABEL[0].vocab[l1.tolist()], self.LABEL[1].vocab[l2.tolist()])
+                for l1, l2 in zip(*map(lambda x: x[mask].split(lens), label_preds))
+            ]
+            label_preds = [
+                list(map(lambda x: x[0] + (x[1] if x[1] != NONE else ''), zip(*label_pred)))
+                for label_pred in label_preds
+            ]
+        else:
+            label_preds = [self.LABEL.vocab[i.tolist()] for i in label_preds[mask].split(lens)]
+
+        deprel_preds = [self.DEPREL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)]
+
+        if self.args.encoder == 'lstm':
+            tag_preds = [self.TAG.vocab[i.tolist()] for i in tag_preds[mask].split(lens)]
+        else:
+            tag_preds = [self.TAG.vocab[i.tolist()] for i in tags[mask].split(lens)]
+
+        # decode
+        head_preds = list()
+        for label_pred, deprel_pred, tag_pred, forms in zip(label_preds, deprel_preds, tag_preds, texts):
+            labels = [D_Label(label, deprel, self.encoder.separator) for label, deprel in zip(label_pred, deprel_pred)]
+            linearized_tree = LinearizedTree(list(forms), tag_pred, ['_'] * len(forms), labels, 0)
+
+            decoded_tree = self.encoder.decode(linearized_tree)
+            decoded_tree.postprocess_tree(D_ROOT_HEAD)
+
+            head_preds.append([int(node.head) for node in decoded_tree.nodes])
+
+        batch.heads = head_preds
+        batch.rels = deprel_preds
+
+        return batch
+
+    @classmethod
+    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+        args = Config(**locals())
+
+        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
+        if os.path.exists(path) and not args.build:
+            parser = cls.load(**args)
+            parser.model = cls.MODEL(**parser.args)
+            parser.model.load_pretrained(parser.transform.FORM[0].embed).to(parser.device)
+            return parser
+
+        logger.info("Building the fields")
+        CHAR = None
+        if args.encoder == 'bert':
+            t = TransformerTokenizer(args.bert)
+            pad_token = t.pad if t.pad else PAD
+            WORD = SubwordField('words', pad=pad_token, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t, delay=args.delay)
+            WORD.vocab = t.vocab
+        else:
+            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True, delay=args.delay)
+            if 'char' in args.feat:
+                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, fix_len=args.fix_len, delay=args.delay)
+        TEXT = RawField('texts')
+        TAG = Field('tags')
+
+        if args.codes == '2p':
+            LABEL1 = Field('labels1', fn=lambda seq: split_planes(seq, 0))
+            LABEL2 = Field('labels2', fn=lambda seq: split_planes(seq, 1))
+            LABEL = (LABEL1, LABEL2)
+        else:
+            LABEL = Field('labels')
+
+        DEPREL = Field('rels')
+        HEAD = Field('heads', use_vocab=False)
+
+        transform = SLDependency(
+            encoder=get_dep_encoder(args.codes, LABEL_SEPARATOR),
+            FORM=(WORD, TEXT, CHAR), UPOS=TAG, HEAD=HEAD, DEPREL=DEPREL, LABEL=LABEL)
+
+        train = Dataset(transform, args.train, **args)
+        if args.encoder != 'bert':
+            WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
+            if CHAR:
+                CHAR.build(train)
+        TAG.build(train)
+        DEPREL.build(train)
+
+        if args.codes == '2p':
+            LABEL[0].build(train)
+            LABEL[1].build(train)
+        else:
+            LABEL.build(train)
+
+        args.update({
+            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
+            'n_labels': len(LABEL.vocab) if args.codes != '2p' else (len(LABEL[0].vocab), len(LABEL[1].vocab)),
+            'n_rels': len(DEPREL.vocab),
+            'n_tags': len(TAG.vocab),
+            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
+            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
+            'pad_index': WORD.pad_index,
+            'unk_index': WORD.unk_index,
+            'bos_index': WORD.bos_index,
+            'delay': args.delay
+        })
+        logger.info(f"{transform}")
+
+        logger.info("Building the model")
+        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None)
+        logger.info(f"{model}\n")
+
+        parser = cls(args, model, transform)
+        parser.model.to(parser.device)
+        return parser
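+
+
+# Hypothetical usage sketch (paths and options below are placeholders, not
+# taken from this repository):
+#   parser = SLDependencyParser.build('exp/sl-dep/model', codes='2p',
+#                                     encoder='lstm', feat=['char'],
+#                                     train='train.conllu')
+#   parser.train(train='train.conllu', dev='dev.conllu', test='test.conllu')
+#   parser.predict('test.conllu', pred='pred.conllu')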
diff --git a/tania_scripts/supar/models/dep/sl/transform.py b/tania_scripts/supar/models/dep/sl/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..8603cb86793fc9f49863542634aaa74ffe7e9fe9
--- /dev/null
+++ b/tania_scripts/supar/models/dep/sl/transform.py
@@ -0,0 +1,102 @@
+from supar.utils.transform import Transform, Sentence
+from supar.utils.field import Field
+from typing import Iterable, List, Optional, Union
+from supar.codelin import D_Tree
+
+class SLDependency(Transform):
+
+    fields = ['ID', 'FORM', 'LEMMA', 'UPOS', 'XPOS', 'FEATS', 'HEAD', 'DEPREL', 'DEPS', 'MISC', 'LABEL']
+
+    def __init__(
+        self,
+        encoder,
+        ID: Optional[Union[Field, Iterable[Field]]] = None,
+        FORM: Optional[Union[Field, Iterable[Field]]] = None,
+        LEMMA: Optional[Union[Field, Iterable[Field]]] = None,
+        UPOS: Optional[Union[Field, Iterable[Field]]] = None,
+        XPOS: Optional[Union[Field, Iterable[Field]]] = None,
+        FEATS: Optional[Union[Field, Iterable[Field]]] = None,
+        HEAD: Optional[Union[Field, Iterable[Field]]] = None,
+        DEPREL: Optional[Union[Field, Iterable[Field]]] = None,
+        DEPS: Optional[Union[Field, Iterable[Field]]] = None,
+        MISC: Optional[Union[Field, Iterable[Field]]] = None,
+        LABEL: Optional[Union[Field, Iterable[Field]]] = None
+    ):
+        super().__init__()
+
+        self.encoder = encoder
+        self.ID = ID
+        self.FORM = FORM
+        self.LEMMA = LEMMA
+        self.UPOS = UPOS
+        self.XPOS = XPOS
+        self.FEATS = FEATS
+        self.HEAD = HEAD
+        self.DEPREL = DEPREL
+        self.DEPS = DEPS
+        self.MISC = MISC
+        self.LABEL = LABEL
+
+    @property
+    def src(self):
+        return self.FORM, self.LEMMA, self.UPOS, self.XPOS, self.FEATS
+
+    @property
+    def tgt(self):
+        return self.HEAD, self.DEPREL, self.DEPS, self.MISC, self.LABEL
+
+    def load(
+        self,
+        data: str,
+        **kwargs
+    ):
+        with open(data) as lines:
+            index, sentence = 0, []
+            for line in lines:
+                line = line.strip()
+                if len(line) == 0:
+                    sentence = SLDependencySentence(self, sentence, self.encoder, index)
+                    if sentence.values:
+                        yield sentence
+                        index += 1
+                    sentence = []
+                else:
+                    sentence.append(line)
+            # flush the last sentence if the file does not end with a blank line
+            if sentence:
+                sentence = SLDependencySentence(self, sentence, self.encoder, index)
+                if sentence.values:
+                    yield sentence
+
+
+class SLDependencySentence(Sentence):
+    def __init__(self, transform: SLDependency, lines: List[str], encoder, index: Optional[int] = None):
+        super().__init__(transform, index)
+        self.annotations = dict()
+        self.values = list()
+
+        for i, line in enumerate(lines):
+            value = line.split('\t')
+            if value[0].startswith('#') or not value[0].isdigit():
+                self.annotations[-i-1] = line
+            else:
+                self.annotations[len(self.values)] = line
+                self.values.append(value)
+
+        # convert values into nodes
+        tree = D_Tree.from_string('\n'.join(['\t'.join(value) for value in self.values]))
+
+        # linearize tree
+        linearized_tree = encoder.encode(tree)
+
+        self.values = list(zip(*self.values))
+        self.values[6] = tuple(map(int, self.values[6]))
+
+        # add labels
+        labels = tuple(str(label.xi) for label in linearized_tree.labels)
+        self.values.append(labels)
+
+    def __repr__(self):
+        # merge the stored comment/annotation lines back with the token lines
+        merged = {**self.annotations,
+                  **{i: '\t'.join(map(str, line))
+                     for i, line in enumerate(zip(*self.values[:-1]))}}
+        return '\n'.join(merged.values()) + '\n'
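+
+# Illustrative input for load() (CoNLL-U style, ten tab-separated columns;
+# shown space-separated here; a blank line ends each sentence; the values are
+# invented for illustration):
+#   1  El      el      DET   _  _  2  det    _  _
+#   2  gato    gato    NOUN  _  _  3  nsubj  _  _
+#   3  duerme  dormir  VERB  _  _  0  root   _  _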
diff --git a/tania_scripts/supar/models/dep/vi/__init__.py b/tania_scripts/supar/models/dep/vi/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..18dc35558d4f383929d588d3f70d319684afb7ec
--- /dev/null
+++ b/tania_scripts/supar/models/dep/vi/__init__.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .model import VIDependencyModel
+from .parser import VIDependencyParser
+
+__all__ = ['VIDependencyModel', 'VIDependencyParser']
diff --git a/tania_scripts/supar/models/dep/vi/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/vi/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78075f4745b4926add51043c086c330542182c05
Binary files /dev/null and b/tania_scripts/supar/models/dep/vi/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/vi/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/vi/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..01b920378d10fc9cdbec8a67d9522ea76e4b0702
Binary files /dev/null and b/tania_scripts/supar/models/dep/vi/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/vi/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/dep/vi/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a980f9eebf2cc021c1f6a146c773ebb86340a75
Binary files /dev/null and b/tania_scripts/supar/models/dep/vi/__pycache__/model.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/vi/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/dep/vi/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d903b7c1fad5509a97624f635495068bf48f4a90
Binary files /dev/null and b/tania_scripts/supar/models/dep/vi/__pycache__/model.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/vi/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/dep/vi/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..daf22311cfc35994234b5e1731ac3cb9e418b269
Binary files /dev/null and b/tania_scripts/supar/models/dep/vi/__pycache__/parser.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/dep/vi/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/dep/vi/__pycache__/parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..94aaa00287400fd00c2dabd76be8508ea9161706
Binary files /dev/null and b/tania_scripts/supar/models/dep/vi/__pycache__/parser.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/dep/vi/model.py b/tania_scripts/supar/models/dep/vi/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ad8c3a123f08919901ec3c06ed404a6d28385cb
--- /dev/null
+++ b/tania_scripts/supar/models/dep/vi/model.py
@@ -0,0 +1,253 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+from supar.models.dep.biaffine.model import BiaffineDependencyModel
+from supar.models.dep.biaffine.transform import CoNLL
+from supar.modules import MLP, Biaffine, Triaffine
+from supar.structs import (DependencyCRF, DependencyLBP, DependencyMFVI,
+                           MatrixTree)
+from supar.utils import Config
+from supar.utils.common import MIN
+
+
+class VIDependencyModel(BiaffineDependencyModel):
+    r"""
+    The implementation of Dependency Parser using Variational Inference :cite:`wang-tu-2020-second`.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_rels (int):
+            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'char'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word embeddings. Default: 100.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, False)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs would be weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling way to get token embeddings.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .33.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 800.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layer. Default: .33.
+        n_arc_mlp (int):
+            Arc MLP size. Default: 500.
+        n_sib_mlp (int):
+            Binary factor MLP size. Default: 100.
+        n_rel_mlp (int):
+            Label MLP size. Default: 100.
+        mlp_dropout (float):
+            The dropout ratio of MLP layers. Default: .33.
+        scale (float):
+            Scaling factor for affine scores. Default: 0.
+        inference (str):
+            Approximate inference methods. Default: ``mfvi``.
+        max_iter (int):
+            Max iteration times for inference. Default: 3.
+        interpolation (float):
+            Constant to even out the label/edge loss. Default: .1.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_rels,
+                 n_tags=None,
+                 n_chars=None,
+                 encoder='lstm',
+                 feat=['char'],
+                 n_embed=100,
+                 n_pretrained=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, False),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.33,
+                 n_encoder_hidden=800,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_arc_mlp=500,
+                 n_sib_mlp=100,
+                 n_rel_mlp=100,
+                 mlp_dropout=.33,
+                 scale=0,
+                 inference='mfvi',
+                 max_iter=3,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        self.arc_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_arc_mlp, dropout=mlp_dropout)
+        self.arc_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_arc_mlp, dropout=mlp_dropout)
+        self.sib_mlp_s = MLP(n_in=self.args.n_encoder_hidden, n_out=n_sib_mlp, dropout=mlp_dropout)
+        self.sib_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_sib_mlp, dropout=mlp_dropout)
+        self.sib_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_sib_mlp, dropout=mlp_dropout)
+        self.rel_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_rel_mlp, dropout=mlp_dropout)
+        self.rel_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_rel_mlp, dropout=mlp_dropout)
+
+        self.arc_attn = Biaffine(n_in=n_arc_mlp, scale=scale, bias_x=True, bias_y=False)
+        self.sib_attn = Triaffine(n_in=n_sib_mlp, scale=scale, bias_x=True, bias_y=True)
+        self.rel_attn = Biaffine(n_in=n_rel_mlp, n_out=n_rels, bias_x=True, bias_y=True)
+        self.inference = (DependencyMFVI if inference == 'mfvi' else DependencyLBP)(max_iter)
+        self.criterion = nn.CrossEntropyLoss()
+
+    def forward(self, words, feats=None):
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+                Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor, ~torch.Tensor:
+                Scores of all possible arcs (``[batch_size, seq_len, seq_len]``),
+                dependent-head-sibling triples (``[batch_size, seq_len, seq_len, seq_len]``) and
+                all possible labels on each arc (``[batch_size, seq_len, seq_len, n_labels]``).
+        """
+
+        x = self.encode(words, feats)
+        mask = words.ne(self.args.pad_index) if len(words.shape) < 3 else words.ne(self.args.pad_index).any(-1)
+
+        arc_d = self.arc_mlp_d(x)
+        arc_h = self.arc_mlp_h(x)
+        sib_s = self.sib_mlp_s(x)
+        sib_d = self.sib_mlp_d(x)
+        sib_h = self.sib_mlp_h(x)
+        rel_d = self.rel_mlp_d(x)
+        rel_h = self.rel_mlp_h(x)
+
+        # [batch_size, seq_len, seq_len]
+        s_arc = self.arc_attn(arc_d, arc_h).masked_fill_(~mask.unsqueeze(1), MIN)
+        # [batch_size, seq_len, seq_len, seq_len]
+        s_sib = self.sib_attn(sib_s, sib_d, sib_h).permute(0, 3, 1, 2)
+        # [batch_size, seq_len, seq_len, n_rels]
+        s_rel = self.rel_attn(rel_d, rel_h).permute(0, 2, 3, 1)
+
+        return s_arc, s_sib, s_rel
+
+    def loss(self, s_arc, s_sib, s_rel, arcs, rels, mask):
+        r"""
+        Args:
+            s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+                Scores of all possible arcs.
+            s_sib (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
+                Scores of all possible dependent-head-sibling triples.
+            s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each arc.
+            arcs (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The tensor of gold-standard arcs.
+            rels (~torch.LongTensor): ``[batch_size, seq_len]``.
+                The tensor of gold-standard labels.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens.
+
+        Returns:
+            ~torch.Tensor:
+                The training loss.
+        """
+
+        arc_loss, marginals = self.inference((s_arc, s_sib), mask, arcs)
+        s_rel, rels = s_rel[mask], rels[mask]
+        s_rel = s_rel[torch.arange(len(rels)), arcs[mask]]
+        rel_loss = self.criterion(s_rel, rels)
+        loss = arc_loss + rel_loss
+        return loss, marginals
+
+    def decode(self, s_arc, s_rel, mask, tree=False, proj=False):
+        r"""
+        Args:
+            s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+                Scores of all possible arcs.
+            s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each arc.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens.
+            tree (bool):
+                If ``True``, ensures to output well-formed trees. Default: ``False``.
+            proj (bool):
+                If ``True``, ensures to output projective trees. Default: ``False``.
+
+        Returns:
+            ~torch.LongTensor, ~torch.LongTensor:
+                Predicted arcs and labels of shape ``[batch_size, seq_len]``.
+        """
+
+        lens = mask.sum(1)
+        arc_preds = s_arc.argmax(-1)
+        bad = [not CoNLL.istree(seq[1:i+1], proj) for i, seq in zip(lens.tolist(), arc_preds.tolist())]
+        if tree and any(bad):
+            arc_preds[bad] = (DependencyCRF if proj else MatrixTree)(s_arc[bad], mask[bad].sum(-1)).argmax
+        rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
+
+        return arc_preds, rel_preds
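+
+# Decode sketch (illustrative): with tree=True, any sentence whose greedy arc
+# prediction is not a well-formed tree is re-decoded with MatrixTree (or
+# DependencyCRF when proj=True):
+#   arc_preds, rel_preds = model.decode(s_arc, s_rel, mask, tree=True, proj=False)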
diff --git a/tania_scripts/supar/models/dep/vi/parser.py b/tania_scripts/supar/models/dep/vi/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..3808a3daa43247aa099f753e4a44d1cdfd480720
--- /dev/null
+++ b/tania_scripts/supar/models/dep/vi/parser.py
@@ -0,0 +1,124 @@
+# -*- coding: utf-8 -*-
+
+from typing import Iterable, Union
+
+import torch
+
+from supar.models.dep.biaffine.parser import BiaffineDependencyParser
+from supar.models.dep.vi.model import VIDependencyModel
+from supar.utils import Config
+from supar.utils.fn import ispunct
+from supar.utils.logging import get_logger
+from supar.utils.metric import AttachmentMetric
+from supar.utils.transform import Batch
+
+logger = get_logger(__name__)
+
+
+class VIDependencyParser(BiaffineDependencyParser):
+    r"""
+    The implementation of Dependency Parser using Variational Inference :cite:`wang-tu-2020-second`.
+    """
+
+    NAME = 'vi-dependency'
+    MODEL = VIDependencyModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        punct: bool = False,
+        tree: bool = False,
+        proj: bool = False,
+        partial: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        punct: bool = False,
+        tree: bool = True,
+        proj: bool = True,
+        partial: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        tree: bool = True,
+        proj: bool = True,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().predict(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        words, _, *feats, arcs, rels = batch
+        mask = batch.mask
+        # ignore the first token of each sentence
+        mask[:, 0] = 0
+        s_arc, s_sib, s_rel = self.model(words, feats)
+        loss, *_ = self.model.loss(s_arc, s_sib, s_rel, arcs, rels, mask)
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> AttachmentMetric:
+        words, _, *feats, arcs, rels = batch
+        mask = batch.mask
+        # ignore the first token of each sentence
+        mask[:, 0] = 0
+        s_arc, s_sib, s_rel = self.model(words, feats)
+        loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, rels, mask)
+        arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
+        if self.args.partial:
+            mask &= arcs.ge(0)
+        # ignore all punctuation if not specified
+        if not self.args.punct:
+            mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
+        return AttachmentMetric(loss, (arc_preds, rel_preds), (arcs, rels), mask)
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, _, *feats = batch
+        mask, lens = batch.mask, (batch.lens - 1).tolist()
+        # ignore the first token of each sentence
+        mask[:, 0] = 0
+        s_arc, s_sib, s_rel = self.model(words, feats)
+        s_arc = self.model.inference((s_arc, s_sib), mask)
+        arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
+        batch.arcs = [i.tolist() for i in arc_preds[mask].split(lens)]
+        batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)]
+        if self.args.prob:
+            batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, s_arc.unbind())]
+        return batch
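+
+# Hypothetical usage sketch (the checkpoint name below is a placeholder):
+#   parser = VIDependencyParser.load('vi-dep-model')
+#   dataset = parser.predict('data/test.conllx', pred='pred.conllx')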
diff --git a/tania_scripts/supar/models/sdp/__init__.py b/tania_scripts/supar/models/sdp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..633e238449fbf947b867006579f439d8b8e78156
--- /dev/null
+++ b/tania_scripts/supar/models/sdp/__init__.py
@@ -0,0 +1,7 @@
+# -*- coding: utf-8 -*-
+
+from .biaffine import BiaffineSemanticDependencyModel, BiaffineSemanticDependencyParser
+from .vi import VISemanticDependencyModel, VISemanticDependencyParser
+
+__all__ = ['BiaffineSemanticDependencyModel', 'BiaffineSemanticDependencyParser',
+           'VISemanticDependencyModel', 'VISemanticDependencyParser']
diff --git a/tania_scripts/supar/models/sdp/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/sdp/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02080dfc284fa4d868f12d9e49fca7f0775a90f5
Binary files /dev/null and b/tania_scripts/supar/models/sdp/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/sdp/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/sdp/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f809f66bbc799ef64855403f747872f7bcc4dc85
Binary files /dev/null and b/tania_scripts/supar/models/sdp/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/sdp/biaffine/__init__.py b/tania_scripts/supar/models/sdp/biaffine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab2feeeb4e81b099086b73e78ae81d5fccbbe82c
--- /dev/null
+++ b/tania_scripts/supar/models/sdp/biaffine/__init__.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .model import BiaffineSemanticDependencyModel
+from .parser import BiaffineSemanticDependencyParser
+
+__all__ = ['BiaffineSemanticDependencyModel', 'BiaffineSemanticDependencyParser']
diff --git a/tania_scripts/supar/models/sdp/biaffine/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/sdp/biaffine/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4845ca571234eecab6e183073d4864e1ec667a85
Binary files /dev/null and b/tania_scripts/supar/models/sdp/biaffine/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/sdp/biaffine/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/sdp/biaffine/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9896718732a155aafb0ff99c6979f1448f90fcda
Binary files /dev/null and b/tania_scripts/supar/models/sdp/biaffine/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/sdp/biaffine/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/sdp/biaffine/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3927ea658cbef06b9e352886b378c4743b3c16b
Binary files /dev/null and b/tania_scripts/supar/models/sdp/biaffine/__pycache__/model.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/sdp/biaffine/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/sdp/biaffine/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c1c571e2bcd578989340f70f59cc638e9cb644e
Binary files /dev/null and b/tania_scripts/supar/models/sdp/biaffine/__pycache__/model.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/sdp/biaffine/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/sdp/biaffine/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f9143362fa65ddfd6a27203f97dd01c2a6a45f3e
Binary files /dev/null and b/tania_scripts/supar/models/sdp/biaffine/__pycache__/parser.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/sdp/biaffine/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/sdp/biaffine/__pycache__/parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3e9e25a0966252386b63bb170c782847ca49afb
Binary files /dev/null and b/tania_scripts/supar/models/sdp/biaffine/__pycache__/parser.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/sdp/biaffine/model.py b/tania_scripts/supar/models/sdp/biaffine/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a7afa4dde4f2d7dee3d764ab88d1e9c899a02d7
--- /dev/null
+++ b/tania_scripts/supar/models/sdp/biaffine/model.py
@@ -0,0 +1,222 @@
+# -*- coding: utf-8 -*-
+
+import torch.nn as nn
+from supar.model import Model
+from supar.modules import MLP, Biaffine
+from supar.utils import Config
+
+
+class BiaffineSemanticDependencyModel(Model):
+    r"""
+    The implementation of Biaffine Semantic Dependency Parser :cite:`dozat-manning-2018-simpler`.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_labels (int):
+            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        n_lemmas (int):
+            The number of lemmas, required if lemma embeddings are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'lemma'``: Lemma embeddings.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'tag'``, ``'char'``, ``'lemma'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word representations. Default: 125.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 400.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        char_dropout (float):
+            The dropout ratio of CharLSTM layers. Default: .33.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, False)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs would be weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling way to get token embeddings.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .2.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 1200.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layer. Default: .33.
+        n_edge_mlp (int):
+            Edge MLP size. Default: 600.
+        n_label_mlp (int):
+            Label MLP size. Default: 600.
+        edge_mlp_dropout (float):
+            The dropout ratio of edge MLP layers. Default: .25.
+        label_mlp_dropout (float):
+            The dropout ratio of label MLP layers. Default: .33.
+        interpolation (float):
+            Constant to even out the label/edge loss. Default: .1.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_labels,
+                 n_tags=None,
+                 n_chars=None,
+                 n_lemmas=None,
+                 encoder='lstm',
+                 feat=['tag', 'char', 'lemma'],
+                 n_embed=100,
+                 n_pretrained=125,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=400,
+                 char_pad_index=0,
+                 char_dropout=0.33,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, False),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.2,
+                 n_encoder_hidden=1200,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_edge_mlp=600,
+                 n_label_mlp=600,
+                 edge_mlp_dropout=.25,
+                 label_mlp_dropout=.33,
+                 interpolation=0.1,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        self.edge_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_edge_mlp, dropout=edge_mlp_dropout, activation=False)
+        self.edge_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_edge_mlp, dropout=edge_mlp_dropout, activation=False)
+        self.label_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False)
+        self.label_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False)
+
+        self.edge_attn = Biaffine(n_in=n_edge_mlp, n_out=2, bias_x=True, bias_y=True)
+        self.label_attn = Biaffine(n_in=n_label_mlp, n_out=n_labels, bias_x=True, bias_y=True)
+        self.criterion = nn.CrossEntropyLoss()
+
+    def load_pretrained(self, embed=None):
+        if embed is not None:
+            self.pretrained = nn.Embedding.from_pretrained(embed)
+            if embed.shape[1] != self.args.n_pretrained:
+                self.embed_proj = nn.Linear(embed.shape[1], self.args.n_pretrained)
+        return self
+
+    def forward(self, words, feats=None):
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+                Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The first tensor of shape ``[batch_size, seq_len, seq_len, 2]`` holds scores of all possible edges.
+                The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds
+                scores of all possible labels on each edge.
+        """
+
+        x = self.encode(words, feats)
+
+        edge_d = self.edge_mlp_d(x)
+        edge_h = self.edge_mlp_h(x)
+        label_d = self.label_mlp_d(x)
+        label_h = self.label_mlp_h(x)
+
+        # [batch_size, seq_len, seq_len, 2]
+        s_edge = self.edge_attn(edge_d, edge_h).permute(0, 2, 3, 1)
+        # [batch_size, seq_len, seq_len, n_labels]
+        s_label = self.label_attn(label_d, label_h).permute(0, 2, 3, 1)
+
+        return s_edge, s_label
+
+    def loss(self, s_edge, s_label, labels, mask):
+        r"""
+        Args:
+            s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``.
+                Scores of all possible edges.
+            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each edge.
+            labels (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
+                The tensor of gold-standard labels.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask for covering the unpadded tokens.
+
+        Returns:
+            ~torch.Tensor:
+                The training loss.
+        """
+
+        edge_mask = labels.ge(0) & mask
+        edge_loss = self.criterion(s_edge[mask], edge_mask[mask].long())
+        label_loss = self.criterion(s_label[edge_mask], labels[edge_mask])
+        return self.args.interpolation * label_loss + (1 - self.args.interpolation) * edge_loss
+
+    def decode(self, s_edge, s_label):
+        r"""
+        Args:
+            s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``.
+                Scores of all possible edges.
+            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each edge.
+
+        Returns:
+            ~torch.LongTensor:
+                Predicted labels of shape ``[batch_size, seq_len, seq_len]``.
+        """
+
+        return s_label.argmax(-1).masked_fill_(s_edge.argmax(-1).lt(1), -1)
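+
+# Decode semantics (illustrative): a cell (i, j) is kept as an edge only when
+# the binary edge classifier prefers class 1; every other cell is set to -1:
+#   labels = model.decode(s_edge, s_label)  # [batch_size, seq_len, seq_len]; -1 = no edge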
diff --git a/tania_scripts/supar/models/sdp/biaffine/parser.py b/tania_scripts/supar/models/sdp/biaffine/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..c28f6c229a5d33f433b2e19378b7966eb442fb74
--- /dev/null
+++ b/tania_scripts/supar/models/sdp/biaffine/parser.py
@@ -0,0 +1,202 @@
+# -*- coding: utf-8 -*-
+
+import os
+from typing import Iterable, Union
+
+import torch
+
+from supar.models.dep.biaffine.transform import CoNLL
+from supar.models.sdp.biaffine import BiaffineSemanticDependencyModel
+from supar.parser import Parser
+from supar.utils import Config, Dataset, Embedding
+from supar.utils.common import BOS, PAD, UNK
+from supar.utils.field import ChartField, Field, RawField, SubwordField
+from supar.utils.logging import get_logger
+from supar.utils.metric import ChartMetric
+from supar.utils.tokenizer import TransformerTokenizer
+from supar.utils.transform import Batch
+
+logger = get_logger(__name__)
+
+
+class BiaffineSemanticDependencyParser(Parser):
+    r"""
+    The implementation of Biaffine Semantic Dependency Parser :cite:`dozat-manning-2018-simpler`.
+    """
+
+    NAME = 'biaffine-semantic-dependency'
+    MODEL = BiaffineSemanticDependencyModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.LEMMA = self.transform.LEMMA
+        self.TAG = self.transform.POS
+        self.LABEL = self.transform.PHEAD
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().predict(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        words, *feats, labels = batch
+        mask = batch.mask
+        mask = mask.unsqueeze(1) & mask.unsqueeze(2)
+        mask[:, 0] = 0
+        s_edge, s_label = self.model(words, feats)
+        loss = self.model.loss(s_edge, s_label, labels, mask)
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> ChartMetric:
+        words, *feats, labels = batch
+        mask = batch.mask
+        mask = mask.unsqueeze(1) & mask.unsqueeze(2)
+        mask[:, 0] = 0
+        s_edge, s_label = self.model(words, feats)
+        loss = self.model.loss(s_edge, s_label, labels, mask)
+        label_preds = self.model.decode(s_edge, s_label)
+        return ChartMetric(loss, label_preds.masked_fill(~mask, -1), labels.masked_fill(~mask, -1))
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, *feats = batch
+        mask, lens = batch.mask, (batch.lens - 1).tolist()
+        mask = mask.unsqueeze(1) & mask.unsqueeze(2)
+        mask[:, 0] = 0
+        with torch.autocast(self.device, enabled=self.args.amp):
+            s_edge, s_label = self.model(words, feats)
+        label_preds = self.model.decode(s_edge, s_label).masked_fill(~mask, -1)
+        batch.labels = [CoNLL.build_relations([[self.LABEL.vocab[i] if i >= 0 else None for i in row]
+                                               for row in chart[1:i, :i].tolist()])
+                        for i, chart in zip(lens, label_preds)]
+        if self.args.prob:
+            batch.probs = [prob[1:i, :i].cpu() for i, prob in zip(lens, s_edge.softmax(-1).unbind())]
+        return batch
+
+    @classmethod
+    def build(cls, path, min_freq=7, fix_len=20, **kwargs):
+        r"""
+        Build a brand-new Parser, including initialization of all data fields and model parameters.
+
+        Args:
+            path (str):
+                The path of the model to be saved.
+            min_freq (int):
+                The minimum frequency needed to include a token in the vocabulary. Default: 7.
+            fix_len (int):
+                The max length of all subword pieces. The excess part of each piece will be truncated.
+                Required if using CharLSTM/BERT.
+                Default: 20.
+            kwargs (Dict):
+                A dict holding the unconsumed arguments.
+        """
+
+        args = Config(**locals())
+        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
+        if os.path.exists(path) and not args.build:
+            parser = cls.load(**args)
+            parser.model = cls.MODEL(**parser.args)
+            parser.model.load_pretrained(parser.transform.FORM[0].embed).to(parser.device)
+            return parser
+
+        logger.info("Building the fields")
+        TAG, CHAR, LEMMA, ELMO, BERT = None, None, None, None, None
+        if args.encoder == 'bert':
+            t = TransformerTokenizer(args.bert)
+            WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t)
+            WORD.vocab = t.vocab
+        else:
+            WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True)
+            if 'tag' in args.feat:
+                TAG = Field('tags', bos=BOS)
+            if 'char' in args.feat:
+                CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, fix_len=args.fix_len)
+            if 'lemma' in args.feat:
+                LEMMA = Field('lemmas', pad=PAD, unk=UNK, bos=BOS, lower=True)
+            if 'elmo' in args.feat:
+                from allennlp.modules.elmo import batch_to_ids
+                ELMO = RawField('elmo')
+                ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device)
+            if 'bert' in args.feat:
+                t = TransformerTokenizer(args.bert)
+                BERT = SubwordField('bert', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t)
+                BERT.vocab = t.vocab
+        LABEL = ChartField('labels', fn=CoNLL.get_labels)
+        transform = CoNLL(FORM=(WORD, CHAR, ELMO, BERT), LEMMA=LEMMA, POS=TAG, PHEAD=LABEL)
+
+        train = Dataset(transform, args.train, **args)
+        if args.encoder != 'bert':
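+            # build the word vocabulary; pretrained embeddings, if provided, are
+            # normalized by their standard deviation before being loaded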
+            WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
+            if TAG is not None:
+                TAG.build(train)
+            if CHAR is not None:
+                CHAR.build(train)
+            if LEMMA is not None:
+                LEMMA.build(train)
+        LABEL.build(train)
+        args.update({
+            'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
+            'n_labels': len(LABEL.vocab),
+            'n_tags': len(TAG.vocab) if TAG is not None else None,
+            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
+            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
+            'n_lemmas': len(LEMMA.vocab) if LEMMA is not None else None,
+            'bert_pad_index': BERT.pad_index if BERT is not None else None,
+            'pad_index': WORD.pad_index,
+            'unk_index': WORD.unk_index,
+            'bos_index': WORD.bos_index
+        })
+        logger.info(f"{transform}")
+
+        logger.info("Building the model")
+        model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None)
+        logger.info(f"{model}\n")
+
+        parser = cls(args, model, transform)
+        parser.model.to(parser.device)
+        return parser
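+
+
+# Illustrative usage sketch (a hedged example, not part of the original file);
+# the data paths and hyperparameters below are hypothetical:
+#
+#     parser = BiaffineSemanticDependencyParser.build(
+#         'exp/sdp/model',
+#         train='data/train.conllu',
+#         encoder='lstm',
+#         feat=['tag', 'char', 'lemma'],
+#     )
+#     parser.train(train='data/train.conllu', dev='data/dev.conllu',
+#                  test='data/test.conllu', epochs=100, patience=20)
+#     parser.evaluate('data/test.conllu')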
diff --git a/tania_scripts/supar/models/sdp/vi/__init__.py b/tania_scripts/supar/models/sdp/vi/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aae65de0fcf695c9c98e76c3f6575e9b69c5577
--- /dev/null
+++ b/tania_scripts/supar/models/sdp/vi/__init__.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .model import VISemanticDependencyModel
+from .parser import VISemanticDependencyParser
+
+__all__ = ['VISemanticDependencyModel', 'VISemanticDependencyParser']
diff --git a/tania_scripts/supar/models/sdp/vi/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/sdp/vi/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b7e00e74f70f0d965c505e8994b7e119a0dec34
Binary files /dev/null and b/tania_scripts/supar/models/sdp/vi/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/sdp/vi/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/sdp/vi/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e17f748d964d58b705b5a14d52e4f3758f8318e7
Binary files /dev/null and b/tania_scripts/supar/models/sdp/vi/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/sdp/vi/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/sdp/vi/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9cfb7bd073d93cc874641db0353d3cc92dcbf402
Binary files /dev/null and b/tania_scripts/supar/models/sdp/vi/__pycache__/model.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/sdp/vi/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/sdp/vi/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac266759295a48ebd875a15b1b28e225c633c0b6
Binary files /dev/null and b/tania_scripts/supar/models/sdp/vi/__pycache__/model.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/sdp/vi/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/sdp/vi/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ceb658d3c6b5f0ceae68bdf73fc482f9763c6d5f
Binary files /dev/null and b/tania_scripts/supar/models/sdp/vi/__pycache__/parser.cpython-310.pyc differ
diff --git a/tania_scripts/supar/models/sdp/vi/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/sdp/vi/__pycache__/parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b28e31b2c498c355f12078f27de07f420d6c8a0
Binary files /dev/null and b/tania_scripts/supar/models/sdp/vi/__pycache__/parser.cpython-311.pyc differ
diff --git a/tania_scripts/supar/models/sdp/vi/model.py b/tania_scripts/supar/models/sdp/vi/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..12c20e1af605ee1c918ddcafe3bb64a9baa52146
--- /dev/null
+++ b/tania_scripts/supar/models/sdp/vi/model.py
@@ -0,0 +1,471 @@
+# -*- coding: utf-8 -*-
+
+import torch.nn as nn
+from supar.model import Model
+from supar.modules import MLP, Biaffine, Triaffine
+from supar.structs import SemanticDependencyLBP, SemanticDependencyMFVI
+from supar.utils import Config
+
+
+class BiaffineSemanticDependencyModel(Model):
+    r"""
+    The implementation of Biaffine Semantic Dependency Parser :cite:`dozat-manning-2018-simpler`.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_labels (int):
+            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        n_lemmas (int):
+            The number of lemmas, required if lemma embeddings are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'lemma'``: Lemma embeddings.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'tag'``, ``'char'``, ``'lemma'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word representations. Default: 125.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 400.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, False)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs are a weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling method used to obtain token embeddings.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .2.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 1200.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layer. Default: .33.
+        n_edge_mlp (int):
+            Edge MLP size. Default: 600.
+        n_label_mlp  (int):
+            Label MLP size. Default: 600.
+        edge_mlp_dropout (float):
+            The dropout ratio of edge MLP layers. Default: .25.
+        label_mlp_dropout (float):
+            The dropout ratio of label MLP layers. Default: .33.
+        interpolation (float):
+            Constant to even out the label/edge loss. Default: .1.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_labels,
+                 n_tags=None,
+                 n_chars=None,
+                 n_lemmas=None,
+                 encoder='lstm',
+                 feat=['tag', 'char', 'lemma'],
+                 n_embed=100,
+                 n_pretrained=125,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=400,
+                 char_pad_index=0,
+                 char_dropout=0.33,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, False),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.2,
+                 n_encoder_hidden=1200,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_edge_mlp=600,
+                 n_label_mlp=600,
+                 edge_mlp_dropout=.25,
+                 label_mlp_dropout=.33,
+                 interpolation=0.1,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        self.edge_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_edge_mlp, dropout=edge_mlp_dropout, activation=False)
+        self.edge_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_edge_mlp, dropout=edge_mlp_dropout, activation=False)
+        self.label_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False)
+        self.label_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False)
+
+        self.edge_attn = Biaffine(n_in=n_edge_mlp, n_out=2, bias_x=True, bias_y=True)
+        self.label_attn = Biaffine(n_in=n_label_mlp, n_out=n_labels, bias_x=True, bias_y=True)
+        self.criterion = nn.CrossEntropyLoss()
+
+    def load_pretrained(self, embed=None):
+        if embed is not None:
+            self.pretrained = nn.Embedding.from_pretrained(embed)
+            if embed.shape[1] != self.args.n_pretrained:
+                self.embed_proj = nn.Linear(embed.shape[1], self.args.n_pretrained)
+        return self
+
+    def forward(self, words, feats=None):
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+                Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The first tensor of shape ``[batch_size, seq_len, seq_len, 2]`` holds scores of all possible edges.
+                The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds
+                scores of all possible labels on each edge.
+        """
+
+        x = self.encode(words, feats)
+
+        edge_d = self.edge_mlp_d(x)
+        edge_h = self.edge_mlp_h(x)
+        label_d = self.label_mlp_d(x)
+        label_h = self.label_mlp_h(x)
+
+        # [batch_size, seq_len, seq_len, 2]
+        s_edge = self.edge_attn(edge_d, edge_h).permute(0, 2, 3, 1)
+        # [batch_size, seq_len, seq_len, n_labels]
+        s_label = self.label_attn(label_d, label_h).permute(0, 2, 3, 1)
+
+        return s_edge, s_label
+
+    def loss(self, s_edge, s_label, labels, mask):
+        r"""
+        Args:
+            s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``.
+                Scores of all possible edges.
+            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each edge.
+            labels (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
+                The tensor of gold-standard labels.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+                The mask for covering the unpadded tokens.
+
+        Returns:
+            ~torch.Tensor:
+                The training loss.
+        """
+
+        edge_mask = labels.ge(0) & mask
+        edge_loss = self.criterion(s_edge[mask], edge_mask[mask].long())
+        label_loss = self.criterion(s_label[edge_mask], labels[edge_mask])
+        return self.args.interpolation * label_loss + (1 - self.args.interpolation) * edge_loss
+
+    def decode(self, s_edge, s_label):
+        r"""
+        Args:
+            s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``.
+                Scores of all possible edges.
+            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each edge.
+
+        Returns:
+            ~torch.LongTensor:
+                Predicted labels of shape ``[batch_size, seq_len, seq_len]``.
+        """
+
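+        # pick the best label for every pair, then mark pairs whose binary edge
+        # classifier prefers the 'no edge' class (argmax 0) with -1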
+        return s_label.argmax(-1).masked_fill_(s_edge.argmax(-1).lt(1), -1)
+
+
+class VISemanticDependencyModel(BiaffineSemanticDependencyModel):
+    r"""
+    The implementation of Semantic Dependency Parser using Variational Inference :cite:`wang-etal-2019-second`.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_labels (int):
+            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, required if character-level representations are used. Default: ``None``.
+        n_lemmas (int):
+            The number of lemmas, required if lemma embeddings are used. Default: ``None``.
+        encoder (str):
+            Encoder to use.
+            ``'lstm'``: BiLSTM encoder.
+            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
+            Default: ``'lstm'``.
+        feat (List[str]):
+            Additional features to use, required if ``encoder='lstm'``.
+            ``'tag'``: POS tag embeddings.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'lemma'``: Lemma embeddings.
+            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
+            Default: [``'tag'``, ``'char'``, ``'lemma'``].
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_pretrained (int):
+            The size of pretrained word embeddings. Default: 125.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
+        n_char_hidden (int):
+            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
+        elmo (str):
+            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
+        elmo_bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
+            Default: ``(True, False)``.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
+            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
+            The final outputs are a weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
+        bert_pooling (str):
+            Pooling method used to obtain token embeddings.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
+            Default: ``mean``.
+        bert_pad_index (int):
+            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
+            Default: 0.
+        finetune (bool):
+            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
+        n_plm_embed (int):
+            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .2.
+        n_encoder_hidden (int):
+            The size of encoder hidden states. Default: 1200.
+        n_encoder_layers (int):
+            The number of encoder layers. Default: 3.
+        encoder_dropout (float):
+            The dropout ratio of encoder layer. Default: .33.
+        n_edge_mlp (int):
+            Unary factor MLP size. Default: 600.
+        n_pair_mlp (int):
+            Binary factor MLP size. Default: 150.
+        n_label_mlp  (int):
+            Label MLP size. Default: 600.
+        edge_mlp_dropout (float):
+            The dropout ratio of unary edge factor MLP layers. Default: .25.
+        pair_mlp_dropout (float):
+            The dropout ratio of binary factor MLP layers. Default: .25.
+        label_mlp_dropout (float):
+            The dropout ratio of label MLP layers. Default: .33.
+        inference (str):
+            Approximate inference methods. Default: ``mfvi``.
+        max_iter (int):
+            Max iteration times for inference. Default: 3.
+        interpolation (float):
+            Constant to even out the label/edge loss. Default: .1.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_labels,
+                 n_tags=None,
+                 n_chars=None,
+                 n_lemmas=None,
+                 encoder='lstm',
+                 feat=['tag', 'char', 'lemma'],
+                 n_embed=100,
+                 n_pretrained=125,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 n_char_hidden=100,
+                 char_pad_index=0,
+                 char_dropout=0,
+                 elmo='original_5b',
+                 elmo_bos_eos=(True, False),
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 bert_pooling='mean',
+                 bert_pad_index=0,
+                 finetune=False,
+                 n_plm_embed=0,
+                 embed_dropout=.2,
+                 n_encoder_hidden=1200,
+                 n_encoder_layers=3,
+                 encoder_dropout=.33,
+                 n_edge_mlp=600,
+                 n_pair_mlp=150,
+                 n_label_mlp=600,
+                 edge_mlp_dropout=.25,
+                 pair_mlp_dropout=.25,
+                 label_mlp_dropout=.33,
+                 inference='mfvi',
+                 max_iter=3,
+                 interpolation=0.1,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))
+
+        self.edge_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_edge_mlp, dropout=edge_mlp_dropout, activation=False)
+        self.edge_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_edge_mlp, dropout=edge_mlp_dropout, activation=False)
+        self.pair_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=pair_mlp_dropout, activation=False)
+        self.pair_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=pair_mlp_dropout, activation=False)
+        self.pair_mlp_g = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=pair_mlp_dropout, activation=False)
+        self.label_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False)
+        self.label_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False)
+
+        self.edge_attn = Biaffine(n_in=n_edge_mlp, bias_x=True, bias_y=True)
+        self.sib_attn = Triaffine(n_in=n_pair_mlp, bias_x=True, bias_y=True)
+        self.cop_attn = Triaffine(n_in=n_pair_mlp, bias_x=True, bias_y=True)
+        self.grd_attn = Triaffine(n_in=n_pair_mlp, bias_x=True, bias_y=True)
+        self.label_attn = Biaffine(n_in=n_label_mlp, n_out=n_labels, bias_x=True, bias_y=True)
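+        # choose between mean-field variational inference and loopy belief
+        # propagation for approximate second-order inference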
+        self.inference = (SemanticDependencyMFVI if inference == 'mfvi' else SemanticDependencyLBP)(max_iter)
+        self.criterion = nn.CrossEntropyLoss()
+
+    def forward(self, words, feats=None):
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (List[~torch.LongTensor]):
+                A list of feat indices.
+                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
+                Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor, ~torch.Tensor, ~torch.Tensor, ~torch.Tensor:
+                The first and last tensors hold scores of all possible edges (``[batch_size, seq_len, seq_len]``)
+                and of all possible labels on each edge (``[batch_size, seq_len, seq_len, n_labels]``).
+                The middle three hold scores of second-order sibling, coparent and grandparent factors,
+                each of shape ``[batch_size, seq_len, seq_len, seq_len]``.
+
+        """
+
+        x = self.encode(words, feats)
+
+        edge_d = self.edge_mlp_d(x)
+        edge_h = self.edge_mlp_h(x)
+        pair_d = self.pair_mlp_d(x)
+        pair_h = self.pair_mlp_h(x)
+        pair_g = self.pair_mlp_g(x)
+        label_d = self.label_mlp_d(x)
+        label_h = self.label_mlp_h(x)
+
+        # [batch_size, seq_len, seq_len]
+        s_edge = self.edge_attn(edge_d, edge_h)
+        # [batch_size, seq_len, seq_len, seq_len], (d->h->s)
+        s_sib = self.sib_attn(pair_d, pair_d, pair_h)
+        s_sib = (s_sib.triu() + s_sib.triu(1).transpose(-1, -2)).permute(0, 3, 1, 2)
+        # [batch_size, seq_len, seq_len, seq_len], (d->h->c)
+        s_cop = self.cop_attn(pair_h, pair_d, pair_h).permute(0, 3, 1, 2)
+        s_cop = s_cop.triu() + s_cop.triu(1).transpose(-1, -2)
+        # [batch_size, seq_len, seq_len, seq_len], (d->h->g)
+        s_grd = self.grd_attn(pair_g, pair_d, pair_h).permute(0, 3, 1, 2)
+        # [batch_size, seq_len, seq_len, n_labels]
+        s_label = self.label_attn(label_d, label_h).permute(0, 2, 3, 1)
+
+        return s_edge, s_sib, s_cop, s_grd, s_label
+
+    def loss(self, s_edge, s_sib, s_cop, s_grd, s_label, labels, mask):
+        r"""
+        Args:
+            s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+                Scores of all possible edges.
+            s_sib (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
+                Scores of all possible dependent-head-sibling triples.
+            s_cop (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
+                Scores of all possible dependent-head-coparent triples.
+            s_grd (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
+                Scores of all possible dependent-head-grandparent triples.
+            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each edge.
+            labels (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
+                The tensor of gold-standard labels.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+                The mask for covering the unpadded tokens.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The training loss and marginals of shape ``[batch_size, seq_len, seq_len]``.
+        """
+
+        edge_mask = labels.ge(0) & mask
+        edge_loss, marginals = self.inference((s_edge, s_sib, s_cop, s_grd), mask, edge_mask.long())
+        label_loss = self.criterion(s_label[edge_mask], labels[edge_mask])
+        loss = self.args.interpolation * label_loss + (1 - self.args.interpolation) * edge_loss
+        return loss, marginals
+
+    def decode(self, s_edge, s_label):
+        r"""
+        Args:
+            s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+                Scores of all possible edges.
+            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each edge.
+
+        Returns:
+            ~torch.LongTensor:
+                Predicted labels of shape ``[batch_size, seq_len, seq_len]``.
+        """
+
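+        # after inference, s_edge holds edge marginals, so pairs with
+        # probability below 0.5 are marked as non-edges (-1)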
+        return s_label.argmax(-1).masked_fill_(s_edge.lt(0.5), -1)
diff --git a/tania_scripts/supar/models/sdp/vi/parser.py b/tania_scripts/supar/models/sdp/vi/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbb85c863e992f5f303537867925c816b6d7692c
--- /dev/null
+++ b/tania_scripts/supar/models/sdp/vi/parser.py
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+
+from typing import Iterable, Union
+
+import torch
+
+from supar.models.dep.biaffine.transform import CoNLL
+from supar.models.sdp.biaffine.parser import BiaffineSemanticDependencyParser
+from supar.models.sdp.vi.model import VISemanticDependencyModel
+from supar.utils import Config
+from supar.utils.logging import get_logger
+from supar.utils.metric import ChartMetric
+from supar.utils.transform import Batch
+
+logger = get_logger(__name__)
+
+
+class VISemanticDependencyParser(BiaffineSemanticDependencyParser):
+    r"""
+    The implementation of Semantic Dependency Parser using Variational Inference :cite:`wang-etal-2019-second`.
+    """
+
+    NAME = 'vi-semantic-dependency'
+    MODEL = VISemanticDependencyModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.LEMMA = self.transform.LEMMA
+        self.TAG = self.transform.POS
+        self.LABEL = self.transform.PHEAD
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int = 1000,
+        patience: int = 100,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().train(**Config().update(locals()))
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().evaluate(**Config().update(locals()))
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        return super().predict(**Config().update(locals()))
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        words, *feats, labels = batch
+        mask = batch.mask
+        mask = mask.unsqueeze(1) & mask.unsqueeze(2)
+        mask[:, 0] = 0
+        s_edge, s_sib, s_cop, s_grd, s_label = self.model(words, feats)
+        loss, s_edge = self.model.loss(s_edge, s_sib, s_cop, s_grd, s_label, labels, mask)
+        return loss
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> ChartMetric:
+        words, *feats, labels = batch
+        mask = batch.mask
+        mask = mask.unsqueeze(1) & mask.unsqueeze(2)
+        mask[:, 0] = 0
+        s_edge, s_sib, s_cop, s_grd, s_label = self.model(words, feats)
+        loss, s_edge = self.model.loss(s_edge, s_sib, s_cop, s_grd, s_label, labels, mask)
+        label_preds = self.model.decode(s_edge, s_label)
+        return ChartMetric(loss, label_preds.masked_fill(~mask, -1), labels.masked_fill(~mask, -1))
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        words, *feats = batch
+        mask, lens = batch.mask, (batch.lens - 1).tolist()
+        mask = mask.unsqueeze(1) & mask.unsqueeze(2)
+        mask[:, 0] = 0
+        s_edge, s_sib, s_cop, s_grd, s_label = self.model(words, feats)
+        s_edge = self.model.inference((s_edge, s_sib, s_cop, s_grd), mask)
+        label_preds = self.model.decode(s_edge, s_label).masked_fill(~mask, -1)
+        batch.labels = [CoNLL.build_relations([[self.LABEL.vocab[i] if i >= 0 else None for i in row]
+                                               for row in chart[1:i, :i].tolist()])
+                        for i, chart in zip(lens, label_preds)]
+        if self.args.prob:
+            batch.probs = [prob[1:i, :i].cpu() for i, prob in zip(lens, s_edge.unbind())]
+        return batch
diff --git a/tania_scripts/supar/modules/__init__.py b/tania_scripts/supar/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..981a6bfbfd0d05203f949840d34d90e2a9097bad
--- /dev/null
+++ b/tania_scripts/supar/modules/__init__.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+
+from .affine import Biaffine, Triaffine
+from .dropout import IndependentDropout, SharedDropout, TokenDropout
+from .gnn import GraphConvolutionalNetwork
+from .lstm import CharLSTM, VariationalLSTM
+from .mlp import MLP
+from .decoder import DecoderLSTM
+from .pretrained import ELMoEmbedding, TransformerEmbedding
+from .transformer import (TransformerDecoder, TransformerEncoder,
+                          TransformerWordEmbedding)
+
+__all__ = ['Biaffine', 'Triaffine',
+           'IndependentDropout', 'SharedDropout', 'TokenDropout',
+           'GraphConvolutionalNetwork',
+           'CharLSTM', 'VariationalLSTM',
+           'MLP', 'DecoderLSTM',
+           'ELMoEmbedding', 'TransformerEmbedding',
+           'TransformerWordEmbedding',
+           'TransformerDecoder', 'TransformerEncoder']
diff --git a/tania_scripts/supar/modules/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c13e44da2a22e9ce2413ca58f980bfe9c2110b8
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8dc3d835d3ccc7fab896ae595885b0226ecebc2
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/affine.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/affine.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39667f1870be23948663babfaa2080b05451c54f
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/affine.cpython-310.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/affine.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/affine.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c47fed9d260e0af01681cd36e39620f2209a26b3
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/affine.cpython-311.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/decoder.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76a4b2187f9eabbe13f6151c3b2f5c962b577852
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/decoder.cpython-310.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/decoder.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/decoder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6de6a887d7ce612ec0a7ef15992c489fb6b21aa7
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/decoder.cpython-311.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/dropout.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/dropout.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9431bcc12f190fe9a1715caba2b9b9b578ef7e8
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/dropout.cpython-310.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/dropout.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/dropout.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24e3f7b05eeb7b0809eb5b72992b12d9e5c160d4
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/dropout.cpython-311.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/gnn.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/gnn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c9899ee4c9fc3620942fc1768b69b489f133280
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/gnn.cpython-310.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/gnn.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/gnn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7471454ca9e111edcc54852b2bfeafb07ef79179
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/gnn.cpython-311.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/lstm.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/lstm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b41ac7e996529a8790ae248cd831e01297c48eb3
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/lstm.cpython-310.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/lstm.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/lstm.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cbe023f8403b631f77886bc738ca050d7e7cf659
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/lstm.cpython-311.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/mlp.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/mlp.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..856587590500e87d19034d97b32fba8d05cf1ddb
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/mlp.cpython-310.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/mlp.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/mlp.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69c92f7adbd0833ce5ccf154ad0997ff3e2cb672
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/mlp.cpython-311.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/pretrained.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/pretrained.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe5270597f6a76ea7e0866c0b20e3227024c7199
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/pretrained.cpython-310.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/pretrained.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/pretrained.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75014d025196b5be3ee840b78c519e0d692f6c4c
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/pretrained.cpython-311.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/transformer.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0aa648d7943086a791868bfe31ccd2dac17e4d0e
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/transformer.cpython-310.pyc differ
diff --git a/tania_scripts/supar/modules/__pycache__/transformer.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/transformer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb5927d047ca6f9c5b2983417896033e2322681e
Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/transformer.cpython-311.pyc differ
diff --git a/tania_scripts/supar/modules/affine.py b/tania_scripts/supar/modules/affine.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe5defbfddebaf66a687b0ed15c0a4e557a69683
--- /dev/null
+++ b/tania_scripts/supar/modules/affine.py
@@ -0,0 +1,260 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+from supar.modules.mlp import MLP
+
+
+class Biaffine(nn.Module):
+    r"""
+    Biaffine layer for first-order scoring :cite:`dozat-etal-2017-biaffine`.
+
+    This function has a tensor of weights :math:`W` and bias terms if needed.
+    The score :math:`s(x, y)` of the vector pair :math:`(x, y)` is computed as :math:`x^T W y / d^s`,
+    where `d` and `s` are vector dimension and scaling factor respectively.
+    :math:`x` and :math:`y` can be concatenated with bias terms.
+
+    Args:
+        n_in (int):
+            The size of the input feature.
+        n_out (int):
+            The number of output channels.
+        n_proj (Optional[int]):
+            If specified, applies MLP layers to reduce vector dimensions. Default: ``None``.
+        dropout (Optional[float]):
+            If specified, applies a :class:`SharedDropout` layer with the ratio on MLP outputs. Default: 0.
+        scale (float):
+            Factor to scale the scores. Default: 0.
+        bias_x (bool):
+            If ``True``, adds a bias term for tensor :math:`x`. Default: ``True``.
+        bias_y (bool):
+            If ``True``, adds a bias term for tensor :math:`y`. Default: ``True``.
+        decompose (bool):
+            If ``True``, represents the weight as the product of 2 independent matrices. Default: ``False``.
+        init (Callable):
+            Callable initialization method. Default: `nn.init.zeros_`.
+    """
+
+    def __init__(
+        self,
+        n_in: int,
+        n_out: int = 1,
+        n_proj: Optional[int] = None,
+        dropout: Optional[float] = 0,
+        scale: int = 0,
+        bias_x: bool = True,
+        bias_y: bool = True,
+        decompose: bool = False,
+        init: Callable = nn.init.zeros_
+    ) -> Biaffine:
+        super().__init__()
+
+        self.n_in = n_in
+        self.n_out = n_out
+        self.n_proj = n_proj
+        self.dropout = dropout
+        self.scale = scale
+        self.bias_x = bias_x
+        self.bias_y = bias_y
+        self.decompose = decompose
+        self.init = init
+
+        if n_proj is not None:
+            self.mlp_x, self.mlp_y = MLP(n_in, n_proj, dropout), MLP(n_in, n_proj, dropout)
+        self.n_model = n_proj or n_in
+        if not decompose:
+            self.weight = nn.Parameter(torch.Tensor(n_out, self.n_model + bias_x, self.n_model + bias_y))
+        else:
+            self.weight = nn.ParameterList((nn.Parameter(torch.Tensor(n_out, self.n_model + bias_x)),
+                                            nn.Parameter(torch.Tensor(n_out, self.n_model + bias_y))))
+
+        self.reset_parameters()
+
+    def __repr__(self):
+        s = f"n_in={self.n_in}"
+        if self.n_out > 1:
+            s += f", n_out={self.n_out}"
+        if self.n_proj is not None:
+            s += f", n_proj={self.n_proj}"
+        if self.dropout > 0:
+            s += f", dropout={self.dropout}"
+        if self.scale != 0:
+            s += f", scale={self.scale}"
+        if self.bias_x:
+            s += f", bias_x={self.bias_x}"
+        if self.bias_y:
+            s += f", bias_y={self.bias_y}"
+        if self.decompose:
+            s += f", decompose={self.decompose}"
+        return f"{self.__class__.__name__}({s})"
+
+    def reset_parameters(self):
+        if self.decompose:
+            for i in self.weight:
+                self.init(i)
+        else:
+            self.init(self.weight)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        y: torch.Tensor
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
+            y (torch.Tensor): ``[batch_size, seq_len, n_in]``.
+
+        Returns:
+            ~torch.Tensor:
+                A scoring tensor of shape ``[batch_size, n_out, seq_len, seq_len]``.
+                If ``n_out=1``, the dimension for ``n_out`` will be squeezed automatically.
+        """
+
+        if hasattr(self, 'mlp_x'):
+            x, y = self.mlp_x(x), self.mlp_y(y)
+        if self.bias_x:
+            x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
+        if self.bias_y:
+            y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
+        # [batch_size, n_out, seq_len, seq_len]
+        if self.decompose:
+            wx = torch.einsum('bxi,oi->box', x, self.weight[0])
+            wy = torch.einsum('byj,oj->boy', y, self.weight[1])
+            s = torch.einsum('box,boy->boxy', wx, wy)
+        else:
+            s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y)
+        return s.squeeze(1) / self.n_in ** self.scale
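+
+# Shape sketch (illustrative only): with the weight initialized to zeros the
+# scores are all zero, but the output shape shows the contract above.
+#
+#     attn = Biaffine(n_in=500, n_out=2)
+#     x, y = torch.randn(8, 40, 500), torch.randn(8, 40, 500)
+#     assert attn(x, y).shape == (8, 2, 40, 40)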
+
+
+class Triaffine(nn.Module):
+    r"""
+    Triaffine layer for second-order scoring :cite:`zhang-etal-2020-efficient,wang-etal-2019-second`.
+
+    This function has a tensor of weights :math:`W` and bias terms if needed.
+    The score :math:`s(x, y, z)` of the vector triple :math:`(x, y, z)` is computed as :math:`x^T z^T W y / d^s`,
+    where `d` and `s` are vector dimension and scaling factor respectively.
+    :math:`x` and :math:`y` can be concatenated with bias terms.
+
+    Args:
+        n_in (int):
+            The size of the input feature.
+        n_out (int):
+            The number of output channels.
+        n_proj (Optional[int]):
+            If specified, applies MLP layers to reduce vector dimensions. Default: ``None``.
+        dropout (Optional[float]):
+            If specified, applies a :class:`SharedDropout` layer with the ratio on MLP outputs. Default: 0.
+        scale (float):
+            Factor to scale the scores. Default: 0.
+        bias_x (bool):
+            If ``True``, adds a bias term for tensor :math:`x`. Default: ``False``.
+        bias_y (bool):
+            If ``True``, adds a bias term for tensor :math:`y`. Default: ``False``.
+        decompose (bool):
+            If ``True``, represents the weight as the product of 3 independent matrices. Default: ``False``.
+        init (Callable):
+            Callable initialization method. Default: `nn.init.zeros_`.
+    """
+
+    def __init__(
+        self,
+        n_in: int,
+        n_out: int = 1,
+        n_proj: Optional[int] = None,
+        dropout: Optional[float] = 0,
+        scale: int = 0,
+        bias_x: bool = False,
+        bias_y: bool = False,
+        decompose: bool = False,
+        init: Callable = nn.init.zeros_
+    ) -> Triaffine:
+        super().__init__()
+
+        self.n_in = n_in
+        self.n_out = n_out
+        self.n_proj = n_proj
+        self.dropout = dropout
+        self.scale = scale
+        self.bias_x = bias_x
+        self.bias_y = bias_y
+        self.decompose = decompose
+        self.init = init
+
+        if n_proj is not None:
+            self.mlp_x = MLP(n_in, n_proj, dropout)
+            self.mlp_y = MLP(n_in, n_proj, dropout)
+            self.mlp_z = MLP(n_in, n_proj, dropout)
+        self.n_model = n_proj or n_in
+        if not decompose:
+            self.weight = nn.Parameter(torch.Tensor(n_out, self.n_model + bias_x, self.n_model, self.n_model + bias_y))
+        else:
+            self.weight = nn.ParameterList((nn.Parameter(torch.Tensor(n_out, self.n_model + bias_x)),
+                                            nn.Parameter(torch.Tensor(n_out, self.n_model)),
+                                            nn.Parameter(torch.Tensor(n_out, self.n_model + bias_y))))
+
+        self.reset_parameters()
+
+    def __repr__(self):
+        s = f"n_in={self.n_in}"
+        if self.n_out > 1:
+            s += f", n_out={self.n_out}"
+        if self.n_proj is not None:
+            s += f", n_proj={self.n_proj}"
+        if self.dropout > 0:
+            s += f", dropout={self.dropout}"
+        if self.scale != 0:
+            s += f", scale={self.scale}"
+        if self.bias_x:
+            s += f", bias_x={self.bias_x}"
+        if self.bias_y:
+            s += f", bias_y={self.bias_y}"
+        if self.decompose:
+            s += f", decompose={self.decompose}"
+        return f"{self.__class__.__name__}({s})"
+
+    def reset_parameters(self):
+        if self.decompose:
+            for i in self.weight:
+                self.init(i)
+        else:
+            self.init(self.weight)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        z: torch.Tensor
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
+            y (torch.Tensor): ``[batch_size, seq_len, n_in]``.
+            z (torch.Tensor): ``[batch_size, seq_len, n_in]``.
+
+        Returns:
+            ~torch.Tensor:
+                A scoring tensor of shape ``[batch_size, n_out, seq_len, seq_len, seq_len]``.
+                If ``n_out=1``, the dimension for ``n_out`` will be squeezed automatically.
+        """
+
+        if hasattr(self, 'mlp_x'):
+            x, y, z = self.mlp_x(x), self.mlp_y(y), self.mlp_z(z)
+        if self.bias_x:
+            x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
+        if self.bias_y:
+            y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
+        # [batch_size, n_out, seq_len, seq_len, seq_len]
+        if self.decompose:
+            wx = torch.einsum('bxi,oi->box', x, self.weight[0])
+            wz = torch.einsum('bzk,ok->boz', z, self.weight[1])
+            wy = torch.einsum('byj,oj->boy', y, self.weight[2])
+            s = torch.einsum('box,boz,boy->bozxy', wx, wz, wy)
+        else:
+            w = torch.einsum('bzk,oikj->bozij', z, self.weight)
+            s = torch.einsum('bxi,bozij,byj->bozxy', x, w, y)
+        return s.squeeze(1) / self.n_in ** self.scale
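+
+# Shape sketch (illustrative only): scoring all (x, y, z) triples in a batch.
+#
+#     attn = Triaffine(n_in=100, bias_x=True, bias_y=True)
+#     h = torch.randn(4, 30, 100)
+#     assert attn(h, h, h).shape == (4, 30, 30, 30)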
diff --git a/tania_scripts/supar/modules/decoder.py b/tania_scripts/supar/modules/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e9f97f310628c589ab16520c288c1b02f11c050
--- /dev/null
+++ b/tania_scripts/supar/modules/decoder.py
@@ -0,0 +1,38 @@
+import torch.nn as nn
+from supar.modules.mlp import MLP
+import torch
+
+class DecoderLSTM(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int, dropout: float, device: str):
+        super().__init__()
+
+        self.input_size, self.hidden_size = input_size, hidden_size
+        self.output_size = output_size
+        self.num_layers = num_layers
+
+        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
+                            num_layers=num_layers, bidirectional=False,
+                            dropout=dropout, batch_first=True)
+
+        self.mlp = MLP(n_in=hidden_size, n_out=output_size, dropout=dropout,
+                       activation=True)
+        self.device = device
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""
+        :param x: torch.Tensor [batch_size, seq_len, input_size]
+        :returns torch.Tensor [batch_size, seq_len, output_size]
+        """
+        batch_size, seq_len, _ = x.shape
+
+        # LSTM forward pass
+        h0, c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(self.device), \
+                 torch.zeros(self.num_layers, batch_size, self.hidden_size).to(self.device)
+        # nn.LSTM returns (output, (h_n, c_n)); we need the per-step outputs
+        output, (hn, cn) = self.lstm(x, (h0, c0))
+
+        # MLP forward pass over the flattened sequence
+        output = self.mlp(output.reshape(batch_size*seq_len, self.hidden_size))
+        return output.reshape(batch_size, seq_len, self.output_size)
+
+    def __repr__(self):
+        return f'DecoderLSTM(input_size={self.input_size}, hidden_size={self.hidden_size}, num_layers={self.num_layers}, output_size={self.output_size})'
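+
+
+# Illustrative usage (hypothetical sizes): mapping encoder states to per-token scores.
+#
+#     decoder = DecoderLSTM(input_size=800, hidden_size=400, output_size=50,
+#                           num_layers=2, dropout=0.33, device='cpu')
+#     x = torch.randn(2, 10, 800)
+#     assert decoder(x).shape == (2, 10, 50)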
diff --git a/tania_scripts/supar/modules/dropout.py b/tania_scripts/supar/modules/dropout.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e3c94705e5b02bb6e1a3df1fc8428377f40684a
--- /dev/null
+++ b/tania_scripts/supar/modules/dropout.py
@@ -0,0 +1,155 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from typing import List
+
+import torch
+import torch.nn as nn
+
+
+class TokenDropout(nn.Module):
+    r"""
+    :class:`TokenDropout` randomly zeroes the vectors of some tokens with probability `p`.
+
+    Args:
+        p (float):
+            The probability of an element to be zeroed. Default: 0.5.
+
+    Examples:
+        >>> batch_size, seq_len, hidden_size = 1, 3, 5
+        >>> x = torch.ones(batch_size, seq_len, hidden_size)
+        >>> nn.Dropout()(x)
+        tensor([[[0., 2., 2., 0., 0.],
+                 [2., 2., 0., 2., 2.],
+                 [2., 2., 2., 2., 0.]]])
+        >>> TokenDropout()(x)
+        tensor([[[2., 2., 2., 2., 2.],
+                 [0., 0., 0., 0., 0.],
+                 [2., 2., 2., 2., 2.]]])
+    """
+
+    def __init__(self, p: float = 0.5) -> TokenDropout:
+        super().__init__()
+
+        self.p = p
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(p={self.p})"
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""
+        Args:
+            x (~torch.Tensor):
+                A tensor of any shape.
+        Returns:
+            A tensor with the same shape as `x`.
+        """
+
+        if not self.training:
+            return x
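+        # sample one Bernoulli keep-mask per token, shared across the hidden
+        # dimension, and rescale the kept tokens by 1/(1-p)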
+        return x * (x.new_empty(x.shape[:2]).bernoulli_(1 - self.p) / (1 - self.p)).unsqueeze(-1)
+
+
+class SharedDropout(nn.Module):
+    r"""
+    :class:`SharedDropout` differs from the vanilla dropout strategy in that the dropout mask is shared across one dimension.
+
+    Args:
+        p (float):
+            The probability of an element to be zeroed. Default: 0.5.
+        batch_first (bool):
+            If ``True``, the input and output tensors are provided as ``[batch_size, seq_len, *]``.
+            Default: ``True``.
+
+    Examples:
+        >>> batch_size, seq_len, hidden_size = 1, 3, 5
+        >>> x = torch.ones(batch_size, seq_len, hidden_size)
+        >>> nn.Dropout()(x)
+        tensor([[[0., 2., 2., 0., 0.],
+                 [2., 2., 0., 2., 2.],
+                 [2., 2., 2., 2., 0.]]])
+        >>> SharedDropout()(x)
+        tensor([[[2., 0., 2., 0., 2.],
+                 [2., 0., 2., 0., 2.],
+                 [2., 0., 2., 0., 2.]]])
+    """
+
+    def __init__(self, p: float = 0.5, batch_first: bool = True) -> SharedDropout:
+        super().__init__()
+
+        self.p = p
+        self.batch_first = batch_first
+
+    def __repr__(self):
+        s = f"p={self.p}"
+        if self.batch_first:
+            s += f", batch_first={self.batch_first}"
+        return f"{self.__class__.__name__}({s})"
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""
+        Args:
+            x (~torch.Tensor):
+                A tensor of any shape.
+        Returns:
+            A tensor with the same shape as `x`.
+        """
+
+        if not self.training:
+            return x
+        mask = self.get_mask(x[:, 0], self.p).unsqueeze(1) if self.batch_first else self.get_mask(x[0], self.p)
+        return x * mask
+
+    @staticmethod
+    def get_mask(x: torch.Tensor, p: float) -> torch.FloatTensor:
+        return x.new_empty(x.shape).bernoulli_(1 - p) / (1 - p)
+
+
+class IndependentDropout(nn.Module):
+    r"""
+    Applies independent dropout masks to each of the :math:`N` input tensors.
+    When :math:`N-M` of them are dropped, the remaining :math:`M` ones are scaled by a factor of :math:`N/M` to compensate,
+    and when all of them are dropped together, zeros are returned.
+
+    Args:
+        p (float):
+            The probability of an element to be zeroed. Default: 0.5.
+
+    Examples:
+        >>> batch_size, seq_len, hidden_size = 1, 3, 5
+        >>> x, y = torch.ones(batch_size, seq_len, hidden_size), torch.ones(batch_size, seq_len, hidden_size)
+        >>> x, y = IndependentDropout()(x, y)
+        >>> x
+        tensor([[[1., 1., 1., 1., 1.],
+                 [0., 0., 0., 0., 0.],
+                 [2., 2., 2., 2., 2.]]])
+        >>> y
+        tensor([[[1., 1., 1., 1., 1.],
+                 [2., 2., 2., 2., 2.],
+                 [0., 0., 0., 0., 0.]]])
+    """
+
+    def __init__(self, p: float = 0.5) -> IndependentDropout:
+        super().__init__()
+
+        self.p = p
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(p={self.p})"
+
+    def forward(self, *items: List[torch.Tensor]) -> List[torch.Tensor]:
+        r"""
+        Args:
+            items (List[~torch.Tensor]):
+                A list of tensors that have the same shape except the last dimension.
+        Returns:
+            The returned tensors are of the same shape as `items`.
+        """
+
+        if not self.training:
+            return items
+        masks = [x.new_empty(x.shape[:2]).bernoulli_(1 - self.p) for x in items]
+        total = sum(masks)
+        scale = len(items) / total.max(torch.ones_like(total))
+        masks = [mask * scale for mask in masks]
+        return [item * mask.unsqueeze(-1) for item, mask in zip(items, masks)]
diff --git a/tania_scripts/supar/modules/gnn.py b/tania_scripts/supar/modules/gnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f2108c20a02ed4a1eb18d468ab2084d365f2b43
--- /dev/null
+++ b/tania_scripts/supar/modules/gnn.py
@@ -0,0 +1,135 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+
+class GraphConvolutionalNetwork(nn.Module):
+    r"""
+    Multiple GCN layers with layer normalization and residual connections, each executing the operator
+    from the `"Semi-supervised Classification with Graph Convolutional Networks" <https://arxiv.org/abs/1609.02907>`_ paper
+
+    .. math::
+        \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
+        \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},
+
+    where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the adjacency matrix with inserted self-loops
+    and :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix.
+
+    Its node-wise formulation is given by:
+
+    .. math::
+        \mathbf{x}^{\prime}_i = \mathbf{\Theta}^{\top} \sum_{j \in
+        \mathcal{N}(i) \cup \{ i \}} \frac{e_{j,i}}{\sqrt{\hat{d}_j
+        \hat{d}_i}} \mathbf{x}_j
+
+    with :math:`\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}`, where
+    :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to target
+    node :obj:`i` (default: :obj:`1.0`)
+
+    Args:
+        n_model (int):
+            The size of node feature vectors.
+        n_layers (int):
+            The number of GCN layers. Default: 1.
+        selfloop (bool):
+            If ``True``, adds self-loops to adjacency matrices. Default: ``True``.
+        dropout (float):
+            The probability of feature vector elements to be zeroed. Default: 0.
+        norm (bool):
+            If ``True``, adds a :class:`~torch.nn.LayerNorm` layer after each GCN layer. Default: ``True``.
+    """
+
+    def __init__(
+        self,
+        n_model: int,
+        n_layers: int = 1,
+        selfloop: bool = True,
+        dropout: float = 0.,
+        norm: bool = True
+    ) -> GraphConvolutionalNetwork:
+        super().__init__()
+
+        self.n_model = n_model
+        self.n_layers = n_layers
+        self.selfloop = selfloop
+        self.norm = norm
+
+        self.conv_layers = nn.ModuleList([
+            nn.Sequential(
+                GraphConv(n_model),
+                nn.LayerNorm([n_model]) if norm else nn.Identity()
+            )
+            for _ in range(n_layers)
+        ])
+        self.dropout = nn.Dropout(dropout)
+
+    def __repr__(self):
+        s = f"n_model={self.n_model}, n_layers={self.n_layers}"
+        if self.selfloop:
+            s += f", selfloop={self.selfloop}"
+        if self.dropout.p > 0:
+            s += f", dropout={self.dropout.p}"
+        if self.norm:
+            s += f", norm={self.norm}"
+        return f"{self.__class__.__name__}({s})"
+
+    def forward(self, x: torch.Tensor, adj: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        r"""
+        Args:
+            x (~torch.Tensor):
+                Node feature tensors of shape ``[batch_size, seq_len, n_model]``.
+            adj (~torch.Tensor):
+                Adjacency matrix of shape ``[batch_size, seq_len, seq_len]``.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask marking the unpadded (valid) tokens in each sequence.
+
+        Returns:
+            ~torch.Tensor:
+                Node feature tensors of shape ``[batch_size, seq_len, n_model]``.
+        """
+
+        if self.selfloop:
+            adj.diagonal(0, 1, 2).fill_(1.)
+        adj = adj.masked_fill(~(mask.unsqueeze(1) & mask.unsqueeze(2)), 0)
+        for conv, norm in self.conv_layers:
+            x = norm(x + self.dropout(conv(x, adj).relu()))
+        return x
+
+
+class GraphConv(nn.Module):
+
+    def __init__(self, n_model: int, bias: bool = True) -> GraphConv:
+        super().__init__()
+
+        self.n_model = n_model
+
+        self.linear = nn.Linear(n_model, n_model, bias=False)
+        self.bias = nn.Parameter(torch.zeros(n_model)) if bias else None
+
+    def __repr__(self):
+        s = f"n_model={self.n_model}"
+        if self.bias is not None:
+            s += ", bias=True"
+        return f"{self.__class__.__name__}({s})"
+
+    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
+        r"""
+        Args:
+            x (~torch.Tensor):
+                Node feature tensors of shape ``[batch_size, seq_len, n_model]``.
+            adj (~torch.Tensor):
+                Adjacency matrix of shape ``[batch_size, seq_len, seq_len]``.
+
+        Returns:
+            ~torch.Tensor:
+                Node feature tensors of shape ``[batch_size, seq_len, n_model]``.
+        """
+
+        x = self.linear(x)
+        x = torch.matmul(adj * (adj.sum(1, True) * adj.sum(2, True) + torch.finfo(adj.dtype).eps).pow(-0.5), x)
+        if self.bias is not None:
+            x = x + self.bias
+        return x
diff --git a/tania_scripts/supar/modules/lstm.py b/tania_scripts/supar/modules/lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..759f9c187f28ff6c8b662a6c6feee751198b83db
--- /dev/null
+++ b/tania_scripts/supar/modules/lstm.py
@@ -0,0 +1,271 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from supar.modules.dropout import SharedDropout
+from torch.nn.modules.rnn import apply_permutation
+from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence
+
+
+class CharLSTM(nn.Module):
+    r"""
+    CharLSTM aims to generate character-level embeddings for tokens.
+    It summarizes the character information of each token into an embedding using an LSTM layer.
+
+    Args:
+        n_chars (int):
+            The number of characters.
+        n_embed (int):
+            The size of each embedding vector as input to LSTM.
+        n_hidden (int):
+            The size of each LSTM hidden state.
+        n_out (int):
+            The size of each output vector. Default: 0.
+            If 0, it equals the size of the hidden states.
+        pad_index (int):
+            The index of the padding token in the vocabulary. Default: 0.
+        dropout (float):
+            The dropout ratio of CharLSTM hidden states. Default: 0.
+    """
+
+    def __init__(
+        self,
+        n_chars: int,
+        n_embed: int,
+        n_hidden: int,
+        n_out: int = 0,
+        pad_index: int = 0,
+        dropout: float = 0
+    ) -> CharLSTM:
+        super().__init__()
+
+        self.n_chars = n_chars
+        self.n_embed = n_embed
+        self.n_hidden = n_hidden
+        self.n_out = n_out or n_hidden
+        self.pad_index = pad_index
+
+        self.embed = nn.Embedding(num_embeddings=n_chars, embedding_dim=n_embed)
+        self.lstm = nn.LSTM(input_size=n_embed, hidden_size=n_hidden//2, batch_first=True, bidirectional=True)
+        self.dropout = nn.Dropout(p=dropout)
+        self.projection = nn.Linear(in_features=n_hidden, out_features=self.n_out) if n_hidden != self.n_out else nn.Identity()
+
+    def __repr__(self):
+        s = f"{self.n_chars}, {self.n_embed}"
+        if self.n_hidden != self.n_out:
+            s += f", n_hidden={self.n_hidden}"
+        s += f", n_out={self.n_out}, pad_index={self.pad_index}"
+        if self.dropout.p != 0:
+            s += f", dropout={self.dropout.p}"
+        return f"{self.__class__.__name__}({s})"
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""
+        Args:
+            x (~torch.Tensor): ``[batch_size, seq_len, fix_len]``.
+                Characters of all tokens.
+                Each token holds no more than `fix_len` characters, and the excess is cut off directly.
+        Returns:
+            ~torch.Tensor:
+                The embeddings of shape ``[batch_size, seq_len, n_out]`` derived from the characters.
+        """
+
+        # [batch_size, seq_len, fix_len]
+        mask = x.ne(self.pad_index)
+        # [batch_size, seq_len]
+        lens = mask.sum(-1)
+        char_mask = lens.gt(0)
+
+        # [n, fix_len, n_embed]
+        x = self.embed(x[char_mask])
+        x = pack_padded_sequence(x, lens[char_mask].tolist(), True, False)
+        x, (h, _) = self.lstm(x)
+        # [n, n_hidden]
+        h = self.dropout(torch.cat(torch.unbind(h), -1))
+        # [n, n_out]
+        h = self.projection(h)
+        # [batch_size, seq_len, n_out]
+        return h.new_zeros(*lens.shape, self.n_out).masked_scatter_(char_mask.unsqueeze(-1), h)
+
+
+class VariationalLSTM(nn.Module):
+    r"""
+    VariationalLSTM :cite:`yarin-etal-2016-dropout` is a variant of the vanilla bidirectional LSTM
+    adopted by Biaffine Parser, differing only in the dropout strategy.
+    It drops nodes in the LSTM layers (input and recurrent connections)
+    and applies the same dropout mask at every recurrent timestep.
+
+    APIs are roughly the same as :class:`~torch.nn.LSTM` except that we only allow
+    :class:`~torch.nn.utils.rnn.PackedSequence` as input.
+
+    Args:
+        input_size (int):
+            The number of expected features in the input.
+        hidden_size (int):
+            The number of features in the hidden state `h`.
+        num_layers (int):
+            The number of recurrent layers. Default: 1.
+        bidirectional (bool):
+            If ``True``, becomes a bidirectional LSTM. Default: ``False``.
+        dropout (float):
+            If non-zero, introduces a :class:`SharedDropout` layer on the outputs of each LSTM layer except the last layer.
+            Default: 0.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        num_layers: int = 1,
+        bidirectional: bool = False,
+        dropout: float = .0
+    ) -> VariationalLSTM:
+        super().__init__()
+
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.bidirectional = bidirectional
+        self.dropout = dropout
+        self.num_directions = 1 + self.bidirectional
+
+        self.f_cells = nn.ModuleList()
+        if bidirectional:
+            self.b_cells = nn.ModuleList()
+        for _ in range(self.num_layers):
+            self.f_cells.append(nn.LSTMCell(input_size=input_size, hidden_size=hidden_size))
+            if bidirectional:
+                self.b_cells.append(nn.LSTMCell(input_size=input_size, hidden_size=hidden_size))
+            input_size = hidden_size * self.num_directions
+
+        self.reset_parameters()
+
+    def __repr__(self):
+        s = f"{self.input_size}, {self.hidden_size}"
+        if self.num_layers > 1:
+            s += f", num_layers={self.num_layers}"
+        if self.bidirectional:
+            s += f", bidirectional={self.bidirectional}"
+        if self.dropout > 0:
+            s += f", dropout={self.dropout}"
+        return f"{self.__class__.__name__}({s})"
+
+    def reset_parameters(self):
+        for param in self.parameters():
+            # apply orthogonal_ to weight
+            if len(param.shape) > 1:
+                nn.init.orthogonal_(param)
+            # apply zeros_ to bias
+            else:
+                nn.init.zeros_(param)
+
+    def permute_hidden(
+        self,
+        hx: Tuple[torch.Tensor, torch.Tensor],
+        permutation: torch.LongTensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if permutation is None:
+            return hx
+        h = apply_permutation(hx[0], permutation)
+        c = apply_permutation(hx[1], permutation)
+
+        return h, c
+
+    def layer_forward(
+        self,
+        x: List[torch.Tensor],
+        hx: Tuple[torch.Tensor, torch.Tensor],
+        cell: nn.LSTMCell,
+        batch_sizes: List[int],
+        reverse: bool = False
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        hx_0 = hx_i = hx
+        hx_n, output = [], []
+        steps = reversed(range(len(x))) if reverse else range(len(x))
+        if self.training:
+            hid_mask = SharedDropout.get_mask(hx_0[0], self.dropout)
+
+        for t in steps:
+            last_batch_size, batch_size = len(hx_i[0]), batch_sizes[t]
+            if last_batch_size < batch_size:
+                hx_i = [torch.cat((h, ih[last_batch_size:batch_size])) for h, ih in zip(hx_i, hx_0)]
+            else:
+                hx_n.append([h[batch_size:] for h in hx_i])
+                hx_i = [h[:batch_size] for h in hx_i]
+            hx_i = [h for h in cell(x[t], hx_i)]
+            output.append(hx_i[0])
+            if self.training:
+                hx_i[0] = hx_i[0] * hid_mask[:batch_size]
+        if reverse:
+            hx_n = hx_i
+            output.reverse()
+        else:
+            hx_n.append(hx_i)
+            hx_n = [torch.cat(h) for h in zip(*reversed(hx_n))]
+        output = torch.cat(output)
+
+        return output, hx_n
+
+    def forward(
+        self,
+        sequence: PackedSequence,
+        hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+    ) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
+        r"""
+        Args:
+            sequence (~torch.nn.utils.rnn.PackedSequence):
+                A packed variable length sequence.
+            hx (~torch.Tensor, ~torch.Tensor):
+                A tuple composed of two tensors `h` and `c`.
+                `h` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the initial hidden state
+                for each element in the batch.
+                `c` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the initial cell state
+                for each element in the batch.
+                If `hx` is not provided, both `h` and `c` default to zero.
+                Default: ``None``.
+
+        Returns:
+            ~torch.nn.utils.rnn.PackedSequence, (~torch.Tensor, ~torch.Tensor):
+                The first is a packed variable length sequence.
+                The second is a tuple of tensors `h` and `c`.
+                `h` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the hidden state for `t=seq_len`.
+                Like output, the layers can be separated using ``h.view(num_layers, num_directions, batch_size, hidden_size)``
+                and similarly for c.
+                `c` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the cell state for `t=seq_len`.
+        """
+        x, batch_sizes = sequence.data, sequence.batch_sizes.tolist()
+        batch_size = batch_sizes[0]
+        h_n, c_n = [], []
+
+        if hx is None:
+            ih = x.new_zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size)
+            h, c = ih, ih
+        else:
+            h, c = self.permute_hidden(hx, sequence.sorted_indices)
+        h = h.view(self.num_layers, self.num_directions, batch_size, self.hidden_size)
+        c = c.view(self.num_layers, self.num_directions, batch_size, self.hidden_size)
+
+        for i in range(self.num_layers):
+            x = torch.split(x, batch_sizes)
+            if self.training:
+                mask = SharedDropout.get_mask(x[0], self.dropout)
+                x = [i * mask[:len(i)] for i in x]
+            x_i, (h_i, c_i) = self.layer_forward(x, (h[i, 0], c[i, 0]), self.f_cells[i], batch_sizes)
+            if self.bidirectional:
+                x_b, (h_b, c_b) = self.layer_forward(x, (h[i, 1], c[i, 1]), self.b_cells[i], batch_sizes, True)
+                x_i = torch.cat((x_i, x_b), -1)
+                h_i = torch.stack((h_i, h_b))
+                c_i = torch.stack((c_i, c_b))
+            x = x_i
+            h_n.append(h_i)
+            c_n.append(c_i)
+
+        x = PackedSequence(x, sequence.batch_sizes, sequence.sorted_indices, sequence.unsorted_indices)
+        hx = torch.cat(h_n, 0), torch.cat(c_n, 0)
+        hx = self.permute_hidden(hx, sequence.unsorted_indices)
+
+        return x, hx
diff --git a/tania_scripts/supar/modules/mlp.py b/tania_scripts/supar/modules/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..f26cdd983b70e80110a81253387f30c3c1c79ead
--- /dev/null
+++ b/tania_scripts/supar/modules/mlp.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+from supar.modules.dropout import SharedDropout
+
+
+class MLP(nn.Module):
+    r"""
+    Applies a linear transformation together with a non-linear activation to the incoming tensor:
+    :math:`y = \mathrm{Activation}(x A^T + b)`
+
+    Args:
+        n_in (int):
+            The size of each input feature.
+        n_out (int):
+            The size of each output feature.
+        dropout (float):
+            If non-zero, introduces a :class:`SharedDropout` layer on the output with this dropout ratio. Default: 0.
+        activation (bool):
+            Whether to use activations. Default: True.
+    """
+
+    def __init__(self, n_in: int, n_out: int, dropout: float = .0, activation: bool = True) -> MLP:
+        super().__init__()
+
+        self.n_in = n_in
+        self.n_out = n_out
+        self.linear = nn.Linear(n_in, n_out)
+        self.activation = nn.LeakyReLU(negative_slope=0.1) if activation else nn.Identity()
+        self.dropout = SharedDropout(p=dropout)
+
+        self.reset_parameters()
+
+    def __repr__(self):
+        s = f"n_in={self.n_in}, n_out={self.n_out}"
+        if self.dropout.p > 0:
+            s += f", dropout={self.dropout.p}"
+
+        return f"{self.__class__.__name__}({s})"
+
+    def reset_parameters(self):
+        nn.init.orthogonal_(self.linear.weight)
+        nn.init.zeros_(self.linear.bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""
+        Args:
+            x (~torch.Tensor):
+                A tensor whose last dimension has size `n_in`.
+
+        Returns:
+            A tensor of the same shape, except that the last dimension has size `n_out`.
+        """
+
+        x = self.linear(x)
+        x = self.activation(x)
+        x = self.dropout(x)
+
+        return x
diff --git a/tania_scripts/supar/modules/pretrained.py b/tania_scripts/supar/modules/pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..4805c88474a42bd3fde1cab6c9fc12f0b8b03558
--- /dev/null
+++ b/tania_scripts/supar/modules/pretrained.py
@@ -0,0 +1,256 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+from supar.utils.fn import pad
+from supar.utils.tokenizer import TransformerTokenizer
+
+
+class TransformerEmbedding(nn.Module):
+    r"""
+    Bidirectional transformer embeddings of words from various transformer architectures :cite:`devlin-etal-2019-bert`.
+
+    Args:
+        name (str):
+            Path or name of the pretrained models registered in `transformers`_, e.g., ``'bert-base-cased'``.
+        n_layers (int):
+            The number of BERT layers to use. If 0, uses all layers.
+        n_out (int):
+            The requested size of the embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
+        stride (int):
+            A sequence longer than the max length will be split into several overlapping pieces
+            with a window size of ``stride``. Default: 256.
+        pooling (str):
+            The pooling method used to derive a token embedding from its subtoken embeddings.
+            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
+            Default: ``mean``.
+        pad_index (int):
+            The index of the padding token in BERT vocabulary. Default: 0.
+        mix_dropout (float):
+            The dropout ratio of BERT layers. This value will be passed into the :class:`ScalarMix` layer. Default: 0.
+        finetune (bool):
+            If ``True``, the model parameters will be updated together with the downstream task. Default: ``False``.
+
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(
+        self,
+        name: str,
+        n_layers: int,
+        n_out: int = 0,
+        stride: int = 256,
+        pooling: str = 'mean',
+        pad_index: int = 0,
+        mix_dropout: float = .0,
+        finetune: bool = False
+    ) -> TransformerEmbedding:
+        super().__init__()
+
+        from transformers import AutoModel
+        try:
+            self.model = AutoModel.from_pretrained(name, output_hidden_states=True, local_files_only=True)
+        except Exception:
+            self.model = AutoModel.from_pretrained(name, output_hidden_states=True, local_files_only=False)
+        self.model = self.model.requires_grad_(finetune)
+        self.tokenizer = TransformerTokenizer(name)
+
+        self.name = name
+        self.n_layers = n_layers or self.model.config.num_hidden_layers
+        self.hidden_size = self.model.config.hidden_size
+        self.n_out = n_out or self.hidden_size
+        self.pooling = pooling
+        self.pad_index = pad_index
+        self.mix_dropout = mix_dropout
+        self.finetune = finetune
+        try:
+            self.max_len = int(max(0, self.model.config.max_position_embeddings) or 1e12) - 2
+        except Exception:
+            self.max_len = 512
+        self.stride = min(stride, self.max_len)
+
+        self.scalar_mix = ScalarMix(self.n_layers, mix_dropout)
+        self.projection = nn.Linear(self.hidden_size, self.n_out, False) if self.hidden_size != self.n_out else nn.Identity()
+
+    def __repr__(self):
+        s = f"{self.name}, n_layers={self.n_layers}, n_out={self.n_out}, "
+        s += f"stride={self.stride}, pooling={self.pooling}, pad_index={self.pad_index}"
+        if self.mix_dropout > 0:
+            s += f", mix_dropout={self.mix_dropout}"
+        if self.finetune:
+            s += f", finetune={self.finetune}"
+        return f"{self.__class__.__name__}({s})"
+
+    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+        r"""
+        Args:
+            tokens (~torch.Tensor): ``[batch_size, seq_len, fix_len]``.
+
+        Returns:
+            ~torch.Tensor:
+                Contextualized token embeddings of shape ``[batch_size, seq_len, n_out]``.
+        """
+
+        mask = tokens.ne(self.pad_index)
+        lens = mask.sum((1, 2))
+        # [batch_size, n_subwords]
+        tokens = pad(tokens[mask].split(lens.tolist()), self.pad_index, padding_side=self.tokenizer.padding_side)
+        token_mask = pad(mask[mask].split(lens.tolist()), 0, padding_side=self.tokenizer.padding_side)
+
+        # return the hidden states of all layers
+        x = self.model(tokens[:, :self.max_len], attention_mask=token_mask[:, :self.max_len].float())[-1]
+        # [batch_size, max_len, hidden_size]
+        x = self.scalar_mix(x[-self.n_layers:])
+        # [batch_size, n_subwords, hidden_size]
+        for i in range(self.stride, (tokens.shape[1]-self.max_len+self.stride-1)//self.stride*self.stride+1, self.stride):
+            part = self.model(tokens[:, i:i+self.max_len], attention_mask=token_mask[:, i:i+self.max_len].float())[-1]
+            x = torch.cat((x, self.scalar_mix(part[-self.n_layers:])[:, self.max_len-self.stride:]), 1)
+        # [batch_size, seq_len]
+        lens = mask.sum(-1)
+        lens = lens.masked_fill_(lens.eq(0), 1)
+        # [batch_size, seq_len, fix_len, hidden_size]
+        x = x.new_zeros(*mask.shape, self.hidden_size).masked_scatter_(mask.unsqueeze(-1), x[token_mask])
+        # [batch_size, seq_len, hidden_size]
+        if self.pooling == 'first':
+            x = x[:, :, 0]
+        elif self.pooling == 'last':
+            x = x.gather(2, (lens-1).unsqueeze(-1).repeat(1, 1, self.hidden_size).unsqueeze(2)).squeeze(2)
+        elif self.pooling == 'mean':
+            x = x.sum(2) / lens.unsqueeze(-1)
+        else:
+            raise RuntimeError(f'Unsupported pooling method "{self.pooling}"!')
+        return self.projection(x)
+
+
+class ELMoEmbedding(nn.Module):
+    r"""
+    Contextual word embeddings using word-level bidirectional LM :cite:`peters-etal-2018-deep`.
+
+    Args:
+        name (str):
+            The name of the pretrained ELMo registered in `OPTION` and `WEIGHT`. Default: ``'original_5b'``.
+        bos_eos (Tuple[bool]):
+            A tuple of two boolean values indicating whether to keep start/end boundaries of sentence outputs.
+            Default: ``(True, True)``.
+        n_out (int):
+            The requested size of the embeddings. If 0, uses the default size of ELMo outputs. Default: 0.
+        dropout (float):
+            The dropout ratio for the ELMo layer. Default: 0.
+        finetune (bool):
+            If ``True``, the model parameters will be updated together with the downstream task. Default: ``False``.
+    """
+
+    OPTION = {
+        'small': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json',  # noqa
+        'medium': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json',  # noqa
+        'original': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json',  # noqa
+        'original_5b': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json',  # noqa
+    }
+    WEIGHT = {
+        'small': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5',  # noqa
+        'medium': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5',  # noqa
+        'original': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5',  # noqa
+        'original_5b': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5',  # noqa
+    }
+
+    def __init__(
+        self,
+        name: str = 'original_5b',
+        bos_eos: Tuple[bool, bool] = (True, True),
+        n_out: int = 0,
+        dropout: float = 0.5,
+        finetune: bool = False
+    ) -> ELMoEmbedding:
+        super().__init__()
+
+        from allennlp.modules import Elmo
+
+        self.elmo = Elmo(options_file=self.OPTION[name],
+                         weight_file=self.WEIGHT[name],
+                         num_output_representations=1,
+                         dropout=dropout,
+                         requires_grad=finetune,
+                         keep_sentence_boundaries=True)
+
+        self.name = name
+        self.bos_eos = bos_eos
+        self.hidden_size = self.elmo.get_output_dim()
+        self.n_out = n_out or self.hidden_size
+        self.dropout = dropout
+        self.finetune = finetune
+
+        self.projection = nn.Linear(self.hidden_size, self.n_out, False) if self.hidden_size != self.n_out else nn.Identity()
+
+    def __repr__(self):
+        s = f"{self.name}, n_out={self.n_out}"
+        if self.dropout > 0:
+            s += f", dropout={self.dropout}"
+        if self.finetune:
+            s += f", finetune={self.finetune}"
+        return f"{self.__class__.__name__}({s})"
+
+    def forward(self, chars: torch.LongTensor) -> torch.Tensor:
+        r"""
+        Args:
+            chars (~torch.LongTensor): ``[batch_size, seq_len, fix_len]``.
+
+        Returns:
+            ~torch.Tensor:
+                ELMo embeddings of shape ``[batch_size, seq_len, n_out]``.
+        """
+
+        x = self.projection(self.elmo(chars)['elmo_representations'][0])
+        if not self.bos_eos[0]:
+            x = x[:, 1:]
+        if not self.bos_eos[1]:
+            x = x[:, :-1]
+        return x
+
+
+class ScalarMix(nn.Module):
+    r"""
+    Computes a parameterized scalar mixture of :math:`N` tensors, :math:`mixture = \gamma * \sum_{k}(s_k * tensor_k)`
+    where :math:`s = \mathrm{softmax}(w)`, with :math:`w` and :math:`\gamma` scalar parameters.
+
+    Args:
+        n_layers (int):
+            The number of layers to be mixed, i.e., :math:`N`.
+        dropout (float):
+            The dropout ratio of the layer weights.
+            If dropout > 0, each softmaxed weight is zeroed with the dropout probability
+            and the surviving weights are rescaled by ``1/(1 - dropout)``,
+            as in standard dropout applied to the weight vector.
+            Default: 0.
+    """
+
+    def __init__(self, n_layers: int, dropout: float = .0) -> ScalarMix:
+        super().__init__()
+
+        self.n_layers = n_layers
+
+        self.weights = nn.Parameter(torch.zeros(n_layers))
+        self.gamma = nn.Parameter(torch.tensor([1.0]))
+        self.dropout = nn.Dropout(dropout)
+
+    def __repr__(self):
+        s = f"n_layers={self.n_layers}"
+        if self.dropout.p > 0:
+            s += f", dropout={self.dropout.p}"
+        return f"{self.__class__.__name__}({s})"
+
+    def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor:
+        r"""
+        Args:
+            tensors (List[~torch.Tensor]):
+                :math:`N` tensors to be mixed.
+
+        Returns:
+            The mixture of :math:`N` tensors.
+        """
+
+        return self.gamma * sum(w * h for w, h in zip(self.dropout(self.weights.softmax(-1)), tensors))
diff --git a/tania_scripts/supar/modules/transformer.py b/tania_scripts/supar/modules/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e7c967fee3e99bdbe4a81f8542267a93465dca7
--- /dev/null
+++ b/tania_scripts/supar/modules/transformer.py
@@ -0,0 +1,585 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import copy
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TransformerWordEmbedding(nn.Module):
+
+    def __init__(
+        self,
+        n_vocab: int = None,
+        n_embed: int = None,
+        embed_scale: Optional[float] = None,
+        max_len: Optional[int] = 512,
+        pos: Optional[str] = None,
+        pad_index: Optional[int] = None,
+    ) -> TransformerWordEmbedding:
+        super(TransformerWordEmbedding, self).__init__()
+
+        self.embed = nn.Embedding(num_embeddings=n_vocab,
+                                  embedding_dim=n_embed)
+        if pos is None:
+            self.pos_embed = nn.Identity()
+        elif pos == 'sinusoid':
+            self.pos_embed = SinusoidPositionalEmbedding()
+        elif pos == 'sinusoid_relative':
+            self.pos_embed = SinusoidRelativePositionalEmbedding()
+        elif pos == 'learnable':
+            self.pos_embed = PositionalEmbedding(max_len=max_len)
+        elif pos == 'learnable_relative':
+            self.pos_embed = RelativePositionalEmbedding(max_len=max_len)
+        else:
+            raise ValueError(f'Unknown positional embedding type {pos}')
+
+        self.n_vocab = n_vocab
+        self.n_embed = n_embed
+        self.embed_scale = embed_scale or n_embed ** 0.5
+        self.max_len = max_len
+        self.pos = pos
+        self.pad_index = pad_index
+
+        self.reset_parameters()
+
+    def __repr__(self):
+        s = self.__class__.__name__ + '('
+        s += f"{self.n_vocab}, {self.n_embed}"
+        if self.embed_scale is not None:
+            s += f", embed_scale={self.embed_scale:.2f}"
+        if self.max_len is not None:
+            s += f", max_len={self.max_len}"
+        if self.pos is not None:
+            s += f", pos={self.pos}"
+        if self.pad_index is not None:
+            s += f", pad_index={self.pad_index}"
+        s += ')'
+        return s
+
+    def reset_parameters(self):
+        nn.init.normal_(self.embed.weight, 0, self.n_embed ** -0.5)
+        if self.pad_index is not None:
+            nn.init.zeros_(self.embed.weight[self.pad_index])
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.embed(x)
+        if self.embed_scale:
+            x = x * self.embed_scale
+        if self.pos is not None:
+            x = x + self.pos_embed(x)
+        return x
+
+
+class TransformerEncoder(nn.Module):
+
+    def __init__(
+        self,
+        layer: nn.Module,
+        n_layers: int = 6,
+        n_model: int = 1024,
+        pre_norm: bool = False,
+    ) -> TransformerEncoder:
+        super(TransformerEncoder, self).__init__()
+
+        self.n_layers = n_layers
+        self.n_model = n_model
+        self.pre_norm = pre_norm
+
+        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
+        self.norm = nn.LayerNorm(n_model) if self.pre_norm else None
+
+    def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        x = x.transpose(0, 1)
+        for layer in self.layers:
+            x = layer(x, mask)
+        if self.pre_norm:
+            x = self.norm(x)
+        return x.transpose(0, 1)
+
+
+class TransformerDecoder(nn.Module):
+
+    def __init__(
+        self,
+        layer: nn.Module,
+        n_layers: int = 6,
+        n_model: int = 1024,
+        pre_norm: bool = False,
+    ) -> TransformerDecoder:
+        super(TransformerDecoder, self).__init__()
+
+        self.n_layers = n_layers
+        self.n_model = n_model
+        self.pre_norm = pre_norm
+
+        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
+        self.norm = nn.LayerNorm(n_model) if self.pre_norm else None
+
+    def forward(
+        self,
+        x_tgt: torch.Tensor,
+        x_src: torch.Tensor,
+        tgt_mask: torch.BoolTensor,
+        src_mask: torch.BoolTensor,
+        attn_mask: Optional[torch.BoolTensor] = None
+    ) -> torch.Tensor:
+        x_tgt, x_src = x_tgt.transpose(0, 1), x_src.transpose(0, 1)
+        for layer in self.layers:
+            x_tgt = layer(x_tgt=x_tgt,
+                          x_src=x_src,
+                          tgt_mask=tgt_mask,
+                          src_mask=src_mask,
+                          attn_mask=attn_mask)
+        if self.pre_norm:
+            x_tgt = self.norm(x_tgt)
+        return x_tgt.transpose(0, 1)
+
+
+class TransformerEncoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        n_heads: int = 8,
+        n_model: int = 1024,
+        n_inner: int = 2048,
+        activation: str = 'relu',
+        bias: bool = True,
+        pre_norm: bool = False,
+        attn_dropout: float = 0.1,
+        ffn_dropout: float = 0.1,
+        dropout: float = 0.1
+    ) -> TransformerEncoderLayer:
+        super(TransformerEncoderLayer, self).__init__()
+
+        self.attn = MultiHeadAttention(n_heads=n_heads,
+                                       n_model=n_model,
+                                       n_embed=n_model//n_heads,
+                                       dropout=attn_dropout,
+                                       bias=bias)
+        self.attn_norm = nn.LayerNorm(n_model)
+        self.ffn = PositionwiseFeedForward(n_model=n_model,
+                                           n_inner=n_inner,
+                                           activation=activation,
+                                           dropout=ffn_dropout)
+        self.ffn_norm = nn.LayerNorm(n_model)
+        self.dropout = nn.Dropout(dropout)
+
+        self.pre_norm = pre_norm
+
+    def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        if self.pre_norm:
+            n = self.attn_norm(x)
+            x = x + self.dropout(self.attn(n, n, n, mask))
+            n = self.ffn_norm(x)
+            x = x + self.dropout(self.ffn(n))
+        else:
+            x = self.attn_norm(x + self.dropout(self.attn(x, x, x, mask)))
+            x = self.ffn_norm(x + self.dropout(self.ffn(x)))
+        return x
+
+
+class RelativePositionTransformerEncoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        n_heads: int = 8,
+        n_model: int = 1024,
+        n_inner: int = 2048,
+        activation: str = 'relu',
+        pre_norm: bool = False,
+        attn_dropout: float = 0.1,
+        ffn_dropout: float = 0.1,
+        dropout: float = 0.1
+    ) -> RelativePositionTransformerEncoderLayer:
+        super(RelativePositionTransformerEncoderLayer, self).__init__()
+
+        self.attn = RelativePositionMultiHeadAttention(n_heads=n_heads,
+                                                       n_model=n_model,
+                                                       n_embed=n_model//n_heads,
+                                                       dropout=attn_dropout)
+        self.attn_norm = nn.LayerNorm(n_model)
+        self.ffn = PositionwiseFeedForward(n_model=n_model,
+                                           n_inner=n_inner,
+                                           activation=activation,
+                                           dropout=ffn_dropout)
+        self.ffn_norm = nn.LayerNorm(n_model)
+        self.dropout = nn.Dropout(dropout)
+
+        self.pre_norm = pre_norm
+
+    def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        if self.pre_norm:
+            n = self.attn_norm(x)
+            x = x + self.dropout(self.attn(n, n, n, mask))
+            n = self.ffn_norm(x)
+            x = x + self.dropout(self.ffn(n))
+        else:
+            x = self.attn_norm(x + self.dropout(self.attn(x, x, x, mask)))
+            x = self.ffn_norm(x + self.dropout(self.ffn(x)))
+        return x
+
+
+class TransformerDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        n_heads: int = 8,
+        n_model: int = 1024,
+        n_inner: int = 2048,
+        activation: str = 'relu',
+        bias: bool = True,
+        pre_norm: bool = False,
+        attn_dropout: float = 0.1,
+        ffn_dropout: float = 0.1,
+        dropout: float = 0.1
+    ) -> TransformerDecoderLayer:
+        super(TransformerDecoderLayer, self).__init__()
+
+        self.self_attn = MultiHeadAttention(n_heads=n_heads,
+                                            n_model=n_model,
+                                            n_embed=n_model//n_heads,
+                                            dropout=attn_dropout,
+                                            bias=bias)
+        self.self_attn_norm = nn.LayerNorm(n_model)
+        self.mha_attn = MultiHeadAttention(n_heads=n_heads,
+                                           n_model=n_model,
+                                           n_embed=n_model//n_heads,
+                                           dropout=attn_dropout,
+                                           bias=bias)
+        self.mha_attn_norm = nn.LayerNorm(n_model)
+        self.ffn = PositionwiseFeedForward(n_model=n_model,
+                                           n_inner=n_inner,
+                                           activation=activation,
+                                           dropout=ffn_dropout)
+        self.ffn_norm = nn.LayerNorm(n_model)
+        self.dropout = nn.Dropout(dropout)
+
+        self.pre_norm = pre_norm
+
+    def forward(
+        self,
+        x_tgt: torch.Tensor,
+        x_src: torch.Tensor,
+        tgt_mask: torch.BoolTensor,
+        src_mask: torch.BoolTensor,
+        attn_mask: Optional[torch.BoolTensor] = None
+    ) -> torch.Tensor:
+        if self.pre_norm:
+            n_tgt = self.self_attn_norm(x_tgt)
+            x_tgt = x_tgt + self.dropout(self.self_attn(n_tgt, n_tgt, n_tgt, tgt_mask, attn_mask))
+            n_tgt = self.mha_attn_norm(x_tgt)
+            x_tgt = x_tgt + self.dropout(self.mha_attn(n_tgt, x_src, x_src, src_mask))
+            n_tgt = self.ffn_norm(x_tgt)
+            x_tgt = x_tgt + self.dropout(self.ffn(n_tgt))
+        else:
+            x_tgt = self.self_attn_norm(x_tgt + self.dropout(self.self_attn(x_tgt, x_tgt, x_tgt, tgt_mask, attn_mask)))
+            x_tgt = self.mha_attn_norm(x_tgt + self.dropout(self.mha_attn(x_tgt, x_src, x_src, src_mask)))
+            x_tgt = self.ffn_norm(x_tgt + self.dropout(self.ffn(x_tgt)))
+        return x_tgt
+
+
+class RelativePositionTransformerDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        n_heads: int = 8,
+        n_model: int = 1024,
+        n_inner: int = 2048,
+        activation: str = 'relu',
+        pre_norm: bool = False,
+        attn_dropout: float = 0.1,
+        ffn_dropout: float = 0.1,
+        dropout: float = 0.1
+    ) -> RelativePositionTransformerDecoderLayer:
+        super(RelativePositionTransformerDecoderLayer, self).__init__()
+
+        self.self_attn = RelativePositionMultiHeadAttention(n_heads=n_heads,
+                                                            n_model=n_model,
+                                                            n_embed=n_model//n_heads,
+                                                            dropout=attn_dropout)
+        self.self_attn_norm = nn.LayerNorm(n_model)
+        self.mha_attn = RelativePositionMultiHeadAttention(n_heads=n_heads,
+                                                           n_model=n_model,
+                                                           n_embed=n_model//n_heads,
+                                                           dropout=attn_dropout)
+        self.mha_attn_norm = nn.LayerNorm(n_model)
+        self.ffn = PositionwiseFeedForward(n_model=n_model,
+                                           n_inner=n_inner,
+                                           activation=activation,
+                                           dropout=ffn_dropout)
+        self.ffn_norm = nn.LayerNorm(n_model)
+        self.dropout = nn.Dropout(dropout)
+
+        self.pre_norm = pre_norm
+
+    def forward(
+        self,
+        x_tgt: torch.Tensor,
+        x_src: torch.Tensor,
+        tgt_mask: torch.BoolTensor,
+        src_mask: torch.BoolTensor,
+        attn_mask: Optional[torch.BoolTensor] = None
+    ) -> torch.Tensor:
+        if self.pre_norm:
+            n_tgt = self.self_attn_norm(x_tgt)
+            x_tgt = x_tgt + self.dropout(self.self_attn(n_tgt, n_tgt, n_tgt, tgt_mask, attn_mask))
+            n_tgt = self.mha_attn_norm(x_tgt)
+            x_tgt = x_tgt + self.dropout(self.mha_attn(n_tgt, x_src, x_src, src_mask))
+            n_tgt = self.ffn_norm(x_tgt)
+            x_tgt = x_tgt + self.dropout(self.ffn(n_tgt))
+        else:
+            x_tgt = self.self_attn_norm(x_tgt + self.dropout(self.self_attn(x_tgt, x_tgt, x_tgt, tgt_mask, attn_mask)))
+            x_tgt = self.mha_attn_norm(x_tgt + self.dropout(self.mha_attn(x_tgt, x_src, x_src, src_mask)))
+            x_tgt = self.ffn_norm(x_tgt + self.dropout(self.ffn(x_tgt)))
+        return x_tgt
+
+
+class MultiHeadAttention(nn.Module):
+
+    def __init__(
+        self,
+        n_heads: int = 8,
+        n_model: int = 1024,
+        n_embed: int = 128,
+        dropout: float = 0.1,
+        bias: bool = True,
+        attn: bool = False,
+    ) -> MultiHeadAttention:
+        super(MultiHeadAttention, self).__init__()
+
+        self.n_heads = n_heads
+        self.n_model = n_model
+        self.n_embed = n_embed
+        self.scale = n_embed**0.5
+
+        self.wq = nn.Linear(n_model, n_heads * n_embed, bias=bias)
+        self.wk = nn.Linear(n_model, n_heads * n_embed, bias=bias)
+        self.wv = nn.Linear(n_model, n_heads * n_embed, bias=bias)
+        self.wo = nn.Linear(n_heads * n_embed, n_model, bias=bias)
+        self.dropout = nn.Dropout(dropout)
+
+        self.bias = bias
+        self.attn = attn
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # borrowed from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/multihead_attention.py
+        nn.init.xavier_uniform_(self.wq.weight, 2 ** -0.5)
+        nn.init.xavier_uniform_(self.wk.weight, 2 ** -0.5)
+        nn.init.xavier_uniform_(self.wv.weight, 2 ** -0.5)
+        nn.init.xavier_uniform_(self.wo.weight)
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        mask: torch.BoolTensor,
+        attn_mask: Optional[torch.BoolTensor] = None
+    ) -> torch.Tensor:
+        batch_size, _ = mask.shape
+        # [seq_len, batch_size * n_heads, n_embed]
+        q = self.wq(q).view(-1, batch_size * self.n_heads, self.n_embed)
+        # [src_len, batch_size * n_heads, n_embed]
+        k = self.wk(k).view(-1, batch_size * self.n_heads, self.n_embed)
+        v = self.wv(v).view(-1, batch_size * self.n_heads, self.n_embed)
+
+        mask = mask.unsqueeze(1).repeat(1, self.n_heads, 1).view(-1, 1, *mask.shape[1:])
+        # [batch_size * n_heads, seq_len, src_len]
+        if attn_mask is not None:
+            mask = mask & attn_mask
+        # [batch_size * n_heads, seq_len, src_len]
+        attn = torch.bmm(q.transpose(0, 1) / self.scale, k.movedim((0, 1), (2, 0)))
+        attn = torch.softmax(attn + torch.where(mask, 0., float('-inf')), -1)
+        attn = self.dropout(attn)
+        # [seq_len, batch_size * n_heads, n_embed]
+        x = torch.bmm(attn, v.transpose(0, 1)).transpose(0, 1)
+        # [seq_len, batch_size, n_model]
+        x = self.wo(x.reshape(-1, batch_size, self.n_heads * self.n_embed))
+
+        return (x, attn.view(batch_size, self.n_heads, *attn.shape[1:])) if self.attn else x
+
+
+class RelativePositionMultiHeadAttention(nn.Module):
+
+    def __init__(
+        self,
+        n_heads: int = 8,
+        n_model: int = 1024,
+        n_embed: int = 128,
+        dropout: float = 0.1,
+        attn: bool = False
+    ) -> RelativePositionMultiHeadAttention:
+        super(RelativePositionMultiHeadAttention, self).__init__()
+
+        self.n_heads = n_heads
+        self.n_model = n_model
+        self.n_embed = n_embed
+        self.scale = n_embed**0.5
+
+        self.pos_embed = RelativePositionalEmbedding(n_model=n_embed)
+        self.wq = nn.Parameter(torch.zeros(n_model, n_heads * n_embed))
+        self.wk = nn.Parameter(torch.zeros(n_model, n_heads * n_embed))
+        self.wv = nn.Parameter(torch.zeros(n_model, n_heads * n_embed))
+        self.wo = nn.Parameter(torch.zeros(n_heads * n_embed, n_model))
+        self.bu = nn.Parameter(torch.zeros(n_heads, n_embed))
+        self.bv = nn.Parameter(torch.zeros(n_heads, n_embed))
+        self.dropout = nn.Dropout(dropout)
+
+        self.attn = attn
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # borrowed from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/multihead_attention.py
+        nn.init.xavier_uniform_(self.wq, 2 ** -0.5)
+        nn.init.xavier_uniform_(self.wk, 2 ** -0.5)
+        nn.init.xavier_uniform_(self.wv, 2 ** -0.5)
+        nn.init.xavier_uniform_(self.wo)
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        mask: torch.BoolTensor,
+        attn_mask: Optional[torch.BoolTensor] = None
+    ) -> torch.Tensor:
+        batch_size, _ = mask.shape
+        # [seq_len, batch_size, n_heads, n_embed]
+        q = F.linear(q, self.wq).view(-1, batch_size, self.n_heads, self.n_embed)
+        # [src_len, batch_size * n_heads, n_embed]
+        k = F.linear(k, self.wk).view(-1, batch_size * self.n_heads, self.n_embed)
+        v = F.linear(v, self.wv).view(-1, batch_size * self.n_heads, self.n_embed)
+        # [seq_len, src_len, n_embed]
+        p = self.pos_embed(q[:, 0, 0], k[:, 0])
+        # [seq_len, batch_size * n_heads, n_embed]
+        qu, qv = (q + self.bu).view(-1, *k.shape[1:]), (q + self.bv).view(-1, *k.shape[1:])
+
+        mask = mask.unsqueeze(1).repeat(1, self.n_heads, 1).view(-1, 1, *mask.shape[1:])
+        if attn_mask is not None:
+            mask = mask & attn_mask
+        # [batch_size * n_heads, seq_len, src_len]
+        attn = torch.bmm(qu.transpose(0, 1), k.movedim((0, 1), (2, 0)))
+        attn = attn + torch.matmul(qv.transpose(0, 1).unsqueeze(2), p.transpose(1, 2)).squeeze(2)
+        attn = torch.softmax(attn / self.scale + torch.where(mask, 0., float('-inf')), -1)
+        attn = self.dropout(attn)
+        # [seq_len, batch_size * n_heads, n_embed]
+        x = torch.bmm(attn, v.transpose(0, 1)).transpose(0, 1)
+        # [seq_len, batch_size, n_model]
+        x = F.linear(x.reshape(-1, batch_size, self.n_heads * self.n_embed), self.wo)
+
+        return (x, attn.view(batch_size, self.n_heads, *attn.shape[1:])) if self.attn else x
+
+
+class PositionwiseFeedForward(nn.Module):
+
+    def __init__(
+        self,
+        n_model: int = 1024,
+        n_inner: int = 2048,
+        activation: str = 'relu',
+        dropout: float = 0.1
+    ) -> PositionwiseFeedForward:
+        super(PositionwiseFeedForward, self).__init__()
+
+        self.w1 = nn.Linear(n_model, n_inner)
+        self.activation = nn.ReLU() if activation == 'relu' else nn.GELU()
+        self.dropout = nn.Dropout(dropout)
+        self.w2 = nn.Linear(n_inner, n_model)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.w1.weight)
+        nn.init.xavier_uniform_(self.w2.weight)
+        nn.init.zeros_(self.w1.bias)
+        nn.init.zeros_(self.w2.bias)
+
+    def forward(self, x):
+        x = self.w1(x)
+        x = self.activation(x)
+        x = self.dropout(x)
+        x = self.w2(x)
+
+        return x
+
+
+class PositionalEmbedding(nn.Module):
+
+    def __init__(
+        self,
+        n_model: int = 1024,
+        max_len: int = 1024
+    ) -> PositionalEmbedding:
+        super().__init__()
+
+        self.embed = nn.Embedding(max_len, n_model)
+
+        self.reset_parameters()
+
+    @torch.no_grad()
+    def reset_parameters(self):
+        w = self.embed.weight
+        max_len, n_model = w.shape
+        w = w.new_tensor(range(max_len)).unsqueeze(-1)
+        w = w / 10000 ** (w.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model)
+        w[:, 0::2], w[:, 1::2] = w[:, 0::2].sin(), w[:, 1::2].cos()
+        self.embed.weight.copy_(w)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.embed(x.new_tensor(range(x.shape[1])).long())
+
+
+class RelativePositionalEmbedding(nn.Module):
+
+    def __init__(
+        self,
+        n_model: int = 1024,
+        max_len: int = 1024
+    ) -> RelativePositionalEmbedding:
+        super().__init__()
+
+        self.embed = nn.Embedding(max_len, n_model)
+
+        self.reset_parameters()
+
+    @torch.no_grad()
+    def reset_parameters(self):
+        w = self.embed.weight
+        max_len, n_model = w.shape
+        pos = torch.cat((w.new_tensor(range(-max_len//2, 0)), w.new_tensor(range(max_len//2))))
+        w = pos.unsqueeze(-1) / 10000 ** (w.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model)
+        w[:, 0::2], w[:, 1::2] = w[:, 0::2].sin(), w[:, 1::2].cos()
+        self.embed.weight.copy_(w)
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
+        offset = sum(divmod(self.embed.weight.shape[0], 2))
+        return self.embed((k.new_tensor(range(k.shape[0])) - q.new_tensor(range(q.shape[0])).unsqueeze(-1)).long() + offset)
+
+
+class SinusoidPositionalEmbedding(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        seq_len, n_model = x[0].shape
+        pos = x.new_tensor(range(seq_len)).unsqueeze(-1)
+        pos = pos / 10000 ** (x.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model)
+        pos[:, 0::2], pos[:, 1::2] = pos[:, 0::2].sin(), pos[:, 1::2].cos()
+        return pos
+
+
+class SinusoidRelativePositionalEmbedding(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        seq_len, n_model = x[0].shape
+        pos = x.new_tensor(range(seq_len))
+        pos = (pos - pos.unsqueeze(-1)).unsqueeze(-1)
+        pos = pos / 10000 ** (x.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model)
+        pos[..., 0::2], pos[..., 1::2] = pos[..., 0::2].sin(), pos[..., 1::2].cos()
+        return pos
diff --git a/tania_scripts/supar/parser.py b/tania_scripts/supar/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1c372e52475d87780f414396606f71a7494653b
--- /dev/null
+++ b/tania_scripts/supar/parser.py
@@ -0,0 +1,620 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import contextlib
+import os
+import shutil
+import sys
+import tempfile
+import pickle
+from contextlib import contextmanager
+from datetime import datetime, timedelta
+from typing import Any, Iterable, Union
+
+import dill
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.cuda.amp import GradScaler
+from torch.optim import Adam, Optimizer
+from torch.optim.lr_scheduler import ExponentialLR, _LRScheduler
+
+import supar
+from supar.utils import Config, Dataset
+from supar.utils.field import Field
+from supar.utils.fn import download, get_rng_state, set_rng_state
+from supar.utils.logging import get_logger, init_logger, progress_bar
+from supar.utils.metric import Metric
+from supar.utils.optim import InverseSquareRootLR, LinearLR
+from supar.utils.parallel import DistributedDataParallel as DDP
+from supar.utils.parallel import gather, is_dist, is_master, reduce
+from supar.utils.transform import Batch
+
+logger = get_logger(__name__)
+
+
+class Parser(object):
+
+    NAME = None
+    MODEL = None
+
+    def __init__(self, args, model, transform):
+        self.args = args
+        self.model = model
+        self.transform = transform
+
+    @property
+    def device(self):
+        return 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    @property
+    def sync_grad(self):
+        return self.step % self.args.update_steps == 0 or self.step % self.n_batches == 0
+
+    @contextmanager
+    def sync(self):
+        context = getattr(contextlib, 'suppress' if sys.version_info < (3, 7) else 'nullcontext')
+        if is_dist() and not self.sync_grad:
+            context = self.model.no_sync
+        with context():
+            yield
+
+    @contextmanager
+    def join(self):
+        context = getattr(contextlib, 'suppress' if sys.version_info < (3, 7) else 'nullcontext')
+        if not is_dist():
+            with context():
+                yield
+        elif self.model.training:
+            with self.model.join():
+                yield
+        else:
+            try:
+                dist_model = self.model
+                # https://github.com/pytorch/pytorch/issues/54059
+                if hasattr(self.model, 'module'):
+                    self.model = self.model.module
+                yield
+            finally:
+                self.model = dist_model
+
+    def train(
+        self,
+        train: Union[str, Iterable],
+        dev: Union[str, Iterable],
+        test: Union[str, Iterable],
+        epochs: int,
+        patience: int,
+        batch_size: int = 5000,
+        update_steps: int = 1,
+        buckets: int = 32,
+        workers: int = 0,
+        clip: float = 5.0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ) -> None:
+        r"""
+        Args:
+            train/dev/test (Union[str, Iterable]):
+                Filenames of the train/dev/test datasets.
+            epochs (int):
+                The number of training iterations.
+            patience (int):
+                The number of consecutive epochs after which training is stopped early if no improvement is observed.
+            batch_size (int):
+                The number of tokens in each batch. Default: 5000.
+            update_steps (int):
+                Gradient accumulation steps. Default: 1.
+            buckets (int):
+                The number of buckets that sentences are assigned to. Default: 32.
+            workers (int):
+                The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
+            clip (float):
+                Clips gradient of an iterable of parameters at specified value. Default: 5.0.
+            amp (bool):
+                Specifies whether to use automatic mixed precision. Default: ``False``.
+            cache (bool):
+                If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
+            verbose (bool):
+                If ``True``, increases the output verbosity. Default: ``True``.
+        """
+
+        args = self.args.update(locals())
+        init_logger(logger, verbose=args.verbose)
+
+        self.transform.train()
+        batch_size = batch_size // update_steps
+        eval_batch_size = args.get('eval_batch_size', batch_size)
+        if is_dist():
+            batch_size = batch_size // dist.get_world_size()
+            eval_batch_size = eval_batch_size // dist.get_world_size()
+        logger.info("Loading the data")
+        if args.cache:
+            args.bin = os.path.join(os.path.dirname(args.path), 'bin')
+        args.even = args.get('even', is_dist())
+
+        train = Dataset(self.transform, args.train, **args).build(batch_size=batch_size,
+                                                                  n_buckets=buckets,
+                                                                  shuffle=True,
+                                                                  distributed=is_dist(),
+                                                                  even=args.even,
+                                                                  n_workers=workers)
+        dev = Dataset(self.transform, args.dev, **args).build(batch_size=eval_batch_size,
+                                                              n_buckets=buckets,
+                                                              shuffle=False,
+                                                              distributed=is_dist(),
+                                                              even=False,
+                                                              n_workers=workers)
+        logger.info(f"{'train:':6} {train}")
+        if not args.test:
+            logger.info(f"{'dev:':6} {dev}\n")
+        else:
+            test = Dataset(self.transform, args.test, **args).build(batch_size=eval_batch_size,
+                                                                    n_buckets=buckets,
+                                                                    shuffle=False,
+                                                                    distributed=is_dist(),
+                                                                    even=False,
+                                                                    n_workers=workers)
+            logger.info(f"{'dev:':6} {dev}")
+            logger.info(f"{'test:':6} {test}\n")
+        loader, sampler = train.loader, train.loader.batch_sampler
+        args.steps = len(loader) * epochs // args.update_steps
+        args.save(f"{args.path}.yaml")
+
+        self.optimizer = self.init_optimizer()
+        self.scheduler = self.init_scheduler()
+        self.scaler = GradScaler(enabled=args.amp)
+
+        if dist.is_initialized():
+            self.model = DDP(module=self.model,
+                             device_ids=[args.local_rank],
+                             find_unused_parameters=args.get('find_unused_parameters', True),
+                             static_graph=args.get('static_graph', False))
+            if args.amp:
+                from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook
+                self.model.register_comm_hook(dist.group.WORLD, fp16_compress_hook)
+        if args.wandb and is_master():
+            import wandb
+            # start a new wandb run to track this script
+            wandb.init(config=args.primitive_config,
+                       project=args.get('project', self.NAME),
+                       name=args.get('name', args.path),
+                       resume=self.args.checkpoint)
+        self.step, self.epoch, self.best_e, self.patience = 1, 1, 1, patience
+        # uneven batches are excluded
+        self.n_batches = min(gather(len(loader))) if is_dist() else len(loader)
+        self.best_metric, self.elapsed = Metric(), timedelta()
+        if args.checkpoint:
+            try:
+                self.optimizer.load_state_dict(self.checkpoint_state_dict.pop('optimizer_state_dict'))
+                self.scheduler.load_state_dict(self.checkpoint_state_dict.pop('scheduler_state_dict'))
+                self.scaler.load_state_dict(self.checkpoint_state_dict.pop('scaler_state_dict'))
+                set_rng_state(self.checkpoint_state_dict.pop('rng_state'))
+                for k, v in self.checkpoint_state_dict.items():
+                    setattr(self, k, v)
+                sampler.set_epoch(self.epoch)
+            except AttributeError:
+                logger.warning("No checkpoint found. Try re-launching the training procedure instead")
+
+        for epoch in range(self.epoch, args.epochs + 1):
+            start = datetime.now()
+            bar, metric = progress_bar(loader), Metric()
+
+            logger.info(f"Epoch {epoch} / {args.epochs}:")
+            self.model.train()
+            with self.join():
+                # we should reset `step` as the number of batches in different processes is not necessarily equal
+                self.step = 1
+                for batch in bar:
+                    with self.sync():
+                        with torch.autocast(self.device, enabled=args.amp):
+                            loss = self.train_step(batch)
+                        self.backward(loss)
+                    if self.sync_grad:
+                        self.clip_grad_norm_(self.model.parameters(), args.clip)
+                        self.scaler.step(self.optimizer)
+                        self.scaler.update()
+                        self.scheduler.step()
+                        self.optimizer.zero_grad(True)
+
+                    bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f}")
+                    # log metrics to wandb
+                    if args.wandb and is_master():
+                        wandb.log({'lr': self.scheduler.get_last_lr()[0], 'loss': loss})
+                    self.step += 1
+                logger.info(f"{bar.postfix}")
+            self.model.eval()
+            with self.join(), torch.autocast(self.device, enabled=args.amp):
+                metric = self.reduce(sum([self.eval_step(i) for i in progress_bar(dev.loader)], Metric()))
+                logger.info(f"{'dev:':5} {metric}")
+                if args.wandb and is_master():
+                    wandb.log({'dev': metric.values, 'epochs': epoch})
+                if args.test:
+                    test_metric = sum([self.eval_step(i) for i in progress_bar(test.loader)], Metric())
+                    logger.info(f"{'test:':5} {self.reduce(test_metric)}")
+                    if args.wandb and is_master():
+                        wandb.log({'test': test_metric.values, 'epochs': epoch})
+
+            t = datetime.now() - start
+            self.epoch += 1
+            self.patience -= 1
+            self.elapsed += t
+
+            if metric > self.best_metric:
+                self.best_e, self.patience, self.best_metric = epoch, patience, metric
+                if is_master():
+                    self.save_checkpoint(args.path)
+                logger.info(f"{t}s elapsed (saved)\n")
+            else:
+                logger.info(f"{t}s elapsed\n")
+            if self.patience < 1:
+                break
+        if is_dist():
+            dist.barrier()
+
+        best = self.load(**args)
+        # only allow the master device to save models
+        if is_master():
+            best.save(args.path)
+
+        logger.info(f"Epoch {self.best_e} saved")
+        logger.info(f"{'dev:':5} {self.best_metric}")
+        if args.test:
+            best.model.eval()
+            with best.join():
+                test_metric = sum([best.eval_step(i) for i in progress_bar(test.loader)], Metric())
+                logger.info(f"{'test:':5} {best.reduce(test_metric)}")
+        logger.info(f"{self.elapsed}s elapsed, {self.elapsed / epoch}s/epoch")
+        if args.wandb and is_master():
+            wandb.finish()
+
+        self.evaluate(data=args.test, batch_size=batch_size)
+        self.predict(args.test, batch_size=batch_size, buckets=buckets, workers=workers)
+
+        with open(f'{self.args.folder}/status', 'w') as file:
+            file.write('finished')
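+
+    # Hypothetical usage sketch (editorial; class, paths and file names are
+    # placeholders, not from the original code). A concrete subclass
+    # implementing `train_step` and `eval_step` would be driven as:
+    #   parser = SomeParser.build('./exp/model', **config)
+    #   parser.train(train='train.conllu', dev='dev.conllu', test='test.conllu',
+    #                epochs=10, patience=5, batch_size=5000)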
+
+
+
+    def evaluate(
+        self,
+        data: Union[str, Iterable],
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        amp: bool = False,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        r"""
+        Args:
+            data (Union[str, Iterable]):
+                The data for evaluation. Both a filename and a list of instances are allowed.
+            batch_size (int):
+                The number of tokens in each batch. Default: 5000.
+            buckets (int):
+                The number of buckets that sentences are assigned to. Default: 8.
+            workers (int):
+                The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
+            amp (bool):
+                Specifies whether to use automatic mixed precision. Default: ``False``.
+            cache (bool):
+                If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
+            verbose (bool):
+                If ``True``, increases the output verbosity. Default: ``True``.
+
+        Returns:
+            The evaluation results.
+        """
+
+        args = self.args.update(locals())
+        init_logger(logger, verbose=args.verbose)
+
+        self.transform.train()
+        logger.info("Loading the data")
+        if args.cache:
+            args.bin = os.path.join(os.path.dirname(args.path), 'bin')
+        if is_dist():
+            batch_size = batch_size // dist.get_world_size()
+        data = Dataset(self.transform, **args)
+        data.build(batch_size=batch_size,
+                   n_buckets=buckets,
+                   shuffle=False,
+                   distributed=is_dist(),
+                   even=False,
+                   n_workers=workers)
+        logger.info(f"\n{data}")
+
+        logger.info("Evaluating the data")
+        start = datetime.now()
+        self.model.eval()
+        with self.join():
+            bar, metric = progress_bar(data.loader), Metric()
+            for batch in bar:
+                metric += self.eval_step(batch)
+                bar.set_postfix_str(metric)
+            metric = self.reduce(metric)
+        elapsed = datetime.now() - start
+        logger.info(f"{metric}")
+        logger.info(f"{elapsed}s elapsed, "
+                    f"{sum(data.sizes)/elapsed.total_seconds():.2f} Tokens/s, "
+                    f"{len(data)/elapsed.total_seconds():.2f} Sents/s")
+        os.makedirs(os.path.dirname(self.args.folder + '/metrics.pickle'), exist_ok=True)
+        with open(f'{self.args.folder}/metrics.pickle', 'wb') as file:
+            pickle.dump(obj=metric, file=file)
+
+        return metric
+
+    def predict(
+        self,
+        data: Union[str, Iterable],
+        pred: str = None,
+        lang: str = None,
+        prob: bool = False,
+        batch_size: int = 5000,
+        buckets: int = 8,
+        workers: int = 0,
+        cache: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ):
+        r"""
+        Args:
+            data (Union[str, Iterable]):
+                The data for prediction.
+                - a filename; if it ends with ``.txt``, the parser reads plain text and makes predictions line by line.
+                - a list of instances.
+            pred (str):
+                If specified, the predicted results will be saved to the file. Default: ``None``.
+            lang (str):
+                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``None``.
+            prob (bool):
+                If ``True``, outputs the probabilities. Default: ``False``.
+            batch_size (int):
+                The number of tokens in each batch. Default: 5000.
+            buckets (int):
+                The number of buckets that sentences are assigned to. Default: 8.
+            workers (int):
+                The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
+            cache (bool):
+                If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
+            verbose (bool):
+                If ``True``, increases the output verbosity. Default: ``True``.
+
+        Returns:
+            A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None``.
+        """
+
+        args = self.args.update(locals())
+        init_logger(logger, verbose=args.verbose)
+
+        if self.args.use_vq:
+            self.model.passes_remaining = 0
+            self.model.vq.observe_steps_remaining = 0
+
+        self.transform.eval()
+        if args.prob:
+            self.transform.append(Field('probs'))
+
+        #logger.info("Loading the data")
+        if args.cache:
+            args.bin = os.path.join(os.path.dirname(args.path), 'bin')
+        if is_dist():
+            batch_size = batch_size // dist.get_world_size()
+        data = Dataset(self.transform, **args)
+        data.build(batch_size=batch_size,
+                   n_buckets=buckets,
+                   shuffle=False,
+                   distributed=is_dist(),
+                   even=False,
+                   n_workers=workers)
+
+        #logger.info(f"\n{data}")
+
+        #logger.info("Making predictions on the data")
+        start = datetime.now()
+        self.model.eval()
+        #with tempfile.TemporaryDirectory() as t:
+            # we have clustered the sentences by length here to speed up prediction,
+            # so the order of the yielded sentences can't be guaranteed
+        for batch in progress_bar(data.loader):
+            #batch, head_preds, deprel_preds, stack_list, buffer_list, actions_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, sent_text, act_dict = self.pred_step(batch)
+            *predicted_values, = self.pred_step(batch)
+            #print('429 supar/parser.py ', batch.sentences)
+            #print(head_preds, deprel_preds, stack_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, sent_text)
+            #logger.info(f"Saving predicted results to {pred}")
+           # with open(pred, 'w') as f:
+                    
+
+            elapsed = datetime.now() - start
+
+            #if is_dist():
+            #    dist.barrier()
+            #tdirs = gather(t) if is_dist() else (t,)
+            #if pred is not None and is_master():
+                #logger.info(f"Saving predicted results to {pred}")
+            """with open(pred, 'w') as f:
+                    # merge all predictions into one single file
+                    if is_dist() or args.cache:
+                        sentences = (os.path.join(i, s) for i in tdirs for s in os.listdir(i))
+                        for i in progress_bar(sorted(sentences, key=lambda x: int(os.path.basename(x)))):
+                            with open(i) as s:
+                                shutil.copyfileobj(s, f)
+                    else:
+                        for s in progress_bar(data):
+                            f.write(str(s) + '\n')"""
+            # exit util all files have been merged
+            if is_dist():
+                dist.barrier()
+        #logger.info(f"{elapsed}s elapsed, "
+        #            f"{sum(data.sizes)/elapsed.total_seconds():.2f} Tokens/s, "
+        #            f"{len(data)/elapsed.total_seconds():.2f} Sents/s")
+
+        if not cache:
+            #return data, head_preds, deprel_preds, stack_list, buffer_list, actions_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, sent_text, act_dict
+            return *predicted_values,
+
+    def backward(self, loss: torch.Tensor, **kwargs):
+        loss /= self.args.update_steps
+        if hasattr(self, 'scaler'):
+            self.scaler.scale(loss).backward(**kwargs)
+        else:
+            loss.backward(**kwargs)
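+
+    # Editorial note: dividing the loss by `update_steps` makes gradients
+    # accumulated over several `backward` calls average (rather than sum) the
+    # per-batch losses, e.g. with update_steps=2:
+    #   self.backward(loss1); self.backward(loss2)  # grads hold (loss1 + loss2) / 2
+    # With a GradScaler the loss is scaled first so fp16 gradients do not
+    # underflow; `clip_grad_norm_` below unscales them before clipping.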
+
+    def clip_grad_norm_(
+        self,
+        params: Union[Iterable[torch.Tensor], torch.Tensor],
+        max_norm: float,
+        norm_type: float = 2
+    ) -> torch.Tensor:
+        self.scaler.unscale_(self.optimizer)
+        return nn.utils.clip_grad_norm_(params, max_norm, norm_type)
+
+    def clip_grad_value_(
+        self,
+        params: Union[Iterable[torch.Tensor], torch.Tensor],
+        clip_value: float
+    ) -> None:
+        self.scaler.unscale_(self.optimizer)
+        return nn.utils.clip_grad_value_(params, clip_value)
+
+    def reduce(self, obj: Any) -> Any:
+        if not is_dist():
+            return obj
+        return reduce(obj)
+
+    def train_step(self, batch: Batch) -> torch.Tensor:
+        ...
+
+    @torch.no_grad()
+    def eval_step(self, batch: Batch) -> Metric:
+        ...
+
+    @torch.no_grad()
+    def pred_step(self, batch: Batch) -> Batch:
+        ...
+
+    def init_optimizer(self) -> Optimizer:
+        if self.args.encoder in ('lstm', 'transformer'):
+            optimizer = Adam(params=self.model.parameters(),
+                             lr=self.args.lr,
+                             betas=(self.args.get('mu', 0.9), self.args.get('nu', 0.999)),
+                             eps=self.args.get('eps', 1e-8),
+                             weight_decay=self.args.get('weight_decay', 0))
+        else:
+            # we found that Huggingface's AdamW is more robust and empirically better than the native implementation
+            from transformers import AdamW
+            optimizer = AdamW(params=[{'params': p, 'lr': self.args.lr * (1 if n.startswith('encoder') else self.args.lr_rate)}
+                                      for n, p in self.model.named_parameters()],
+                              lr=self.args.lr,
+                              betas=(self.args.get('mu', 0.9), self.args.get('nu', 0.999)),
+                              eps=self.args.get('eps', 1e-8),
+                              weight_decay=self.args.get('weight_decay', 0))
+        return optimizer
+
+    def init_scheduler(self) -> _LRScheduler:
+        if self.args.encoder == 'lstm':
+            scheduler = ExponentialLR(optimizer=self.optimizer,
+                                      gamma=self.args.decay**(1/self.args.decay_steps))
+        elif self.args.encoder == 'transformer':
+            scheduler = InverseSquareRootLR(optimizer=self.optimizer,
+                                            warmup_steps=self.args.warmup_steps)
+        else:
+            scheduler = LinearLR(optimizer=self.optimizer,
+                                 warmup_steps=self.args.get('warmup_steps', int(self.args.steps*self.args.get('warmup', 0))),
+                                 steps=self.args.steps)
+        return scheduler
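+
+    # Editorial summary (the exact factors live in supar.utils.optim; this is a
+    # sketch, not the authoritative formulas):
+    #   lstm        -> lr * decay**(step / decay_steps)       (exponential decay)
+    #   transformer -> warmup, then lr decaying ~ step**-0.5  (inverse square root)
+    #   otherwise   -> linear warmup over `warmup_steps`, then linear decay over `steps`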
+
+    @classmethod
+    def build(cls, path, **kwargs):
+        ...
+
+    @classmethod
+    def load(
+        cls,
+        path: str,
+        reload: bool = True,
+        src: str = 'github',
+        checkpoint: bool = False,
+        **kwargs
+    ) -> Parser:
+        r"""
+        Loads a parser with data fields and pretrained model parameters.
+
+        Args:
+            path (str):
+                - a string with the shortcut name of a pretrained model defined in ``supar.MODEL``
+                  to load from cache or download, e.g., ``'biaffine-dep-en'``.
+                - a local path to a pretrained model, e.g., ``./<path>/model``.
+            reload (bool):
+                Whether to discard the existing cache and force a fresh download. Default: ``True``.
+            src (str):
+                Specifies where to download the model.
+                ``'github'``: github release page.
+                ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
+                Default: ``'github'``.
+            checkpoint (bool):
+                If ``True``, loads all checkpoint states to restore the training process. Default: ``False``.
+
+        Examples:
+            >>> from supar import Parser
+            >>> parser = Parser.load('biaffine-dep-en')
+            >>> parser = Parser.load('./ptb.biaffine.dep.lstm.char')
+        """
+
+        args = Config(**locals())
+        if not os.path.exists(path):
+            path = download(supar.MODEL[src].get(path, path), reload=reload)
+        state = torch.load(path, map_location='cpu', weights_only=False)
+        #torch.load(path, map_location='cpu')
+        cls = supar.PARSER[state['name']] if cls.NAME is None else cls
+        args = state['args'].update(args)
+        #print('ARGS', args)
+        model = cls.MODEL(**args)
+        model.load_pretrained(state['pretrained'])
+        model.load_state_dict(state['state_dict'], True)
+        transform = state['transform']
+        parser = cls(args, model, transform)
+        parser.checkpoint_state_dict = state.get('checkpoint_state_dict', None) if checkpoint else None
+        parser.model.to(parser.device)
+        return parser
+
+    def save(self, path: str) -> None:
+        model = self.model
+        if hasattr(model, 'module'):
+            model = self.model.module
+        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
+        pretrained = state_dict.pop('pretrained.weight', None)
+        state = {'name': self.NAME,
+                 'args': model.args,
+                 'state_dict': state_dict,
+                 'pretrained': pretrained,
+                 'transform': self.transform}
+        torch.save(state, path, pickle_module=dill)
+
+    def save_checkpoint(self, path: str) -> None:
+        model = self.model
+        if hasattr(model, 'module'):
+            model = self.model.module
+        checkpoint_state_dict = {k: getattr(self, k) for k in ['epoch', 'best_e', 'patience', 'best_metric', 'elapsed']}
+        checkpoint_state_dict.update({'optimizer_state_dict': self.optimizer.state_dict(),
+                                      'scheduler_state_dict': self.scheduler.state_dict(),
+                                      'scaler_state_dict': self.scaler.state_dict(),
+                                      'rng_state': get_rng_state()})
+        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
+        pretrained = state_dict.pop('pretrained.weight', None)
+        state = {'name': self.NAME,
+                 'args': model.args,
+                 'state_dict': state_dict,
+                 'pretrained': pretrained,
+                 'checkpoint_state_dict': checkpoint_state_dict,
+                 'transform': self.transform}
+        torch.save(state, path, pickle_module=dill)
diff --git a/tania_scripts/supar/structs/__init__.py b/tania_scripts/supar/structs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9afbd79927a0bb15b198b08ae53a18af03e42b7f
--- /dev/null
+++ b/tania_scripts/supar/structs/__init__.py
@@ -0,0 +1,23 @@
+# -*- coding: utf-8 -*-
+
+from .chain import LinearChainCRF, SemiMarkovCRF
+from .dist import StructuredDistribution
+from .tree import (BiLexicalizedConstituencyCRF, ConstituencyCRF,
+                   Dependency2oCRF, DependencyCRF, MatrixTree)
+from .vi import (ConstituencyLBP, ConstituencyMFVI, DependencyLBP,
+                 DependencyMFVI, SemanticDependencyLBP, SemanticDependencyMFVI)
+
+__all__ = ['StructuredDistribution',
+           'LinearChainCRF',
+           'SemiMarkovCRF',
+           'MatrixTree',
+           'DependencyCRF',
+           'Dependency2oCRF',
+           'ConstituencyCRF',
+           'BiLexicalizedConstituencyCRF',
+           'DependencyMFVI',
+           'DependencyLBP',
+           'ConstituencyMFVI',
+           'ConstituencyLBP',
+           'SemanticDependencyMFVI',
+           'SemanticDependencyLBP', ]
diff --git a/tania_scripts/supar/structs/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef548e93f5fe211cae6a010de99dbcb04c3710ad
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/structs/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4517337ef5ef26d1d5c1d3e7001d4ea719d5120
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/structs/__pycache__/chain.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/chain.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e42d31e38ad624ca095da2a77618bca27d421e88
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/chain.cpython-310.pyc differ
diff --git a/tania_scripts/supar/structs/__pycache__/chain.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/chain.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f42a094bb248a728f5748ec2648f7828e587fdd
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/chain.cpython-311.pyc differ
diff --git a/tania_scripts/supar/structs/__pycache__/dist.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/dist.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4fd9588165983a30bc91fedfe71a124d9e274b7a
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/dist.cpython-310.pyc differ
diff --git a/tania_scripts/supar/structs/__pycache__/dist.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/dist.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bde8b84641e2db640d4bc8b6077df3a2e1aae03a
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/dist.cpython-311.pyc differ
diff --git a/tania_scripts/supar/structs/__pycache__/fn.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/fn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..128838b69315f5cfbbab3bcd5933e1af0b4e9816
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/fn.cpython-310.pyc differ
diff --git a/tania_scripts/supar/structs/__pycache__/fn.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/fn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52542e3188bfcde5e56f78ad516e4e75e2da4552
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/fn.cpython-311.pyc differ
diff --git a/tania_scripts/supar/structs/__pycache__/semiring.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/semiring.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eebbfcd87f3a9e230655a6f65e90217253551767
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/semiring.cpython-310.pyc differ
diff --git a/tania_scripts/supar/structs/__pycache__/semiring.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/semiring.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9cfb3a935ad9c1fee1faac17b0a0e4317f6aa3d
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/semiring.cpython-311.pyc differ
diff --git a/tania_scripts/supar/structs/__pycache__/tree.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/tree.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9cfa02dffd8263f04ff8556087083f71ef8bccee
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/tree.cpython-310.pyc differ
diff --git a/tania_scripts/supar/structs/__pycache__/tree.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/tree.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b42efb7cbe328a1db42ec469766abe4c93796ba5
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/tree.cpython-311.pyc differ
diff --git a/tania_scripts/supar/structs/__pycache__/vi.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/vi.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0dc6ac96670b2c3dda14334e09a50e31c53a3ff5
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/vi.cpython-310.pyc differ
diff --git a/tania_scripts/supar/structs/__pycache__/vi.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/vi.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b5a62221a93d1af40f302d54c6da4bfff1818d4
Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/vi.cpython-311.pyc differ
diff --git a/tania_scripts/supar/structs/chain.py b/tania_scripts/supar/structs/chain.py
new file mode 100644
index 0000000000000000000000000000000000000000..828caab76fac936df414ee46118050525933405f
--- /dev/null
+++ b/tania_scripts/supar/structs/chain.py
@@ -0,0 +1,208 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from typing import List, Optional
+
+import torch
+from supar.structs.dist import StructuredDistribution
+from supar.structs.semiring import LogSemiring, Semiring
+from torch.distributions.utils import lazy_property
+
+
+class LinearChainCRF(StructuredDistribution):
+    r"""
+        Linear-chain CRFs :cite:`lafferty-etal-2001-crf`.
+
+        Args:
+            scores (~torch.Tensor): ``[batch_size, seq_len, n_tags]``.
+                Log potentials.
+            trans (~torch.Tensor): ``[n_tags+1, n_tags+1]``.
+                Transition scores.
+                ``trans[-1, :-1]``/``trans[:-1, -1]`` hold the transition scores for the start/end positions respectively.
+            lens (~torch.LongTensor): ``[batch_size]``.
+                Sentence lengths for masking. Default: ``None``.
+
+        Examples:
+            >>> from supar import LinearChainCRF
+            >>> batch_size, seq_len, n_tags = 2, 5, 4
+            >>> lens = torch.tensor([3, 4])
+            >>> value = torch.randint(n_tags, (batch_size, seq_len))
+            >>> s1 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags),
+                                    torch.randn(n_tags+1, n_tags+1),
+                                    lens)
+            >>> s2 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags),
+                                    torch.randn(n_tags+1, n_tags+1),
+                                    lens)
+            >>> s1.max
+            tensor([4.4120, 8.9672], grad_fn=<MaxBackward0>)
+            >>> s1.argmax
+            tensor([[2, 0, 3, 0, 0],
+                    [3, 3, 3, 2, 0]])
+            >>> s1.log_partition
+            tensor([ 6.3486, 10.9106], grad_fn=<LogsumexpBackward>)
+            >>> s1.log_prob(value)
+            tensor([ -8.1515, -10.5572], grad_fn=<SubBackward0>)
+            >>> s1.entropy
+            tensor([3.4150, 3.6549], grad_fn=<SelectBackward>)
+            >>> s1.kl(s2)
+            tensor([4.0333, 4.3807], grad_fn=<SelectBackward>)
+    """
+
+    def __init__(
+        self,
+        scores: torch.Tensor,
+        trans: Optional[torch.Tensor] = None,
+        lens: Optional[torch.LongTensor] = None
+    ) -> LinearChainCRF:
+        super().__init__(scores, lens=lens)
+
+        batch_size, seq_len, self.n_tags = scores.shape[:3]
+        self.lens = scores.new_full((batch_size,), seq_len).long() if lens is None else lens
+        self.mask = self.lens.unsqueeze(-1).gt(self.lens.new_tensor(range(seq_len)))
+
+        self.trans = self.scores.new_full((self.n_tags+1, self.n_tags+1), LogSemiring.one) if trans is None else trans
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(n_tags={self.n_tags})"
+
+    def __add__(self, other):
+        return LinearChainCRF(torch.stack((self.scores, other.scores), -1),
+                              torch.stack((self.trans, other.trans), -1),
+                              self.lens)
+
+    @lazy_property
+    def argmax(self):
+        return self.lens.new_zeros(self.mask.shape).masked_scatter_(self.mask, torch.where(self.backward(self.max.sum()))[2])
+
+    def topk(self, k: int) -> torch.LongTensor:
+        preds = torch.stack([torch.where(self.backward(i))[2] for i in self.kmax(k).sum(0)], -1)
+        return self.lens.new_zeros(*self.mask.shape, k).masked_scatter_(self.mask.unsqueeze(-1), preds)
+
+    def score(self, value: torch.LongTensor) -> torch.Tensor:
+        scores, mask, value = self.scores.transpose(0, 1), self.mask.t(), value.t()
+        prev, succ = torch.cat((torch.full_like(value[:1], -1), value[:-1]), 0), value
+        # [seq_len, batch_size]
+        alpha = scores.gather(-1, value.unsqueeze(-1)).squeeze(-1)
+        # [batch_size]
+        alpha = LogSemiring.prod(LogSemiring.one_mask(LogSemiring.mul(alpha, self.trans[prev, succ]), ~mask), 0)
+        alpha = alpha + self.trans[value.gather(0, self.lens.unsqueeze(0) - 1).squeeze(0), torch.full_like(value[0], -1)]
+        return alpha
+
+    def forward(self, semiring: Semiring) -> torch.Tensor:
+        # [seq_len, batch_size, n_tags, ...]
+        scores = semiring.convert(self.scores.transpose(0, 1))
+        trans = semiring.convert(self.trans)
+        mask = self.mask.t()
+
+        # [batch_size, n_tags]
+        alpha = semiring.mul(trans[-1, :-1], scores[0])
+        for i in range(1, len(mask)):
+            alpha[mask[i]] = semiring.mul(semiring.dot(alpha.unsqueeze(2), trans[:-1, :-1], 1), scores[i])[mask[i]]
+        alpha = semiring.dot(alpha, trans[:-1, -1], 1)
+        return semiring.unconvert(alpha)
+
+
+class SemiMarkovCRF(StructuredDistribution):
+    r"""
+        Semi-Markov CRFs :cite:`sarawagi-cohen-2004-semicrf`.
+
+        Args:
+            scores (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_tags]``.
+                Log potentials.
+            trans (~torch.Tensor): ``[n_tags, n_tags]``.
+                Transition scores.
+            lens (~torch.LongTensor): ``[batch_size]``.
+                Sentence lengths for masking. Default: ``None``.
+
+        Examples:
+            >>> from supar import SemiMarkovCRF
+            >>> batch_size, seq_len, n_tags = 2, 5, 4
+            >>> lens = torch.tensor([3, 4])
+            >>> value = torch.tensor([[[ 0, -1, -1, -1, -1],
+                                       [-1, -1,  2, -1, -1],
+                                       [-1, -1, -1, -1, -1],
+                                       [-1, -1, -1, -1, -1],
+                                       [-1, -1, -1, -1, -1]],
+                                      [[-1,  1, -1, -1, -1],
+                                       [-1, -1,  3, -1, -1],
+                                       [-1, -1, -1,  0, -1],
+                                       [-1, -1, -1, -1, -1],
+                                       [-1, -1, -1, -1, -1]]])
+            >>> s1 = SemiMarkovCRF(torch.randn(batch_size, seq_len, seq_len, n_tags),
+                                   torch.randn(n_tags, n_tags),
+                                   lens)
+            >>> s2 = SemiMarkovCRF(torch.randn(batch_size, seq_len, seq_len, n_tags),
+                                   torch.randn(n_tags, n_tags),
+                                   lens)
+            >>> s1.max
+            tensor([4.1971, 5.5746], grad_fn=<MaxBackward0>)
+            >>> s1.argmax
+            [[[0, 0, 1], [1, 1, 0], [2, 2, 1]], [[0, 0, 1], [1, 1, 3], [2, 2, 0], [3, 3, 1]]]
+            >>> s1.log_partition
+            tensor([6.3641, 8.4384], grad_fn=<LogsumexpBackward0>)
+            >>> s1.log_prob(value)
+            tensor([-5.7982, -7.4534], grad_fn=<SubBackward0>)
+            >>> s1.entropy
+            tensor([3.7520, 5.1609], grad_fn=<SelectBackward0>)
+            >>> s1.kl(s2)
+            tensor([3.5348, 2.2826], grad_fn=<SelectBackward0>)
+    """
+
+    def __init__(
+        self,
+        scores: torch.Tensor,
+        trans: Optional[torch.Tensor] = None,
+        lens: Optional[torch.LongTensor] = None
+    ) -> SemiMarkovCRF:
+        super().__init__(scores, lens=lens)
+
+        batch_size, seq_len, _, self.n_tags = scores.shape[:4]
+        self.lens = scores.new_full((batch_size,), seq_len).long() if lens is None else lens
+        self.mask = self.lens.unsqueeze(-1).gt(self.lens.new_tensor(range(seq_len)))
+        self.mask = self.mask.unsqueeze(1) & self.mask.unsqueeze(2)
+
+        self.trans = self.scores.new_full((self.n_tags, self.n_tags), LogSemiring.one) if trans is None else trans
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(n_tags={self.n_tags})"
+
+    def __add__(self, other):
+        return SemiMarkovCRF(torch.stack((self.scores, other.scores), -1),
+                             torch.stack((self.trans, other.trans), -1),
+                             self.lens)
+
+    @lazy_property
+    def argmax(self) -> List:
+        return [torch.nonzero(i).tolist() for i in self.backward(self.max.sum())]
+
+    def topk(self, k: int) -> List:
+        return list(zip(*[[torch.nonzero(j).tolist() for j in self.backward(i)] for i in self.kmax(k).sum(0)]))
+
+    def score(self, value: torch.LongTensor) -> torch.Tensor:
+        mask = self.mask & value.ge(0)
+        lens = mask.sum((1, 2))
+        indices = torch.where(mask)
+        batch_size, seq_len = lens.shape[0], lens.max()
+        span_mask = lens.unsqueeze(-1).gt(lens.new_tensor(range(seq_len)))
+        scores = self.scores.new_full((batch_size, seq_len), LogSemiring.one)
+        scores = scores.masked_scatter_(span_mask, self.scores[(*indices, value[indices])])
+        scores = LogSemiring.prod(LogSemiring.one_mask(scores, ~span_mask), -1)
+        value = value.new_zeros(batch_size, seq_len).masked_scatter_(span_mask, value[indices])
+        trans = LogSemiring.prod(LogSemiring.one_mask(self.trans[value[:, :-1], value[:, 1:]], ~span_mask[:, 1:]), -1)
+        return LogSemiring.mul(scores, trans)
+
+    def forward(self, semiring: Semiring) -> torch.Tensor:
+        # [seq_len, seq_len, batch_size, n_tags, ...]
+        scores = semiring.convert(self.scores.movedim((1, 2), (0, 1)))
+        trans = semiring.convert(self.trans)
+        # [seq_len, batch_size, n_tags, ...]
+        alpha = semiring.zeros_like(scores[0])
+
+        alpha[0] = scores[0, 0]
+        # [batch_size, n_tags]
+        for t in range(1, len(scores)):
+            # [batch_size, n_tags, ...]
+            s = semiring.dot(semiring.dot(alpha[:t].unsqueeze(3), trans, 2), scores[1:t+1, t], 0)
+            alpha[t] = semiring.sum(torch.stack((s, scores[0, t])), 0)
+        return semiring.unconvert(semiring.sum(alpha[self.lens - 1, range(len(self.lens))], 1))
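+
+
+# Editorial note (not part of the original file): the semi-Markov recurrence in
+# `SemiMarkovCRF.forward` computes, in semiring notation,
+#   alpha[t] = (+)_{1 <= t' <= t} alpha[t'-1] (*) trans (*) scores[t', t]
+#              (+) scores[0, t]
+# i.e. every span (t', t) extends a prefix ending at t'-1, with spans starting
+# at position 0 contributed by the standalone scores[0, t] term.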
diff --git a/tania_scripts/supar/structs/dist.py b/tania_scripts/supar/structs/dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab594d4e4c72374122e66672529394f656d788e
--- /dev/null
+++ b/tania_scripts/supar/structs/dist.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from typing import Iterable, Union
+
+import torch
+import torch.autograd as autograd
+from supar.structs.semiring import (CrossEntropySemiring, EntropySemiring,
+                                    KLDivergenceSemiring, KMaxSemiring,
+                                    LogSemiring, MaxSemiring, SampledSemiring,
+                                    Semiring)
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import lazy_property
+
+
+class StructuredDistribution(Distribution):
+    r"""
+    Base class for structured distribution :math:`p(y)` :cite:`eisner-2016-inside,goodman-1999-semiring,li-eisner-2009-first`.
+
+    Args:
+        scores (torch.Tensor):
+            Log potentials, also for high-order cases.
+
+    """
+
+    def __init__(self, scores: torch.Tensor, **kwargs) -> StructuredDistribution:
+        self.scores = scores.requires_grad_() if isinstance(scores, torch.Tensor) else [s.requires_grad_() for s in scores]
+        self.kwargs = kwargs
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}()"
+
+    def __add__(self, other: 'StructuredDistribution') -> StructuredDistribution:
+        return self.__class__(torch.stack((self.scores, other.scores), -1), lens=self.lens)
+
+    @lazy_property
+    def log_partition(self):
+        r"""
+        Computes the log partition function of the distribution :math:`p(y)`.
+        """
+
+        return self.forward(LogSemiring)
+
+    @lazy_property
+    def marginals(self):
+        r"""
+        Computes marginal probabilities of the distribution :math:`p(y)`.
+        """
+
+        return self.backward(self.log_partition.sum())
+
+    @lazy_property
+    def max(self):
+        r"""
+        Computes the max score of the distribution :math:`p(y)`.
+        """
+
+        return self.forward(MaxSemiring)
+
+    @lazy_property
+    def argmax(self):
+        r"""
+        Computes :math:`\arg\max_y p(y)` of the distribution :math:`p(y)`.
+        """
+
+        return self.backward(self.max.sum())
+
+    @lazy_property
+    def mode(self):
+        return self.argmax
+
+    def kmax(self, k: int) -> torch.Tensor:
+        r"""
+        Computes the k-max of the distribution :math:`p(y)`.
+        """
+
+        return self.forward(KMaxSemiring(k))
+
+    def topk(self, k: int) -> Union[torch.Tensor, Iterable]:
+        r"""
+        Computes the k-argmax of the distribution :math:`p(y)`.
+        """
+        raise NotImplementedError
+
+    def sample(self):
+        r"""
+        Obtains a structured sample from the distribution :math:`y \sim p(y)`.
+        TODO: multi-sampling.
+        """
+
+        return self.backward(self.forward(SampledSemiring).sum()).detach()
+
+    @lazy_property
+    def entropy(self):
+        r"""
+        Computes entropy :math:`H[p]` of the distribution :math:`p(y)`.
+        """
+
+        return self.forward(EntropySemiring)
+
+    def cross_entropy(self, other: 'StructuredDistribution') -> torch.Tensor:
+        r"""
+        Computes cross-entropy :math:`H[p,q]` of self and another distribution.
+
+        Args:
+            other (~supar.structs.dist.StructuredDistribution): Comparison distribution.
+        """
+
+        return (self + other).forward(CrossEntropySemiring)
+
+    def kl(self, other: 'StructuredDistribution') -> torch.Tensor:
+        r"""
+        Computes KL-divergence :math:`KL[p \parallel q]=H[p,q]-H[p]` of self and another distribution.
+
+        Args:
+            other (~supar.structs.dist.StructuredDistribution): Comparison distribution.
+        """
+
+        return (self + other).forward(KLDivergenceSemiring)
+
+    def log_prob(self, value: torch.LongTensor, *args, **kwargs) -> torch.Tensor:
+        """
+        Computes log probability over values :math:`p(y)`.
+        """
+
+        return self.score(value, *args, **kwargs) - self.log_partition
+
+    def score(self, value: torch.LongTensor, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError
+
+    @torch.enable_grad()
+    def forward(self, semiring: Semiring) -> torch.Tensor:
+        raise NotImplementedError
+
+    def backward(self, log_partition: torch.Tensor) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
+        grads = autograd.grad(log_partition, self.scores, create_graph=True)
+        return grads[0] if isinstance(self.scores, torch.Tensor) else grads
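+
+
+# Editorial sketch (not part of the original file): `marginals` exploits the
+# identity that the gradient of the log partition function w.r.t. each log
+# potential equals that part's marginal probability, d log Z / d s_i = p(i).
+# Assuming LinearChainCRF from supar.structs, the per-position marginals
+# therefore sum to one over tags:
+#
+#   >>> dist = LinearChainCRF(torch.randn(2, 5, 4))
+#   >>> dist.marginals.sum(-1)    # ~ torch.ones(2, 5)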
diff --git a/tania_scripts/supar/structs/fn.py b/tania_scripts/supar/structs/fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a0af2ad71380169b1b30d27e545bf332a281e84
--- /dev/null
+++ b/tania_scripts/supar/structs/fn.py
@@ -0,0 +1,379 @@
+# -*- coding: utf-8 -*-
+
+import operator
+from typing import Iterable, Tuple, Union
+
+import torch
+from supar.utils.common import INF, MIN
+from supar.utils.fn import pad
+from torch.autograd import Function
+
+
+def tarjan(sequence: Iterable[int]) -> Iterable[int]:
+    r"""
+    Tarjan algorithm for finding Strongly Connected Components (SCCs) of a graph.
+
+    Args:
+        sequence (list):
+            List of head indices.
+
+    Yields:
+        A list of indices making up an SCC. All self-loops are ignored.
+
+    Examples:
+        >>> next(tarjan([2, 5, 0, 3, 1]))  # (1 -> 5 -> 2 -> 1) is a cycle
+        [2, 5, 1]
+    """
+
+    sequence = [-1] + sequence
+    # record the search order, i.e., the timestep
+    dfn = [-1] * len(sequence)
+    # record the smallest timestep in an SCC
+    low = [-1] * len(sequence)
+    # push the visited into the stack
+    stack, onstack = [], [False] * len(sequence)
+
+    def connect(i, timestep):
+        dfn[i] = low[i] = timestep[0]
+        timestep[0] += 1
+        stack.append(i)
+        onstack[i] = True
+
+        for j, head in enumerate(sequence):
+            if head != i:
+                continue
+            if dfn[j] == -1:
+                yield from connect(j, timestep)
+                low[i] = min(low[i], low[j])
+            elif onstack[j]:
+                low[i] = min(low[i], dfn[j])
+
+        # a SCC is completed
+        if low[i] == dfn[i]:
+            cycle = [stack.pop()]
+            while cycle[-1] != i:
+                onstack[cycle[-1]] = False
+                cycle.append(stack.pop())
+            onstack[i] = False
+            # ignore the self-loop
+            if len(cycle) > 1:
+                yield cycle
+
+    timestep = [0]
+    for i in range(len(sequence)):
+        if dfn[i] == -1:
+            yield from connect(i, timestep)
+
+
+def chuliu_edmonds(s: torch.Tensor) -> torch.Tensor:
+    r"""
+    ChuLiu/Edmonds algorithm for non-projective decoding :cite:`mcdonald-etal-2005-non`.
+
+    Some code is borrowed from `tdozat's implementation`_.
+    Descriptions of notations and formulas can be found in :cite:`mcdonald-etal-2005-non`.
+
+    Notes:
+        The algorithm does not guarantee to parse a single-root tree.
+
+    Args:
+        s (~torch.Tensor): ``[seq_len, seq_len]``.
+            Scores of all dependent-head pairs.
+
+    Returns:
+        ~torch.Tensor:
+            A tensor with shape ``[seq_len]`` for the resulting non-projective parse tree.
+
+    .. _tdozat's implementation:
+        https://github.com/tdozat/Parser-v3
+    """
+
+    s[0, 1:] = MIN
+    # prevent self-loops
+    s.diagonal()[1:].fill_(MIN)
+    # select heads with highest scores
+    tree = s.argmax(-1)
+    # lazily return the first cycle found by the Tarjan algorithm
+    cycle = next(tarjan(tree.tolist()[1:]), None)
+    # if the tree has no cycles, then it is a MST
+    if not cycle:
+        return tree
+    # indices of cycle in the original tree
+    cycle = torch.tensor(cycle)
+    # indices of noncycle in the original tree
+    noncycle = torch.ones(len(s)).index_fill_(0, cycle, 0)
+    noncycle = torch.where(noncycle.gt(0))[0]
+
+    def contract(s):
+        # heads of cycle in original tree
+        cycle_heads = tree[cycle]
+        # scores of cycle in original tree
+        s_cycle = s[cycle, cycle_heads]
+
+        # calculate the scores of cycle's potential dependents
+        # s(c->x) = max(s(x'->x)), x in noncycle and x' in cycle
+        s_dep = s[noncycle][:, cycle]
+        # find the best cycle head for each noncycle dependent
+        deps = s_dep.argmax(1)
+        # calculate the scores of cycle's potential heads
+        # s(x->c) = max(s(x'->x) - s(a(x')->x') + s(cycle)), x in noncycle and x' in cycle
+        #                                                    a(v) is the predecessor of v in cycle
+        #                                                    s(cycle) = sum(s(a(v)->v))
+        s_head = s[cycle][:, noncycle] - s_cycle.view(-1, 1) + s_cycle.sum()
+        # find the best noncycle head for each cycle dependent
+        heads = s_head.argmax(0)
+
+        contracted = torch.cat((noncycle, torch.tensor([-1])))
+        # calculate the scores of contracted graph
+        s = s[contracted][:, contracted]
+        # set the contracted graph scores of cycle's potential dependents
+        s[:-1, -1] = s_dep[range(len(deps)), deps]
+        # set the contracted graph scores of cycle's potential heads
+        s[-1, :-1] = s_head[heads, range(len(heads))]
+
+        return s, heads, deps
+
+    # keep track of the endpoints of the edges into and out of cycle for reconstruction later
+    s, heads, deps = contract(s)
+
+    # y is the contracted tree
+    y = chuliu_edmonds(s)
+    # exclude head of cycle from y
+    y, cycle_head = y[:-1], y[-1]
+
+    # fix the subtree with no heads coming from the cycle
+    # len(y) denotes heads coming from the cycle
+    subtree = y < len(y)
+    # add the nodes to the new tree
+    tree[noncycle[subtree]] = noncycle[y[subtree]]
+    # fix the subtree with heads coming from the cycle
+    subtree = ~subtree
+    # add the nodes to the tree
+    tree[noncycle[subtree]] = cycle[deps[subtree]]
+    # fix the root of the cycle
+    cycle_root = heads[cycle_head]
+    # break the cycle and add the root of the cycle to the tree
+    tree[cycle[cycle_root]] = noncycle[cycle_head]
+
+    return tree
+
+
+def mst(scores: torch.Tensor, mask: torch.BoolTensor, multiroot: bool = False) -> torch.Tensor:
+    r"""
+    MST algorithm for decoding non-projective trees.
+    This is a wrapper for ChuLiu/Edmonds algorithm.
+
+    The algorithm first runs ChuLiu/Edmonds to parse a tree and then checks for multiple roots.
+    If ``multiroot=False`` and multiple roots do exist, the algorithm seeks the best
+    single-root tree by iterating over all single-root trees parsed by ChuLiu/Edmonds.
+    Otherwise the resulting trees are directly taken as the final outputs.
+
+    Args:
+        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+            Scores of all dependent-head pairs.
+        mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+            The mask to avoid parsing over padding tokens.
+            The first column serving as pseudo words for roots should be ``False``.
+        multiroot (bool):
+            If ``False``, ensures that the parsed tree has a single root. Default: ``False``.
+
+    Returns:
+        ~torch.Tensor:
+            A tensor with shape ``[batch_size, seq_len]`` for the resulting non-projective parse trees.
+
+    Examples:
+        >>> scores = torch.tensor([[[-11.9436, -13.1464,  -6.4789, -13.8917],
+                                    [-60.6957, -60.2866, -48.6457, -63.8125],
+                                    [-38.1747, -49.9296, -45.2733, -49.5571],
+                                    [-19.7504, -23.9066,  -9.9139, -16.2088]]])
+        >>> scores[:, 0, 1:] = MIN
+        >>> scores.diagonal(0, 1, 2)[1:].fill_(MIN)
+        >>> mask = torch.tensor([[False,  True,  True,  True]])
+        >>> mst(scores, mask)
+        tensor([[0, 2, 0, 2]])
+    """
+
+    _, seq_len, _ = scores.shape
+    scores = scores.cpu().unbind()
+
+    preds = []
+    for i, length in enumerate(mask.sum(1).tolist()):
+        s = scores[i][:length+1, :length+1]
+        tree = chuliu_edmonds(s)
+        roots = torch.where(tree[1:].eq(0))[0] + 1
+        if not multiroot and len(roots) > 1:
+            s_root = s[:, 0]
+            s_best = MIN
+            s = s.index_fill(1, torch.tensor(0), MIN)
+            for root in roots:
+                s[:, 0] = MIN
+                s[root, 0] = s_root[root]
+                t = chuliu_edmonds(s)
+                s_tree = s[1:].gather(1, t[1:].unsqueeze(-1)).sum()
+                if s_tree > s_best:
+                    s_best, tree = s_tree, t
+        preds.append(tree)
+
+    return pad(preds, total_length=seq_len).to(mask.device)
+
+
+def levenshtein(x: Iterable, y: Iterable, align: bool = False) -> int:
+    """
+    Calculates the Levenshtein edit-distance between two sequences.
+    The edit distance is the number of characters that need to be
+    substituted, inserted, or deleted, to transform `x` into `y`.
+
+    For example, transforming "rain" to "shine" requires three steps,
+    consisting of two substitutions and one insertion:
+    "rain" -> "sain" -> "shin" -> "shine".
+    These operations could have been done in other orders, but at least three steps are needed.
+
+    All edit operations carry unit cost in this implementation; with ``align=True``,
+    one minimum-cost alignment between the two sequences is returned as well.
+
+    The code is revised from `nltk`_ and `wiki`_'s implementations.
+
+    Args:
+        x/y (Iterable):
+            The sequences to be analysed.
+        align (bool):
+            Whether to return the alignments based on the minimum Levenshtein edit-distance. Default: ``False``.
+
+    Examples:
+        >>> from supar.structs.utils.fn import levenshtein
+        >>> levenshtein('intention', 'execution', align=True)
+        (5, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (8, 8), (9, 9)])
+
+    .. _nltk:
+        https://github.com/nltk/nltk/blob/develop/nltk/metrics/distance.py
+    .. _wiki:
+        https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance
+    """
+
+    # set up a 2-D array
+    len1, len2 = len(x), len(y)
+    lev = [list(range(len2 + 1))] + [[i] + [0] * len2 for i in range(1, len1 + 1)]
+
+    # iterate over the array
+    # i and j start from 1 and not 0 to stay close to the wikipedia pseudo-code
+    # see https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance
+    for i in range(1, len1 + 1):
+        for j in range(1, len2 + 1):
+            # substitution
+            s = lev[i - 1][j - 1] + (x[i - 1] != y[j - 1])
+            # deletion
+            a = lev[i - 1][j] + 1
+            # insertion
+            b = lev[i][j - 1] + 1
+
+            lev[i][j] = min(s, a, b)
+    distance = lev[-1][-1]
+    if align:
+        i, j = len1, len2
+        alignments = [(i, j)]
+        while (i, j) != (0, 0):
+            directions = [
+                (i - 1, j - 1),  # substitution
+                (i - 1, j),  # deletion
+                (i, j - 1),  # insertion
+            ]
+            direction_costs = ((lev[i][j] if (i >= 0 and j >= 0) else INF, (i, j)) for i, j in directions)
+            _, (i, j) = min(direction_costs, key=operator.itemgetter(0))
+            alignments.append((i, j))
+        alignments = list(reversed(alignments))
+    return (distance, alignments) if align else distance
+
+
+class Logsumexp(Function):
+
+    r"""
+    Safer ``logsumexp`` to cure unnecessary NaN values that arise from inf arguments.
+    See discussions at http://github.com/pytorch/pytorch/issues/49724.
+    To be optimized with C++/Cuda extensions.
+    """
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float)
+    def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        output = x.logsumexp(dim)
+        ctx.dim = dim
+        ctx.save_for_backward(x, output)
+        return output.clone()
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, None]:
+        x, output, dim = *ctx.saved_tensors, ctx.dim
+        g, output = g.unsqueeze(dim), output.unsqueeze(dim)
+        mask = g.eq(0).expand_as(x)
+        grad = g * (x - output).exp()
+        return torch.where(mask, x.new_tensor(0.), grad), None
+
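+# A hedged illustration (not part of the module) of the failure mode
+# `Logsumexp` guards against: when a row of scores is all ``-inf`` (e.g. fully
+# masked out) and the upstream gradient is zero, plain `torch.logsumexp`
+# backpropagates ``0 * exp(nan) = nan``, while this Function zeroes those
+# positions out:
+#
+#     x = torch.full((3,), float('-inf'), requires_grad=True)
+#     g, = torch.autograd.grad(Logsumexp.apply(x, -1), x,
+#                              grad_outputs=torch.zeros(()))
+#     assert g.eq(0).all()
+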
+
+class Logaddexp(Function):
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float)
+    def forward(ctx, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        output = torch.logaddexp(x, y)
+        ctx.save_for_backward(x, y, output)
+        return output.clone()
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx, g: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        x, y, output = ctx.saved_tensors
+        mask = g.eq(0)
+        # scale by the upstream gradient, mirroring Logsumexp.backward
+        grad_x, grad_y = g * (x - output).exp(), g * (y - output).exp()
+        grad_x = torch.where(mask, x.new_tensor(0.), grad_x)
+        grad_y = torch.where(mask, y.new_tensor(0.), grad_y)
+        return grad_x, grad_y
+
+
+class SampledLogsumexp(Function):
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float)
+    def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        ctx.dim = dim
+        ctx.save_for_backward(x)
+        return x.logsumexp(dim=dim)
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, None]:
+        from torch.distributions import OneHotCategorical
+        (x, ), dim = ctx.saved_tensors, ctx.dim
+        return g.unsqueeze(dim).mul(OneHotCategorical(logits=x.movedim(dim, -1)).sample().movedim(-1, dim)), None
+
+
+class Sparsemax(Function):
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float)
+    def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        ctx.dim = dim
+        sorted_x, _ = x.sort(dim, True)
+        z = sorted_x.cumsum(dim) - 1
+        k = x.new_tensor(range(1, sorted_x.size(dim) + 1)).view(-1, *[1] * (x.dim() - 1)).transpose(0, dim)
+        k = (k * sorted_x).gt(z).sum(dim, True)
+        tau = z.gather(dim, k - 1) / k
+        p = torch.clamp(x - tau, 0)
+        ctx.save_for_backward(k, p)
+        return p
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx, g: torch.Tensor) -> Tuple[torch.Tensor, None]:
+        k, p, dim = *ctx.saved_tensors, ctx.dim
+        grad = g.masked_fill(p.eq(0), 0)
+        grad = torch.where(p.ne(0), grad - grad.sum(dim, True) / k, grad)
+        return grad, None
+
+
+logsumexp = Logsumexp.apply
+
+logaddexp = Logaddexp.apply
+
+sampled_logsumexp = SampledLogsumexp.apply
+
+sparsemax = Sparsemax.apply
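+
+# Hedged usage sketch (illustration only): `sparsemax` projects scores onto
+# the probability simplex and, unlike softmax, assigns exactly zero mass to
+# low-scoring entries, e.g.
+#
+#     sparsemax(torch.tensor([2.0, 1.0, -1.0]), -1)
+#     # -> tensor([1., 0., 0.])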
diff --git a/tania_scripts/supar/structs/semiring.py b/tania_scripts/supar/structs/semiring.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b66beee4f2e681a12cb7329e26c8880020077af
--- /dev/null
+++ b/tania_scripts/supar/structs/semiring.py
@@ -0,0 +1,406 @@
+# -*- coding: utf-8 -*-
+
+import itertools
+from functools import reduce
+from typing import Iterable
+
+import torch
+from supar.structs.fn import sampled_logsumexp, sparsemax
+from supar.utils.common import MIN
+
+
+class Semiring(object):
+    r"""
+    Base semiring class :cite:`goodman-1999-semiring`.
+
+    A semiring is defined by a tuple :math:`<K, \oplus, \otimes, \mathbf{0}, \mathbf{1}>`.
+    :math:`K` is a set of values;
+    :math:`\oplus` is commutative and associative and has an identity element :math:`\mathbf{0}`;
+    :math:`\otimes` is associative, has an identity element :math:`\mathbf{1}` and distributes over :math:`\oplus`.
+    """
+
+    zero = 0
+    one = 1
+
+    @classmethod
+    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return x + y
+
+    @classmethod
+    def mul(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return x * y
+
+    @classmethod
+    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return x.sum(dim)
+
+    @classmethod
+    def prod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return x.prod(dim)
+
+    @classmethod
+    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return x.cumsum(dim)
+
+    @classmethod
+    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return x.cumprod(dim)
+
+    @classmethod
+    def dot(cls, x: torch.Tensor, y: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return cls.sum(cls.mul(x, y), dim)
+
+    @classmethod
+    def times(cls, *x: Iterable[torch.Tensor]) -> torch.Tensor:
+        return reduce(lambda i, j: cls.mul(i, j), x)
+
+    @classmethod
+    def zero_(cls, x: torch.Tensor) -> torch.Tensor:
+        return x.fill_(cls.zero)
+
+    @classmethod
+    def one_(cls, x: torch.Tensor) -> torch.Tensor:
+        return x.fill_(cls.one)
+
+    @classmethod
+    def zero_mask(cls, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        return x.masked_fill(mask, cls.zero)
+
+    @classmethod
+    def zero_mask_(cls, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        return x.masked_fill_(mask, cls.zero)
+
+    @classmethod
+    def one_mask(cls, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        return x.masked_fill(mask, cls.one)
+
+    @classmethod
+    def one_mask_(cls, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        return x.masked_fill_(mask, cls.one)
+
+    @classmethod
+    def zeros_like(cls, x: torch.Tensor) -> torch.Tensor:
+        return x.new_full(x.shape, cls.zero)
+
+    @classmethod
+    def ones_like(cls, x: torch.Tensor) -> torch.Tensor:
+        return x.new_full(x.shape, cls.one)
+
+    @classmethod
+    def convert(cls, x: torch.Tensor) -> torch.Tensor:
+        return x
+
+    @classmethod
+    def unconvert(cls, x: torch.Tensor) -> torch.Tensor:
+        return x
+
+
+class LogSemiring(Semiring):
+    r"""
+    Log-space semiring :math:`<\mathrm{logsumexp}, +, -\infty, 0>`.
+    """
+
+    zero = MIN
+    one = 0
+
+    @classmethod
+    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return x.logaddexp(y)
+
+    @classmethod
+    def mul(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return x + y
+
+    @classmethod
+    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return x.logsumexp(dim)
+
+    @classmethod
+    def prod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return x.sum(dim)
+
+    @classmethod
+    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return x.logcumsumexp(dim)
+
+    @classmethod
+    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return x.cumsum(dim)
+
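+# Hedged sanity check (illustration only): under `LogSemiring`, `mul` is `+`
+# and `sum` is `logsumexp`, so the semiring dot product of two log-score
+# tensors reduces to
+#
+#     x, y = torch.randn(3, 4), torch.randn(3, 4)
+#     assert LogSemiring.dot(x, y, -1).allclose((x + y).logsumexp(-1))
+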
+
+class MaxSemiring(LogSemiring):
+    r"""
+    Max semiring :math:`<\mathrm{max}, +, -\infty, 0>`.
+    """
+
+    @classmethod
+    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return x.max(y)
+
+    @classmethod
+    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return x.max(dim)[0]
+
+    @classmethod
+    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return x.cummax(dim)
+
+
+def KMaxSemiring(k):
+    r"""
+    k-max semiring :math:`<\mathrm{kmax}, +, [-\infty, -\infty, \dots], [0, -\infty, \dots]>`.
+    """
+
+    class KMaxSemiring(LogSemiring):
+
+        @classmethod
+        def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            return x.unsqueeze(-1).max(y.unsqueeze(-2)).flatten(-2).topk(k, -1)[0]
+
+        @classmethod
+        def mul(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            return (x.unsqueeze(-1) + y.unsqueeze(-2)).flatten(-2).topk(k, -1)[0]
+
+        @classmethod
+        def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+            return x.movedim(dim, -1).flatten(-2).topk(k, -1)[0]
+
+        @classmethod
+        def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+            return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim)
+
+        @classmethod
+        def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+            return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim)
+
+        @classmethod
+        def one_(cls, x: torch.Tensor) -> torch.Tensor:
+            x[..., :1].fill_(cls.one)
+            x[..., 1:].fill_(cls.zero)
+            return x
+
+        @classmethod
+        def convert(cls, x: torch.Tensor) -> torch.Tensor:
+            return torch.cat((x.unsqueeze(-1), cls.zero_(x.new_empty(*x.shape, k - 1))), -1)
+
+    return KMaxSemiring
+
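+# Hedged usage sketch (illustration only): values of `KMaxSemiring(k)` carry
+# their k best log-scores in a trailing axis; `convert` lifts plain scores
+# into that representation by padding with `zero`:
+#
+#     K2 = KMaxSemiring(2)
+#     K2.convert(torch.randn(3, 4)).shape   # -> torch.Size([3, 4, 2])
+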
+
+class ExpectationSemiring(Semiring):
+    r"""
+    Expectation semiring :math:`<\oplus, +, [0, 0], [1, 0]>` :cite:`li-eisner-2009-first`.
+
+    A practical application is computing the entropy :math:`H(p) = \log Z - \frac{1}{Z}\sum_{d \in D} p(d) r(d)`.
+    """
+
+    @classmethod
+    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return x + y
+
+    @classmethod
+    def mul(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return torch.stack((x[..., 0] * y[..., 0], x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]), -1)
+
+    @classmethod
+    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return x.sum(dim)
+
+    @classmethod
+    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim)
+
+    @classmethod
+    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim)
+
+    @classmethod
+    def zero_(cls, x: torch.Tensor) -> torch.Tensor:
+        return x.fill_(cls.zero)
+
+    @classmethod
+    def one_(cls, x: torch.Tensor) -> torch.Tensor:
+        x[..., 0].fill_(cls.one)
+        x[..., 1].fill_(cls.zero)
+        return x
+
+
+class EntropySemiring(LogSemiring):
+    r"""
+    Entropy expectation semiring :math:`<\oplus, +, [-\infty, 0], [0, 0]>`,
+    where :math:`\oplus` computes the log-values and the running distributional entropy :math:`H[p]`
+    :cite:`li-eisner-2009-first,hwa-2000-sample,kim-etal-2019-unsupervised`.
+    """
+
+    @classmethod
+    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return cls.sum(torch.stack((x, y)), 0)
+
+    @classmethod
+    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        p = x[..., 0].logsumexp(dim)
+        r = x[..., 0] - p.unsqueeze(dim)
+        r = r.exp().mul((x[..., 1] - r)).sum(dim)
+        return torch.stack((p, r), -1)
+
+    @classmethod
+    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim)
+
+    @classmethod
+    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim)
+
+    @classmethod
+    def zero_(cls, x: torch.Tensor) -> torch.Tensor:
+        x[..., 0].fill_(cls.zero)
+        x[..., 1].fill_(cls.one)
+        return x
+
+    @classmethod
+    def one_(cls, x: torch.Tensor) -> torch.Tensor:
+        return x.fill_(cls.one)
+
+    @classmethod
+    def convert(cls, x: torch.Tensor) -> torch.Tensor:
+        return torch.stack((x, cls.ones_like(x)), -1)
+
+    @classmethod
+    def unconvert(cls, x: torch.Tensor) -> torch.Tensor:
+        return x[..., 1]
+
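+# Hedged note: `EntropySemiring` values are pairs stacked in a trailing axis,
+# (log-score, running entropy); `convert` attaches a zero entropy channel and
+# `unconvert` reads the accumulated entropy back out after the inside pass.
+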
+
+class CrossEntropySemiring(LogSemiring):
+    r"""
+    Cross Entropy expectation semiring :math:`<\oplus, +, [-\infty, -\infty, 0], [0, 0, 0]>`,
+    where :math:`\oplus` computes the log-values and the running distributional cross entropy :math:`H[p,q]`
+    of the two distributions :cite:`li-eisner-2009-first`.
+    """
+
+    @classmethod
+    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return cls.sum(torch.stack((x, y)), 0)
+
+    @classmethod
+    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        p = x[..., :-1].logsumexp(dim)
+        r = x[..., :-1] - p.unsqueeze(dim)
+        r = r[..., 0].exp().mul((x[..., -1] - r[..., 1])).sum(dim)
+        return torch.cat((p, r.unsqueeze(-1)), -1)
+
+    @classmethod
+    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim)
+
+    @classmethod
+    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim)
+
+    @classmethod
+    def zero_(cls, x: torch.Tensor) -> torch.Tensor:
+        x[..., :-1].fill_(cls.zero)
+        x[..., -1].fill_(cls.one)
+        return x
+
+    @classmethod
+    def one_(cls, x: torch.Tensor) -> torch.Tensor:
+        return x.fill_(cls.one)
+
+    @classmethod
+    def convert(cls, x: torch.Tensor) -> torch.Tensor:
+        return torch.cat((x, cls.one_(torch.empty_like(x[..., :1]))), -1)
+
+    @classmethod
+    def unconvert(cls, x: torch.Tensor) -> torch.Tensor:
+        return x[..., -1]
+
+
+class KLDivergenceSemiring(LogSemiring):
+    r"""
+    KL divergence expectation semiring :math:`<\oplus, +, [-\infty, -\infty, 0], [0, 0, 0]>`,
+    where :math:`\oplus` computes the log-values and the running distributional KL divergence :math:`KL[p \parallel q]`
+    of the two distributions :cite:`li-eisner-2009-first`.
+    """
+
+    @classmethod
+    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return cls.sum(torch.stack((x, y)), 0)
+
+    @classmethod
+    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        p = x[..., :-1].logsumexp(dim)
+        r = x[..., :-1] - p.unsqueeze(dim)
+        r = r[..., 0].exp().mul((x[..., -1] - r[..., 1] + r[..., 0])).sum(dim)
+        return torch.cat((p, r.unsqueeze(-1)), -1)
+
+    @classmethod
+    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim)
+
+    @classmethod
+    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim)
+
+    @classmethod
+    def zero_(cls, x: torch.Tensor) -> torch.Tensor:
+        x[..., :-1].fill_(cls.zero)
+        x[..., -1].fill_(cls.one)
+        return x
+
+    @classmethod
+    def one_(cls, x: torch.Tensor) -> torch.Tensor:
+        return x.fill_(cls.one)
+
+    @classmethod
+    def convert(cls, x: torch.Tensor) -> torch.Tensor:
+        return torch.cat((x, cls.one_(torch.empty_like(x[..., :1]))), -1)
+
+    @classmethod
+    def unconvert(cls, x: torch.Tensor) -> torch.Tensor:
+        return x[..., -1]
+
+
+class SampledSemiring(LogSemiring):
+    r"""
+    Sampling semiring :math:`<\mathrm{logsumexp}, +, -\infty, 0>`,
+    which is an exact forward-filtering, backward-sampling approach.
+    """
+
+    @classmethod
+    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return cls.sum(torch.stack((x, y)), 0)
+
+    @classmethod
+    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return sampled_logsumexp(x, dim)
+
+    @classmethod
+    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim)
+
+    @classmethod
+    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim)
+
+
+class SparsemaxSemiring(LogSemiring):
+    r"""
+    Sparsemax semiring :math:`<\mathrm{sparsemax}, +, -\infty, 0>`
+    :cite:`martins-etal-2016-sparsemax,mensch-etal-2018-dp,correia-etal-2020-efficient`.
+    """
+
+    @classmethod
+    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return cls.sum(torch.stack((x, y)), 0)
+
+    @staticmethod
+    def sum(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        p = sparsemax(x, dim)
+        return x.mul(p).sum(dim) - p.norm(p=2, dim=dim)
+
+    @classmethod
+    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim)
+
+    @classmethod
+    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim)
diff --git a/tania_scripts/supar/structs/tree.py b/tania_scripts/supar/structs/tree.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1cc5554f9fbb1476916a31f25f3bc9821ee71d6
--- /dev/null
+++ b/tania_scripts/supar/structs/tree.py
@@ -0,0 +1,635 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from supar.structs.dist import StructuredDistribution
+from supar.structs.fn import mst
+from supar.structs.semiring import LogSemiring, Semiring
+from supar.utils.fn import diagonal_stripe, expanded_stripe, stripe
+from torch.distributions.utils import lazy_property
+
+
+class MatrixTree(StructuredDistribution):
+    r"""
+    MatrixTree for calculating partitions and marginals of non-projective dependency trees in :math:`O(n^3)`
+    by an adaptation of Kirchhoff's MatrixTree Theorem :cite:`koo-etal-2007-structured`.
+
+    Args:
+        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+            Scores of all possible dependent-head pairs.
+        lens (~torch.LongTensor): ``[batch_size]``.
+            Sentence lengths for masking, regardless of root positions. Default: ``None``.
+        multiroot (bool):
+            If ``False``, requires the tree to contain only a single root. Default: ``False``.
+
+    Examples:
+        >>> from supar import MatrixTree
+        >>> batch_size, seq_len = 2, 5
+        >>> lens = torch.tensor([3, 4])
+        >>> arcs = torch.tensor([[0, 2, 0, 4, 2], [0, 3, 1, 0, 3]])
+        >>> s1 = MatrixTree(torch.randn(batch_size, seq_len, seq_len), lens)
+        >>> s2 = MatrixTree(torch.randn(batch_size, seq_len, seq_len), lens)
+        >>> s1.max
+        tensor([0.7174, 3.7910], grad_fn=<SumBackward1>)
+        >>> s1.argmax
+        tensor([[0, 0, 1, 1, 0],
+                [0, 4, 1, 0, 3]])
+        >>> s1.log_partition
+        tensor([2.0229, 6.0558], grad_fn=<CopyBackwards>)
+        >>> s1.log_prob(arcs)
+        tensor([-3.2209, -2.5756], grad_fn=<SubBackward0>)
+        >>> s1.entropy
+        tensor([1.9711, 3.4497], grad_fn=<SubBackward0>)
+        >>> s1.kl(s2)
+        tensor([1.3354, 2.6914], grad_fn=<AddBackward0>)
+    """
+
+    def __init__(
+        self,
+        scores: torch.Tensor,
+        lens: Optional[torch.LongTensor] = None,
+        multiroot: bool = False
+    ) -> MatrixTree:
+        super().__init__(scores)
+
+        batch_size, seq_len, *_ = scores.shape
+        self.lens = scores.new_full((batch_size,), seq_len-1).long() if lens is None else lens
+        self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len)))
+        self.mask = self.mask.index_fill(1, self.lens.new_tensor(0), 0)
+
+        self.multiroot = multiroot
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(multiroot={self.multiroot})"
+
+    def __add__(self, other):
+        return MatrixTree(torch.stack((self.scores, other.scores)), self.lens, self.multiroot)
+
+    @lazy_property
+    def max(self):
+        arcs = self.argmax
+        return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1), ~self.mask), -1)
+
+    @lazy_property
+    def argmax(self):
+        with torch.no_grad():
+            return mst(self.scores, self.mask, self.multiroot)
+
+    def kmax(self, k: int) -> torch.Tensor:
+        # TODO: Camerini algorithm
+        raise NotImplementedError
+
+    def sample(self):
+        raise NotImplementedError
+
+    @lazy_property
+    def entropy(self):
+        return self.log_partition - (self.marginals * self.scores).sum((-1, -2))
+
+    def cross_entropy(self, other: MatrixTree) -> torch.Tensor:
+        return other.log_partition - (self.marginals * other.scores).sum((-1, -2))
+
+    def kl(self, other: MatrixTree) -> torch.Tensor:
+        return other.log_partition - self.log_partition + (self.marginals * (self.scores - other.scores)).sum((-1, -2))
+
+    def score(self, value: torch.LongTensor, partial: bool = False) -> torch.Tensor:
+        arcs = value
+        if partial:
+            mask, lens = self.mask, self.lens
+            mask = mask.index_fill(1, self.lens.new_tensor(0), 1)
+            mask = mask.unsqueeze(1) & mask.unsqueeze(2)
+            arcs = arcs.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1)
+            arcs = arcs.eq(lens.new_tensor(range(mask.shape[1]))) | arcs.lt(0)
+            scores = LogSemiring.zero_mask(self.scores, ~(arcs & mask))
+            return self.__class__(scores, lens, **self.kwargs).log_partition
+        return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1), ~self.mask), -1)
+
+    @torch.enable_grad()
+    def forward(self, semiring: Semiring) -> torch.Tensor:
+        s_arc = self.scores
+        batch_size, *_ = s_arc.shape
+        mask, lens = self.mask.index_fill(1, self.lens.new_tensor(0), 1), self.lens
+        # double precision to prevent overflows
+        s_arc = semiring.zero_mask(s_arc, ~(mask.unsqueeze(-1) & mask.unsqueeze(-2))).double()
+
+        # A(i, j) = exp(s(i, j))
+        m = s_arc.view(batch_size, -1).max(-1)[0]
+        A = torch.exp(s_arc - m.view(-1, 1, 1))
+
+        # Weighted degree matrix
+        # D(i, j) = sum_j(A(i, j)), if h == m
+        #           0,              otherwise
+        D = torch.zeros_like(A)
+        D.diagonal(0, 1, 2).copy_(A.sum(-1))
+        # Laplacian matrix
+        # L(i, j) = D(i, j) - A(i, j)
+        L = D - A
+        if not self.multiroot:
+            L.diagonal(0, 1, 2).add_(-A[..., 0])
+            L[..., 1] = A[..., 0]
+        L = nn.init.eye_(torch.empty_like(A[0])).repeat(batch_size, 1, 1).masked_scatter_(mask.unsqueeze(-1), L[mask])
+        L = L + nn.init.eye_(torch.empty_like(A[0])) * torch.finfo().tiny
+        # Z = L^(0, 0), the minor of L w.r.t row 0 and column 0
+        return (L[:, 1:, 1:].logdet() + m * lens).float()
+
+
+class DependencyCRF(StructuredDistribution):
+    r"""
+    First-order TreeCRF for projective dependency trees :cite:`eisner-2000-bilexical,zhang-etal-2020-efficient`.
+
+    Args:
+        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+            Scores of all possible dependent-head pairs.
+        lens (~torch.LongTensor): ``[batch_size]``.
+            Sentence lengths for masking, regardless of root positions. Default: ``None``.
+        multiroot (bool):
+            If ``False``, requires the tree to contain only a single root. Default: ``False``.
+
+    Examples:
+        >>> from supar import DependencyCRF
+        >>> batch_size, seq_len = 2, 5
+        >>> lens = torch.tensor([3, 4])
+        >>> arcs = torch.tensor([[0, 2, 0, 4, 2], [0, 3, 1, 0, 3]])
+        >>> s1 = DependencyCRF(torch.randn(batch_size, seq_len, seq_len), lens)
+        >>> s2 = DependencyCRF(torch.randn(batch_size, seq_len, seq_len), lens)
+        >>> s1.max
+        tensor([3.6346, 1.7194], grad_fn=<IndexBackward>)
+        >>> s1.argmax
+        tensor([[0, 2, 3, 0, 0],
+                [0, 0, 3, 1, 1]])
+        >>> s1.log_partition
+        tensor([4.1007, 3.3383], grad_fn=<IndexBackward>)
+        >>> s1.log_prob(arcs)
+        tensor([-1.3866, -5.5352], grad_fn=<SubBackward0>)
+        >>> s1.entropy
+        tensor([0.9979, 2.6056], grad_fn=<IndexBackward>)
+        >>> s1.kl(s2)
+        tensor([1.6631, 2.6558], grad_fn=<IndexBackward>)
+    """
+
+    def __init__(
+        self,
+        scores: torch.Tensor,
+        lens: Optional[torch.LongTensor] = None,
+        multiroot: bool = False
+    ) -> DependencyCRF:
+        super().__init__(scores)
+
+        batch_size, seq_len, *_ = scores.shape
+        self.lens = scores.new_full((batch_size,), seq_len-1).long() if lens is None else lens
+        self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len)))
+        self.mask = self.mask.index_fill(1, self.lens.new_tensor(0), 0)
+
+        self.multiroot = multiroot
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(multiroot={self.multiroot})"
+
+    def __add__(self, other):
+        return DependencyCRF(torch.stack((self.scores, other.scores), -1), self.lens, self.multiroot)
+
+    @lazy_property
+    def argmax(self):
+        return self.lens.new_zeros(self.mask.shape).masked_scatter_(self.mask, torch.where(self.backward(self.max.sum()))[2])
+
+    def topk(self, k: int) -> torch.LongTensor:
+        preds = torch.stack([torch.where(self.backward(i))[2] for i in self.kmax(k).sum(0)], -1)
+        return self.lens.new_zeros(*self.mask.shape, k).masked_scatter_(self.mask.unsqueeze(-1), preds)
+
+    def score(self, value: torch.Tensor, partial: bool = False) -> torch.Tensor:
+        arcs = value
+        if partial:
+            mask, lens = self.mask, self.lens
+            mask = mask.index_fill(1, self.lens.new_tensor(0), 1)
+            mask = mask.unsqueeze(1) & mask.unsqueeze(2)
+            arcs = arcs.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1)
+            arcs = arcs.eq(lens.new_tensor(range(mask.shape[1]))) | arcs.lt(0)
+            scores = LogSemiring.zero_mask(self.scores, ~(arcs & mask))
+            return self.__class__(scores, lens, **self.kwargs).log_partition
+        return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1), ~self.mask), -1)
+
+    def forward(self, semiring: Semiring) -> torch.Tensor:
+        s_arc = self.scores
+        batch_size, seq_len = s_arc.shape[:2]
+        # [seq_len, seq_len, batch_size, ...], (h->m)
+        s_arc = semiring.convert(s_arc.movedim((1, 2), (1, 0)))
+        s_i = semiring.zeros_like(s_arc)
+        s_c = semiring.zeros_like(s_arc)
+        semiring.one_(s_c.diagonal().movedim(-1, 1))
+
+        for w in range(1, seq_len):
+            n = seq_len - w
+
+            # [n, batch_size, ...]
+            il = ir = semiring.dot(stripe(s_c, n, w), stripe(s_c, n, w, (w, 1)), 1)
+            # INCOMPLETE-L: I(j->i) = <C(i->r), C(j->r+1)> * s(j->i), i <= r < j
+            # fill the w-th diagonal of the lower triangular part of s_i with I(j->i) of n spans
+            s_i.diagonal(-w).copy_(semiring.mul(il, s_arc.diagonal(-w).movedim(-1, 0)).movedim(0, -1))
+            # INCOMPLETE-R: I(i->j) = <C(i->r), C(j->r+1)> * s(i->j), i <= r < j
+            # fill the w-th diagonal of the upper triangular part of s_i with I(i->j) of n spans
+            s_i.diagonal(w).copy_(semiring.mul(ir, s_arc.diagonal(w).movedim(-1, 0)).movedim(0, -1))
+
+            # [n, batch_size, ...]
+            # COMPLETE-L: C(j->i) = <C(r->i), I(j->r)>, i <= r < j
+            cl = semiring.dot(stripe(s_c, n, w, (0, 0), 0), stripe(s_i, n, w, (w, 0)), 1)
+            s_c.diagonal(-w).copy_(cl.movedim(0, -1))
+            # COMPLETE-R: C(i->j) = <I(i->r), C(r->j)>, i < r <= j
+            cr = semiring.dot(stripe(s_i, n, w, (0, 1)), stripe(s_c, n, w, (1, w), 0), 1)
+            s_c.diagonal(w).copy_(cr.movedim(0, -1))
+            if not self.multiroot:
+                s_c[0, w][self.lens.ne(w)] = semiring.zero
+        return semiring.unconvert(s_c)[0][self.lens, range(batch_size)]
+
+
+class Dependency2oCRF(StructuredDistribution):
+    r"""
+    Second-order TreeCRF for projective dependency trees :cite:`mcdonald-pereira-2006-online,zhang-etal-2020-efficient`.
+
+    Args:
+        scores (tuple(~torch.Tensor, ~torch.Tensor)):
+            Scores of all possible dependent-head pairs (``[batch_size, seq_len, seq_len]``) and
+            dependent-head-sibling triples ``[batch_size, seq_len, seq_len, seq_len]``.
+        lens (~torch.LongTensor): ``[batch_size]``.
+            Sentence lengths for masking, regardless of root positions. Default: ``None``.
+        multiroot (bool):
+            If ``False``, requires the tree to contain only a single root. Default: ``False``.
+
+    Examples:
+        >>> from supar import Dependency2oCRF
+        >>> batch_size, seq_len = 2, 5
+        >>> lens = torch.tensor([3, 4])
+        >>> arcs = torch.tensor([[0, 2, 0, 4, 2], [0, 3, 1, 0, 3]])
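+        >>> from supar.utils.transform import CoNLL  # assumed import path for CoNLL.get_sibs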
+        >>> sibs = torch.tensor([CoNLL.get_sibs(i) for i in arcs[:, 1:].tolist()])
+        >>> s1 = Dependency2oCRF((torch.randn(batch_size, seq_len, seq_len),
+                                  torch.randn(batch_size, seq_len, seq_len, seq_len)),
+                                 lens)
+        >>> s2 = Dependency2oCRF((torch.randn(batch_size, seq_len, seq_len),
+                                  torch.randn(batch_size, seq_len, seq_len, seq_len)),
+                                 lens)
+        >>> s1.max
+        tensor([0.7574, 3.3634], grad_fn=<IndexBackward>)
+        >>> s1.argmax
+        tensor([[0, 3, 3, 0, 0],
+                [0, 4, 4, 4, 0]])
+        >>> s1.log_partition
+        tensor([1.9906, 4.3599], grad_fn=<IndexBackward>)
+        >>> s1.log_prob((arcs, sibs))
+        tensor([-0.6975, -6.2845], grad_fn=<SubBackward0>)
+        >>> s1.entropy
+        tensor([1.6436, 2.1717], grad_fn=<IndexBackward>)
+        >>> s1.kl(s2)
+        tensor([0.4929, 2.0759], grad_fn=<IndexBackward>)
+    """
+
+    def __init__(
+        self,
+        scores: Tuple[torch.Tensor, torch.Tensor],
+        lens: Optional[torch.LongTensor] = None,
+        multiroot: bool = False
+    ) -> Dependency2oCRF:
+        super().__init__(scores)
+
+        batch_size, seq_len, *_ = scores[0].shape
+        self.lens = scores[0].new_full((batch_size,), seq_len-1).long() if lens is None else lens
+        self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len)))
+        self.mask = self.mask.index_fill(1, self.lens.new_tensor(0), 0)
+
+        self.multiroot = multiroot
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(multiroot={self.multiroot})"
+
+    def __add__(self, other):
+        return Dependency2oCRF([torch.stack((i, j), -1) for i, j in zip(self.scores, other.scores)], self.lens, self.multiroot)
+
+    @lazy_property
+    def argmax(self):
+        return self.lens.new_zeros(self.mask.shape).masked_scatter_(self.mask,
+                                                                    torch.where(self.backward(self.max.sum())[0])[2])
+
+    def topk(self, k: int) -> torch.LongTensor:
+        preds = torch.stack([torch.where(self.backward(i)[0])[2] for i in self.kmax(k).sum(0)], -1)
+        return self.lens.new_zeros(*self.mask.shape, k).masked_scatter_(self.mask.unsqueeze(-1), preds)
+
+    def score(self, value: Tuple[torch.LongTensor, torch.LongTensor], partial: bool = False) -> torch.Tensor:
+        arcs, sibs = value
+        if partial:
+            mask, lens = self.mask, self.lens
+            mask = mask.index_fill(1, self.lens.new_tensor(0), 1)
+            mask = mask.unsqueeze(1) & mask.unsqueeze(2)
+            arcs = arcs.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1)
+            arcs = arcs.eq(lens.new_tensor(range(mask.shape[1]))) | arcs.lt(0)
+            s_arc, s_sib = LogSemiring.zero_mask(self.scores[0], ~(arcs & mask)), self.scores[1]
+            return self.__class__((s_arc, s_sib), lens, **self.kwargs).log_partition
+        s_arc = self.scores[0].gather(-1, arcs.unsqueeze(-1)).squeeze(-1)
+        s_arc = LogSemiring.prod(LogSemiring.one_mask(s_arc, ~self.mask), -1)
+        s_sib = self.scores[1].gather(-1, sibs.unsqueeze(-1)).squeeze(-1)
+        s_sib = LogSemiring.prod(LogSemiring.one_mask(s_sib, ~sibs.gt(0)), (-1, -2))
+        return LogSemiring.mul(s_arc, s_sib)
+
+    @torch.enable_grad()
+    def forward(self, semiring: Semiring) -> torch.Tensor:
+        s_arc, s_sib = self.scores
+        batch_size, seq_len = s_arc.shape[:2]
+        # [seq_len, seq_len, batch_size, ...], (h->m)
+        s_arc = semiring.convert(s_arc.movedim((1, 2), (1, 0)))
+        # [seq_len, seq_len, seq_len, batch_size, ...], (h->m->s)
+        s_sib = semiring.convert(s_sib.movedim((0, 2), (3, 0)))
+        s_i = semiring.zeros_like(s_arc)
+        s_s = semiring.zeros_like(s_arc)
+        s_c = semiring.zeros_like(s_arc)
+        semiring.one_(s_c.diagonal().movedim(-1, 1))
+
+        for w in range(1, seq_len):
+            n = seq_len - w
+
+            # INCOMPLETE-L: I(j->i) = <I(j->r), S(j->r, i)> * s(j->i), i < r < j
+            #                         <C(j->j), C(i->j-1)>  * s(j->i), otherwise
+            # [n, w, batch_size, ...]
+            il = semiring.times(stripe(s_i, n, w, (w, 1)),
+                                stripe(s_s, n, w, (1, 0), 0),
+                                stripe(s_sib[range(w, n+w), range(n), :], n, w, (0, 1)))
+            il[:, -1] = semiring.mul(stripe(s_c, n, 1, (w, w)), stripe(s_c, n, 1, (0, w - 1))).squeeze(1)
+            il = semiring.sum(il, 1)
+            s_i.diagonal(-w).copy_(semiring.mul(il, s_arc.diagonal(-w).movedim(-1, 0)).movedim(0, -1))
+            # INCOMPLETE-R: I(i->j) = <I(i->r), S(i->r, j)> * s(i->j), i < r < j
+            #                         <C(i->i), C(j->i+1)>  * s(i->j), otherwise
+            # [n, w, batch_size, ...]
+            ir = semiring.times(stripe(s_i, n, w),
+                                stripe(s_s, n, w, (0, w), 0),
+                                stripe(s_sib[range(n), range(w, n+w), :], n, w))
+            if not self.multiroot:
+                semiring.zero_(ir[0])
+            ir[:, 0] = semiring.mul(stripe(s_c, n, 1), stripe(s_c, n, 1, (w, 1))).squeeze(1)
+            ir = semiring.sum(ir, 1)
+            s_i.diagonal(w).copy_(semiring.mul(ir, s_arc.diagonal(w).movedim(-1, 0)).movedim(0, -1))
+
+            # [batch_size, ..., n]
+            sl = sr = semiring.dot(stripe(s_c, n, w), stripe(s_c, n, w, (w, 1)), 1).movedim(0, -1)
+            # SIB: S(j, i) = <C(i->r), C(j->r+1)>, i <= r < j
+            s_s.diagonal(-w).copy_(sl)
+            # SIB: S(i, j) = <C(i->r), C(j->r+1)>, i <= r < j
+            s_s.diagonal(w).copy_(sr)
+
+            # [n, batch_size, ...]
+            # COMPLETE-L: C(j->i) = <C(r->i), I(j->r)>, i <= r < j
+            cl = semiring.dot(stripe(s_c, n, w, (0, 0), 0), stripe(s_i, n, w, (w, 0)), 1)
+            s_c.diagonal(-w).copy_(cl.movedim(0, -1))
+            # COMPLETE-R: C(i->j) = <I(i->r), C(r->j)>, i < r <= j
+            cr = semiring.dot(stripe(s_i, n, w, (0, 1)), stripe(s_c, n, w, (1, w), 0), 1)
+            s_c.diagonal(w).copy_(cr.movedim(0, -1))
+        return semiring.unconvert(s_c)[0][self.lens, range(batch_size)]
+
+
+class ConstituencyCRF(StructuredDistribution):
+    r"""
+    Constituency TreeCRF :cite:`zhang-etal-2020-fast,stern-etal-2017-minimal`.
+
+    Args:
+        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+            Scores of all constituents.
+        lens (~torch.LongTensor): ``[batch_size]``.
+            Sentence lengths for masking.
+        label (bool):
+            If ``True``, the last dimension of ``scores`` ranges over constituent labels. Default: ``False``.
+
+    Examples:
+        >>> from supar import ConstituencyCRF
+        >>> batch_size, seq_len, n_labels = 2, 5, 4
+        >>> lens = torch.tensor([3, 4])
+        >>> charts = torch.tensor([[[-1,  0, -1,  0, -1],
+                                    [-1, -1,  0,  0, -1],
+                                    [-1, -1, -1,  0, -1],
+                                    [-1, -1, -1, -1, -1],
+                                    [-1, -1, -1, -1, -1]],
+                                   [[-1,  0,  0, -1,  0],
+                                    [-1, -1,  0, -1, -1],
+                                    [-1, -1, -1,  0,  0],
+                                    [-1, -1, -1, -1,  0],
+                                    [-1, -1, -1, -1, -1]]])
+        >>> s1 = ConstituencyCRF(torch.randn(batch_size, seq_len, seq_len, n_labels), lens, True)
+        >>> s2 = ConstituencyCRF(torch.randn(batch_size, seq_len, seq_len, n_labels), lens, True)
+        >>> s1.max
+        tensor([3.7036, 7.2569], grad_fn=<IndexBackward0>)
+        >>> s1.argmax
+        [[[0, 1, 2], [0, 3, 0], [1, 2, 1], [1, 3, 0], [2, 3, 3]],
+         [[0, 1, 1], [0, 4, 2], [1, 2, 3], [1, 4, 1], [2, 3, 2], [2, 4, 3], [3, 4, 3]]]
+        >>> s1.log_partition
+        tensor([ 8.5394, 12.9940], grad_fn=<IndexBackward0>)
+        >>> s1.log_prob(charts)
+        tensor([ -8.5209, -14.1160], grad_fn=<SubBackward0>)
+        >>> s1.entropy
+        tensor([6.8868, 9.3996], grad_fn=<IndexBackward0>)
+        >>> s1.kl(s2)
+        tensor([4.0039, 4.1037], grad_fn=<IndexBackward0>)
+    """
+
+    def __init__(
+        self,
+        scores: torch.Tensor,
+        lens: Optional[torch.LongTensor] = None,
+        label: bool = False
+    ) -> ConstituencyCRF:
+        super().__init__(scores)
+
+        batch_size, seq_len, *_ = scores.shape
+        self.lens = scores.new_full((batch_size,), seq_len-1).long() if lens is None else lens
+        self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len)))
+        self.mask = self.mask.unsqueeze(1) & scores.new_ones(scores.shape[:3]).bool().triu_(1)
+        self.label = label
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(label={self.label})"
+
+    def __add__(self, other):
+        return ConstituencyCRF(torch.stack((self.scores, other.scores), -1), self.lens, self.label)
+
+    @lazy_property
+    def argmax(self):
+        return [torch.nonzero(i).tolist() for i in self.backward(self.max.sum())]
+
+    def topk(self, k: int) -> List[List[Tuple]]:
+        return list(zip(*[[torch.nonzero(j).tolist() for j in self.backward(i)] for i in self.kmax(k).sum(0)]))
+
+    def score(self, value: torch.LongTensor) -> torch.Tensor:
+        mask = self.mask & value.ge(0)
+        if self.label:
+            scores = self.scores[mask].gather(-1, value[mask].unsqueeze(-1)).squeeze(-1)
+            scores = torch.full_like(mask, LogSemiring.one, dtype=scores.dtype).masked_scatter_(mask, scores)
+        else:
+            scores = LogSemiring.one_mask(self.scores, ~mask)
+        return LogSemiring.prod(LogSemiring.prod(scores, -1), -1)
+
+    @torch.enable_grad()
+    def forward(self, semiring: Semiring) -> torch.Tensor:
+        batch_size, seq_len = self.scores.shape[:2]
+        # [seq_len, seq_len, batch_size, ...], (l->r)
+        scores = semiring.convert(self.scores.movedim((1, 2), (0, 1)))
+        scores = semiring.sum(scores, 3) if self.label else scores
+        s = semiring.zeros_like(scores)
+        s.diagonal(1).copy_(scores.diagonal(1))
+
+        for w in range(2, seq_len):
+            n = seq_len - w
+            # [n, batch_size, ...]
+            s_s = semiring.dot(stripe(s, n, w-1, (0, 1)), stripe(s, n, w-1, (1, w), False), 1)
+            s.diagonal(w).copy_(semiring.mul(s_s, scores.diagonal(w).movedim(-1, 0)).movedim(0, -1))
+        return semiring.unconvert(s)[0][self.lens, range(batch_size)]
+
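+# Hedged note: the loop in `ConstituencyCRF.forward` is a CKY-style inside
+# pass; for a span (i, j) of width w, `semiring.dot` combines all split points
+# k, which in log space amounts to
+#
+#     s(i, j) = logsumexp_k [s(i, k) + s(k, j)] + score(i, j),  i < k < j
+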
+
+class BiLexicalizedConstituencyCRF(StructuredDistribution):
+    r"""
+    Grammarless Eisner-Satta Algorithm :cite:`eisner-satta-1999-efficient,yang-etal-2021-neural`.
+
+    Code is revised from `Songlin Yang's implementation <https://github.com/sustcsonglin/span-based-dependency-parsing>`_.
+
+    Args:
+        scores (list(~torch.Tensor)):
+            Scores of dependencies (``[batch_size, seq_len, seq_len]``), constituents (``[batch_size, seq_len, seq_len]``)
+            and headed spans (``[batch_size, seq_len, seq_len, seq_len]``).
+        lens (~torch.LongTensor): ``[batch_size]``.
+            Sentence lengths for masking.
+
+    Examples:
+        >>> from supar import BiLexicalizedConstituencyCRF
+        >>> batch_size, seq_len = 2, 5
+        >>> lens = torch.tensor([3, 4])
+        >>> deps = torch.tensor([[0, 0, 1, 1, 0], [0, 3, 1, 0, 3]])
+        >>> cons = torch.tensor([[[0, 1, 1, 1, 0],
+                                  [0, 0, 1, 0, 0],
+                                  [0, 0, 0, 1, 0],
+                                  [0, 0, 0, 0, 0],
+                                  [0, 0, 0, 0, 0]],
+                                 [[0, 1, 1, 1, 1],
+                                  [0, 0, 1, 0, 0],
+                                  [0, 0, 0, 1, 0],
+                                  [0, 0, 0, 0, 1],
+                                  [0, 0, 0, 0, 0]]]).bool()
+        >>> heads = torch.tensor([[[0, 1, 1, 1, 0],
+                                   [0, 0, 2, 0, 0],
+                                   [0, 0, 0, 3, 0],
+                                   [0, 0, 0, 0, 0],
+                                   [0, 0, 0, 0, 0]],
+                                  [[0, 1, 1, 3, 3],
+                                   [0, 0, 2, 0, 0],
+                                   [0, 0, 0, 3, 0],
+                                   [0, 0, 0, 0, 4],
+                                   [0, 0, 0, 0, 0]]])
+        >>> s1 = BiLexicalizedConstituencyCRF((torch.randn(batch_size, seq_len, seq_len),
+                                               torch.randn(batch_size, seq_len, seq_len),
+                                               torch.randn(batch_size, seq_len, seq_len, seq_len)),
+                                              lens)
+        >>> s2 = BiLexicalizedConstituencyCRF((torch.randn(batch_size, seq_len, seq_len),
+                                               torch.randn(batch_size, seq_len, seq_len),
+                                               torch.randn(batch_size, seq_len, seq_len, seq_len)),
+                                              lens)
+        >>> s1.max
+        tensor([0.5792, 2.1737], grad_fn=<MaxBackward0>)
+        >>> s1.argmax[0]
+        tensor([[0, 3, 1, 0, 0],
+                [0, 4, 1, 1, 0]])
+        >>> s1.argmax[1]
+        [[[0, 3], [0, 2], [0, 1], [1, 2], [2, 3]], [[0, 4], [0, 3], [0, 2], [0, 1], [1, 2], [2, 3], [3, 4]]]
+        >>> s1.log_partition
+        tensor([1.1923, 3.2343], grad_fn=<LogsumexpBackward>)
+        >>> s1.log_prob((deps, cons, heads))
+        tensor([-1.9123, -3.6127], grad_fn=<SubBackward0>)
+        >>> s1.entropy
+        tensor([1.3376, 2.2996], grad_fn=<SelectBackward>)
+        >>> s1.kl(s2)
+        tensor([1.0617, 2.7839], grad_fn=<SelectBackward>)
+    """
+
+    def __init__(
+        self,
+        scores: List[torch.Tensor],
+        lens: Optional[torch.LongTensor] = None
+    ) -> BiLexicalizedConstituencyCRF:
+        super().__init__(scores)
+
+        batch_size, seq_len, *_ = scores[1].shape
+        self.lens = scores[1].new_full((batch_size,), seq_len-1).long() if lens is None else lens
+        self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len)))
+        self.mask = self.mask.unsqueeze(1) & scores[1].new_ones(scores[1].shape[:3]).bool().triu_(1)
+
+    def __add__(self, other):
+        return BiLexicalizedConstituencyCRF([torch.stack((i, j), -1) for i, j in zip(self.scores, other.scores)], self.lens)
+
+    @lazy_property
+    def argmax(self):
+        marginals = self.backward(self.max.sum())
+        dep_mask = self.mask[:, 0]
+        dep = self.lens.new_zeros(dep_mask.shape).masked_scatter_(dep_mask, torch.where(marginals[0])[2])
+        con = [torch.nonzero(i).tolist() for i in marginals[1]]
+        return dep, con
+
+    def topk(self, k: int) -> Tuple[torch.LongTensor, List[List[Tuple]]]:
+        dep_mask = self.mask[:, 0]
+        marginals = [self.backward(i) for i in self.kmax(k).sum(0)]
+        dep_preds = torch.stack([torch.where(i)[2] for i in marginals[0]], -1)
+        dep_preds = self.lens.new_zeros(*dep_mask.shape, k).masked_scatter_(dep_mask.unsqueeze(-1), dep_preds)
+        con_preds = list(zip(*[[torch.nonzero(j).tolist() for j in i] for i in marginals[1]]))
+        return dep_preds, con_preds
+
+    def score(self, value: List[Union[torch.LongTensor, torch.BoolTensor]], partial: bool = False) -> torch.Tensor:
+        deps, cons, heads = value
+        s_dep, s_con, s_head = self.scores
+        mask, lens = self.mask, self.lens
+        dep_mask, con_mask = mask[:, 0], mask
+        if partial:
+            if deps is not None:
+                dep_mask = dep_mask.index_fill(1, self.lens.new_tensor(0), 1)
+                dep_mask = dep_mask.unsqueeze(1) & dep_mask.unsqueeze(2)
+                deps = deps.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1)
+                deps = deps.eq(lens.new_tensor(range(mask.shape[1]))) | deps.lt(0)
+                s_dep = LogSemiring.zero_mask(s_dep, ~(deps & dep_mask))
+            if cons is not None:
+                s_con = LogSemiring.zero_mask(s_con, ~(cons & con_mask))
+            if heads is not None:
+                head_mask = heads.unsqueeze(-1).eq(lens.new_tensor(range(mask.shape[1])))
+                head_mask = head_mask & con_mask.unsqueeze(-1)
+                s_head = LogSemiring.zero_mask(s_head, ~head_mask)
+            return self.__class__((s_dep, s_con, s_head), lens, **self.kwargs).log_partition
+        s_dep = LogSemiring.prod(LogSemiring.one_mask(s_dep.gather(-1, deps.unsqueeze(-1)).squeeze(-1), ~dep_mask), -1)
+        s_head = LogSemiring.mul(s_con, s_head.gather(-1, heads.unsqueeze(-1)).squeeze(-1))
+        s_head = LogSemiring.prod(LogSemiring.prod(LogSemiring.one_mask(s_head, ~(con_mask & cons)), -1), -1)
+        return LogSemiring.mul(s_dep, s_head)
+
+    def forward(self, semiring: Semiring) -> torch.Tensor:
+        s_dep, s_con, s_head = self.scores
+        batch_size, seq_len, *_ = s_con.shape
+        # [seq_len, seq_len, batch_size, ...], (m<-h)
+        s_dep = semiring.convert(s_dep.movedim(0, 2))
+        s_root, s_dep = s_dep[1:, 0], s_dep[1:, 1:]
+        # [seq_len, seq_len, batch_size, ...], (i, j)
+        s_con = semiring.convert(s_con.movedim(0, 2))
+        # [seq_len, seq_len, seq_len-1, batch_size, ...], (i, j, h)
+        s_head = semiring.mul(s_con.unsqueeze(2), semiring.convert(s_head.movedim(0, -1)[:, :, 1:]))
+        # [seq_len, seq_len, seq_len-1, batch_size, ...], (i, j, h)
+        s_span = semiring.zeros_like(s_head)
+        # [seq_len, seq_len, seq_len-1, batch_size, ...], (i, j<-h)
+        s_hook = semiring.zeros_like(s_head)
+        diagonal_stripe(s_span, 1).copy_(diagonal_stripe(s_head, 1))
+        s_hook.diagonal(1).copy_(semiring.mul(s_dep, diagonal_stripe(s_head, 1)).movedim(0, -1))
+
+        for w in range(2, seq_len):
+            n = seq_len - w
+            # COMPLETE-L: s_span_l(i, j, h) = <s_span(i, k, h), s_hook(h->k, j)>, i < k < j
+            # [n, w, batch_size, ...]
+            s_l = stripe(semiring.dot(stripe(s_span, n, w-1, (0, 1)), stripe(s_hook, n, w-1, (1, w), False), 1), n, w)
+            # COMPLETE-R: s_span_r(i, j, h) = <s_hook(i, k<-h), s_span(k, j, h)>, i < k < j
+            # [n, w, batch_size, ...]
+            s_r = stripe(semiring.dot(stripe(s_hook, n, w-1, (0, 1)), stripe(s_span, n, w-1, (1, w), False), 1), n, w)
+            # COMPLETE: s_span(i, j, h) = (s_span_l(i, j, h) + s_span_r(i, j, h)) * s(i, j, h)
+            # [n, w, batch_size, ...]
+            s = semiring.mul(semiring.sum(torch.stack((s_l, s_r)), 0), diagonal_stripe(s_head, w))
+            diagonal_stripe(s_span, w).copy_(s)
+
+            if w == seq_len - 1:
+                continue
+            # ATTACH: s_hook(h->i, j) = <s(h->m), s_span(i, j, m)>, i <= m < j
+            # [n, seq_len, batch_size, ...]
+            s = semiring.dot(expanded_stripe(s_dep, n, w), diagonal_stripe(s_span, w).unsqueeze(2), 1)
+            s_hook.diagonal(w).copy_(s.movedim(0, -1))
+        return semiring.unconvert(semiring.dot(s_span[0][self.lens, :, range(batch_size)].transpose(0, 1), s_root, 0))
diff --git a/tania_scripts/supar/structs/vi.py b/tania_scripts/supar/structs/vi.py
new file mode 100644
index 0000000000000000000000000000000000000000..b848825e661ea0ecd0a258ba6c9498b8d7624a23
--- /dev/null
+++ b/tania_scripts/supar/structs/vi.py
@@ -0,0 +1,499 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from typing import List, Optional, Tuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from supar.structs import DependencyCRF
+from supar.utils.common import MIN
+
+
+class DependencyMFVI(nn.Module):
+    r"""
+    Mean Field Variational Inference for approximately calculating marginals
+    of dependency trees :cite:`wang-tu-2020-second`.
+    """
+
+    def __init__(self, max_iter: int = 3) -> DependencyMFVI:
+        super().__init__()
+
+        self.max_iter = max_iter
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(max_iter={self.max_iter})"
+
+    @torch.enable_grad()
+    def forward(
+        self,
+        scores: List[torch.Tensor],
+        mask: torch.BoolTensor,
+        target: Optional[torch.LongTensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""
+        Args:
+            scores (~torch.Tensor, ~torch.Tensor):
+                Tuple of two tensors `s_arc` and `s_sib`.
+                `s_arc` (``[batch_size, seq_len, seq_len]``) holds scores of all possible dependent-head pairs.
+                `s_sib` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-sibling triples.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask to avoid aggregation on padding tokens.
+            target (~torch.LongTensor): ``[batch_size, seq_len]``.
+                A Tensor of gold-standard dependent-head pairs. Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``.
+                The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``.
+        """
+
+        logits = self.mfvi(*scores, mask)
+        marginals = logits.softmax(-1)
+
+        if target is None:
+            return marginals
+        loss = F.cross_entropy(logits[mask], target[mask])
+
+        return loss, marginals
+
+    def mfvi(self, s_arc, s_sib, mask):
+        batch_size, seq_len = mask.shape
+        ls, rs = torch.stack(torch.where(mask.new_ones(seq_len, seq_len))).view(-1, seq_len, seq_len).sort(0)[0]
+        mask = mask.index_fill(1, ls.new_tensor(0), 1)
+        # [seq_len, seq_len, batch_size], (h->m)
+        mask = (mask.unsqueeze(-1) & mask.unsqueeze(-2)).permute(2, 1, 0)
+        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
+        mask2o = mask.unsqueeze(1) & mask.unsqueeze(2)
+        mask2o = mask2o & ls.unsqueeze(-1).ne(ls.new_tensor(range(seq_len))).unsqueeze(-1)
+        mask2o = mask2o & rs.unsqueeze(-1).ne(rs.new_tensor(range(seq_len))).unsqueeze(-1)
+        # [seq_len, seq_len, batch_size], (h->m)
+        s_arc = s_arc.permute(2, 1, 0)
+        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
+        s_sib = s_sib.permute(2, 1, 3, 0) * mask2o
+
+        # posterior distributions
+        # [seq_len, seq_len, batch_size], (h->m)
+        q = s_arc
+
+        for _ in range(self.max_iter):
+            q = q.softmax(0)
+            # q(ij) = s(ij) + sum(q(ik)s^sib(ij,ik)), k != i,j
+            q = s_arc + (q.unsqueeze(1) * s_sib).sum(2)
+
+        return q.permute(2, 1, 0)
+
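+# Hedged usage sketch (shapes only; all tensors are random placeholders):
+#
+#     mfvi = DependencyMFVI(max_iter=3)
+#     s_arc, s_sib = torch.randn(2, 5, 5), torch.randn(2, 5, 5, 5)
+#     mask = torch.ones(2, 5, dtype=torch.bool).index_fill(1, torch.tensor(0), 0)
+#     marginals = mfvi((s_arc, s_sib), mask)   # -> [2, 5, 5]
+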
+
+class DependencyLBP(nn.Module):
+    r"""
+    Loopy Belief Propagation for approximately calculating marginals
+    of dependency trees :cite:`smith-eisner-2008-dependency`.
+    """
+
+    def __init__(self, max_iter: int = 3) -> DependencyLBP:
+        super().__init__()
+
+        self.max_iter = max_iter
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(max_iter={self.max_iter})"
+
+    @torch.enable_grad()
+    def forward(
+        self,
+        scores: List[torch.Tensor],
+        mask: torch.BoolTensor,
+        target: Optional[torch.LongTensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""
+        Args:
+            scores (~torch.Tensor, ~torch.Tensor):
+                Tuple of two tensors `s_arc` and `s_sib`.
+                `s_arc` (``[batch_size, seq_len, seq_len]``) holds scores of all possible dependent-head pairs.
+                `s_sib` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-sibling triples.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+                The mask to avoid aggregation on padding tokens.
+            target (~torch.LongTensor): ``[batch_size, seq_len]``.
+                A Tensor of gold-standard dependent-head pairs. Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``.
+                The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``.
+        """
+
+        logits = self.lbp(*scores, mask)
+        marginals = logits.softmax(-1)
+
+        if target is None:
+            return marginals
+        loss = F.cross_entropy(logits[mask], target[mask])
+
+        return loss, marginals
+
+    def lbp(self, s_arc, s_sib, mask):
+        batch_size, seq_len = mask.shape
+        ls, rs = torch.stack(torch.where(mask.new_ones(seq_len, seq_len))).view(-1, seq_len, seq_len).sort(0)[0]
+        mask = mask.index_fill(1, ls.new_tensor(0), 1)
+        # [seq_len, seq_len, batch_size], (h->m)
+        mask = (mask.unsqueeze(-1) & mask.unsqueeze(-2)).permute(2, 1, 0)
+        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
+        mask2o = mask.unsqueeze(1) & mask.unsqueeze(2)
+        mask2o = mask2o & ls.unsqueeze(-1).ne(ls.new_tensor(range(seq_len))).unsqueeze(-1)
+        mask2o = mask2o & rs.unsqueeze(-1).ne(rs.new_tensor(range(seq_len))).unsqueeze(-1)
+        # [seq_len, seq_len, batch_size], (h->m)
+        s_arc = s_arc.permute(2, 1, 0)
+        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
+        s_sib = s_sib.permute(2, 1, 3, 0).masked_fill_(~mask2o, MIN)
+
+        # log beliefs
+        # [seq_len, seq_len, batch_size], (h->m)
+        q = s_arc
+        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
+        m_sib = s_sib.new_zeros(seq_len, seq_len, seq_len, batch_size)
+
+        for _ in range(self.max_iter):
+            q = q.log_softmax(0)
+            # m(ik->ij) = logsumexp(q(ik) - m(ij->ik) + s(ij->ik))
+            m = q.unsqueeze(2) - m_sib
+            # TODO: better solution for OOM
+            m_sib = torch.logaddexp(m.logsumexp(0), m + s_sib).transpose(1, 2).log_softmax(0)
+            # q(ij) = s(ij) + sum(m(ik->ij)), k != i,j
+            q = s_arc + (m_sib * mask2o).sum(2)
+
+        return q.permute(2, 1, 0)
+
+
+class ConstituencyMFVI(nn.Module):
+    r"""
+    Mean Field Variational Inference for approximately calculating marginals of constituent trees.
+    """
+
+    def __init__(self, max_iter: int = 3) -> ConstituencyMFVI:
+        super().__init__()
+
+        self.max_iter = max_iter
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(max_iter={self.max_iter})"
+
+    @torch.enable_grad()
+    def forward(
+        self,
+        scores: List[torch.Tensor],
+        mask: torch.BoolTensor,
+        target: Optional[torch.LongTensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""
+        Args:
+            scores (~torch.Tensor, ~torch.Tensor):
+                Tuple of two tensors `s_span` and `s_pair`.
+                `s_span` (``[batch_size, seq_len, seq_len]``) holds scores of all possible spans.
+                `s_pair` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of second-order triples.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+                The mask to avoid aggregation on padding tokens.
+            target (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+                A Tensor of gold-standard spans. Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``.
+                The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``.
+        """
+
+        logits = self.mfvi(*scores, mask)
+        marginals = logits.sigmoid()
+
+        if target is None:
+            return marginals
+        loss = F.binary_cross_entropy_with_logits(logits[mask], target[mask].float())
+
+        return loss, marginals
+
+    def mfvi(self, s_span, s_pair, mask):
+        batch_size, seq_len, _ = mask.shape
+        ls, rs = torch.stack(torch.where(torch.ones_like(mask[0]))).view(-1, seq_len, seq_len).sort(0)[0]
+        # [seq_len, seq_len, batch_size], (l->r)
+        mask = mask.movedim(0, 2)
+        # [seq_len, seq_len, seq_len, batch_size], (l->r->b)
+        mask2o = mask.unsqueeze(2).repeat(1, 1, seq_len, 1)
+        mask2o = mask2o & ls.unsqueeze(-1).ne(ls.new_tensor(range(seq_len))).unsqueeze(-1)
+        mask2o = mask2o & rs.unsqueeze(-1).ne(rs.new_tensor(range(seq_len))).unsqueeze(-1)
+        # [seq_len, seq_len, batch_size], (l->r)
+        s_span = s_span.movedim(0, 2)
+        # [seq_len, seq_len, seq_len, batch_size], (l->r->b)
+        s_pair = s_pair.permute(1, 2, 3, 0) * mask2o
+
+        # posterior distributions
+        # [seq_len, seq_len, batch_size], (l->r)
+        q = s_span
+
+        for _ in range(self.max_iter):
+            q = q.sigmoid()
+            # q(ij) = s(ij) + sum(q(jk)*s^pair(ij,jk)), k != i,j
+            q = s_span + (q.unsqueeze(1) * s_pair).sum(2)
+
+        return q.permute(2, 0, 1)
+
+
+class ConstituencyLBP(nn.Module):
+    r"""
+    Loopy Belief Propagation for approximately calculating marginals of constituent trees.
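+
+    The snippet below is an illustrative sketch with random scores, included
+    only to demonstrate the expected shapes:
+
+    Examples:
+        >>> import torch
+        >>> lbp = ConstituencyLBP(max_iter=3)
+        >>> s_span = torch.randn(2, 6, 6)
+        >>> s_pair = torch.randn(2, 6, 6, 6)
+        >>> mask = torch.ones(2, 6, 6, dtype=torch.bool).triu(1)
+        >>> lbp((s_span, s_pair), mask).shape
+        torch.Size([2, 6, 6])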
+    """
+
+    def __init__(self, max_iter: int = 3) -> ConstituencyLBP:
+        super().__init__()
+
+        self.max_iter = max_iter
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(max_iter={self.max_iter})"
+
+    @torch.enable_grad()
+    def forward(
+        self,
+        scores: List[torch.Tensor],
+        mask: torch.BoolTensor,
+        target: Optional[torch.LongTensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""
+        Args:
+            scores (~torch.Tensor, ~torch.Tensor):
+                Tuple of two tensors `s_span` and `s_pair`.
+                `s_span` (``[batch_size, seq_len, seq_len]``) holds scores of all possible spans.
+                `s_pair` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of second-order triples.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+                The mask to avoid aggregation on padding tokens.
+            target (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+                A Tensor of gold-standard spans. Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The first is the training loss averaged over the number of tokens, which is not returned if ``target=None``.
+                The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``.
+        """
+
+        logits = self.lbp(*scores, mask)
+        marginals = logits.softmax(-1)[..., 1]
+
+        if target is None:
+            return marginals
+        loss = F.cross_entropy(logits[mask], target[mask].long())
+
+        return loss, marginals
+
+    def lbp(self, s_span, s_pair, mask):
+        batch_size, seq_len, _ = mask.shape
+        ls, rs = torch.stack(torch.where(torch.ones_like(mask[0]))).view(-1, seq_len, seq_len).sort(0)[0]
+        # [seq_len, seq_len, batch_size], (l->r)
+        mask = mask.movedim(0, 2)
+        # [seq_len, seq_len, seq_len, batch_size], (l->r->b)
+        mask2o = mask.unsqueeze(2).repeat(1, 1, seq_len, 1)
+        mask2o = mask2o & ls.unsqueeze(-1).ne(ls.new_tensor(range(seq_len))).unsqueeze(-1)
+        mask2o = mask2o & rs.unsqueeze(-1).ne(rs.new_tensor(range(seq_len))).unsqueeze(-1)
+        # [2, seq_len, seq_len, batch_size], (l->r)
+        s_span = torch.stack((torch.zeros_like(s_span), s_span)).permute(0, 3, 2, 1)
+        # [seq_len, seq_len, seq_len, batch_size], (l->r->p)
+        s_pair = s_pair.permute(2, 1, 3, 0)
+
+        # log beliefs
+        # [2, seq_len, seq_len, batch_size], (h->m)
+        q = s_span
+        # [2, seq_len, seq_len, seq_len, batch_size], (h->m->s)
+        m_pair = s_pair.new_zeros(2, seq_len, seq_len, seq_len, batch_size)
+
+        for _ in range(self.max_iter):
+            q = q.log_softmax(0)
+            # m(ik->ij) = logsumexp(q(ik) - m(ij->ik) + s(ij->ik))
+            m = q.unsqueeze(3) - m_pair
+            m_pair = torch.stack((m.logsumexp(0), torch.stack((m[0], m[1] + s_pair)).logsumexp(0))).log_softmax(0)
+            # q(ij) = s(ij) + sum(m(ik->ij)), k != i,j
+            q = s_span + (m_pair.transpose(2, 3) * mask2o).sum(3)
+
+        return q.permute(3, 2, 1, 0)
+
+
+class SemanticDependencyMFVI(nn.Module):
+    r"""
+    Mean Field Variational Inference for approximately calculating marginals
+    of semantic dependency trees :cite:`wang-etal-2019-second`.
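+
+    The snippet below is an illustrative sketch with random scores; the mask
+    construction mirrors typical usage, where position 0 (the root) is never
+    a dependent:
+
+    Examples:
+        >>> import torch
+        >>> mfvi = SemanticDependencyMFVI(max_iter=3)
+        >>> s_edge = torch.randn(2, 7, 7)
+        >>> s_sib, s_cop, s_grd = (torch.randn(2, 7, 7, 7) for _ in range(3))
+        >>> mask = torch.ones(2, 7, 7, dtype=torch.bool)
+        >>> mask[:, 0] = 0
+        >>> mfvi((s_edge, s_sib, s_cop, s_grd), mask).shape
+        torch.Size([2, 7, 7])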
+    """
+
+    def __init__(self, max_iter: int = 3) -> SemanticDependencyMFVI:
+        super().__init__()
+
+        self.max_iter = max_iter
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(max_iter={self.max_iter})"
+
+    @torch.enable_grad()
+    def forward(
+        self,
+        scores: List[torch.Tensor],
+        mask: torch.BoolTensor,
+        target: Optional[torch.LongTensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""
+        Args:
+            scores (~torch.Tensor, ~torch.Tensor):
+                Tuple of four tensors `s_edge`, `s_sib`, `s_cop` and `s_grd`.
+                `s_edge` (``[batch_size, seq_len, seq_len]``) holds scores of all possible dependent-head pairs.
+                `s_sib` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-sibling triples.
+                `s_cop` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-coparent triples.
+                `s_grd` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-grandparent triples.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+                The mask to avoid aggregation on padding tokens.
+            target (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
+                A Tensor of gold-standard dependent-head pairs. Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The first is the training loss averaged over the number of tokens, which is not returned if ``target=None``.
+                The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``.
+        """
+
+        logits = self.mfvi(*scores, mask)
+        marginals = logits.sigmoid()
+
+        if target is None:
+            return marginals
+        loss = F.binary_cross_entropy_with_logits(logits[mask], target[mask].float())
+
+        return loss, marginals
+
+    def mfvi(self, s_edge, s_sib, s_cop, s_grd, mask):
+        _, seq_len, _ = mask.shape
+        hs, ms = torch.stack(torch.where(torch.ones_like(mask[0]))).view(-1, seq_len, seq_len)
+        # [seq_len, seq_len, batch_size], (h->m)
+        mask = mask.permute(2, 1, 0)
+        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
+        mask2o = mask.unsqueeze(1) & mask.unsqueeze(2)
+        mask2o = mask2o & hs.unsqueeze(-1).ne(hs.new_tensor(range(seq_len))).unsqueeze(-1)
+        mask2o = mask2o & ms.unsqueeze(-1).ne(ms.new_tensor(range(seq_len))).unsqueeze(-1)
+        mask2o.diagonal().fill_(0)
+        # [seq_len, seq_len, batch_size], (h->m)
+        s_edge = s_edge.permute(2, 1, 0)
+        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
+        s_sib = s_sib.permute(2, 1, 3, 0) * mask2o
+        # [seq_len, seq_len, seq_len, batch_size], (h->m->c)
+        s_cop = s_cop.permute(2, 1, 3, 0) * mask2o
+        # [seq_len, seq_len, seq_len, batch_size], (h->m->g)
+        s_grd = s_grd.permute(2, 1, 3, 0) * mask2o
+
+        # posterior distributions
+        # [seq_len, seq_len, batch_size], (h->m)
+        q = s_edge
+
+        for _ in range(self.max_iter):
+            q = q.sigmoid()
+            # q(ij) = s(ij) + sum(q(ik)s^sib(ij,ik) + q(kj)s^cop(ij,kj) + q(jk)s^grd(ij,jk)), k != i,j
+            q = s_edge + (q.unsqueeze(1) * s_sib + q.transpose(0, 1).unsqueeze(0) * s_cop + q.unsqueeze(0) * s_grd).sum(2)
+
+        return q.permute(2, 1, 0)
+
+
+class SemanticDependencyLBP(nn.Module):
+    r"""
+    Loopy Belief Propagation for approximately calculating marginals
+    of semantic dependency trees :cite:`wang-etal-2019-second`.
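+
+    The snippet below is an illustrative sketch with random scores. Masking out
+    position 0 as a dependent keeps the sentence lengths derived inside
+    :meth:`lbp` (from the first column of the mask) consistent with the tree
+    factor:
+
+    Examples:
+        >>> import torch
+        >>> lbp = SemanticDependencyLBP(max_iter=3)
+        >>> s_edge = torch.randn(2, 7, 7)
+        >>> s_sib, s_cop, s_grd = (torch.randn(2, 7, 7, 7) for _ in range(3))
+        >>> mask = torch.ones(2, 7, 7, dtype=torch.bool)
+        >>> mask[:, 0] = 0
+        >>> lbp((s_edge, s_sib, s_cop, s_grd), mask).shape
+        torch.Size([2, 7, 7])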
+    """
+
+    def __init__(self, max_iter: int = 3) -> SemanticDependencyLBP:
+        super().__init__()
+
+        self.max_iter = max_iter
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(max_iter={self.max_iter})"
+
+    @torch.enable_grad()
+    def forward(
+        self,
+        scores: List[torch.Tensor],
+        mask: torch.BoolTensor,
+        target: Optional[torch.LongTensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""
+        Args:
+            scores (~torch.Tensor, ~torch.Tensor):
+                Tuple of four tensors `s_edge`, `s_sib`, `s_cop` and `s_grd`.
+                `s_edge` (``[batch_size, seq_len, seq_len]``) holds scores of all possible dependent-head pairs.
+                `s_sib` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-sibling triples.
+                `s_cop` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-coparent triples.
+                `s_grd` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-grandparent triples.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+                The mask to avoid aggregation on padding tokens.
+            target (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
+                A Tensor of gold-standard dependent-head pairs. Default: ``None``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The first is the training loss averaged over the number of tokens, which is not returned if ``target=None``.
+                The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``.
+        """
+
+        logits = self.lbp(*scores, mask)
+        marginals = logits.softmax(-1)[..., 1]
+
+        if target is None:
+            return marginals
+        loss = F.cross_entropy(logits[mask], target[mask])
+
+        return loss, marginals
+
+    def lbp(self, s_edge, s_sib, s_cop, s_grd, mask):
+        lens = mask[..., 0].sum(1)
+        _, seq_len, _ = mask.shape
+        hs, ms = torch.stack(torch.where(torch.ones_like(mask[0]))).view(-1, seq_len, seq_len)
+        # [seq_len, seq_len, batch_size], (h->m)
+        mask = mask.permute(2, 1, 0)
+        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
+        mask2o = mask.unsqueeze(1) & mask.unsqueeze(2)
+        mask2o = mask2o & hs.unsqueeze(-1).ne(hs.new_tensor(range(seq_len))).unsqueeze(-1)
+        mask2o = mask2o & ms.unsqueeze(-1).ne(ms.new_tensor(range(seq_len))).unsqueeze(-1)
+        mask2o.diagonal().fill_(0)
+        # [2, seq_len, seq_len, batch_size], (h->m)
+        s_edge = torch.stack((torch.zeros_like(s_edge), s_edge)).permute(0, 3, 2, 1)
+        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
+        s_sib = s_sib.permute(2, 1, 3, 0)
+        # [seq_len, seq_len, seq_len, batch_size], (h->m->c)
+        s_cop = s_cop.permute(2, 1, 3, 0)
+        # [seq_len, seq_len, seq_len, batch_size], (h->m->g)
+        s_grd = s_grd.permute(2, 1, 3, 0)
+
+        # log beliefs
+        # [2, seq_len, seq_len, batch_size], (h->m)
+        q = s_edge
+        # sibling factor
+        # [2, seq_len, seq_len, seq_len, batch_size], (h->m->s)
+        m_sib = s_sib.new_zeros(2, *mask2o.shape)
+        # coparent factor
+        # [2, seq_len, seq_len, seq_len, batch_size], (h->m->c)
+        m_cop = s_cop.new_zeros(2, *mask2o.shape)
+        # grandparent factor
+        # [2, seq_len, seq_len, seq_len, batch_size], (h->m->g)
+        m_grd = s_grd.new_zeros(2, *mask2o.shape)
+        # tree factor
+        # [2, seq_len, seq_len, batch_size], (h->m)
+        m_tree = torch.zeros_like(s_edge)
+
+        for _ in range(self.max_iter):
+            # sibling factor
+            v_sib = q.unsqueeze(2) - m_sib
+            m_sib = torch.stack((v_sib.logsumexp(0), torch.stack((v_sib[0], v_sib[1] + s_sib)).logsumexp(0))).log_softmax(0)
+            # coparent factor
+            v_cop = q.transpose(1, 2).unsqueeze(1) - m_cop
+            m_cop = torch.stack((v_cop.logsumexp(0), torch.stack((v_cop[0], v_cop[1] + s_cop)).logsumexp(0))).log_softmax(0)
+            # grandparent factor
+            v_grd = q.unsqueeze(1) - m_grd
+            m_grd = torch.stack((v_grd.logsumexp(0), torch.stack((v_grd[0], v_grd[1] + s_grd)).logsumexp(0))).log_softmax(0)
+            # tree factor
+            v_tree = q - m_tree
+            b_tree = DependencyCRF((v_tree[1] - v_tree[0]).permute(2, 1, 0), lens).marginals.permute(2, 1, 0)
+            b_tree = torch.stack((1 - b_tree, b_tree))
+            m_tree = (b_tree.clamp(torch.finfo().eps).log() - v_tree).log_softmax(0)
+            # q(ij) = s(ij) + sum(m(ik->ij)), k != i,j
+            q = s_edge + ((m_sib + m_cop + m_grd).transpose(2, 3) * mask2o).sum(3) + m_tree
+
+        return q.permute(3, 2, 1, 0)
diff --git a/tania_scripts/supar/utils/__init__.py b/tania_scripts/supar/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..279bb3fee548100e6538d245e4adf226c6a41a67
--- /dev/null
+++ b/tania_scripts/supar/utils/__init__.py
@@ -0,0 +1,17 @@
+# -*- coding: utf-8 -*-
+
+from . import field, fn, metric, transform
+from .config import Config
+from .data import Dataset
+from .embed import Embedding
+from .field import ChartField, Field, RawField, SubwordField
+from .transform import Transform
+from .vocab import Vocab
+
+__all__ = ['Config',
+           'Dataset',
+           'Embedding',
+           'RawField', 'Field', 'SubwordField', 'ChartField',
+           'Transform',
+           'Vocab',
+           'field', 'fn', 'metric', 'transform']
diff --git a/tania_scripts/supar/utils/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7085981158edde053f2e613913ae5104f375bc5
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea9d342a076da2ffb2ad39c0137168c37c8ba0a9
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/common.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/common.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d3aa81c201d0bcf97afeb8657a7fbfda08ae766
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/common.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/common.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/common.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a95bf475fc0c1b5ed5fbe397eed664cc93daf29
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/common.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/config.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ca40a0457ad92a8c30fc0f85221363889e017cd
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/config.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/config.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/config.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c702557431115d22e5cd81b51b3e043b2e3df185
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/config.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/data.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/data.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86c2665f120944e63bad5a419675ff4d4c8cdfdc
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/data.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/data.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/data.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6e9fad9bc068980a106abf82b540436629543ca
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/data.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/embed.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/embed.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06e59dca4554b13420c432192952564973ebd013
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/embed.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/embed.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/embed.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50a582efc082811b6496eeb81a30e62f80b7f484
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/embed.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/field.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/field.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67521b6cccf3e5ff86e13752f33222767dc10c66
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/field.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/field.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/field.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..448dcecae208170106c7fee557c7e348e4ae066f
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/field.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/fn.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/fn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d4582deab9dc73c2dbc769f06c95e6ab6002084
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/fn.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/fn.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/fn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b894334d6aceb12d3be2fb5c1afb8363709ba8d3
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/fn.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/logging.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/logging.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d47f1e45e93b38ad9f76e9acb3bef318cf82085f
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/logging.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/logging.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/logging.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f52d49b3cbff680cb6322a30b5b1986941d883a6
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/logging.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/metric.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/metric.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1eb49a374e350bfdb86f44ddc9bb384ce41fe121
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/metric.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/metric.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/metric.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d685d90d287c335bc5300af1b9cb989c3df9b65
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/metric.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/optim.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/optim.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf0d056ae56044d19b2e5e290f45058425bc2af3
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/optim.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/optim.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/optim.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f34b83e7bcd94eda11b8b415630ab6d562cf1c79
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/optim.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/parallel.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/parallel.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d12400edeb441199d157172e73705a9fc94003fb
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/parallel.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/parallel.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/parallel.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20c67aea66c324dbdeb02c9220e11cd0318b49f8
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/parallel.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/tokenizer.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/tokenizer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cb268708f3777028bfa8541824a2616c07fc8b52
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/tokenizer.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/tokenizer.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/tokenizer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37128ac854fdfde72e08c626b1a6cda83b86d673
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/tokenizer.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/transform.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa125d981767ec90dad382896177289419f0d04a
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/transform.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/transform.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd491045274e55fd2e1572e2fcbdf56acaec5e61
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/transform.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/vocab.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/vocab.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..89fc3d6b458d28f39f20492fa93b0fd4981e51fd
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/vocab.cpython-310.pyc differ
diff --git a/tania_scripts/supar/utils/__pycache__/vocab.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/vocab.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37a8843712eaa44afd7d45ac81dfeef4e4038e50
Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/vocab.cpython-311.pyc differ
diff --git a/tania_scripts/supar/utils/common.py b/tania_scripts/supar/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..f320d2909e3bb86a83721d6ce820545295ccca3d
--- /dev/null
+++ b/tania_scripts/supar/utils/common.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+
+import os
+
+PAD = '<pad>'
+UNK = '<unk>'
+BOS = '<bos>'
+EOS = '<eos>'
+NUL = '<nul>'
+
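+# MIN stands in for -inf when masking scores in log space, avoiding NaNs in softmax;
+# INF serves as an unbounded default (e.g., `max_len` in `supar.utils.data.Dataset`).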
+MIN = -1e32
+INF = float('inf')
+
+CACHE = os.path.expanduser('~/.cache/supar')
diff --git a/tania_scripts/supar/utils/config.py b/tania_scripts/supar/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..28b18adcab745b57c6a1000090510f82b986062c
--- /dev/null
+++ b/tania_scripts/supar/utils/config.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import argparse
+import yaml
+import os
+from ast import literal_eval
+from configparser import ConfigParser
+from typing import Any, Dict, Optional, Sequence
+
+import supar
+from omegaconf import OmegaConf
+from supar.utils.fn import download
+
+
+class Config(object):
+
+    def __init__(self, **kwargs: Any) -> None:
+        super(Config, self).__init__()
+
+        self.update(kwargs)
+
+    def __repr__(self) -> str:
+        return yaml.dump(self.__dict__)
+
+    def __getitem__(self, key: str) -> Any:
+        return getattr(self, key)
+
+    def __contains__(self, key: str) -> bool:
+        return hasattr(self, key)
+
+    def __getstate__(self) -> Dict[str, Any]:
+        return self.__dict__
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
+
+    @property
+    def primitive_config(self) -> Dict[str, Any]:
+        from enum import Enum
+        from pathlib import Path
+        primitive_types = (int, float, bool, str, bytes, Enum, Path)
+        return {name: value for name, value in self.__dict__.items() if type(value) in primitive_types}
+
+    def keys(self) -> Any:
+        return self.__dict__.keys()
+
+    def items(self) -> Any:
+        return self.__dict__.items()
+
+    def update(self, kwargs: Dict[str, Any]) -> Config:
+        for key in ('self', 'cls', '__class__'):
+            kwargs.pop(key, None)
+        kwargs.update(kwargs.pop('kwargs', dict()))
+        for name, value in kwargs.items():
+            setattr(self, name, value)
+        return self
+
+    def get(self, key: str, default: Optional[Any] = None) -> Any:
+        return getattr(self, key, default)
+
+    def pop(self, key: str, default: Optional[Any] = None) -> Any:
+        return self.__dict__.pop(key, default)
+
+    def save(self, path):
+        with open(path, 'w') as f:
+            f.write(str(self))
+
+    @classmethod
+    def load(cls, conf: str = '', unknown: Optional[Sequence[str]] = None, **kwargs: Any) -> Config:
+        if conf and not os.path.exists(conf):
+            conf = download(supar.CONFIG['github'].get(conf, conf))
+        if conf.endswith(('.yml', '.yaml')):
+            config = OmegaConf.load(conf)
+        else:
+            config = ConfigParser()
+            config.read(conf)
+            config = dict((name, literal_eval(value)) for s in config.sections() for name, value in config.items(s))
+        if unknown is not None:
+            parser = argparse.ArgumentParser()
+            for name, value in config.items():
+                parser.add_argument('--'+name.replace('_', '-'), type=type(value), default=value)
+            config.update(vars(parser.parse_args(unknown)))
+        return cls(**config).update(kwargs)
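+
+
+# A minimal usage sketch (illustrative, not part of the original module):
+# `Config` stores arbitrary keyword arguments as attributes and supports
+# dict-like reads, chained updates and (de)serialization.
+#
+#     >>> args = Config(lr=2e-3, epochs=10)
+#     >>> args.lr, args['epochs']
+#     (0.002, 10)
+#     >>> args.update({'batch_size': 5000}).batch_size
+#     5000
+#     >>> 'lr' in args
+#     True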
diff --git a/tania_scripts/supar/utils/data.py b/tania_scripts/supar/utils/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b32d60f78b39cc2e6c425fe69c54f0919e61d
--- /dev/null
+++ b/tania_scripts/supar/utils/data.py
@@ -0,0 +1,344 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import itertools
+import os
+import queue
+import shutil
+import tempfile
+import threading
+from contextlib import contextmanager
+from typing import Dict, Iterable, List, Union
+
+import pathos.multiprocessing as mp
+import torch
+import torch.distributed as dist
+from supar.utils.common import INF
+from supar.utils.fn import binarize, debinarize, kmeans
+from supar.utils.logging import get_logger, progress_bar
+from supar.utils.parallel import is_dist, is_master
+from supar.utils.transform import Batch, Transform
+from torch.distributions.utils import lazy_property
+
+logger = get_logger(__name__)
+
+
+class Dataset(torch.utils.data.Dataset):
+    r"""
+    Dataset that is compatible with :class:`torch.utils.data.Dataset`, serving as a wrapper that manipulates all data fields
+    with the loading and processing behaviours defined in :class:`~supar.utils.transform.Transform`.
+    The data fields of all the instantiated sentences can be accessed as an attribute of the dataset.
+
+    Args:
+        transform (Transform):
+            An instance of :class:`~supar.utils.transform.Transform` or its derivations.
+            The instance holds a series of loading and processing behaviours with regard to the specific data format.
+        data (Union[str, Iterable]):
+            A filename or a list of instances that will be passed into :meth:`transform.load`.
+        cache (bool):
+            If ``True``, tries to use the previously cached binarized data for fast loading.
+            In this way, sentences are loaded on-the-fly according to the meta data.
+            If ``False``, all sentences will be directly loaded into the memory.
+            Default: ``False``.
+        binarize (bool):
+            If ``True``, binarizes the dataset once building it. Only works if ``cache=True``. Default: ``False``.
+        bin (str):
+            Path for saving binarized files. If ``None``, defaults to the path of ``data`` suffixed with ``.pt``.
+            Only used if ``cache=True``. Default: ``None``.
+        max_len (int):
+            Sentences exceeding this length will be discarded. Default: ``None``.
+        kwargs (Dict):
+            Together with `data`, kwargs will be passed into :meth:`transform.load` to control the loading behaviour.
+
+    Attributes:
+        transform (Transform):
+            An instance of :class:`~supar.utils.transform.Transform`.
+        sentences (List[Sentence]):
+            A list of sentences loaded from the data.
+            Each sentence includes fields obeying the data format defined in ``transform``.
+            If ``cache=True``, each is a pointer to the sentence stored in the cache file.
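+
+    A typical workflow (an illustrative sketch; the ``transform`` instance and
+    the data file are assumed to exist) loads the data and then calls
+    :meth:`build` to create the batched ``loader``::
+
+        dataset = Dataset(transform, 'train.conllx')
+        dataset.build(batch_size=5000, n_buckets=32, shuffle=True)
+        for batch in dataset.loader:
+            ...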
+    """
+
+    def __init__(
+        self,
+        transform: Transform,
+        data: Union[str, Iterable],
+        cache: bool = False,
+        binarize: bool = False,
+        bin: str = None,
+        max_len: int = None,
+        **kwargs
+    ) -> Dataset:
+        super(Dataset, self).__init__()
+
+        self.transform = transform
+        self.data = data
+        self.cache = cache
+        self.binarize = binarize
+        self.bin = bin
+        self.max_len = max_len or INF
+        self.kwargs = kwargs
+
+        if cache:
+            if not isinstance(data, str) or not os.path.exists(data):
+                raise FileNotFoundError("Binarization is only supported for existing files, but the given data was not found")
+            if self.bin is None:
+                self.fbin = data + '.pt'
+            else:
+                os.makedirs(self.bin, exist_ok=True)
+                self.fbin = os.path.join(self.bin, os.path.split(data)[1]) + '.pt'
+            if not self.binarize and os.path.exists(self.fbin):
+                try:
+                    self.sentences = debinarize(self.fbin, meta=True)['sentences']
+                except Exception:
+                    raise RuntimeError(f"Error found while debinarizing {self.fbin}, which may have been corrupted. "
+                                       "Try re-binarizing it first!")
+        else:
+            self.sentences = list(transform.load(data, **kwargs))
+
+    def __repr__(self):
+        s = f"{self.__class__.__name__}("
+        s += f"n_sentences={len(self.sentences)}"
+        if hasattr(self, 'loader'):
+            s += f", n_batches={len(self.loader)}"
+        if hasattr(self, 'buckets'):
+            s += f", n_buckets={len(self.buckets)}"
+        if self.cache:
+            s += f", cache={self.cache}"
+        if self.binarize:
+            s += f", binarize={self.binarize}"
+        if self.max_len < INF:
+            s += f", max_len={self.max_len}"
+        s += ")"
+        return s
+
+    def __len__(self):
+        return len(self.sentences)
+
+    def __getitem__(self, index):
+        return debinarize(self.fbin, self.sentences[index]) if self.cache else self.sentences[index]
+
+    def __getattr__(self, name):
+        if name not in {f.name for f in self.transform.flattened_fields}:
+            raise AttributeError
+        if self.cache:
+            if os.path.exists(self.fbin) and not self.binarize:
+                sentences = self
+            else:
+                sentences = self.transform.load(self.data, **self.kwargs)
+            return (getattr(sentence, name) for sentence in sentences)
+        return [getattr(sentence, name) for sentence in self.sentences]
+
+    def __getstate__(self):
+        return self.__dict__
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+
+    @lazy_property
+    def sizes(self):
+        if not self.cache:
+            return [s.size for s in self.sentences]
+        return debinarize(self.fbin, 'sizes')
+
+    def build(
+        self,
+        batch_size: int,
+        n_buckets: int = 1,
+        shuffle: bool = False,
+        distributed: bool = False,
+        even: bool = True,
+        n_workers: int = 0,
+        pin_memory: bool = True,
+        chunk_size: int = 1000,
+    ) -> Dataset:
+        # numericalize all fields
+        if not self.cache:
+            self.sentences = [i for i in self.transform(self.sentences) if len(i) < self.max_len]
+        else:
+            # if not forced to do binarization and the binarized file already exists, directly load the meta file
+            if os.path.exists(self.fbin) and not self.binarize:
+                self.sentences = debinarize(self.fbin, meta=True)['sentences']
+            else:
+                @contextmanager
+                def cache(sentences):
+                    ftemp = tempfile.mkdtemp()
+                    fs = os.path.join(ftemp, 'sentences')
+                    fb = os.path.join(ftemp, os.path.basename(self.fbin))
+                    global global_transform
+                    global_transform = self.transform
+                    sentences = binarize({'sentences': progress_bar(sentences)}, fs)[1]['sentences']
+                    try:
+                        yield ((sentences[s:s+chunk_size], fs, f"{fb}.{i}", self.max_len)
+                               for i, s in enumerate(range(0, len(sentences), chunk_size)))
+                    finally:
+                        del global_transform
+                        shutil.rmtree(ftemp)
+
+                def numericalize(sentences, fs, fb, max_len):
+                    sentences = global_transform((debinarize(fs, sentence) for sentence in sentences))
+                    sentences = [i for i in sentences if len(i) < max_len]
+                    return binarize({'sentences': sentences, 'sizes': [sentence.size for sentence in sentences]}, fb)[0]
+
+                logger.info(f"Seeking to cache the data to {self.fbin} first")
+                # numericalize the fields of each sentence
+                if is_master():
+                    with cache(self.transform.load(self.data, **self.kwargs)) as chunks, mp.Pool(32) as pool:
+                        results = [pool.apply_async(numericalize, chunk) for chunk in chunks]
+                        self.sentences = binarize((r.get() for r in results), self.fbin, merge=True)[1]['sentences']
+                if is_dist():
+                    dist.barrier()
+                if not is_master():
+                    self.sentences = debinarize(self.fbin, meta=True)['sentences']
+        # NOTE: the final bucket count is roughly equal to n_buckets
+        self.buckets = dict(zip(*kmeans(self.sizes, n_buckets)))
+        self.loader = DataLoader(transform=self.transform,
+                                 dataset=self,
+                                 batch_sampler=Sampler(self.buckets, batch_size, shuffle, distributed, even),
+                                 num_workers=n_workers,
+                                 collate_fn=collate_fn,
+                                 pin_memory=pin_memory)
+        return self
+
+
+class Sampler(torch.utils.data.Sampler):
+    r"""
+    Sampler that supports bucketization and token-level batchification.
+
+    Args:
+        buckets (Dict):
+            A dict that maps each centroid to indices of clustered sentences.
+            The centroid corresponds to the average length of all sentences in the bucket.
+        batch_size (int):
+            Token-level batch size. The resulting batch contains roughly the same number of tokens as ``batch_size``.
+        shuffle (bool):
+            If ``True``, the sampler will shuffle both buckets and samples in each bucket. Default: ``False``.
+        distributed (bool):
+            If ``True``, the sampler will be used in conjunction with :class:`torch.nn.parallel.DistributedDataParallel`
+            that restricts data loading to a subset of the dataset.
+            Default: ``False``.
+        even (bool):
+            If ``True``, the sampler will add extra indices to make the data evenly divisible across the replicas.
+            Default: ``True``.
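+
+    The example below is an illustrative sketch with a hypothetical ``buckets``
+    dict mapping average lengths to sentence indices:
+
+    Examples:
+        >>> buckets = {10.0: [0, 2, 4], 20.0: [1, 3]}
+        >>> sampler = Sampler(buckets, batch_size=40)
+        >>> len(sampler)  # one batch per bucket at this batch size
+        2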
+    """
+
+    def __init__(
+        self,
+        buckets: Dict[float, List],
+        batch_size: int,
+        shuffle: bool = False,
+        distributed: bool = False,
+        even: bool = True
+    ) -> Sampler:
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.distributed = distributed
+        self.even = even
+        #print("237 utils data buckets items", buckets.items()) 
+        self.sizes, self.buckets = zip(*[(size, bucket) for size, bucket in buckets.items()])
+        # number of batches in each bucket, clipped by range [1, len(bucket)]
+        self.n_batches = [min(len(bucket), max(round(size * len(bucket) / batch_size), 1))
+                          for size, bucket in zip(self.sizes, self.buckets)]
+        self.rank, self.n_replicas, self.n_samples = 0, 1, self.n_total_samples
+
+        if distributed:
+            self.rank = dist.get_rank()
+            self.n_replicas = dist.get_world_size()
+            self.n_samples = self.n_total_samples // self.n_replicas
+            if self.n_total_samples % self.n_replicas != 0:
+                self.n_samples += 1 if even else int(self.rank < self.n_total_samples % self.n_replicas)
+        self.epoch = 1
+
+    def __iter__(self):
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        self.epoch += 1
+
+        total, batches = 0, []
+        # if `shuffle=True`, shuffle both the buckets and samples in each bucket
+        # for distributed training, make sure each process generates the same random sequence at each epoch
+        range_fn = torch.arange if not self.shuffle else lambda x: torch.randperm(x, generator=g)
+        for i in itertools.cycle(range(len(self.buckets))):
+            bucket = self.buckets[i]
+            split_sizes = [(len(bucket) - j - 1) // self.n_batches[i] + 1 for j in range(self.n_batches[i])]
+            # DON'T use `torch.chunk`, which may return a wrong number of batches
+            for batch in range_fn(len(bucket)).split(split_sizes):
+                if total % self.n_replicas == self.rank:
+                    batches.append([bucket[j] for j in batch.tolist()])
+                if len(batches) == self.n_samples:
+                    return iter(batches[i] for i in range_fn(self.n_samples).tolist())
+                total += 1
+
+    def __len__(self):
+        return self.n_samples
+
+    @property
+    def n_total_samples(self):
+        return sum(self.n_batches)
+
+    def set_epoch(self, epoch: int) -> None:
+        self.epoch = epoch
+
+
+class DataLoader(torch.utils.data.DataLoader):
+
+    r"""
+    A wrapper for native :class:`torch.utils.data.DataLoader` enhanced with a data prefetcher.
+    See http://stackoverflow.com/questions/7323664/python-generator-pre-fetch and
+    https://github.com/NVIDIA/apex/issues/304.
+    """
+
+    def __init__(self, transform, **kwargs):
+        super().__init__(**kwargs)
+
+        self.transform = transform
+
+    def __iter__(self):
+        return PrefetchGenerator(self.transform, super().__iter__())
+
+
+class PrefetchGenerator(threading.Thread):
+
+    def __init__(self, transform, loader, prefetch=1):
+        threading.Thread.__init__(self)
+
+        self.transform = transform
+
+        self.queue = queue.Queue(prefetch)
+        self.loader = loader
+        self.daemon = True
+        if torch.cuda.is_available():
+            self.stream = torch.cuda.Stream()
+
+        self.start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if hasattr(self, 'stream'):
+            torch.cuda.current_stream().wait_stream(self.stream)
+        batch = self.queue.get()
+        if batch is None:
+            raise StopIteration
+        return batch
+
+    def run(self):
+        # `torch.cuda.current_device` is thread local
+        # see https://github.com/pytorch/pytorch/issues/56588
+        if is_dist() and torch.cuda.is_available():
+            torch.cuda.set_device(dist.get_rank())
+        if hasattr(self, 'stream'):
+            with torch.cuda.stream(self.stream):
+                for batch in self.loader:
+                    self.queue.put(batch.compose(self.transform))
+        else:
+            for batch in self.loader:
+                self.queue.put(batch.compose(self.transform))
+        self.queue.put(None)
+
+
+def collate_fn(x):
+    return Batch(x)
diff --git a/tania_scripts/supar/utils/embed.py b/tania_scripts/supar/utils/embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..c132b4da9627425576d339bf57cfdb15ed94aa89
--- /dev/null
+++ b/tania_scripts/supar/utils/embed.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import os
+from typing import Iterable, Optional, Union
+
+import torch
+from supar.utils.common import CACHE
+from supar.utils.fn import download
+from supar.utils.logging import progress_bar
+from torch.distributions.utils import lazy_property
+
+
+class Embedding(object):
+    r"""
+    Defines a container object for holding pretrained embeddings.
+    This object is callable and behaves like :class:`torch.nn.Embedding`.
+    For huge files, this object supports lazy loading, seeking to retrieve vectors from the disk on the fly if necessary.
+
+    Currently available embeddings:
+        - `GloVe`_
+        - `Fasttext`_
+        - `Giga`_
+        - `Tencent`_
+
+    Args:
+        path (str):
+            Path to the embedding file or short name registered in ``supar.utils.embed.PRETRAINED``.
+        unk (Optional[str]):
+            The string token used to represent OOV tokens. Default: ``None``.
+        skip_first (bool):
+            If ``True``, skips the first line of the embedding file. Default: ``False``.
+        cache (bool):
+            If ``True``, instead of loading entire embeddings into memory, seeks to load vectors from the disk once called.
+            Default: ``True``.
+        sep (str):
+            Separator used by embedding file. Default: ``' '``.
+
+    Examples:
+        >>> import torch.nn as nn
+        >>> from supar.utils.embed import Embedding
+        >>> glove = Embedding.load('glove-6b-100')
+        >>> glove
+        GloVeEmbedding(n_tokens=400000, dim=100, unk=unk, cache=True)
+        >>> fasttext = Embedding.load('fasttext-en')
+        >>> fasttext
+        FasttextEmbedding(n_tokens=2000000, dim=300, skip_first=True, cache=True)
+        >>> giga = Embedding.load('giga-100')
+        >>> giga
+        GigaEmbedding(n_tokens=372846, dim=100, cache=True)
+        >>> indices = torch.tensor([glove.vocab[i.lower()] for i in ['She', 'enjoys', 'playing', 'tennis', '.']])
+        >>> indices
+        tensor([  67, 8371,  697, 2140,    2])
+        >>> glove(indices).shape
+        torch.Size([5, 100])
+        >>> glove(indices).equal(nn.Embedding.from_pretrained(glove.vectors)(indices))
+        True
+
+    .. _GloVe:
+        https://nlp.stanford.edu/projects/glove/
+    .. _Fasttext:
+        https://fasttext.cc/docs/en/crawl-vectors.html
+    .. _Giga:
+        https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip
+    .. _Tencent:
+        https://ai.tencent.com/ailab/nlp/zh/download.html
+    """
+
+    CACHE = os.path.join(CACHE, 'data/embeds')
+
+    def __init__(
+        self,
+        path: str,
+        unk: Optional[str] = None,
+        skip_first: bool = False,
+        cache: bool = True,
+        sep: str = ' ',
+        **kwargs
+    ) -> Embedding:
+        super().__init__()
+
+        self.path = path
+        self.unk = unk
+        self.skip_first = skip_first
+        self.cache = cache
+        self.sep = sep
+        self.kwargs = kwargs
+
+        self.vocab = {token: i for i, token in enumerate(self.tokens)}
+
+    def __len__(self):
+        return len(self.vocab)
+
+    def __repr__(self):
+        s = f"{self.__class__.__name__}("
+        s += f"n_tokens={len(self)}, dim={self.dim}"
+        if self.unk is not None:
+            s += f", unk={self.unk}"
+        if self.skip_first:
+            s += f", skip_first={self.skip_first}"
+        if self.cache:
+            s += f", cache={self.cache}"
+        s += ")"
+        return s
+
+    def __contains__(self, token):
+        return token in self.vocab
+
+    def __getitem__(self, key: Union[int, Iterable[int], torch.Tensor]) -> torch.Tensor:
+        indices = key
+        if not isinstance(indices, torch.Tensor):
+            indices = torch.tensor(key)
+        if self.cache:
+            elems, indices = indices.unique(return_inverse=True)
+            with open(self.path) as f:
+                vectors = []
+                for index in elems.tolist():
+                    f.seek(self.positions[index])
+                    vectors.append(list(map(float, f.readline().strip().split(self.sep)[1:])))
+                vectors = torch.tensor(vectors)
+        else:
+            vectors = self.vectors
+        return torch.embedding(vectors, indices)
+
+    def __call__(self, key: Union[int, Iterable[int], torch.Tensor]) -> torch.Tensor:
+        return self[key]
+
+    @lazy_property
+    def dim(self):
+        return len(self[0])
+
+    @lazy_property
+    def unk_index(self):
+        if self.unk is not None:
+            return self.vocab[self.unk]
+        raise AttributeError
+
+    @lazy_property
+    def tokens(self):
+        with open(self.path) as f:
+            if self.skip_first:
+                f.readline()
+            return [line.strip().split(self.sep)[0] for line in progress_bar(f)]
+
+    @lazy_property
+    def vectors(self):
+        with open(self.path) as f:
+            if self.skip_first:
+                f.readline()
+            return torch.tensor([list(map(float, line.strip().split(self.sep)[1:])) for line in progress_bar(f)])
+
+    @lazy_property
+    def positions(self):
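+        # record the byte offset of every vector line so that cache-mode lookups
+        # can `seek` straight to the required rows instead of loading the file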
+        with open(self.path) as f:
+            if self.skip_first:
+                f.readline()
+            positions = [f.tell()]
+            while True:
+                line = f.readline()
+                if line:
+                    positions.append(f.tell())
+                else:
+                    break
+            return positions
+
+    @classmethod
+    def load(cls, path: str, unk: Optional[str] = None, **kwargs) -> Embedding:
+        if path in PRETRAINED:
+            cfg = dict(**PRETRAINED[path])
+            embed = cfg.pop('_target_')
+            return embed(**cfg, **kwargs)
+        return cls(path, unk, **kwargs)
+
+
+class GloVeEmbedding(Embedding):
+
+    r"""
+    `GloVe`_: Global Vectors for Word Representation.
+    Training is performed on aggregated global word-word co-occurrence statistics from a corpus,
+    and the resulting representations showcase interesting linear substructures of the word vector space.
+
+    Args:
+        src (str):
+            Size of the source data for training. Default: ``6B``.
+        dim (int):
+            Which dimension of the embeddings to use. Default: 100.
+        reload (bool):
+            If ``True``, forces a fresh download. Default: ``False``.
+
+    Examples:
+        >>> from supar.utils.embed import Embedding
+        >>> Embedding.load('glove-6b-100')
+        GloVeEmbedding(n_tokens=400000, dim=100, unk=unk, cache=True)
+
+    .. _GloVe:
+        https://nlp.stanford.edu/projects/glove/
+    """
+
+    def __init__(self, src: str = '6B', dim: int = 100, reload=False, *args, **kwargs) -> GloVeEmbedding:
+        if src == '6B' or src == 'twitter.27B':
+            url = f'https://nlp.stanford.edu/data/glove.{src}.zip'
+        else:
+            url = f'https://nlp.stanford.edu/data/glove.{src}.{dim}d.zip'
+        path = os.path.join(os.path.join(self.CACHE, 'glove'), f'glove.{src}.{dim}d.txt')
+        if not os.path.exists(path) or reload:
+            download(url, os.path.join(self.CACHE, 'glove'), clean=True)
+
+        super().__init__(path=path, unk='unk', *args, **kwargs)
+
+
+class FasttextEmbedding(Embedding):
+
+    r"""
+    `Fasttext`_ word embeddings for 157 languages, trained using CBOW, in dimension 300,
+    with character n-grams of length 5, a window of size 5 and 10 negatives.
+
+    Args:
+        lang (str):
+            Language code. Default: ``en``.
+        reload (bool):
+            If ``True``, forces a fresh download. Default: ``False``.
+
+    Examples:
+        >>> from supar.utils.embed import Embedding
+        >>> Embedding.load('fasttext-en')
+        FasttextEmbedding(n_tokens=2000000, dim=300, skip_first=True, cache=True)
+
+    .. _Fasttext:
+        https://fasttext.cc/docs/en/crawl-vectors.html
+    """
+
+    def __init__(self, lang: str = 'en', reload=False, *args, **kwargs) -> FasttextEmbedding:
+        url = f'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{lang}.300.vec.gz'
+        path = os.path.join(self.CACHE, 'fasttext', f'cc.{lang}.300.vec')
+        if not os.path.exists(path) or reload:
+            download(url, os.path.join(self.CACHE, 'fasttext'), clean=True)
+
+        super().__init__(path=path, skip_first=True, *args, **kwargs)
+
+
+class GigaEmbedding(Embedding):
+
+    r"""
+    `Giga`_ word embeddings, trained on Chinese Gigaword Third Edition for Chinese using word2vec,
+    used by :cite:`zhang-etal-2020-efficient` and :cite:`zhang-etal-2020-fast`.
+
+    Args:
+        reload (bool):
+            If ``True``, forces a fresh download. Default: ``False``.
+
+    Examples:
+        >>> from supar.utils.embed import Embedding
+        >>> Embedding.load('giga-100')
+        GigaEmbedding(n_tokens=372846, dim=100, cache=True)
+
+    .. _Giga:
+        https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip
+    """
+
+    def __init__(self, reload=False, *args, **kwargs) -> GigaEmbedding:
+        url = 'https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip'
+        path = os.path.join(self.CACHE, 'giga', 'giga.100.txt')
+        if not os.path.exists(path) or reload:
+            download(url, os.path.join(self.CACHE, 'giga'), clean=True)
+
+        super().__init__(path=path, *args, **kwargs)
+
+
+class TencentEmbedding(Embedding):
+
+    r"""
+    `Tencent`_ word embeddings.
+    The embeddings are trained on large-scale text collected from news, webpages, and novels with Directional Skip-Gram.
+    100-dimension and 200-dimension embeddings for over 12 million Chinese words are provided.
+
+    Args:
+        dim (int):
+            Which dimension of the embeddings to use. Currently 100 and 200 are available. Default: 100.
+        large (bool):
+            If ``True``, uses large version with larger vocab size (12,287,933); 2,000,000 otherwise. Default: ``False``.
+        reload (bool):
+            If ``True``, forces a fresh download. Default: ``False``.
+
+    Examples:
+        >>> from supar.utils.embed import Embedding
+        >>> Embedding.load('tencent-100')
+        TencentEmbedding(n_tokens=2000000, dim=100, skip_first=True, cache=True)
+        >>> Embedding.load('tencent-100-large')
+        TencentEmbedding(n_tokens=12287933, dim=100, skip_first=True, cache=True)
+
+    .. _Tencent:
+        https://ai.tencent.com/ailab/nlp/zh/download.html
+    """
+
+    def __init__(self, dim: int = 100, large: bool = False, reload=False, *args, **kwargs) -> TencentEmbedding:
+        url = f'https://ai.tencent.com/ailab/nlp/zh/data/tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if large else "-s"}.tar.gz'  # noqa
+        name = f'tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if large else "-s"}'
+        path = os.path.join(os.path.join(self.CACHE, 'tencent'), name, f'{name}.txt')
+        if not os.path.exists(path) or reload:
+            download(url, os.path.join(self.CACHE, 'tencent'), clean=True)
+
+        super().__init__(path=path, skip_first=True, *args, **kwargs)
+
+
+PRETRAINED = {
+    'glove-6b-50': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 50},
+    'glove-6b-100': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 100},
+    'glove-6b-200': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 200},
+    'glove-6b-300': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 300},
+    'glove-42b-300': {'_target_': GloVeEmbedding, 'src': '42B', 'dim': 300},
+    'glove-840b-300': {'_target_': GloVeEmbedding, 'src': '840B', 'dim': 300},
+    'glove-twitter-27b-25': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 25},
+    'glove-twitter-27b-50': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 50},
+    'glove-twitter-27b-100': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 100},
+    'glove-twitter-27b-200': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 200},
+    'fasttext-bg': {'_target_': FasttextEmbedding, 'lang': 'bg'},
+    'fasttext-ca': {'_target_': FasttextEmbedding, 'lang': 'ca'},
+    'fasttext-cs': {'_target_': FasttextEmbedding, 'lang': 'cs'},
+    'fasttext-de': {'_target_': FasttextEmbedding, 'lang': 'de'},
+    'fasttext-en': {'_target_': FasttextEmbedding, 'lang': 'en'},
+    'fasttext-es': {'_target_': FasttextEmbedding, 'lang': 'es'},
+    'fasttext-fr': {'_target_': FasttextEmbedding, 'lang': 'fr'},
+    'fasttext-it': {'_target_': FasttextEmbedding, 'lang': 'it'},
+    'fasttext-nl': {'_target_': FasttextEmbedding, 'lang': 'nl'},
+    'fasttext-no': {'_target_': FasttextEmbedding, 'lang': 'no'},
+    'fasttext-ro': {'_target_': FasttextEmbedding, 'lang': 'ro'},
+    'fasttext-ru': {'_target_': FasttextEmbedding, 'lang': 'ru'},
+    'giga-100': {'_target_': GigaEmbedding},
+    'tencent-100': {'_target_': TencentEmbedding, 'dim': 100},
+    'tencent-100-large': {'_target_': TencentEmbedding, 'dim': 100, 'large': True},
+    'tencent-200': {'_target_': TencentEmbedding, 'dim': 200},
+    'tencent-200-large': {'_target_': TencentEmbedding, 'dim': 200, 'large': True},
+}
diff --git a/tania_scripts/supar/utils/field.py b/tania_scripts/supar/utils/field.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd14bc78f7c45302b61c5b17ff50c90cbde22abd
--- /dev/null
+++ b/tania_scripts/supar/utils/field.py
@@ -0,0 +1,415 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from collections import Counter
+from typing import Callable, Iterable, List, Optional, Union
+
+import torch
+from supar.utils.data import Dataset
+from supar.utils.embed import Embedding
+from supar.utils.fn import pad
+from supar.utils.logging import progress_bar
+from supar.utils.parallel import wait
+from supar.utils.vocab import Vocab
+
+
+class RawField(object):
+    r"""
+    Defines a general datatype.
+
+    A :class:`RawField` object does not assume any property of the datatype and
+    it holds parameters relating to how a datatype should be processed.
+
+    Args:
+        name (str):
+            The name of the field.
+        fn (function):
+            The function used for preprocessing the examples. Default: ``None``.
+    """
+
+    def __init__(self, name: str, fn: Optional[Callable] = None) -> RawField:
+        self.name = name
+        self.fn = fn
+
+    def __repr__(self):
+        return f"({self.name}): {self.__class__.__name__}()"
+
+    def preprocess(self, sequence: Iterable) -> Iterable:
+        return self.fn(sequence) if self.fn is not None else sequence
+
+    def transform(self, sequences: Iterable[List]) -> Iterable[List]:
+        return (self.preprocess(seq) for seq in sequences)
+
+    def compose(self, sequences: Iterable[List]) -> Iterable[List]:
+        return sequences
+
+
+class Field(RawField):
+    r"""
+    Defines a datatype together with instructions for converting to :class:`~torch.Tensor`.
+    :class:`Field` models common text processing datatypes that can be represented by tensors.
+    It holds a :class:`~supar.utils.vocab.Vocab` object that defines the set of possible values
+    for elements of the field and their corresponding numerical representations.
+    The :class:`Field` object also holds other parameters relating to how a datatype
+    should be numericalized, such as a tokenization method.
+
+    Args:
+        name (str):
+            The name of the field.
+        pad (str):
+            The string token used as padding. Default: ``None``.
+        unk (str):
+            The string token used to represent OOV words. Default: ``None``.
+        bos (str):
+            A token that will be prepended to every example using this field, or ``None`` for no ``bos`` token.
+            Default: ``None``.
+        eos (str):
+            A token that will be appended to every example using this field, or ``None`` for no ``eos`` token.
+            Default: ``None``.
+        lower (bool):
+            Whether to lowercase the text in this field. Default: ``False``.
+        use_vocab (bool):
+            Whether to use a :class:`~supar.utils.vocab.Vocab` object.
+            If ``False``, the data in this field should already be numerical.
+            Default: ``True``.
+        tokenize (function):
+            The function used to tokenize strings using this field into sequential examples. Default: ``None``.
+        fn (function):
+            The function used for preprocessing the examples. Default: ``None``.
+        delay (int):
+            The number of ``pad`` tokens appended to each sequence before the ``eos`` token. Default: ``0``.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        pad: Optional[str] = None,
+        unk: Optional[str] = None,
+        bos: Optional[str] = None,
+        eos: Optional[str] = None,
+        lower: bool = False,
+        use_vocab: bool = True,
+        tokenize: Optional[Callable] = None,
+        fn: Optional[Callable] = None,
+        delay: Optional[int] = 0
+    ) -> Field:
+        self.name = name
+        self.pad = pad
+        self.unk = unk
+        self.bos = bos
+        self.eos = eos
+        self.lower = lower
+        self.use_vocab = use_vocab
+        self.tokenize = tokenize
+        self.fn = fn
+        self.delay = delay
+
+        self.specials = [token for token in [pad, unk, bos, eos] if token is not None]
+
+    def __repr__(self):
+        s, params = f"({self.name}): {self.__class__.__name__}(", []
+        if hasattr(self, 'vocab'):
+            params.append(f"vocab_size={len(self.vocab)}")
+        if self.pad is not None:
+            params.append(f"pad={self.pad}")
+        if self.unk is not None:
+            params.append(f"unk={self.unk}")
+        if self.bos is not None:
+            params.append(f"bos={self.bos}")
+        if self.eos is not None:
+            params.append(f"eos={self.eos}")
+        if self.lower:
+            params.append(f"lower={self.lower}")
+        if not self.use_vocab:
+            params.append(f"use_vocab={self.use_vocab}")
+        if self.delay > 0:
+            params.append(f'delay={self.delay}')
+        return s + ', '.join(params) + ')'
+
+    @property
+    def pad_index(self):
+        if self.pad is None:
+            return 0
+        if hasattr(self, 'vocab'):
+            return self.vocab[self.pad]
+        return self.specials.index(self.pad)
+
+    @property
+    def unk_index(self):
+        if self.unk is None:
+            return 0
+        if hasattr(self, 'vocab'):
+            return self.vocab[self.unk]
+        return self.specials.index(self.unk)
+
+    @property
+    def bos_index(self):
+        if hasattr(self, 'vocab'):
+            return self.vocab[self.bos]
+        return self.specials.index(self.bos)
+
+    @property
+    def eos_index(self):
+        if hasattr(self, 'vocab'):
+            return self.vocab[self.eos]
+        return self.specials.index(self.eos)
+
+    @property
+    def device(self):
+        return 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    def preprocess(self, data: Union[str, Iterable]) -> Iterable:
+        r"""
+        Loads a single example and tokenizes it if necessary.
+        The sequence is first passed to ``fn`` if available,
+        then tokenized if ``tokenize`` is not ``None``,
+        and finally lowercased if ``lower`` is set.
+
+        Args:
+            data (Union[str, Iterable]):
+                The data to be preprocessed.
+
+        Returns:
+            The preprocessed sequence.
+        """
+
+        if self.fn is not None:
+            data = self.fn(data)
+        if self.tokenize is not None:
+            data = self.tokenize(data)
+        if self.lower:
+            data = [str.lower(token) for token in data]
+        return data
+
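+    # A minimal sketch of the resulting behavior, assuming a whitespace
+    # tokenizer (illustrative, not part of this module):
+    #   >>> field = Field('words', lower=True, tokenize=str.split)
+    #   >>> field.preprocess('The Dog barks')
+    #   ['the', 'dog', 'barks']
+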
+    def build(
+        self,
+        dataset: Dataset,
+        min_freq: int = 1,
+        embed: Optional[Embedding] = None,
+        norm: Callable = None
+    ) -> Field:
+        r"""
+        Constructs a :class:`~supar.utils.vocab.Vocab` object for this field from the dataset.
+        If the vocabulary already exists, this function has no effect.
+
+        Args:
+            dataset (Dataset):
+                A :class:`~supar.utils.data.Dataset` object.
+                One of the attributes should be named after the name of this field.
+            min_freq (int):
+                The minimum frequency needed to include a token in the vocabulary. Default: 1.
+            embed (Embedding):
+                An Embedding object whose tokens will be added to the vocabulary. Default: ``None``.
+            norm (Callable):
+                Callable function used for normalizing embedding weights. Default: ``None``.
+        """
+
+        if hasattr(self, 'vocab'):
+            return self
+
+        @wait
+        def build_vocab(dataset):
+            return Vocab(counter=Counter(token
+                                         for seq in progress_bar(getattr(dataset, self.name))
+                                         for token in self.preprocess(seq)),
+                         min_freq=min_freq,
+                         specials=self.specials,
+                         unk_index=self.unk_index)
+        self.vocab = build_vocab(dataset)
+
+        if not embed:
+            self.embed = None
+        else:
+            tokens = self.preprocess(embed.tokens)
+            # replace the `unk` token in the pretrained embeddings with a self-defined one if it exists
+            if embed.unk:
+                tokens[embed.unk_index] = self.unk
+
+            self.vocab.update(tokens)
+            self.embed = torch.zeros(len(self.vocab), embed.dim)
+            self.embed[self.vocab[tokens]] = embed.vectors
+            if norm is not None:
+                self.embed = norm(self.embed)
+        return self
+
+    def transform(self, sequences: Iterable[List[str]]) -> Iterable[torch.Tensor]:
+        r"""
+        Turns a list of sequences that use this field into tensors.
+
+        Each sequence is first preprocessed and then numericalized if needed.
+
+        Args:
+            sequences (Iterable[List[str]]):
+                A list of sequences.
+
+        Returns:
+            A list of tensors transformed from the input sequences.
+        """
+
+        for seq in sequences:
+            seq = self.preprocess(seq)
+            if self.use_vocab:
+                try:
+                    seq = [self.vocab[token] for token in seq]
+                except KeyError as e:
+                    raise RuntimeError(f"Token {e} of field {self.name} is not covered by the vocab") from e
+            if self.bos:
+                seq = [self.bos_index] + seq
+            if self.delay > 0:
+                seq = seq + [self.pad_index for _ in range(self.delay)]
+            if self.eos:
+                seq = seq + [self.eos_index]
+            yield torch.tensor(seq, dtype=torch.long)
+
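+    # A minimal sketch: with `use_vocab=False` the tokens are assumed to be
+    # numerical already, so only the `bos`/`delay`/`eos` positions are added
+    # ('<bos>' maps to index 0 of the specials here):
+    #   >>> field = Field('ids', bos='<bos>', use_vocab=False)
+    #   >>> next(field.transform([[3, 4, 5]]))
+    #   tensor([0, 3, 4, 5])
+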
+    def compose(self, batch: Iterable[torch.Tensor]) -> torch.Tensor:
+        r"""
+        Composes a batch of sequences into a padded tensor.
+
+        Args:
+            batch (Iterable[~torch.Tensor]):
+                A list of tensors.
+
+        Returns:
+            A padded tensor converted to proper device.
+        """
+
+        return pad(batch, self.pad_index).to(self.device, non_blocking=True)
+
+
+class SubwordField(Field):
+    r"""
+    A field that conducts tokenization and numericalization over each token rather than over the whole sequence.
+
+    This is customized for models requiring character/subword-level inputs, e.g., CharLSTM and BERT.
+
+    Args:
+        fix_len (int):
+            A fixed length to which all subword pieces are padded.
+            Pieces exceeding this length are truncated.
+            To save memory, the final length is the smaller of
+            the max length of subword pieces in a batch and `fix_len`.
+
+    Examples:
+        >>> from supar.utils.tokenizer import TransformerTokenizer
+        >>> tokenizer = TransformerTokenizer('bert-base-cased')
+        >>> field = SubwordField('bert',
+                                 pad=tokenizer.pad,
+                                 unk=tokenizer.unk,
+                                 bos=tokenizer.bos,
+                                 eos=tokenizer.eos,
+                                 fix_len=20,
+                                 tokenize=tokenizer)
+        >>> field.vocab = tokenizer.vocab  # no need to re-build the vocab
+        >>> next(field.transform([['This', 'field', 'performs', 'token-level', 'tokenization']]))
+        tensor([[  101,     0,     0],
+                [ 1188,     0,     0],
+                [ 1768,     0,     0],
+                [10383,     0,     0],
+                [22559,   118,  1634],
+                [22559,  2734,     0],
+                [  102,     0,     0]])
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.fix_len = kwargs.pop('fix_len') if 'fix_len' in kwargs else 0
+        super().__init__(*args, **kwargs)
+
+    def build(
+        self,
+        dataset: Dataset,
+        min_freq: int = 1,
+        embed: Optional[Embedding] = None,
+        norm: Callable = None
+    ) -> SubwordField:
+        if hasattr(self, 'vocab'):
+            return self
+
+        @wait
+        def build_vocab(dataset):
+            return Vocab(counter=Counter(piece
+                                         for seq in progress_bar(getattr(dataset, self.name))
+                                         for token in seq
+                                         for piece in self.preprocess(token)),
+                         min_freq=min_freq,
+                         specials=self.specials,
+                         unk_index=self.unk_index)
+        self.vocab = build_vocab(dataset)
+
+        if not embed:
+            self.embed = None
+        else:
+            tokens = self.preprocess(embed.tokens)
+            # if the `unk` token exists in the pretrained embeddings,
+            # replace it with a self-defined one
+            if embed.unk:
+                tokens[embed.unk_index] = self.unk
+
+            self.vocab.update(tokens)
+            self.embed = torch.zeros(len(self.vocab), embed.dim)
+            self.embed[self.vocab[tokens]] = embed.vectors
+            if norm is not None:
+                self.embed = norm(self.embed)
+        return self
+
+    def transform(self, sequences: Iterable[List[str]]) -> Iterable[torch.Tensor]:
+        for seq in sequences:
+            seq = [self.preprocess(token) for token in seq]
+            if self.use_vocab:
+                seq = [[self.vocab[i] if i in self.vocab else self.unk_index for i in token] if token else [self.unk_index]
+                       for token in seq]
+            if self.bos:
+                seq = [[self.bos_index]] + seq
+            if self.delay > 0:
+                seq = seq + [[self.pad_index] for _ in range(self.delay)]
+            if self.eos:
+                seq = seq + [[self.eos_index]]
+            if self.fix_len > 0:
+                seq = [ids[:self.fix_len] for ids in seq]
+            yield pad([torch.tensor(ids, dtype=torch.long) for ids in seq], self.pad_index)
+
+
+class ChartField(Field):
+    r"""
+    Field dealing with chart inputs.
+
+    Examples:
+        >>> chart = [[    None,    'NP',    None,    None,    'S*',     'S'],
+                     [    None,    None,   'VP*',    None,    'VP',    None],
+                     [    None,    None,    None,   'VP*', 'S::VP',    None],
+                     [    None,    None,    None,    None,    'NP',    None],
+                     [    None,    None,    None,    None,    None,    'S*'],
+                     [    None,    None,    None,    None,    None,    None]]
+        >>> next(field.transform([chart]))
+        tensor([[ -1,  37,  -1,  -1, 107,  79],
+                [ -1,  -1, 120,  -1, 112,  -1],
+                [ -1,  -1,  -1, 120,  86,  -1],
+                [ -1,  -1,  -1,  -1,  37,  -1],
+                [ -1,  -1,  -1,  -1,  -1, 107],
+                [ -1,  -1,  -1,  -1,  -1,  -1]])
+    """
+
+    def build(
+        self,
+        dataset: Dataset,
+        min_freq: int = 1
+    ) -> ChartField:
+        @wait
+        def build_vocab(dataset):
+            return Vocab(counter=Counter(i
+                                         for chart in progress_bar(getattr(dataset, self.name))
+                                         for row in self.preprocess(chart)
+                                         for i in row if i is not None),
+                         min_freq=min_freq,
+                         specials=self.specials,
+                         unk_index=self.unk_index)
+        self.vocab = build_vocab(dataset)
+        return self
+
+    def transform(self, charts: Iterable[List[List]]) -> Iterable[torch.Tensor]:
+        for chart in charts:
+            chart = self.preprocess(chart)
+            if self.use_vocab:
+                chart = [[self.vocab[i] if i is not None else -1 for i in row] for row in chart]
+            if self.bos:
+                chart = [[self.bos_index]*len(chart[0])] + chart
+            if self.eos:
+                chart = chart + [[self.eos_index]*len(chart[0])]
+            yield torch.tensor(chart, dtype=torch.long)
diff --git a/tania_scripts/supar/utils/fn.py b/tania_scripts/supar/utils/fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..95d826a1dea476907b520a9eff8ea6ceb6a50f95
--- /dev/null
+++ b/tania_scripts/supar/utils/fn.py
@@ -0,0 +1,383 @@
+# -*- coding: utf-8 -*-
+
+import gzip
+import mmap
+import os
+import pickle
+import shutil
+import struct
+import sys
+import tarfile
+import unicodedata
+import urllib
+import zipfile
+from collections import defaultdict
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from supar.utils.common import CACHE
+from supar.utils.parallel import wait
+
+
+def ispunct(token: str) -> bool:
+    return all(unicodedata.category(char).startswith('P') for char in token)
+
+
+def isfullwidth(token: str) -> bool:
+    return all(unicodedata.east_asian_width(char) in ['W', 'F', 'A'] for char in token)
+
+
+def islatin(token: str) -> bool:
+    return all('LATIN' in unicodedata.name(char) for char in token)
+
+
+def isdigit(token: str) -> bool:
+    return all('DIGIT' in unicodedata.name(char) for char in token)
+
+
+def tohalfwidth(token: str) -> str:
+    return unicodedata.normalize('NFKC', token)
+
+
+def kmeans(x: List[int], k: int, max_it: int = 32) -> Tuple[List[float], List[List[int]]]:
+    r"""
+    KMeans algorithm for clustering the sentences by length.
+
+    Args:
+        x (List[int]):
+            The list of sentence lengths.
+        k (int):
+            The number of clusters, which is an approximate value.
+            The final number of clusters can be less than or equal to `k`.
+        max_it (int):
+            Maximum number of iterations.
+            If the centroids do not converge within `max_it` iterations, the algorithm stops early.
+
+    Returns:
+        List[float], List[List[int]]:
+            The first list contains average lengths of sentences in each cluster.
+            The second is the list of clusters holding indices of data points.
+
+    Examples:
+        >>> x = torch.randint(10, 20, (10,)).tolist()
+        >>> x
+        [15, 10, 17, 11, 18, 13, 17, 19, 18, 14]
+        >>> centroids, clusters = kmeans(x, 3)
+        >>> centroids
+        [10.5, 14.0, 17.799999237060547]
+        >>> clusters
+        [[1, 3], [0, 5, 9], [2, 4, 6, 7, 8]]
+    """
+
+    x = torch.tensor(x, dtype=torch.float)
+    # collect unique datapoints
+    datapoints, indices, freqs = x.unique(return_inverse=True, return_counts=True)
+    # the number of clusters must not be greater than the number of datapoints
+    k = min(len(datapoints), k)
+    # initialize k centroids randomly
+    centroids = datapoints[torch.randperm(len(datapoints))[:k]]
+    # assign each datapoint to the cluster with the closest centroid
+    dists, y = torch.abs_(datapoints.unsqueeze(-1) - centroids).min(-1)
+
+    for _ in range(max_it):
+        # if an empty cluster is encountered,
+        # choose the farthest datapoint from the biggest cluster and move it to the empty one
+        mask = torch.arange(k).unsqueeze(-1).eq(y)
+        none = torch.where(~mask.any(-1))[0].tolist()
+        for i in none:
+            # the biggest cluster
+            biggest = torch.where(mask[mask.sum(-1).argmax()])[0]
+            # the datapoint farthest from the centroid of the biggest cluster
+            farthest = dists[biggest].argmax()
+            # update the assigned cluster of the farthest datapoint
+            y[biggest[farthest]] = i
+            # re-calculate the mask
+            mask = torch.arange(k).unsqueeze(-1).eq(y)
+        # update the centroids
+        centroids, old = (datapoints * freqs * mask).sum(-1) / (freqs * mask).sum(-1), centroids
+        # re-assign all datapoints to clusters
+        dists, y = torch.abs_(datapoints.unsqueeze(-1) - centroids).min(-1)
+        # stop iteration early if the centroids converge
+        if centroids.equal(old):
+            break
+    # assign all datapoints to the newly-generated clusters
+    # the empty ones are discarded
+    assigned = y.unique().tolist()
+    # get the centroids of the assigned clusters
+    centroids = centroids[assigned].tolist()
+    # map all values of datapoints to buckets
+    clusters = [torch.where(indices.unsqueeze(-1).eq(torch.where(y.eq(i))[0]).any(-1))[0].tolist() for i in assigned]
+
+    return centroids, clusters
+
+
+def stripe(x: torch.Tensor, n: int, w: int, offset: Tuple = (0, 0), horizontal: bool = True) -> torch.Tensor:
+    r"""
+    Returns a parallelogram stripe of the tensor.
+
+    Args:
+        x (~torch.Tensor): the input tensor with 2 or more dims.
+        n (int): the length of the stripe.
+        w (int): the width of the stripe.
+        offset (tuple): the offset of the first two dims.
+        horizontal (bool): `True` to return a horizontal stripe; `False` for a vertical one.
+
+    Returns:
+        A parallelogram stripe of the tensor.
+
+    Examples:
+        >>> x = torch.arange(25).view(5, 5)
+        >>> x
+        tensor([[ 0,  1,  2,  3,  4],
+                [ 5,  6,  7,  8,  9],
+                [10, 11, 12, 13, 14],
+                [15, 16, 17, 18, 19],
+                [20, 21, 22, 23, 24]])
+        >>> stripe(x, 2, 3)
+        tensor([[0, 1, 2],
+                [6, 7, 8]])
+        >>> stripe(x, 2, 3, (1, 1))
+        tensor([[ 6,  7,  8],
+                [12, 13, 14]])
+        >>> stripe(x, 2, 3, (1, 1), False)
+        tensor([[ 6, 11, 16],
+                [12, 17, 22]])
+    """
+
+    x = x.contiguous()
+    seq_len, stride = x.size(1), list(x.stride())
+    numel = stride[1]
+    return x.as_strided(size=(n, w, *x.shape[2:]),
+                        stride=[(seq_len + 1) * numel, (1 if horizontal else seq_len) * numel] + stride[2:],
+                        storage_offset=(offset[0]*seq_len+offset[1])*numel)
+
+
+def diagonal_stripe(x: torch.Tensor, offset: int = 1) -> torch.Tensor:
+    r"""
+    Returns a diagonal parallelogram stripe of the tensor.
+
+    Args:
+        x (~torch.Tensor): the input tensor with 3 or more dims.
+        offset (int): which diagonal to consider. Default: 1.
+
+    Returns:
+        A diagonal parallelogram stripe of the tensor.
+
+    Examples:
+        >>> x = torch.arange(125).view(5, 5, 5)
+        >>> diagonal_stripe(x)
+        tensor([[ 5],
+                [36],
+                [67],
+                [98]])
+        >>> diagonal_stripe(x, 2)
+        tensor([[10, 11],
+                [41, 42],
+                [72, 73]])
+        >>> diagonal_stripe(x, -2)
+        tensor([[ 50,  51],
+                [ 81,  82],
+                [112, 113]])
+    """
+
+    x = x.contiguous()
+    seq_len, stride = x.size(1), list(x.stride())
+    n, w, numel = seq_len - abs(offset), abs(offset), stride[2]
+    return x.as_strided(size=(n, w, *x.shape[3:]),
+                        stride=[((seq_len + 1) * x.size(2) + 1) * numel] + stride[2:],
+                        storage_offset=offset*stride[1] if offset > 0 else abs(offset)*stride[0])
+
+
+def expanded_stripe(x: torch.Tensor, n: int, w: int, offset: Tuple = (0, 0)) -> torch.Tensor:
+    r"""
+    Returns an expanded parallelogram stripe of the tensor.
+
+    Args:
+        x (~torch.Tensor): the input tensor with 2 or more dims.
+        n (int): the length of the stripe.
+        w (int): the width of the stripe.
+        offset (tuple): the offset of the first two dims.
+
+    Returns:
+        An expanded parallelogram stripe of the tensor.
+
+    Examples:
+        >>> x = torch.arange(25).view(5, 5)
+        >>> x
+        tensor([[ 0,  1,  2,  3,  4],
+                [ 5,  6,  7,  8,  9],
+                [10, 11, 12, 13, 14],
+                [15, 16, 17, 18, 19],
+                [20, 21, 22, 23, 24]])
+        >>> expanded_stripe(x, 2, 3)
+        tensor([[[ 0,  1,  2,  3,  4],
+                 [ 5,  6,  7,  8,  9],
+                 [10, 11, 12, 13, 14]],
+
+                [[ 5,  6,  7,  8,  9],
+                 [10, 11, 12, 13, 14],
+                 [15, 16, 17, 18, 19]]])
+        >>> expanded_stripe(x, 2, 3, (1, 1))
+        tensor([[[ 5,  6,  7,  8,  9],
+                 [10, 11, 12, 13, 14],
+                 [15, 16, 17, 18, 19]],
+
+                [[10, 11, 12, 13, 14],
+                 [15, 16, 17, 18, 19],
+                 [20, 21, 22, 23, 24]]])
+
+    """
+    x = x.contiguous()
+    stride = list(x.stride())
+    return x.as_strided(size=(n, w, *list(x.shape[1:])),
+                        stride=stride[:1] + [stride[0]] + stride[1:],
+                        storage_offset=(offset[1])*stride[0])
+
+
+def pad(
+    tensors: List[torch.Tensor],
+    padding_value: int = 0,
+    total_length: Optional[int] = None,
+    padding_side: str = 'right'
+) -> torch.Tensor:
+    size = [len(tensors)] + [max(tensor.size(i) for tensor in tensors)
+                             for i in range(len(tensors[0].size()))]
+    if total_length is not None:
+        assert total_length >= size[1]
+        size[1] = total_length
+    out_tensor = tensors[0].data.new(*size).fill_(padding_value)
+    for i, tensor in enumerate(tensors):
+        out_tensor[i][[slice(-i, None) if padding_side == 'left' else slice(0, i) for i in tensor.size()]] = tensor
+    return out_tensor
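+
+# A minimal sketch of right-padding two 1-D tensors to a common length:
+#   >>> pad([torch.tensor([1, 2]), torch.tensor([3])])
+#   tensor([[1, 2],
+#           [3, 0]])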
+
+
+@wait
+def download(url: str, path: Optional[str] = None, reload: bool = False, clean: bool = False) -> str:
+    filename = os.path.basename(urllib.parse.urlparse(url).path)
+    if path is None:
+        path = CACHE
+    os.makedirs(path, exist_ok=True)
+    path = os.path.join(path, filename)
+    if reload and os.path.exists(path):
+        os.remove(path)
+    if not os.path.exists(path):
+        sys.stderr.write(f"Downloading {url} to {path}\n")
+        try:
+            torch.hub.download_url_to_file(url, path, progress=True)
+        except (ValueError, urllib.error.URLError):
+            raise RuntimeError(f"File {url} unavailable. Please try other sources.")
+    return extract(path, reload, clean)
+
+
+def extract(path: str, reload: bool = False, clean: bool = False) -> str:
+    extracted = path
+    if zipfile.is_zipfile(path):
+        with zipfile.ZipFile(path) as f:
+            extracted = os.path.join(os.path.dirname(path), f.infolist()[0].filename)
+            if reload or not os.path.exists(extracted):
+                f.extractall(os.path.dirname(path))
+    elif tarfile.is_tarfile(path):
+        with tarfile.open(path) as f:
+            extracted = os.path.join(os.path.dirname(path), f.getnames()[0])
+            if reload or not os.path.exists(extracted):
+                f.extractall(os.path.dirname(path))
+    elif path.endswith('.gz'):
+        extracted = path[:-3]
+        if reload or not os.path.exists(extracted):
+            with gzip.open(path) as fgz, open(extracted, 'wb') as f:
+                shutil.copyfileobj(fgz, f)
+    if clean:
+        os.remove(path)
+    return extracted
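+
+# A minimal sketch (hypothetical URL): download an archive into the default
+# cache and get back the path of the extracted file
+#   >>> path = download('https://example.com/data.zip')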
+
+
+def binarize(
+    data: Union[List[str], Dict[str, Iterable]],
+    fbin: str = None,
+    merge: bool = False
+) -> Tuple[str, torch.Tensor]:
+    start, meta = 0, defaultdict(list)
+    # the binarized file is organized as:
+    # `data`: pickled objects
+    # `meta`: a dict containing the pointers of each kind of data
+    # `index`: fixed size integers representing the storage positions of the meta data
+    with open(fbin, 'wb') as f:
+        # in this case, data should be a list of binarized files
+        if merge:
+            for file in data:
+                if not os.path.exists(file):
+                    raise RuntimeError("Some files are missing. Please check the paths")
+                mi = debinarize(file, meta=True)
+                for key, val in mi.items():
+                    val[:, 0] += start
+                    meta[key].append(val)
+                with open(file, 'rb') as fi:
+                    length = int(sum(val[:, 1].sum() for val in mi.values()))
+                    f.write(fi.read(length))
+                start = start + length
+            meta = {key: torch.cat(val) for key, val in meta.items()}
+        else:
+            for key, val in data.items():
+                for i in val:
+                    # use `buf` to avoid shadowing the built-in `bytes`
+                    buf = pickle.dumps(i)
+                    f.write(buf)
+                    meta[key].append((start, len(buf)))
+                    start = start + len(buf)
+            meta = {key: torch.tensor(val) for key, val in meta.items()}
+        pickled = pickle.dumps(meta)
+        # append the meta data to the end of the bin file
+        f.write(pickled)
+        # record the positions of the meta data
+        f.write(struct.pack('LL', start, len(pickled)))
+    return fbin, meta
+
+
+def debinarize(
+    fbin: str,
+    pos_or_key: Optional[Union[Tuple[int, int], str]] = (0, 0),
+    meta: bool = False
+) -> Union[Any, Iterable[Any]]:
+    with open(fbin, 'rb') as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
+        if meta or isinstance(pos_or_key, str):
+            length = len(struct.pack('LL', 0, 0))
+            mm.seek(-length, os.SEEK_END)
+            offset, length = struct.unpack('LL', mm.read(length))
+            mm.seek(offset)
+            if meta:
+                return pickle.loads(mm.read(length))
+            # fetch by key
+            objs, meta = [], pickle.loads(mm.read(length))[pos_or_key]
+            for offset, length in meta.tolist():
+                mm.seek(offset)
+                objs.append(pickle.loads(mm.read(length)))
+            return objs
+        # fetch by positions
+        offset, length = pos_or_key
+        mm.seek(offset)
+        return pickle.loads(mm.read(length))
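+
+# A minimal round-trip sketch (hypothetical path):
+#   >>> fbin, meta = binarize({'sentences': ['a', 'b']}, 'data.bin')
+#   >>> debinarize('data.bin', 'sentences')
+#   ['a', 'b']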
+
+
+def resolve_config(args: Union[Dict, DictConfig]) -> DictConfig:
+    OmegaConf.register_new_resolver("eval", eval)
+    return DictConfig(OmegaConf.to_container(args, resolve=True))
+
+
+def collect_args(args: Union[Dict, DictConfig]) -> DictConfig:
+    for key in ('self', 'cls', '__class__'):
+        args.pop(key, None)
+    args.update(args.pop('kwargs', dict()))
+    return DictConfig(args)
+
+
+def get_rng_state() -> Dict[str, torch.Tensor]:
+    state = {'rng_state': torch.get_rng_state()}
+    if torch.cuda.is_available():
+        state['cuda_rng_state'] = torch.cuda.get_rng_state()
+    return state
+
+
+def set_rng_state(state: Dict) -> None:
+    torch.set_rng_state(state['rng_state'])
+    if torch.cuda.is_available():
+        torch.cuda.set_rng_state(state['cuda_rng_state'])
diff --git a/tania_scripts/supar/utils/logging.py b/tania_scripts/supar/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..73e4a77a0ca5f660083ae74ff0c8377efeb5a280
--- /dev/null
+++ b/tania_scripts/supar/utils/logging.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+
+import logging
+import os
+from logging import FileHandler, Formatter, Handler, Logger, StreamHandler
+from typing import Iterable, Optional
+
+from supar.utils.parallel import is_master
+from tqdm import tqdm
+
+
+def get_logger(name: Optional[str] = None) -> Logger:
+    logger = logging.getLogger(name)
+    # init the root logger
+    if name is None:
+        logging.basicConfig(format='[%(asctime)s %(levelname)s] %(message)s',
+                            datefmt='%Y-%m-%d %H:%M:%S',
+                            handlers=[TqdmHandler()])
+    return logger
+
+
+class TqdmHandler(StreamHandler):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def emit(self, record):
+        try:
+            msg = self.format(record)
+            tqdm.write(msg)
+            self.flush()
+        except (KeyboardInterrupt, SystemExit):
+            raise
+        except Exception:
+            self.handleError(record)
+
+
+def init_logger(
+    logger: Logger,
+    path: Optional[str] = None,
+    mode: str = 'w',
+    handlers: Optional[Iterable[Handler]] = None,
+    verbose: bool = True
+) -> Logger:
+    if not handlers:
+        if path:
+            os.makedirs(os.path.dirname(path) or './', exist_ok=True)
+            logger.addHandler(FileHandler(path, mode))
+    for handler in logger.handlers:
+        handler.setFormatter(ColoredFormatter(colored=not isinstance(handler, FileHandler)))
+    logger.setLevel(logging.INFO if is_master() and verbose else logging.WARNING)
+    return logger
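+
+# A minimal sketch (hypothetical log path): a module-level logger that also
+# writes to a file
+#   >>> logger = get_logger(__name__)
+#   >>> init_logger(logger, path='exp/train.log')
+#   >>> logger.info('training started')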
+
+
+def progress_bar(
+    iterator: Iterable,
+    ncols: Optional[int] = None,
+    bar_format: str = '{l_bar}{bar:20}| {n_fmt}/{total_fmt} {elapsed}<{remaining}, {rate_fmt}{postfix}',
+    leave: bool = False,
+    **kwargs
+) -> tqdm:
+    return tqdm(iterator,
+                ncols=ncols,
+                bar_format=bar_format,
+                ascii=True,
+                disable=(not (logger.level == logging.INFO and is_master())),
+                leave=leave,
+                **kwargs)
+
+
+class ColoredFormatter(Formatter):
+
+    BLACK = '\033[30m'
+    RED = '\033[31m'
+    GREEN = '\033[32m'
+    GREY = '\033[37m'
+    RESET = '\033[0m'
+
+    COLORS = {
+        logging.ERROR: RED,
+        logging.WARNING: RED,
+        logging.INFO: GREEN,
+        logging.DEBUG: BLACK,
+        logging.NOTSET: BLACK
+    }
+
+    def __init__(self, colored=True, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.colored = colored
+
+    def format(self, record):
+        fmt = '[%(asctime)s %(levelname)s] %(message)s'
+        if self.colored:
+            fmt = f'{self.COLORS[record.levelno]}[%(asctime)s %(levelname)s]{self.RESET} %(message)s'
+        datefmt = '%Y-%m-%d %H:%M:%S'
+        return Formatter(fmt=fmt, datefmt=datefmt).format(record)
+
+
+logger = get_logger()
diff --git a/tania_scripts/supar/utils/metric.py b/tania_scripts/supar/utils/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..f64940c1a909006d264c86976242ea6ff931affe
--- /dev/null
+++ b/tania_scripts/supar/utils/metric.py
@@ -0,0 +1,346 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from collections import Counter
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+
+class Metric(object):
+
+    def __init__(self, reverse: Optional[bool] = None, eps: float = 1e-12) -> Metric:
+        super().__init__()
+
+        self.n = 0.0
+        self.count = 0.0
+        self.total_loss = 0.0
+        self.reverse = reverse
+        self.eps = eps
+
+    def __repr__(self):
+        return f"loss: {self.loss:.4f} - " + ' '.join([f"{key}: {val:6.2%}" for key, val in self.values.items()])
+
+    def __lt__(self, other: Metric) -> bool:
+        if not hasattr(self, 'score'):
+            return True
+        if not hasattr(other, 'score'):
+            return False
+        return (self.score < other.score) if not self.reverse else (self.score > other.score)
+
+    def __le__(self, other: Metric) -> bool:
+        if not hasattr(self, 'score'):
+            return True
+        if not hasattr(other, 'score'):
+            return False
+        return (self.score <= other.score) if not self.reverse else (self.score >= other.score)
+
+    def __gt__(self, other: Metric) -> bool:
+        if not hasattr(self, 'score'):
+            return False
+        if not hasattr(other, 'score'):
+            return True
+        return (self.score > other.score) if not self.reverse else (self.score < other.score)
+
+    def __ge__(self, other: Metric) -> bool:
+        if not hasattr(self, 'score'):
+            return False
+        if not hasattr(other, 'score'):
+            return True
+        return (self.score >= other.score) if not self.reverse else (self.score <= other.score)
+
+    def __add__(self, other: Metric) -> Metric:
+        return other
+
+    @property
+    def score(self):
+        raise AttributeError
+
+    @property
+    def loss(self):
+        return self.total_loss / (self.count + self.eps)
+
+    @property
+    def values(self):
+        raise AttributeError
+
+
+class AttachmentMetric(Metric):
+
+    def __init__(
+        self,
+        loss: Optional[float] = None,
+        preds: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        golds: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        mask: Optional[torch.BoolTensor] = None,
+        reverse: bool = False,
+        eps: float = 1e-12
+    ) -> AttachmentMetric:
+        super().__init__(reverse=reverse, eps=eps)
+
+        self.n_ucm = 0.0
+        self.n_lcm = 0.0
+        self.total = 0.0
+        self.correct_arcs = 0.0
+        self.correct_rels = 0.0
+
+        if loss is not None:
+            self(loss, preds, golds, mask)
+
+    def __call__(
+        self,
+        loss: float,
+        preds: Tuple[torch.Tensor, torch.Tensor],
+        golds: Tuple[torch.Tensor, torch.Tensor],
+        mask: torch.BoolTensor
+    ) -> AttachmentMetric:
+        lens = mask.sum(1)
+        arc_preds, rel_preds, arc_golds, rel_golds = *preds, *golds
+        arc_mask = arc_preds.eq(arc_golds) & mask
+        rel_mask = rel_preds.eq(rel_golds) & arc_mask
+        arc_mask_seq, rel_mask_seq = arc_mask[mask], rel_mask[mask]
+
+        self.n += len(mask)
+        self.count += 1
+        self.total_loss += float(loss)
+        self.n_ucm += arc_mask.sum(1).eq(lens).sum().item()
+        self.n_lcm += rel_mask.sum(1).eq(lens).sum().item()
+
+        self.total += len(arc_mask_seq)
+        self.correct_arcs += arc_mask_seq.sum().item()
+        self.correct_rels += rel_mask_seq.sum().item()
+        return self
+
+    def __add__(self, other: AttachmentMetric) -> AttachmentMetric:
+        metric = AttachmentMetric(eps=self.eps)
+        metric.n = self.n + other.n
+        metric.count = self.count + other.count
+        metric.total_loss = self.total_loss + other.total_loss
+        metric.n_ucm = self.n_ucm + other.n_ucm
+        metric.n_lcm = self.n_lcm + other.n_lcm
+        metric.total = self.total + other.total
+        metric.correct_arcs = self.correct_arcs + other.correct_arcs
+        metric.correct_rels = self.correct_rels + other.correct_rels
+        metric.reverse = self.reverse or other.reverse
+        return metric
+
+    @property
+    def score(self):
+        return self.las
+
+    @property
+    def ucm(self):
+        return self.n_ucm / (self.n + self.eps)
+
+    @property
+    def lcm(self):
+        return self.n_lcm / (self.n + self.eps)
+
+    @property
+    def uas(self):
+        return self.correct_arcs / (self.total + self.eps)
+
+    @property
+    def las(self):
+        return self.correct_rels / (self.total + self.eps)
+
+    @property
+    def values(self) -> Dict:
+        return {'UCM': self.ucm,
+                'LCM': self.lcm,
+                'UAS': self.uas,
+                'LAS': self.las}
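+
+    # A minimal sketch: one 1-token batch where both the arc and the relation
+    # are predicted correctly, so all values are ~1.0
+    #   >>> m = AttachmentMetric()
+    #   >>> mask = torch.tensor([[True]])
+    #   >>> preds = golds = (torch.tensor([[0]]), torch.tensor([[1]]))
+    #   >>> m(0.0, preds, golds, mask).las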
+
+
+class SpanMetric(Metric):
+
+    def __init__(
+        self,
+        loss: Optional[float] = None,
+        preds: Optional[List[List[Tuple]]] = None,
+        golds: Optional[List[List[Tuple]]] = None,
+        reverse: bool = False,
+        eps: float = 1e-12
+    ) -> SpanMetric:
+        super().__init__(reverse=reverse, eps=eps)
+
+        self.n_ucm = 0.0
+        self.n_lcm = 0.0
+        self.utp = 0.0
+        self.ltp = 0.0
+        self.pred = 0.0
+        self.gold = 0.0
+
+        if loss is not None:
+            self(loss, preds, golds)
+
+    def __call__(
+        self,
+        loss: float,
+        preds: List[List[Tuple]],
+        golds: List[List[Tuple]]
+    ) -> SpanMetric:
+        self.n += len(preds)
+        self.count += 1
+        self.total_loss += float(loss)
+        for pred, gold in zip(preds, golds):
+            upred, ugold = Counter([tuple(span[:-1]) for span in pred]), Counter([tuple(span[:-1]) for span in gold])
+            lpred, lgold = Counter([tuple(span) for span in pred]), Counter([tuple(span) for span in gold])
+            utp, ltp = list((upred & ugold).elements()), list((lpred & lgold).elements())
+            self.n_ucm += len(utp) == len(pred) == len(gold)
+            self.n_lcm += len(ltp) == len(pred) == len(gold)
+            self.utp += len(utp)
+            self.ltp += len(ltp)
+            self.pred += len(pred)
+            self.gold += len(gold)
+        return self
+
+    def __add__(self, other: SpanMetric) -> SpanMetric:
+        metric = SpanMetric(eps=self.eps)
+        metric.n = self.n + other.n
+        metric.count = self.count + other.count
+        metric.total_loss = self.total_loss + other.total_loss
+        metric.n_ucm = self.n_ucm + other.n_ucm
+        metric.n_lcm = self.n_lcm + other.n_lcm
+        metric.utp = self.utp + other.utp
+        metric.ltp = self.ltp + other.ltp
+        metric.pred = self.pred + other.pred
+        metric.gold = self.gold + other.gold
+        metric.reverse = self.reverse or other.reverse
+        return metric
+
+    @property
+    def score(self):
+        return self.lf
+
+    @property
+    def ucm(self):
+        return self.n_ucm / (self.n + self.eps)
+
+    @property
+    def lcm(self):
+        return self.n_lcm / (self.n + self.eps)
+
+    @property
+    def up(self):
+        return self.utp / (self.pred + self.eps)
+
+    @property
+    def ur(self):
+        return self.utp / (self.gold + self.eps)
+
+    @property
+    def uf(self):
+        return 2 * self.utp / (self.pred + self.gold + self.eps)
+
+    @property
+    def lp(self):
+        return self.ltp / (self.pred + self.eps)
+
+    @property
+    def lr(self):
+        return self.ltp / (self.gold + self.eps)
+
+    @property
+    def lf(self):
+        return 2 * self.ltp / (self.pred + self.gold + self.eps)
+
+    @property
+    def values(self) -> Dict:
+        return {'UCM': self.ucm,
+                'LCM': self.lcm,
+                'UP': self.up,
+                'UR': self.ur,
+                'UF': self.uf,
+                'LP': self.lp,
+                'LR': self.lr,
+                'LF': self.lf}
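+
+    # A minimal sketch: a single span correct in both boundaries and label,
+    # so UF and LF are ~1.0
+    #   >>> m = SpanMetric()
+    #   >>> m(0.0, [[(0, 2, 'NP')]], [[(0, 2, 'NP')]]).lf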
+
+
+class ChartMetric(Metric):
+
+    def __init__(
+        self,
+        loss: Optional[float] = None,
+        preds: Optional[torch.Tensor] = None,
+        golds: Optional[torch.Tensor] = None,
+        reverse: bool = False,
+        eps: float = 1e-12
+    ) -> ChartMetric:
+        super().__init__(reverse=reverse, eps=eps)
+
+        self.tp = 0.0
+        self.utp = 0.0
+        self.pred = 0.0
+        self.gold = 0.0
+
+        if loss is not None:
+            self(loss, preds, golds)
+
+    def __call__(
+        self,
+        loss: float,
+        preds: torch.Tensor,
+        golds: torch.Tensor
+    ) -> ChartMetric:
+        self.n += len(preds)
+        self.count += 1
+        self.total_loss += float(loss)
+        pred_mask = preds.ge(0)
+        gold_mask = golds.ge(0)
+        span_mask = pred_mask & gold_mask
+        self.pred += pred_mask.sum().item()
+        self.gold += gold_mask.sum().item()
+        self.tp += (preds.eq(golds) & span_mask).sum().item()
+        self.utp += span_mask.sum().item()
+        return self
+
+    def __add__(self, other: ChartMetric) -> ChartMetric:
+        metric = ChartMetric(eps=self.eps)
+        metric.n = self.n + other.n
+        metric.count = self.count + other.count
+        metric.total_loss = self.total_loss + other.total_loss
+        metric.tp = self.tp + other.tp
+        metric.utp = self.utp + other.utp
+        metric.pred = self.pred + other.pred
+        metric.gold = self.gold + other.gold
+        metric.reverse = self.reverse or other.reverse
+        return metric
+
+    @property
+    def score(self):
+        return self.f
+
+    @property
+    def up(self):
+        return self.utp / (self.pred + self.eps)
+
+    @property
+    def ur(self):
+        return self.utp / (self.gold + self.eps)
+
+    @property
+    def uf(self):
+        return 2 * self.utp / (self.pred + self.gold + self.eps)
+
+    @property
+    def p(self):
+        return self.tp / (self.pred + self.eps)
+
+    @property
+    def r(self):
+        return self.tp / (self.gold + self.eps)
+
+    @property
+    def f(self):
+        return 2 * self.tp / (self.pred + self.gold + self.eps)
+
+    @property
+    def values(self) -> Dict:
+        return {'UP': self.up,
+                'UR': self.ur,
+                'UF': self.uf,
+                'P': self.p,
+                'R': self.r,
+                'F': self.f}
diff --git a/tania_scripts/supar/utils/optim.py b/tania_scripts/supar/utils/optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..e67730b4eacbf68008f6b955cb0fa7bea2f33bc1
--- /dev/null
+++ b/tania_scripts/supar/utils/optim.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class InverseSquareRootLR(_LRScheduler):
+    r"""
+    Linearly warms up the learning rate for `warmup_steps` steps,
+    then decays it in proportion to the inverse square root of the step number.
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        warmup_steps: int,
+        last_epoch: int = -1
+    ) -> InverseSquareRootLR:
+        self.warmup_steps = warmup_steps
+        self.factor = warmup_steps ** 0.5
+        super(InverseSquareRootLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        epoch = max(self.last_epoch, 1)
+        scale = min(epoch ** -0.5, epoch * self.warmup_steps ** -1.5) * self.factor
+        return [scale * lr for lr in self.base_lrs]
+
+
+class PolynomialLR(_LRScheduler):
+    r"""
+    Set the learning rate for each parameter group using a polynomial defined as `lr = base_lr * (1 - t / T) ^ power`,
+    where `t` is the number of steps taken after warmup and `T` is `steps - warmup_steps`.
+    During the first `warmup_steps` steps, the learning rate increases linearly from 0 to `base_lr`.
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        warmup_steps: int = 0,
+        steps: int = 100000,
+        power: float = 1.,
+        last_epoch: int = -1
+    ) -> PolynomialLR:
+        self.warmup_steps = warmup_steps
+        self.steps = steps
+        self.power = power
+        super(PolynomialLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        epoch = max(self.last_epoch, 1)
+        if epoch <= self.warmup_steps:
+            return [epoch / self.warmup_steps * lr for lr in self.base_lrs]
+        t, T = (epoch - self.warmup_steps), (self.steps - self.warmup_steps)
+        return [lr * (1 - t / T) ** self.power for lr in self.base_lrs]
+
+
+def LinearLR(optimizer: Optimizer, warmup_steps: int = 0, steps: int = 100000, last_epoch: int = -1) -> PolynomialLR:
+    return PolynomialLR(optimizer, warmup_steps, steps, 1, last_epoch)
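+
+
+# A minimal usage sketch: linear warmup for 100 steps, then linear decay to 0
+#   >>> import torch
+#   >>> optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
+#   >>> scheduler = LinearLR(optimizer, warmup_steps=100, steps=1000)
+#   >>> optimizer.step(); scheduler.step()  # once per training step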
diff --git a/tania_scripts/supar/utils/parallel.py b/tania_scripts/supar/utils/parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6f1e0b0cee6b98d07bbec7e2193fcd5bcb97343
--- /dev/null
+++ b/tania_scripts/supar/utils/parallel.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import functools
+import os
+import re
+from typing import Any, Callable, Iterable
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+
+class DistributedDataParallel(nn.parallel.DistributedDataParallel):
+
+    def __init__(self, module, **kwargs):
+        super().__init__(module, **kwargs)
+
+    def __getattr__(self, name):
+        wrapped = super().__getattr__('module')
+        if hasattr(wrapped, name):
+            return getattr(wrapped, name)
+        return super().__getattr__(name)
+
+
+def wait(fn: Callable) -> Callable:
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        value = None
+        if is_master():
+            value = fn(*args, **kwargs)
+        if is_dist():
+            dist.barrier()
+            value = gather(value)[0]
+        return value
+    return wrapper
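+
+# A minimal sketch: run an expensive build on the master process only; under
+# DDP every rank then receives the master's return value (`heavy_io` is a
+# hypothetical helper)
+#   >>> @wait
+#   ... def build_vocab():
+#   ...     return heavy_io()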
+
+
+def gather(obj: Any) -> Iterable[Any]:
+    objs = [None] * dist.get_world_size()
+    dist.all_gather_object(objs, obj)
+    return objs
+
+
+def reduce(obj: Any, reduction: str = 'sum') -> Any:
+    objs = gather(obj)
+    if reduction == 'sum':
+        return functools.reduce(lambda x, y: x + y, objs)
+    elif reduction == 'mean':
+        return functools.reduce(lambda x, y: x + y, objs) / len(objs)
+    elif reduction == 'min':
+        return min(objs)
+    elif reduction == 'max':
+        return max(objs)
+    else:
+        raise NotImplementedError(f"Unsupported reduction {reduction}")
+
+
+def is_dist():
+    return dist.is_available() and dist.is_initialized()
+
+
+def is_master():
+    return not is_dist() or dist.get_rank() == 0
+
+
+def get_free_port():
+    import socket
+    s = socket.socket()
+    s.bind(('', 0))
+    port = str(s.getsockname()[1])
+    s.close()
+    return port
+
+
+def get_device_count():
+    if 'CUDA_VISIBLE_DEVICES' in os.environ:
+        return len(re.findall(r'\d+', os.environ['CUDA_VISIBLE_DEVICES']))
+    return torch.cuda.device_count()
diff --git a/tania_scripts/supar/utils/tokenizer.py b/tania_scripts/supar/utils/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..35920b63a672c8b2cb4283355c779313bb5fbebb
--- /dev/null
+++ b/tania_scripts/supar/utils/tokenizer.py
@@ -0,0 +1,220 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import os
+import re
+import tempfile
+from collections import Counter, defaultdict
+from typing import Any, Dict, List, Optional, Union
+
+import torch.distributed as dist
+from supar.utils.parallel import is_dist, is_master
+from supar.utils.vocab import Vocab
+from torch.distributions.utils import lazy_property
+
+
+class Tokenizer:
+
+    def __init__(self, lang: str = 'en') -> Tokenizer:
+        import stanza
+        try:
+            self.pipeline = stanza.Pipeline(lang=lang, processors='tokenize', verbose=False, tokenize_no_ssplit=True)
+        except Exception:
+            stanza.download(lang=lang, resources_url='stanford')
+            self.pipeline = stanza.Pipeline(lang=lang, processors='tokenize', verbose=False, tokenize_no_ssplit=True)
+
+    def __call__(self, text: str) -> List[str]:
+        return [i.text for i in self.pipeline(text).sentences[0].tokens]
+
+
+class TransformerTokenizer:
+
+    def __init__(self, name) -> TransformerTokenizer:
+        from transformers import AutoTokenizer
+        self.name = name
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(name, local_files_only=True)
+        except Exception:
+            self.tokenizer = AutoTokenizer.from_pretrained(name, local_files_only=False)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.name})"
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
+    def __call__(self, text: str) -> List[str]:
+        from tokenizers.pre_tokenizers import ByteLevel
+        if isinstance(self.tokenizer.backend_tokenizer.pre_tokenizer, ByteLevel):
+            text = ' ' + text
+        return self.tokenizer.tokenize(text)
+
+    def __getattr__(self, name: str) -> Any:
+        return getattr(self.tokenizer, name)
+
+    def __getstate__(self) -> Dict:
+        return self.__dict__
+
+    def __setstate__(self, state: Dict):
+        self.__dict__.update(state)
+
+    @lazy_property
+    def vocab(self):
+        return defaultdict(lambda: self.tokenizer.vocab[self.unk], self.tokenizer.get_vocab())
+
+    @lazy_property
+    def tokens(self):
+        return sorted(self.vocab, key=lambda x: self.vocab[x])
+
+    @property
+    def vocab_size(self):
+        return len(self.vocab)
+
+    @property
+    def pad(self):
+        return self.tokenizer.pad_token
+
+    @property
+    def unk(self):
+        return self.tokenizer.unk_token
+
+    @property
+    def bos(self):
+        return self.tokenizer.bos_token or self.tokenizer.cls_token
+
+    @property
+    def eos(self):
+        return self.tokenizer.eos_token or self.tokenizer.sep_token
+
+    def decode(self, text: List) -> str:
+        return self.tokenizer.decode(text, skip_special_tokens=True, clean_up_tokenization_spaces=False)
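+
+    # A minimal sketch (the subword pieces depend on the checkpoint's vocab):
+    #   >>> t = TransformerTokenizer('bert-base-cased')
+    #   >>> t('Tokenize this sentence')  # -> a list of subword pieces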
+
+
+class BPETokenizer:
+
+    def __init__(
+        self,
+        path: Optional[str] = None,
+        files: Optional[List[str]] = None,
+        vocab_size: Optional[int] = 32000,
+        min_freq: Optional[int] = 2,
+        dropout: Optional[float] = None,
+        backend: str = 'huggingface',
+        pad: Optional[str] = None,
+        unk: Optional[str] = None,
+        bos: Optional[str] = None,
+        eos: Optional[str] = None,
+    ) -> BPETokenizer:
+
+        self.path = path
+        self.files = files
+        self.min_freq = min_freq
+        self.dropout = dropout or .0
+        self.backend = backend
+        self.pad = pad
+        self.unk = unk
+        self.bos = bos
+        self.eos = eos
+        self.special_tokens = [i for i in [pad, unk, bos, eos] if i is not None]
+
+        if backend == 'huggingface':
+            from tokenizers import Tokenizer
+            from tokenizers.decoders import BPEDecoder
+            from tokenizers.models import BPE
+            from tokenizers.pre_tokenizers import WhitespaceSplit
+            from tokenizers.trainers import BpeTrainer
+            path = os.path.join(path, 'tokenizer.json')
+            if is_master() and not os.path.exists(path):
+                # start to train a tokenizer from scratch
+                self.tokenizer = Tokenizer(BPE(dropout=dropout, unk_token=unk))
+                self.tokenizer.pre_tokenizer = WhitespaceSplit()
+                self.tokenizer.decoder = BPEDecoder()
+                self.tokenizer.train(files=files,
+                                     trainer=BpeTrainer(vocab_size=vocab_size,
+                                                        min_frequency=min_freq,
+                                                        special_tokens=self.special_tokens,
+                                                        end_of_word_suffix='</w>'))
+                self.tokenizer.save(path)
+            if is_dist():
+                dist.barrier()
+            self.tokenizer = Tokenizer.from_file(path)
+            self.vocab = self.tokenizer.get_vocab()
+
+        elif backend == 'subword-nmt':
+            import argparse
+            from argparse import Namespace
+
+            from subword_nmt.apply_bpe import BPE, read_vocabulary
+            from subword_nmt.learn_joint_bpe_and_vocab import learn_joint_bpe_and_vocab
+            fmerge = os.path.join(path, 'merge.txt')
+            fvocab = os.path.join(path, 'vocab.txt')
+            separator = '@@'
+            if is_master() and (not os.path.exists(fmerge) or not os.path.exists(fvocab)):
+                with tempfile.TemporaryDirectory() as ftemp:
+                    fall = os.path.join(ftemp, 'fall')
+                    with open(fall, 'w') as f:
+                        for file in files:
+                            with open(file) as fi:
+                                f.write(fi.read())
+                    learn_joint_bpe_and_vocab(Namespace(input=[argparse.FileType()(fall)],
+                                                        output=argparse.FileType('w')(fmerge),
+                                                        symbols=vocab_size,
+                                                        separator=separator,
+                                                        vocab=[argparse.FileType('w')(fvocab)],
+                                                        min_frequency=min_freq,
+                                                        total_symbols=False,
+                                                        verbose=False,
+                                                        num_workers=32))
+            if is_dist():
+                dist.barrier()
+            self.tokenizer = BPE(codes=open(fmerge), separator=separator, vocab=read_vocabulary(open(fvocab), None))
+            self.vocab = Vocab(counter=Counter(self.tokenizer.vocab),
+                               specials=self.special_tokens,
+                               unk_index=self.special_tokens.index(unk))
+        else:
+            raise ValueError(f'Unsupported backend: {backend} not in (huggingface, subword-nmt)')
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + f'({self.vocab_size}, min_freq={self.min_freq}'
+        if self.dropout > 0:
+            s += f", dropout={self.dropout}"
+        s += f", backend={self.backend}"
+        if self.pad is not None:
+            s += f", pad={self.pad}"
+        if self.unk is not None:
+            s += f", unk={self.unk}"
+        if self.bos is not None:
+            s += f", bos={self.bos}"
+        if self.eos is not None:
+            s += f", eos={self.eos}"
+        s += ')'
+        return s
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
+    def __call__(self, text: Union[str, List]) -> List[str]:
+        is_pretokenized = isinstance(text, list)
+        if self.backend == 'huggingface':
+            return self.tokenizer.encode(text, is_pretokenized=is_pretokenized).tokens
+        else:
+            if not is_pretokenized:
+                text = text.split()
+            return self.tokenizer.segment_tokens(text, dropout=self.dropout)
+
+    @lazy_property
+    def tokens(self):
+        return sorted(self.vocab, key=lambda x: self.vocab[x])
+
+    @property
+    def vocab_size(self):
+        return len(self.vocab)
+
+    def decode(self, text: List) -> str:
+        if self.backend == 'huggingface':
+            return self.tokenizer.decode(text)
+        else:
+            text = self.vocab[text]
+            text = ' '.join([i for i in text if i not in self.special_tokens])
+            return re.sub(f'({self.tokenizer.separator} )|({self.tokenizer.separator} ?$)', '', text)
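+
+# A minimal sketch (hypothetical paths; the output directory is assumed to
+# exist): train or load a HuggingFace-backed BPE and segment a sentence
+#   >>> tok = BPETokenizer(path='exp/bpe', files=['train.txt'], vocab_size=8000,
+#   ...                    pad='<pad>', unk='<unk>')
+#   >>> tok('a sentence to segment')  # -> a list of BPE pieces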
diff --git a/tania_scripts/supar/utils/transform.py b/tania_scripts/supar/utils/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8d070c7ae27942474a8004b7688d7033582ba46
--- /dev/null
+++ b/tania_scripts/supar/utils/transform.py
@@ -0,0 +1,219 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from typing import Any, Iterable, Optional, Tuple
+
+import torch
+from torch.distributions.utils import lazy_property
+
+from supar.utils.fn import debinarize
+from supar.utils.logging import get_logger, progress_bar
+
+logger = get_logger(__name__)
+
+
+class Transform(object):
+    r"""
+    A :class:`Transform` object corresponds to a specific data format, which holds several instances of data fields
+    that provide instructions for preprocessing and numericalization, etc.
+
+    Attributes:
+        training (bool):
+            Sets the object in training mode.
+            If ``False``, some data fields not required for predictions won't be returned.
+            Default: ``True``.
+    """
+
+    fields = []
+
+    def __init__(self):
+        self.training = True
+
+    def __len__(self):
+        return len(self.fields)
+
+    def __repr__(self):
+        s = '\n' + '\n'.join([f" {f}" for f in self.flattened_fields]) + '\n'
+        return f"{self.__class__.__name__}({s})"
+
+    def __call__(self, sentences: Iterable[Sentence]) -> Iterable[Sentence]:
+        return [sentence.numericalize(self.flattened_fields) for sentence in progress_bar(sentences)]
+
+    def __getitem__(self, index):
+        return getattr(self, self.fields[index])
+
+    @property
+    def flattened_fields(self):
+        flattened = []
+        for field in self:
+            if field not in self.src and field not in self.tgt:
+                continue
+            if not self.training and field in self.tgt:
+                continue
+            if not isinstance(field, Iterable):
+                field = [field]
+            for f in field:
+                if f is not None:
+                    flattened.append(f)
+        return flattened
+
+    def train(self, training=True):
+        self.training = training
+
+    def eval(self):
+        self.train(False)
+
+    def append(self, field):
+        self.fields.append(field.name)
+        setattr(self, field.name, field)
+
+    @property
+    def src(self):
+        raise AttributeError
+
+    @property
+    def tgt(self):
+        raise AttributeError
+
+
+class Batch(object):
+
+    def __init__(self, sentences: Iterable[Sentence]) -> Batch:
+        self.sentences = sentences
+
+        self.names, self.fields = [], {}
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({", ".join([f"{name}" for name in self.names])})'
+
+    def __len__(self):
+        return len(self.sentences)
+
+    def __getitem__(self, index):
+        return self.fields[self.names[index]]
+
+    def __getattr__(self, name):
+        return [s.fields[name] for s in self.sentences]
+
+    def __setattr__(self, name: str, value: Iterable[Any]):
+        if name not in ('sentences', 'fields', 'names'):
+            for s, v in zip(self.sentences, value):
+                setattr(s, name, v)
+        else:
+            self.__dict__[name] = value
+
+    def __getstate__(self):
+        return self.__dict__
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+
+    @property
+    def device(self):
+        return 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    @lazy_property
+    def lens(self):
+        return torch.tensor([len(i) for i in self.sentences]).to(self.device, non_blocking=True)
+
+    @lazy_property
+    def mask(self):
+        return self.lens.unsqueeze(-1).gt(self.lens.new_tensor(range(self.lens.max())))
+
+    def compose(self, transform: Transform) -> Batch:
+        for f in transform.flattened_fields:
+            self.names.append(f.name)
+            self.fields[f.name] = f.compose([s.fields[f.name] for s in self.sentences])
+        return self
+
+    def shrink(self, batch_size: Optional[int] = None) -> Batch:
+        if batch_size is None:
+            batch_size = len(self) // 2
+        if batch_size <= 0:
+            raise RuntimeError(f"The batch has only {len(self)} sentences and can't be shrinked!")
+        return Batch([self.sentences[i] for i in torch.randperm(len(self))[:batch_size].tolist()])
+
+    def pin_memory(self):
+        for s in self.sentences:
+            for i in s.fields.values():
+                if isinstance(i, torch.Tensor):
+                    i.pin_memory()
+        return self
+
+
+class Sentence(object):
+
+    def __init__(self, transform, index: Optional[int] = None) -> Sentence:
+        self.index = index
+        # mapping from each nested field to its proper position
+        self.maps = dict()
+        # original values and numericalized values of each position
+        self.values, self.fields = [], {}
+        for i, field in enumerate(transform):
+
+            if not isinstance(field, Iterable):
+                field = [field]
+            for f in field:
+                if f is not None:
+                    self.maps[f.name] = i
+                    self.fields[f.name] = None
+
+    def __contains__(self, name):
+        return name in self.fields
+
+    def __getattr__(self, name):
+        if name in self.fields:
+            return self.values[self.maps[name]]
+        raise AttributeError(f"`{name}` not found")
+
+    def __setattr__(self, name, value):
+        if 'fields' in self.__dict__ and name in self:
+            index = self.maps[name]
+            if index >= len(self.values):
+                self.__dict__[name] = value
+            else:
+                self.values[index] = value
+        else:
+            self.__dict__[name] = value
+
+    def __getstate__(self):
+        state = vars(self)
+        if 'fields' in state:
+            state['fields'] = {
+                name: ((value.dtype, value.tolist())
+                       if isinstance(value, torch.Tensor)
+                       else value)
+                for name, value in state['fields'].items()
+            }
+        return state
+
+    def __setstate__(self, state):
+        if 'fields' in state:
+            state['fields'] = {
+                name: (torch.tensor(value[1], dtype=value[0])
+                       if isinstance(value, tuple) and isinstance(value[0], torch.dtype)
+                       else value)
+                for name, value in state['fields'].items()
+            }
+            self.__dict__.update(state)
+
+    def __len__(self):
+        try:
+            return len(next(iter(self.fields.values())))
+        except Exception:
+            raise AttributeError("Cannot get size of a sentence with no fields")
+
+    @lazy_property
+    def size(self):
+        return len(self)
+
+    def numericalize(self, fields):
+        for f in fields:
+            self.fields[f.name] = next(f.transform([getattr(self, f.name)]))
+        self.pad_index = fields[0].pad_index
+        return self
+
+    @classmethod
+    def from_cache(cls, fbin: str, pos: Tuple[int, int]) -> Sentence:
+        return debinarize(fbin, pos)
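+
+
+# Serialization note: Sentence.__getstate__ stores each tensor field as a
+# (dtype, list) pair and __setstate__ rebuilds the tensors, keeping cached
+# sentences portable. A minimal sketch, assuming `sentence` is a
+# numericalized Sentence:
+#   import pickle
+#   restored = pickle.loads(pickle.dumps(sentence))  # tensors recreated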
diff --git a/tania_scripts/supar/utils/vocab.py b/tania_scripts/supar/utils/vocab.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee9cae0a04a30bf92d1799525f6bfd73b36a4046
--- /dev/null
+++ b/tania_scripts/supar/utils/vocab.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from collections import Counter, defaultdict
+from typing import Iterable, Tuple, Union
+
+
+class Vocab(object):
+    r"""
+    Defines a vocabulary object that will be used to numericalize a field.
+
+    Args:
+        counter (~collections.Counter):
+            :class:`~collections.Counter` object holding the frequencies of each value found in the data.
+        min_freq (int):
+            The minimum frequency needed to include a token in the vocabulary. Default: 1.
+        specials (Tuple[str]):
+            The list of special tokens (e.g., pad, unk, bos and eos) that will be prepended to the vocabulary. Default: ``()``.
+        unk_index (int):
+            The index of unk token. Default: 0.
+
+    Attributes:
+        itos:
+            A list of token strings indexed by their numerical identifiers.
+        stoi:
+            A :class:`~collections.defaultdict` object mapping token strings to numerical identifiers.
+    """
+
+    def __init__(self, counter: Counter, min_freq: int = 1, specials: Tuple = tuple(), unk_index: int = 0) -> Vocab:
+        self.itos = list(specials)
+        self.stoi = defaultdict(lambda: unk_index)
+        self.stoi.update({token: i for i, token in enumerate(self.itos)})
+        self.update([token for token, freq in counter.items() if freq >= min_freq])
+        self.unk_index = unk_index
+        self.n_init = len(self)
+
+    def __len__(self):
+        return len(self.itos)
+
+    def __getitem__(self, key: Union[int, str, Iterable]) -> Union[str, int, Iterable]:
+        if isinstance(key, str):
+            return self.stoi[key]
+        elif not isinstance(key, Iterable):
+            return self.itos[key]
+        elif isinstance(key[0], str):
+            return [self.stoi[i] for i in key]
+        else:
+            return [self.itos[i] for i in key]
+
+    def __contains__(self, token):
+        return token in self.stoi
+
+    def __getstate__(self):
+        # avoid pickling the defaultdict directly
+        attrs = dict(self.__dict__)
+        # cast to regular dict
+        attrs['stoi'] = dict(self.stoi)
+        return attrs
+
+    def __setstate__(self, state):
+        stoi = defaultdict(lambda: self.unk_index)
+        stoi.update(state['stoi'])
+        state['stoi'] = stoi
+        self.__dict__.update(state)
+
+    def items(self):
+        return self.stoi.items()
+
+    def update(self, vocab: Union[Iterable[str], Vocab, Counter]) -> Vocab:
+        if isinstance(vocab, Vocab):
+            vocab = vocab.itos
+        # NOTE: PAY CAREFUL ATTENTION TO DICT ORDER UNDER DISTRIBUTED TRAINING!
+        vocab = sorted(set(vocab).difference(self.stoi))
+        self.itos.extend(vocab)
+        self.stoi.update({token: i for i, token in enumerate(vocab, len(self.stoi))})
+        return self
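+
+
+# Minimal usage sketch (illustrative values):
+#   >>> from collections import Counter
+#   >>> vocab = Vocab(Counter(['a', 'b', 'a']), min_freq=2, specials=('<unk>',))
+#   >>> vocab['a'], vocab['oov']     # unknown tokens fall back to unk_index
+#   (1, 0)
+#   >>> vocab[[1, 0]]
+#   ['a', '<unk>']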
diff --git a/tania_scripts/supar/vector_quantize.py b/tania_scripts/supar/vector_quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9a6ac374ae3c04ed5ad738aca34c909f31eaa6
--- /dev/null
+++ b/tania_scripts/supar/vector_quantize.py
@@ -0,0 +1,130 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from clusopt_core.cluster import Streamkm
+
+
+def ema_inplace(moving_avg, new, decay):
+    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def laplace_smoothing(x, n_categories, eps=1e-5):
+    return (x + eps) / (x.sum() + n_categories * eps)
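+
+# e.g. laplace_smoothing(torch.tensor([0., 1.]), n_categories=2) is roughly
+# tensor([1e-5, 1.0]) / (1 + 2e-5): zero counts receive a small positive mass,
+# so the per-cluster normalization in VectorQuantize never divides by zero.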
+
+
+class VectorQuantize(nn.Module):
+    # Based on: https://github.com/lucidrains/vector-quantize-pytorch
+    def __init__(
+        self,
+        dim,
+        n_embed,
+        decay=0.8,
+        commitment=1.0,
+        eps=1e-5,
+        wait_steps=0,
+        observe_steps=1245,
+        coreset_size_multiplier=10,
+    ):
+        super().__init__()
+
+        self.dim = dim
+        self.n_embed = n_embed
+        self.decay = decay
+        self.eps = eps
+        self.commitment = commitment
+
+        embed = torch.randn(dim, n_embed)
+        self.register_buffer("embed", nn.Parameter(embed))
+        self.register_buffer("cluster_size", torch.zeros(n_embed))
+        self.register_buffer("embed_avg", nn.Parameter(embed.clone()))
+
+        self.wait_steps_remaining = wait_steps
+        self.observe_steps_remaining = observe_steps
+        self.clustering_model = Streamkm(
+            coresetsize=n_embed * coreset_size_multiplier,
+            length=1500000,
+            seed=42,
+        )
+        self.data_chunks = []
+
+    def stream_cluster(self, input, expected_num_tokens=None):
+        input = input.reshape(-1, self.dim)
+        input_np = input.detach().cpu().numpy()
+        assert len(input.shape) == 2
+        self.data_chunks.append(input_np)
+        if (
+            expected_num_tokens is not None
+            and sum([chunk.shape[0] for chunk in self.data_chunks])
+            < expected_num_tokens
+        ):
+            return  # This is not the last sub-batch.
+        if self.wait_steps_remaining > 0:
+            self.wait_steps_remaining -= 1
+            self.data_chunks.clear()
+            return
+
+        self.observe_steps_remaining -= 1
+        input_np = np.concatenate(self.data_chunks, axis=0)
+        self.data_chunks.clear()
+        self.clustering_model.partial_fit(input_np)
+        if self.observe_steps_remaining == 0:
+            print("\nInitializing vq clusters (this may take a while)...")
+            clusters, _ = self.clustering_model.get_final_clusters(
+                self.n_embed, seed=42
+            )
+            new_embed = torch.tensor(
+                clusters.T, dtype=self.embed.dtype, device=self.embed.device
+            )
+            self.embed.copy_(new_embed)
+            # Don't set initial cluster sizes to zero! If a cluster is rare,
+            # embed_avg will be undergoing exponential decay until it's seen for
+            # the first time. If cluster_size is zero, this will lead to *embed*
+            # also undergoing exponential decay towards the origin before the
+            # cluster is ever encountered. Initializing to 1.0 will instead
+            # leave embed in place for many iterations, up until cluster_size
+            # finally decays to near-zero.
+            self.cluster_size.fill_(1.0)
+            self.embed_avg.copy_(new_embed)
+
+    def forward(self, input, expected_num_tokens=None):
+        if self.observe_steps_remaining > 0:
+            if self.training:
+                self.stream_cluster(input, expected_num_tokens)
+            return (
+                input,
+                torch.zeros(input.shape[0],
+                            dtype=torch.long, device=input.device),
+                torch.tensor(0.0, dtype=input.dtype, device=input.device),
+                None
+            )
+
+        dtype = input.dtype
+        flatten = input.reshape(-1, self.dim)
+        dist = (
+            flatten.pow(2).sum(1, keepdim=True)
+            - 2 * flatten @ self.embed
+            + self.embed.pow(2).sum(0, keepdim=True)
+        )
+        _, embed_ind = (-dist).max(1)
+        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
+        embed_ind = embed_ind.view(*input.shape[:-1])
+        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
+
+        if self.training:
+            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+            embed_sum = flatten.transpose(0, 1) @ embed_onehot
+            ema_inplace(self.embed_avg, embed_sum, self.decay)
+            cluster_size = (
+                laplace_smoothing(self.cluster_size, self.n_embed, self.eps)
+                * self.cluster_size.sum()
+            )
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
+            self.embed.data.copy_(embed_normalized)
+
+        loss = F.mse_loss(quantize.detach(), input) * self.commitment
+        quantize = input + (quantize - input).detach()
+        quantize = torch.reshape(quantize, input.size())
+        return quantize, embed_ind, loss, dist
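+
+
+# Usage sketch (illustrative; observe_steps=0 skips the streaming k-means
+# warm-up so the randomly initialized codebook is quantized against directly):
+#   vq = VectorQuantize(dim=256, n_embed=512, observe_steps=0)
+#   x = torch.randn(8, 256)
+#   quantized, codes, commit_loss, dist = vq(x)
+#   # `quantized` carries straight-through gradients to x; `codes` holds the
+#   # selected codebook indices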
diff --git a/tania_scripts/tania-some-other-metrics.ipynb b/tania_scripts/tania-some-other-metrics.ipynb
index e23180be4b2b3ebb0fa7b86798291f716322a6a5..ceb6052a49e387b6e1af94e850c3bffb5c13fcd0 100644
--- a/tania_scripts/tania-some-other-metrics.ipynb
+++ b/tania_scripts/tania-some-other-metrics.ipynb
@@ -88,7 +88,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "id": "2a882cc9-8f9d-4457-becb-d2e26ab3f14f",
    "metadata": {},
    "outputs": [
@@ -107,7 +107,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "id": "8897dcc3-4218-4ee5-9984-17b9a6d8dce2",
    "metadata": {},
    "outputs": [],
@@ -163,7 +163,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 6,
    "id": "1363f307-fa4b-43ba-93d5-2d1c11ceb9e4",
    "metadata": {},
    "outputs": [
@@ -184,7 +184,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 7,
    "id": "1362e192-514a-4a77-a8cb-5c012026e2bb",
    "metadata": {},
    "outputs": [],
@@ -238,7 +238,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 8,
    "id": "544ff6aa-4104-4580-a01f-97429ffcc228",
    "metadata": {},
    "outputs": [
@@ -328,7 +328,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 9,
    "id": "b9052dc2-ce45-4af4-a0a0-46c60a13da12",
    "metadata": {},
    "outputs": [],
@@ -388,7 +388,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 10,
    "id": "1e9dd0fb-db6a-47d1-8bfb-1015845f6d3e",
    "metadata": {},
    "outputs": [
@@ -398,7 +398,7 @@
        "{'Flesch-Douma': 88.68, 'LIX': 11.55, 'Kandel-Moles': 5.86}"
       ]
      },
-     "execution_count": 11,
+     "execution_count": 10,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -421,7 +421,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 11,
    "id": "24bc84a5-b2df-4194-838a-8f24302599bd",
    "metadata": {},
    "outputs": [],
@@ -467,7 +467,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 12,
    "id": "0cdb972f-31b6-4e7e-82a8-371eda344f2c",
    "metadata": {},
    "outputs": [
@@ -477,7 +477,7 @@
        "{'Average Word Length': 3.79, 'Average Sentence Length': 7.0}"
       ]
      },
-     "execution_count": 13,
+     "execution_count": 12,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -521,7 +521,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 13,
    "id": "56af520c-d56b-404a-aebf-ad7c2a9ca503",
    "metadata": {},
    "outputs": [],
@@ -567,7 +567,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 14,
    "id": "f7c8b125-4651-4b21-bcc4-93ef78a4239b",
    "metadata": {},
    "outputs": [
@@ -603,7 +603,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 15,
    "id": "daa17c33-adca-4695-90eb-741579382939",
    "metadata": {},
    "outputs": [],
@@ -622,7 +622,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 16,
    "id": "80d8fa08-6b7d-4ab7-85cd-987823639277",
    "metadata": {},
    "outputs": [
@@ -665,7 +665,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 18,
    "id": "3f9c7dc7-6820-4013-a85c-2af4f846d4f5",
    "metadata": {},
    "outputs": [
@@ -693,7 +693,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 19,
    "id": "65e1a630-c46e-4b18-9831-b97864de53ee",
    "metadata": {},
    "outputs": [],
@@ -713,7 +713,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 20,
    "id": "1612e911-12a8-47c9-b811-b2d6885c3647",
    "metadata": {},
    "outputs": [
@@ -742,7 +742,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 21,
    "id": "925a3a75-aaaa-4851-b77b-b42cb1e21e11",
    "metadata": {},
    "outputs": [],
@@ -757,7 +757,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 22,
    "id": "6fa60897-ad26-43b4-b8de-861290ca6bd3",
    "metadata": {},
    "outputs": [
@@ -783,25 +783,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 23,
    "id": "f3678462-e572-4ce5-8d3d-a5389b2356c8",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Defaulting to user installation because normal site-packages is not writeable\n",
-      "Collecting scipy\n",
-      "  Downloading scipy-1.15.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n",
-      "Requirement already satisfied: numpy<2.5,>=1.23.5 in /public/conda/Miniconda/envs/pytorch-2.6/lib/python3.11/site-packages (from scipy) (2.2.4)\n",
-      "Downloading scipy-1.15.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.7 MB)\n",
-      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m37.7/37.7 MB\u001b[0m \u001b[31m75.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
-      "\u001b[?25hInstalling collected packages: scipy\n",
-      "Successfully installed scipy-1.15.3\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "#!pip3 install seaborn\n",
     "#!pip3 install scipy"
@@ -809,7 +794,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 24,
    "id": "b621b2a8-488f-44db-b085-fe156f453943",
    "metadata": {},
    "outputs": [
@@ -830,7 +815,7 @@
     "import matplotlib.pyplot as plt\n",
     "from scipy.stats import spearmanr\n",
     "\n",
-    "# Sample data (replace with your real values)\n",
+    "# Sample data (to be replaces with real values)\n",
     "data = {\n",
     "    \"perplexity\": [32.5, 45.2, 28.1, 39.0, 50.3],\n",
     "    \"avg_word_length\": [4.1, 4.3, 4.0, 4.2, 4.5],\n",
@@ -853,10 +838,123 @@
     "plt.show()"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "45ee04fc-acab-4bba-ba06-e4cf4bca9fe5",
+   "metadata": {},
+   "source": [
+    "## Tree depth"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "79f99787-c220-4f1d-93a9-59230363ec3f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def parse_sentence_block(text):\n",
+    "    lines = text.strip().split('\\n')\n",
+    "    result = []\n",
+    "    tokenlist = []\n",
+    "    for line in lines:\n",
+    "        # Split the line by tab and strip whitespace\n",
+    "        parts = tuple(line.strip().split('\\t'))\n",
+    "        # Only include lines that have exactly 4 parts\n",
+    "        if len(parts) == 4:\n",
+    "            parentidx =  int(parts[3])\n",
+    "            if '@@' in parts[2]:\n",
+    "                nonterm1 = parts[2].split('@@')[0]\n",
+    "                nonterm2 = parts[2].split('@@')[1]\n",
+    "            else:\n",
+    "                nonterm1 = parts[2]\n",
+    "                nonterm2 = '<nul>'\n",
+    "            postag = parts[1]\n",
+    "            token = parts[0]\n",
+    "            result.append((parentidx, nonterm1, nonterm2, postag))\n",
+    "            tokenlist.append(token)\n",
+    "    return result, tokenlist\n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "id": "f567efb0-8b0b-4782-9345-052cf1785776",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "example_sentence = \"\"\"\n",
+    "<s>\t<s>\t<s>\t1\n",
+    "--\tponct\t<nul>@@<nul>\t1\n",
+    "Eh\tnpp\t<nul>@@<nul>\t1\n",
+    "bien?\tadv\tAP@@<nul>\t1\n",
+    "fit\tv\tVN@@<nul>\t2\n",
+    "-il\tcls-suj\tVN@@VPinf-OBJ\t3\n",
+    ".\tponct\t<nul>@@<nul>\t4\n",
+    "</s>\t</s>\t</s>\t4\n",
+    "\"\"\""
+   ]
+  },
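+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity check (illustrative): one 4-tuple and one token per well-formed line\n",
+    "actions, tokens = parse_sentence_block(example_sentence)\n",
+    "print(tokens)\n",
+    "print(actions[:3])"
+   ]
+  },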
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "id": "8d4ecba9-89b8-4000-a061-aa16aa68a404",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transform import *\n",
+    "\n",
+    "def visualize_const_prediction(example_sent):\n",
+    "    parsed, tokenlist = parse_sentence_block(example_sent)\n",
+    "    tree = AttachJuxtaposeTree.totree(tokenlist, 'SENT')\n",
+    "    AttachJuxtaposeTree.action2tree(tree, parsed).pretty_print()\n",
+    "    nltk_tree = AttachJuxtaposeTree.action2tree(tree, parsed)\n",
+    "    #print(\"NLTK TREE\", nltk_tree)\n",
+    "    depth = nltk_tree.height() - 1  # NLTK includes the leaf level as height 1, so subtract 1 for tree depth \n",
+    "    print(\"Tree depth:\", depth)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "id": "bfd3abf3-b83a-4817-85ad-654daf72be88",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "               SENT                                 \n",
+      "                |                                    \n",
+      "               <s>                                  \n",
+      "                |                                    \n",
+      "               <s>                                  \n",
+      "  ______________|__________                          \n",
+      " |    |    |               AP                       \n",
+      " |    |    |     __________|________                 \n",
+      " |    |    |    |               VPinf-OBJ           \n",
+      " |    |    |    |     ______________|_______         \n",
+      " |    |    |    |    |                      VN      \n",
+      " |    |    |    |    |      ________________|____    \n",
+      " |    |    |    |    VN    |                |   </s>\n",
+      " |    |    |    |    |     |                |    |   \n",
+      " |  ponct npp  adv   v  cls-suj           ponct </s>\n",
+      " |    |    |    |    |     |                |    |   \n",
+      "<s>   --   Eh bien? fit   -il               .   </s>\n",
+      "\n",
+      "Tree depth: 8\n"
+     ]
+    }
+   ],
+   "source": [
+    "visualize_const_prediction(example_sentence)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "3a6e3b53-7104-45ef-a4b5-e831bdd6ca6f",
+   "id": "bc51ab44-6885-45cc-bad2-6a43a7791fdb",
    "metadata": {},
    "outputs": [],
    "source": []
diff --git a/tania_scripts/transform.py b/tania_scripts/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..56e0f5158b6b3b214b7e99e8a479fbabc56bcbf6
--- /dev/null
+++ b/tania_scripts/transform.py
@@ -0,0 +1,459 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
+
+import nltk
+import torch
+
+from supar.models.const.crf.transform import Tree
+from supar.utils.common import NUL
+from supar.utils.logging import get_logger
+from supar.utils.tokenizer import Tokenizer
+from supar.utils.transform import Sentence
+
+if TYPE_CHECKING:
+    from supar.utils import Field
+
+logger = get_logger(__name__)
+
+
+class AttachJuxtaposeTree(Tree):
+    r"""
+    :class:`AttachJuxtaposeTree` is derived from the :class:`Tree` class,
+    supporting back-and-forth transformations between trees and AttachJuxtapose actions :cite:`yang-deng-2020-aj`.
+
+    Attributes:
+        WORD:
+            Words in the sentence.
+        POS:
+            Part-of-speech tags, or underscores if not available.
+        TREE:
+            The raw constituency tree in :class:`nltk.tree.Tree` format.
+        NODE:
+            The target node on each rightmost chain.
+        PARENT:
+            The label of the parent node of each terminal.
+        NEW:
+            The label of each newly inserted non-terminal with a target node and a terminal as juxtaposed children.
+            ``NUL`` represents the `Attach` action.
+    """
+
+    fields = ['WORD', 'POS', 'TREE', 'NODE', 'PARENT', 'NEW']
+
+    def __init__(
+        self,
+        WORD: Optional[Union[Field, Iterable[Field]]] = None,
+        POS: Optional[Union[Field, Iterable[Field]]] = None,
+        TREE: Optional[Union[Field, Iterable[Field]]] = None,
+        NODE: Optional[Union[Field, Iterable[Field]]] = None,
+        PARENT: Optional[Union[Field, Iterable[Field]]] = None,
+        NEW: Optional[Union[Field, Iterable[Field]]] = None
+    ) -> Tree:
+        super().__init__()
+
+        self.WORD = WORD
+        self.POS = POS
+        self.TREE = TREE
+        self.NODE = NODE
+        self.PARENT = PARENT
+        self.NEW = NEW
+
+    @property
+    def src(self):
+        return self.WORD, self.POS, self.TREE
+
+    @property
+    def tgt(self):
+        return self.NODE, self.PARENT, self.NEW
+
+    @classmethod
+    def tree2action(cls, tree: nltk.Tree):
+        r"""
+        Converts a constituency tree into AttachJuxtapose actions.
+
+        Args:
+            tree (nltk.tree.Tree):
+                A constituency tree in :class:`nltk.tree.Tree` format.
+
+        Returns:
+            A sequence of AttachJuxtapose actions.
+
+        Examples:
+            >>> from supar.models.const.aj.transform import AttachJuxtaposeTree
+            >>> tree = nltk.Tree.fromstring('''
+                                            (TOP
+                                              (S
+                                                (NP (_ Arthur))
+                                                (VP
+                                                  (_ is)
+                                                  (NP (NP (_ King)) (PP (_ of) (NP (_ the) (_ Britons)))))
+                                                (_ .)))
+                                            ''')
+            >>> tree.pretty_print()
+                            TOP
+                             |
+                             S
+               ______________|_______________________
+              |              VP                      |
+              |      ________|___                    |
+              |     |            NP                  |
+              |     |    ________|___                |
+              |     |   |            PP              |
+              |     |   |     _______|___            |
+              NP    |   NP   |           NP          |
+              |     |   |    |        ___|_____      |
+              _     _   _    _       _         _     _
+              |     |   |    |       |         |     |
+            Arthur  is King  of     the     Britons  .
+            >>> AttachJuxtaposeTree.tree2action(tree)
+            [(0, 'NP', '<nul>'), (0, 'VP', 'S'), (1, 'NP', '<nul>'),
+             (2, 'PP', 'NP'), (3, 'NP', '<nul>'), (4, '<nul>', '<nul>'),
+             (0, '<nul>', '<nul>')]
+        """
+
+        def isroot(node):
+            return node == tree[0]
+
+        def isterminal(node):
+            return len(node) == 1 and not isinstance(node[0], nltk.Tree)
+
+        def last_leaf(node):
+            pos = ()
+            while True:
+                pos += (len(node) - 1,)
+                node = node[-1]
+                if isterminal(node):
+                    return node, pos
+
+        def parent(position):
+            return tree[position[:-1]]
+
+        def grand(position):
+            return tree[position[:-2]] if len(position) > 1 else None
+
+        def detach(tree):
+            last, last_pos = last_leaf(tree)
+            siblings = parent(last_pos)[:-1]
+
+            if len(siblings) > 0:
+                last_subtree = last
+                last_subtree_siblings = siblings
+                parent_label = NUL
+            else:
+                last_subtree, last_pos = parent(last_pos), last_pos[:-1]
+                last_subtree_siblings = [] if isroot(last_subtree) else parent(last_pos)[:-1]
+                parent_label = last_subtree.label()
+
+            target_pos, new_label, last_tree = 0, NUL, tree
+            if isroot(last_subtree):
+                last_tree = None
+            elif len(last_subtree_siblings) == 1 and not isterminal(last_subtree_siblings[0]):
+                new_label = parent(last_pos).label()
+                target = last_subtree_siblings[0]
+                last_grand = grand(last_pos)
+                if last_grand is None:
+                    last_tree = target
+                else:
+                    last_grand[-1] = target
+                target_pos = len(last_pos) - 2
+            else:
+                target = parent(last_pos)
+                target.pop()
+                target_pos = len(last_pos) - 2
+            action = target_pos, parent_label, new_label
+            return action, last_tree
+
+        if tree is None:
+            return []
+        action, last_tree = detach(tree)
+        return cls.tree2action(last_tree) + [action]
+
+    @classmethod
+    def action2tree(
+        cls,
+        tree: nltk.Tree,
+        actions: List[Tuple[int, str, str, str]],
+        join: str = '::',
+    ) -> nltk.Tree:
+        r"""
+        Recovers a constituency tree from a sequence of AttachJuxtapose actions.
+
+        Args:
+            tree (nltk.tree.Tree):
+                An empty tree that provides a base for building a result tree.
+            actions (List[Tuple[int, str, str, str]]):
+                A sequence of AttachJuxtapose actions, each holding a target node index, parent label, new label and POS tag.
+            join (str):
+                A string used to connect collapsed node labels. Non-terminals containing this will be expanded to unary chains.
+                Default: ``'::'``.
+
+        Returns:
+            A result constituency tree.
+
+        Examples:
+            >>> from supar.models.const.aj.transform import AttachJuxtaposeTree
+            >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP')
+            >>> AttachJuxtaposeTree.action2tree(tree,
+                                                [(0, 'NP', '<nul>', '_'), (0, 'VP', 'S', '_'), (1, 'NP', '<nul>', '_'),
+                                                 (2, 'PP', 'NP', '_'), (3, 'NP', '<nul>', '_'), (4, '<nul>', '<nul>', '_'),
+                                                 (0, '<nul>', '<nul>', '_')]).pretty_print()
+                            TOP
+                             |
+                             S
+               ______________|_______________________
+              |              VP                      |
+              |      ________|___                    |
+              |     |            NP                  |
+              |     |    ________|___                |
+              |     |   |            PP              |
+              |     |   |     _______|___            |
+              NP    |   NP   |           NP          |
+              |     |   |    |        ___|_____      |
+              _     _   _    _       _         _     _
+              |     |   |    |       |         |     |
+            Arthur  is King  of     the     Britons  .
+        """
+
+        def target(node, depth):
+            node_pos = ()
+            for _ in range(depth):
+                node_pos += (len(node) - 1,)
+                node = node[-1]
+            return node, node_pos
+
+        def parent(tree, position):
+            return tree[position[:-1]]
+
+        def execute(tree: nltk.Tree, terminal: Tuple[str, str], action: Tuple[int, str, str, str]) -> nltk.Tree:
+            target_pos, parent_label, new_label, post = action
+            #print(target_pos, parent_label, new_label)
+            new_leaf = nltk.Tree(post, [terminal[0]])
+
+            # create the subtree to be inserted
+            new_subtree = new_leaf if parent_label == NUL else nltk.Tree(parent_label, [new_leaf])
+            # find the target position at which to insert the new subtree
+            target_node = tree
+            if target_node is not None:
+                target_node, target_pos = target(target_node, target_pos)
+
+            # Attach
+            if new_label == NUL:
+                # attach the first token
+                if target_node is None:
+                    return new_subtree
+                target_node.append(new_subtree)
+            # Juxtapose
+            else:
+                new_subtree = nltk.Tree(new_label, [target_node, new_subtree])
+                if len(target_pos) > 0:
+                    parent_node = parent(tree, target_pos)
+                    parent_node[-1] = new_subtree
+                else:
+                    tree = new_subtree
+            return tree
+
+        tree, root, terminals = None, tree.label(), tree.pos()
+        for terminal, action in zip(terminals, actions):
+            tree = execute(tree, terminal, action)
+        # recover unary chains
+        nodes = [tree]
+        while nodes:
+            node = nodes.pop()
+            if isinstance(node, nltk.Tree):
+                nodes.extend(node)
+                if join in node.label():
+                    labels = node.label().split(join)
+                    node.set_label(labels[0])
+                    subtree = nltk.Tree(labels[-1], node)
+                    for label in reversed(labels[1:-1]):
+                        subtree = nltk.Tree(label, [subtree])
+                    node[:] = [subtree]
+        return nltk.Tree(root, [tree])
+
+    @classmethod
+    def action2span(
+        cls,
+        action: torch.Tensor,
+        spans: torch.Tensor = None,
+        nul_index: int = -1,
+        mask: torch.BoolTensor = None
+    ) -> torch.Tensor:
+        r"""
+        Converts a batch of the tensorized action at a given step into spans.
+
+        Args:
+            action (~torch.Tensor): ``[3, batch_size]``.
+                A batch of the tensorized action at a given step, containing indices of target nodes, parent and new labels.
+            spans (~torch.Tensor):
+                Spans generated at previous steps, ``None`` at the first step. Default: ``None``.
+            nul_index (int):
+                The index for the :obj:`NUL` token, representing the Attach action. Default: -1.
+            mask (~torch.BoolTensor): ``[batch_size]``.
+                The mask for covering the unpadded tokens.
+
+        Returns:
+            A tensor representing a batch of spans for the given step.
+
+        Examples:
+            >>> from collections import Counter
+            >>> from supar.models.const.aj.transform import AttachJuxtaposeTree, Vocab
+            >>> from supar.utils.common import NUL
+            >>> nodes, parents, news = zip(*[(0, 'NP', NUL), (0, 'VP', 'S'), (1, 'NP', NUL),
+                                             (2, 'PP', 'NP'), (3, 'NP', NUL), (4, NUL, NUL),
+                                             (0, NUL, NUL)])
+            >>> vocab = Vocab(Counter(sorted(set([*parents, *news]))))
+            >>> actions = torch.tensor([nodes, vocab[parents], vocab[news]]).unsqueeze(1)
+            >>> spans = None
+            >>> for action in actions.unbind(-1):
+            ...     spans = AttachJuxtaposeTree.action2span(action, spans, vocab[NUL])
+            ...
+            >>> spans
+            tensor([[[-1,  1, -1, -1, -1, -1, -1,  3],
+                     [-1, -1, -1, -1, -1, -1,  4, -1],
+                     [-1, -1, -1,  1, -1, -1,  1, -1],
+                     [-1, -1, -1, -1, -1, -1,  2, -1],
+                     [-1, -1, -1, -1, -1, -1,  1, -1],
+                     [-1, -1, -1, -1, -1, -1, -1, -1],
+                     [-1, -1, -1, -1, -1, -1, -1, -1],
+                     [-1, -1, -1, -1, -1, -1, -1, -1]]])
+            >>> sequence = torch.where(spans.ge(0))
+            >>> sequence = list(zip(sequence[1].tolist(), sequence[2].tolist(), vocab[spans[sequence]]))
+            >>> sequence
+            [(0, 1, 'NP'), (0, 7, 'S'), (1, 6, 'VP'), (2, 3, 'NP'), (2, 6, 'NP'), (3, 6, 'PP'), (4, 6, 'NP')]
+            >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP')
+            >>> AttachJuxtaposeTree.build(tree, sequence).pretty_print()
+                            TOP
+                             |
+                             S
+               ______________|_______________________
+              |              VP                      |
+              |      ________|___                    |
+              |     |            NP                  |
+              |     |    ________|___                |
+              |     |   |            PP              |
+              |     |   |     _______|___            |
+              NP    |   NP   |           NP          |
+              |     |   |    |        ___|_____      |
+              _     _   _    _       _         _     _
+              |     |   |    |       |         |     |
+            Arthur  is King  of     the     Britons  .
+
+        """
+
+        # [batch_size]
+        target, parent, new = action
+        if spans is None:
+            spans = action.new_full((action.shape[1], 2, 2), -1)
+            spans[:, 0, 1] = parent
+            return spans
+        if mask is None:
+            mask = torch.ones_like(target, dtype=bool)
+        juxtapose_mask = new.ne(nul_index) & mask
+        # ancestor nodes are those on the rightmost chain and higher than the target node
+        # [batch_size, seq_len]
+        rightmost_mask = spans[..., -1].ge(0)
+        ancestors = rightmost_mask.cumsum(-1).masked_fill_(~rightmost_mask, -1) - 1
+        # should not include the target node for the Juxtapose action
+        ancestor_mask = mask.unsqueeze(-1) & ancestors.ge(0) & ancestors.le((target - juxtapose_mask.long()).unsqueeze(-1))
+        target_pos = torch.where(ancestors.eq(target.unsqueeze(-1))[juxtapose_mask])[-1]
+        # the right boundaries of ancestor nodes should be aligned with the new generated terminals
+        spans = torch.cat((spans, torch.where(ancestor_mask, spans[..., -1], -1).unsqueeze(-1)), -1)
+        spans[..., -2].masked_fill_(ancestor_mask, -1)
+        spans[juxtapose_mask, target_pos, -1] = new.masked_fill(new.eq(nul_index), -1)[juxtapose_mask]
+        spans[mask, -1, -1] = parent.masked_fill(parent.eq(nul_index), -1)[mask]
+        # [batch_size, seq_len+1, seq_len+1]
+        spans = torch.cat((spans, torch.full_like(spans[:, :1], -1)), 1)
+        return spans
+
+    def load(
+        self,
+        data: Union[str, Iterable],
+        lang: Optional[str] = None,
+        **kwargs
+    ) -> List[AttachJuxtaposeTreeSentence]:
+        r"""
+        Args:
+            data (Union[str, Iterable]):
+                A filename or a list of instances.
+            lang (str):
+                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``None``.
+
+        Returns:
+            A list of :class:`AttachJuxtaposeTreeSentence` instances.
+        """
+
+        if lang is not None:
+            tokenizer = Tokenizer(lang)
+        if isinstance(data, str) and os.path.exists(data):
+            if data.endswith('.txt'):
+                data = (s.split() if lang is None else tokenizer(s) for s in open(data) if len(s) > 1)
+            else:
+                data = open(data)
+        else:
+            if lang is not None:
+                data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)]
+            else:
+                data = [data] if isinstance(data[0], str) else data
+
+        index = 0
+        for s in data:
+
+            try:
+                tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root)
+                sentence = AttachJuxtaposeTreeSentence(self, tree, index)
+            except ValueError:
+                logger.warning(f"Error found while converting Sentence {index} to a tree:\n{s}\nDiscarding it!")
+                continue
+            except IndexError:
+                # fall back to a flat tree; yield here explicitly, since the
+                # else clause below is skipped when an exception was caught
+                tree = nltk.Tree.fromstring('(S ' + s + ')')
+                yield AttachJuxtaposeTreeSentence(self, tree, index)
+                index += 1
+            else:
+                yield sentence
+                index += 1
+        self.root = tree.label()
+
+
+class AttachJuxtaposeTreeSentence(Sentence):
+    r"""
+    Args:
+        transform (AttachJuxtaposeTree):
+            A :class:`AttachJuxtaposeTree` object.
+        tree (nltk.tree.Tree):
+            A :class:`nltk.tree.Tree` object.
+        index (Optional[int]):
+            Index of the sentence in the corpus. Default: ``None``.
+    """
+
+    def __init__(
+        self,
+        transform: AttachJuxtaposeTree,
+        tree: nltk.Tree,
+        index: Optional[int] = None
+    ) -> AttachJuxtaposeTreeSentence:
+        super().__init__(transform, index)
+
+        words, tags = zip(*tree.pos())
+        nodes, parents, news = None, None, None
+        if transform.training:
+            oracle_tree = tree.copy(True)
+            # the root node must have a unary chain
+            if len(oracle_tree) > 1:
+                oracle_tree[:] = [nltk.Tree('*', oracle_tree)]
+            oracle_tree.collapse_unary(joinChar='::')
+            if len(oracle_tree) == 1 and not isinstance(oracle_tree[0][0], nltk.Tree):
+                oracle_tree[0] = nltk.Tree('*', [oracle_tree[0]])
+            nodes, parents, news = zip(*transform.tree2action(oracle_tree))
+        tags = [x.split("##")[0] for x in tags]
+        self.values = [words, tags, tree, nodes, parents, news]
+
+    def __repr__(self):
+        return self.values[-4].pformat(1000000)
+
+    def pretty_print(self):
+        self.values[-4].pretty_print()
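+
+
+# Construction sketch (illustrative bracketed input; fields configured
+# upstream as usual):
+#   tree = nltk.Tree.fromstring('(TOP (S (NP (_ Arthur)) (VP (_ is))))')
+#   sent = AttachJuxtaposeTreeSentence(AttachJuxtaposeTree(), tree)
+#   sent.pretty_print()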